• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 17111174344

20 Aug 2025 09:50PM UTC coverage: 60.99% (-1.5%) from 62.483%
17111174344

push

github

web-flow
Merge pull request #208 from daisytuner/reduce-ml

adds ML reduction nodes

6 of 372 new or added lines in 7 files covered. (1.61%)

9251 of 15168 relevant lines covered (60.99%)

118.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/data_flow/library_nodes/math/ml/log_softmax.cpp
1
#include "sdfg/data_flow/library_nodes/math/ml/softmax.h"
2

3
#include "sdfg/analysis/analysis.h"
4
#include "sdfg/analysis/scope_analysis.h"
5
#include "sdfg/builder/structured_sdfg_builder.h"
6

7
namespace sdfg {
8
namespace math {
9
namespace ml {
10

NEW
11
SoftmaxNode::SoftmaxNode(
×
12
    size_t element_id,
13
    const DebugInfo &debug_info,
14
    const graph::Vertex vertex,
15
    data_flow::DataFlowGraph &parent,
16
    int axis
17
)
NEW
18
    : MathNode(
×
NEW
19
          element_id, debug_info, vertex, parent, LibraryNodeType_Softmax, {"output"}, {"input"}, data_flow::ImplementationType_NONE
×
20
      ),
NEW
21
      axis_(axis) {}
×
22

NEW
23
/// Validation hook for this node; currently performs no checks.
void SoftmaxNode::validate(const Function &) const {
    /* TODO */
}
×
24

NEW
25
bool SoftmaxNode::expand(builder::StructuredSDFGBuilder &builder, analysis::AnalysisManager &analysis_manager) {
×
NEW
26
    auto &dataflow = this->get_parent();
×
NEW
27
    auto &block = static_cast<structured_control_flow::Block &>(*dataflow.get_parent());
×
28

NEW
29
    auto &scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
×
NEW
30
    auto &parent = static_cast<structured_control_flow::Sequence &>(*scope_analysis.parent_scope(&block));
×
31

32
    // Locate edges
NEW
33
    const data_flow::Memlet *iedge_input = nullptr;
×
NEW
34
    const data_flow::Memlet *oedge_output = nullptr;
×
NEW
35
    for (const auto &edge : dataflow.in_edges(*this)) {
×
NEW
36
        if (edge.dst_conn() == "input") {
×
NEW
37
            iedge_input = &edge;
×
NEW
38
        }
×
39
    }
NEW
40
    for (const auto &edge : dataflow.out_edges(*this)) {
×
NEW
41
        if (edge.src_conn() == "output") {
×
NEW
42
            oedge_output = &edge;
×
NEW
43
        }
×
44
    }
NEW
45
    if (!iedge_input || !oedge_output) return false;
×
46

NEW
47
    std::string input_name = static_cast<const data_flow::AccessNode &>(iedge_input->src()).data();
×
NEW
48
    std::string output_name = static_cast<const data_flow::AccessNode &>(oedge_output->dst()).data();
×
49

50
    // Create new sequence before
NEW
51
    auto &new_sequence = builder.add_sequence_before(parent, block, block.debug_info()).first;
×
NEW
52
    structured_control_flow::Sequence *last_scope = &new_sequence;
×
53

54
    // Create maps over output subset dims (parallel dims)
NEW
55
    data_flow::Subset domain_begin = oedge_output->begin_subset();
×
NEW
56
    data_flow::Subset domain_end = oedge_output->end_subset();
×
57

NEW
58
    std::vector<symbolic::Expression> loop_syms;
×
NEW
59
    structured_control_flow::Map *last_map = nullptr;
×
NEW
60
    for (size_t d = 0; d < domain_begin.size(); ++d) {
×
NEW
61
        std::string indvar_str = builder.find_new_name("_i");
×
NEW
62
        builder.add_container(indvar_str, types::Scalar(types::PrimitiveType::UInt64));
×
NEW
63
        auto indvar = symbolic::symbol(indvar_str);
×
NEW
64
        auto init = domain_begin[d];
×
NEW
65
        auto update = symbolic::add(indvar, symbolic::one());
×
NEW
66
        auto cond = symbolic::Lt(indvar, symbolic::add(domain_end[d], symbolic::one()));
×
NEW
67
        last_map = &builder.add_map(
×
NEW
68
            *last_scope,
×
69
            indvar,
70
            cond,
71
            init,
72
            update,
73
            structured_control_flow::ScheduleType_Sequential,
NEW
74
            {},
×
NEW
75
            block.debug_info()
×
76
        );
NEW
77
        last_scope = &last_map->root();
×
NEW
78
        loop_syms.push_back(indvar);
×
NEW
79
    }
×
80
    
81
    // Initialize temp variable to zero
NEW
82
    std::string temp_name = builder.find_new_name("_tmp");
×
NEW
83
    std::string temp_name2 = builder.find_new_name("_tmp");
×
NEW
84
    types::Scalar temp_type(types::PrimitiveType::Float);
×
NEW
85
    builder.add_container(temp_name, temp_type);
×
NEW
86
    builder.add_container(temp_name2, temp_type);
×
87
    
NEW
88
    auto &init_block = builder.add_block(*last_scope);
×
NEW
89
    auto &init_tasklet = builder.add_tasklet(init_block, data_flow::TaskletCode::assign, "_out", {"0.0f"});
×
NEW
90
    auto &tmp_access_init = builder.add_access(init_block, temp_name);
×
NEW
91
    builder.add_computational_memlet(init_block, init_tasklet, "_out", tmp_access_init, {}, temp_type);
×
92

93
    // add reduction for loop
NEW
94
    symbolic::Expression red_begin;
×
NEW
95
    symbolic::Expression red_end;
×
NEW
96
    if (axis_ >= 0) {
×
NEW
97
        red_begin = iedge_input->begin_subset()[axis_];
×
NEW
98
        red_end = iedge_input->end_subset()[axis_];
×
NEW
99
    } else {
×
NEW
100
        red_begin = iedge_input->begin_subset().back();
×
NEW
101
        red_end = iedge_input->end_subset().back();
×
102
    }
NEW
103
    std::string red_name = builder.find_new_name("_i");
×
NEW
104
    builder.add_container(red_name, types::Scalar(types::PrimitiveType::UInt64));
×
NEW
105
    auto red_indvar = symbolic::symbol(red_name);
×
NEW
106
    auto red_init = red_begin;
×
NEW
107
    auto red_update = symbolic::add(red_indvar, symbolic::one());
×
NEW
108
    auto red_cond = symbolic::Lt(red_indvar, symbolic::add(red_end, symbolic::one()));
×
NEW
109
    auto red_map = &builder.add_for(
×
NEW
110
        *last_scope,
×
111
        red_indvar,
112
        red_cond,
113
        red_init,
114
        red_update,
NEW
115
        {},
×
NEW
116
        block.debug_info()
×
117
    );
118

119
    // Create innermost block
NEW
120
    auto &code_block = builder.add_block(red_map->root());
×
121
    
122
    // Create access nodes for input and output
NEW
123
    auto &input_access = builder.add_access(code_block, input_name);
×
NEW
124
    auto &tmp2_access = builder.add_access(code_block, temp_name2);
×
NEW
125
    auto &tmp_access_out = builder.add_access(code_block, temp_name);
×
NEW
126
    auto &tmp_access_in = builder.add_access(code_block, temp_name2);
×
127
    
128
    // Create index expressions for input and output
NEW
129
    std::vector<symbolic::Expression> input_subset = loop_syms;
×
130
    
131
    // Replace the reduction axis index with the reduction variable for input
NEW
132
    if (axis_ >= 0 && axis_ < static_cast<int>(input_subset.size())) {
×
NEW
133
        input_subset.insert(input_subset.begin() + axis_, red_indvar);
×
NEW
134
    } else if (axis_ < 0) {
×
NEW
135
        input_subset.push_back(red_indvar);
×
NEW
136
    }
×
137
    
138
    // Compute exponential
NEW
139
    auto &exp_tasklet = builder.add_tasklet(code_block, data_flow::TaskletCode::expf, "_out", {"_in"});
×
NEW
140
    builder.add_computational_memlet(code_block, input_access, exp_tasklet, "_in", input_subset, iedge_input->base_type());
×
NEW
141
    builder.add_computational_memlet(code_block, exp_tasklet, "_out", tmp2_access, {}, temp_type);
×
142
    
143
    // Add to temp (reduction)
NEW
144
    auto &add_tasklet = builder.add_tasklet(code_block, data_flow::TaskletCode::add, "_out", {"_in1", "_in2"});
×
NEW
145
    builder.add_computational_memlet(code_block, tmp_access_in, add_tasklet, "_in1", {}, temp_type);
×
NEW
146
    builder.add_computational_memlet(code_block, tmp2_access, add_tasklet, "_in2", {}, temp_type);
×
NEW
147
    builder.add_computational_memlet(code_block, add_tasklet, "_out", tmp_access_out, {}, temp_type);
×
148

149
    // Create writeback - assign the accumulated sum to output
NEW
150
    auto &writeback_block = builder.add_block(*last_scope);
×
NEW
151
    auto &tmp_access_wb = builder.add_access(writeback_block, temp_name);
×
NEW
152
    auto &output_access_wb = builder.add_access(writeback_block, output_name);
×
NEW
153
    auto &writeback_tasklet = builder.add_tasklet(writeback_block, data_flow::TaskletCode::assign, "_out", {"_in"});
×
NEW
154
    builder.add_computational_memlet(writeback_block, tmp_access_wb, writeback_tasklet, "_in", {}, temp_type);
×
NEW
155
    builder.add_computational_memlet(writeback_block, writeback_tasklet, "_out", output_access_wb, loop_syms, oedge_output->base_type());
×
156

157
    // Cleanup old block
NEW
158
    builder.remove_memlet(block, *iedge_input);
×
NEW
159
    builder.remove_memlet(block, *oedge_output);
×
NEW
160
    builder.remove_node(block, *this);
×
NEW
161
    builder.remove_child(parent, block);
×
162

NEW
163
    return true;
×
NEW
164
}
×
165

166
std::unique_ptr<data_flow::DataFlowNode> SoftmaxNode::
NEW
167
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph &parent) const {
×
NEW
168
    return std::unique_ptr<data_flow::DataFlowNode>(new SoftmaxNode(
×
NEW
169
        element_id, this->debug_info(), vertex, parent, axis_
×
170
    ));
NEW
171
}
×
172

NEW
173
/**
 * Serializes a Softmax library node to JSON.
 *
 * The result holds the node's library code under "code" and its axis under
 * "axis". The argument is cast to SoftmaxNode without a runtime check.
 */
nlohmann::json SoftmaxNodeSerializer::serialize(const data_flow::LibraryNode &library_node) {
    const auto &softmax = static_cast<const SoftmaxNode &>(library_node);

    nlohmann::json result;
    result["code"] = softmax.code().value();
    result["axis"] = softmax.axis();
    return result;
}
×
182

NEW
183
/**
 * Rebuilds a Softmax library node from JSON and adds it to a block.
 *
 * Reads "code", "debug_info" and "axis" from the JSON object. Keys are
 * accessed with at(), which throws on a missing key; operator[] on a const
 * json has undefined behaviour when the key is absent.
 *
 * @param j       JSON object describing the node.
 * @param builder Builder used to create the node.
 * @param parent  Block that receives the new node.
 * @return Reference to the library node created by the builder.
 * @throws std::runtime_error if "code" is not the Softmax library-node code.
 * @throws nlohmann::json::out_of_range if a required key is missing.
 */
data_flow::LibraryNode &SoftmaxNodeSerializer::deserialize(
    const nlohmann::json &j, builder::StructuredSDFGBuilder &builder, structured_control_flow::Block &parent
) {
    auto code = j.at("code").get<std::string>();
    if (code != LibraryNodeType_Softmax.value()) {
        throw std::runtime_error("Invalid library node code");
    }

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j.at("debug_info"));

    auto axis = j.at("axis").get<int>();

    return builder.add_library_node<SoftmaxNode>(parent, debug_info, axis);
}
×
198

199
} // namespace ml
200
} // namespace math
201
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc