• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 20585279183

29 Dec 2025 11:48PM UTC coverage: 39.581% (-0.8%) from 40.359%
20585279183

push

github

web-flow
Merge pull request #412 from daisytuner/mean-std-nodes

adds mean/std library nodes

14647 of 48066 branches covered (30.47%)

Branch coverage included in aggregate %.

225 of 622 new or added lines in 14 files covered. (36.17%)

41 existing lines in 1 file now uncovered.

12489 of 20493 relevant lines covered (60.94%)

87.19 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/data_flow/library_nodes/math/tensor/reduce_ops/softmax_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/reduce_ops/softmax_node.h"
2
#include "sdfg/analysis/scope_analysis.h"
3
#include "sdfg/builder/structured_sdfg_builder.h"
4
#include "sdfg/data_flow/library_nodes/math/tensor/broadcast_node.h"
5
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_ops/div_node.h"
6
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_ops/exp_node.h"
7
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_ops/sub_node.h"
8
#include "sdfg/data_flow/library_nodes/math/tensor/reduce_ops/max_node.h"
9
#include "sdfg/data_flow/library_nodes/math/tensor/reduce_ops/sum_node.h"
10
#include "sdfg/data_flow/library_nodes/stdlib/malloc.h"
11
#include "sdfg/structured_control_flow/block.h"
12
#include "sdfg/structured_control_flow/for.h"
13
#include "sdfg/types/pointer.h"
14
#include "sdfg/types/scalar.h"
15
#include "sdfg/types/utils.h"
16

17
namespace sdfg {
18
namespace math {
19
namespace tensor {
20

NEW
21
/**
 * Constructs a softmax reduction library node.
 *
 * @param element_id unique element id for this node within the SDFG
 * @param debug_info source-location debug information
 * @param vertex     graph vertex this node occupies in the dataflow graph
 * @param parent     dataflow graph that owns this node
 * @param shape      symbolic shape of the input tensor
 * @param axes       axes along which the softmax reduction is performed
 * @param keepdims   must be false; keepdims is not supported for softmax
 * @throws InvalidSDFGException if keepdims is true
 */
SoftmaxNode::SoftmaxNode(
    size_t element_id,
    const DebugInfo& debug_info,
    const graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    const std::vector<symbolic::Expression>& shape,
    const std::vector<int64_t>& axes,
    bool keepdims
)
    : ReduceNode(element_id, debug_info, vertex, parent, LibraryNodeType_Softmax, shape, axes, keepdims) {
    // Only keepdims == false is accepted; reject the attribute up front so an
    // unsupported configuration fails at construction rather than at expansion.
    if (keepdims) {
        throw InvalidSDFGException("Unsupported attribute on library node: softmax");
    }
}
×
35

NEW
36
bool SoftmaxNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
×
NEW
37
    auto& dataflow = this->get_parent();
×
NEW
38
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
×
39

NEW
40
    if (dataflow.in_degree(*this) != 1 || dataflow.out_degree(*this) != 1) {
×
NEW
41
        return false;
×
42
    }
43

NEW
44
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
×
NEW
45
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
×
NEW
46
    int index = parent.index(block);
×
47

NEW
48
    auto& in_edge = *dataflow.in_edges(*this).begin();
×
NEW
49
    auto& out_edge = *dataflow.out_edges(*this).begin();
×
NEW
50
    auto& in_node = static_cast<data_flow::AccessNode&>(in_edge.src());
×
NEW
51
    auto& out_node = static_cast<data_flow::AccessNode&>(out_edge.dst());
×
52

53
    // Calculate reduced shape (for Max and Sum)
NEW
54
    std::vector<symbolic::Expression> reduced_shape;
×
NEW
55
    std::vector<int64_t> sorted_axes = axes_;
×
NEW
56
    std::sort(sorted_axes.begin(), sorted_axes.end());
×
57

NEW
58
    for (size_t i = 0; i < shape_.size(); ++i) {
×
NEW
59
        bool is_axis = false;
×
NEW
60
        for (auto axis : sorted_axes) {
×
NEW
61
            if (axis == (int64_t) i) {
×
NEW
62
                is_axis = true;
×
NEW
63
                break;
×
64
            }
65
        }
66

NEW
67
        if (is_axis) {
×
NEW
68
            reduced_shape.push_back(symbolic::one());
×
NEW
69
        } else {
×
NEW
70
            reduced_shape.push_back(shape_[i]);
×
71
        }
NEW
72
    }
×
73

74
    // Determine element type
NEW
75
    const types::IType* element_type = &in_edge.base_type();
×
NEW
76
    if (in_edge.base_type().type_id() == types::TypeID::Pointer) {
×
NEW
77
        element_type = &static_cast<const types::Pointer&>(in_edge.base_type()).pointee_type();
×
NEW
78
    }
×
79

NEW
80
    types::Pointer intermediate_type(*element_type);
×
81

82
    // Temporary buffers
NEW
83
    std::string tmp_max_name = builder.find_new_name("_softmax_max");
×
NEW
84
    builder.add_container(tmp_max_name, intermediate_type);
×
85

NEW
86
    std::string tmp_max_bcast_name = builder.find_new_name("_softmax_max_bcast");
×
NEW
87
    builder.add_container(tmp_max_bcast_name, intermediate_type);
×
88

NEW
89
    std::string tmp_sub_name = builder.find_new_name("_softmax_sub");
×
NEW
90
    builder.add_container(tmp_sub_name, intermediate_type);
×
91

NEW
92
    std::string tmp_exp_name = builder.find_new_name("_softmax_exp");
×
NEW
93
    builder.add_container(tmp_exp_name, intermediate_type);
×
94

NEW
95
    std::string tmp_sum_name = builder.find_new_name("_softmax_sum");
×
NEW
96
    builder.add_container(tmp_sum_name, intermediate_type);
×
97

NEW
98
    std::string tmp_sum_bcast_name = builder.find_new_name("_softmax_sum_bcast");
×
NEW
99
    builder.add_container(tmp_sum_bcast_name, intermediate_type);
×
100

101
    // Mallocs
NEW
102
    if (in_edge.base_type().type_id() == types::TypeID::Pointer) {
×
NEW
103
        symbolic::Expression bytes_elem = types::get_type_size(*element_type, false);
×
104

NEW
105
        symbolic::Expression bytes_full = bytes_elem;
×
NEW
106
        for (auto& dim : this->shape_) {
×
NEW
107
            bytes_full = symbolic::mul(dim, bytes_full);
×
108
        }
109

NEW
110
        symbolic::Expression bytes_reduced = bytes_elem;
×
NEW
111
        for (auto& dim : reduced_shape) {
×
NEW
112
            bytes_reduced = symbolic::mul(dim, bytes_reduced);
×
113
        }
114

NEW
115
        auto& alloc_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
116

NEW
117
        auto malloc_helper = [&](const std::string& name, const symbolic::Expression& size) {
×
NEW
118
            auto& access = builder.add_access(alloc_block, name);
×
NEW
119
            auto& malloc_node = builder.add_library_node<stdlib::MallocNode>(alloc_block, this->debug_info(), size);
×
NEW
120
            builder.add_computational_memlet(
×
NEW
121
                alloc_block, malloc_node, "_ret", access, {}, in_edge.base_type(), this->debug_info()
×
122
            );
NEW
123
        };
×
124

NEW
125
        malloc_helper(tmp_max_name, bytes_reduced);
×
NEW
126
        malloc_helper(tmp_max_bcast_name, bytes_full);
×
NEW
127
        malloc_helper(tmp_sub_name, bytes_full);
×
NEW
128
        malloc_helper(tmp_exp_name, bytes_full);
×
NEW
129
        malloc_helper(tmp_sum_name, bytes_reduced);
×
NEW
130
        malloc_helper(tmp_sum_bcast_name, bytes_full);
×
NEW
131
    }
×
132

133
    // 1. Max(X) -> TmpMax
134
    {
NEW
135
        auto& max_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
NEW
136
        auto& max_node =
×
NEW
137
            builder.add_library_node<MaxNode>(max_block, this->debug_info(), this->shape_, this->axes_, true);
×
138

NEW
139
        auto& in_access = builder.add_access(max_block, in_node.data());
×
NEW
140
        auto& out_access = builder.add_access(max_block, tmp_max_name);
×
141

NEW
142
        builder
×
NEW
143
            .add_computational_memlet(max_block, in_access, max_node, "X", {}, in_edge.base_type(), this->debug_info());
×
NEW
144
        builder
×
NEW
145
            .add_computational_memlet(max_block, max_node, "Y", out_access, {}, intermediate_type, this->debug_info());
×
146
    }
147

148
    // 1.5 Broadcast Max -> TmpMaxBcast
149
    {
NEW
150
        auto& bcast_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
NEW
151
        auto& bcast_node =
×
NEW
152
            builder.add_library_node<BroadcastNode>(bcast_block, this->debug_info(), reduced_shape, this->shape_);
×
153

NEW
154
        auto& in_access = builder.add_access(bcast_block, tmp_max_name);
×
NEW
155
        auto& out_access = builder.add_access(bcast_block, tmp_max_bcast_name);
×
156

NEW
157
        builder
×
NEW
158
            .add_computational_memlet(bcast_block, in_access, bcast_node, "X", {}, intermediate_type, this->debug_info());
×
NEW
159
        builder
×
NEW
160
            .add_computational_memlet(bcast_block, bcast_node, "Y", out_access, {}, intermediate_type, this->debug_info());
×
161
    }
162

163
    // 2. Sub(X, TmpMaxBcast) -> TmpSub
164
    {
NEW
165
        auto& sub_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
NEW
166
        auto& sub_node = builder.add_library_node<SubNode>(sub_block, this->debug_info(), this->shape_);
×
167

NEW
168
        auto& in1_access = builder.add_access(sub_block, in_node.data());
×
NEW
169
        auto& in2_access = builder.add_access(sub_block, tmp_max_bcast_name);
×
NEW
170
        auto& out_access = builder.add_access(sub_block, tmp_sub_name);
×
171

NEW
172
        builder
×
NEW
173
            .add_computational_memlet(sub_block, in1_access, sub_node, "A", {}, in_edge.base_type(), this->debug_info());
×
NEW
174
        builder
×
NEW
175
            .add_computational_memlet(sub_block, in2_access, sub_node, "B", {}, intermediate_type, this->debug_info());
×
NEW
176
        builder
×
NEW
177
            .add_computational_memlet(sub_block, sub_node, "C", out_access, {}, intermediate_type, this->debug_info());
×
178
    }
179

180
    // 3. Exp(TmpSub) -> TmpExp
181
    {
NEW
182
        auto& exp_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
NEW
183
        auto& exp_node = builder.add_library_node<ExpNode>(exp_block, this->debug_info(), this->shape_);
×
184

NEW
185
        auto& in_access = builder.add_access(exp_block, tmp_sub_name);
×
NEW
186
        auto& out_access = builder.add_access(exp_block, tmp_exp_name);
×
187

NEW
188
        builder.add_computational_memlet(exp_block, in_access, exp_node, "X", {}, intermediate_type, this->debug_info());
×
NEW
189
        builder
×
NEW
190
            .add_computational_memlet(exp_block, exp_node, "Y", out_access, {}, intermediate_type, this->debug_info());
×
191
    }
192

193
    // 4. Sum(TmpExp) -> TmpSum
194
    {
NEW
195
        auto& sum_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
NEW
196
        auto& sum_node =
×
NEW
197
            builder.add_library_node<SumNode>(sum_block, this->debug_info(), this->shape_, this->axes_, true);
×
198

NEW
199
        auto& in_access = builder.add_access(sum_block, tmp_exp_name);
×
NEW
200
        auto& out_access = builder.add_access(sum_block, tmp_sum_name);
×
201

NEW
202
        builder.add_computational_memlet(sum_block, in_access, sum_node, "X", {}, intermediate_type, this->debug_info());
×
NEW
203
        builder
×
NEW
204
            .add_computational_memlet(sum_block, sum_node, "Y", out_access, {}, intermediate_type, this->debug_info());
×
205
    }
206

207
    // 4.5 Broadcast Sum -> TmpSumBcast
208
    {
NEW
209
        auto& bcast_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
NEW
210
        auto& bcast_node =
×
NEW
211
            builder.add_library_node<BroadcastNode>(bcast_block, this->debug_info(), reduced_shape, this->shape_);
×
212

NEW
213
        auto& in_access = builder.add_access(bcast_block, tmp_sum_name);
×
NEW
214
        auto& out_access = builder.add_access(bcast_block, tmp_sum_bcast_name);
×
215

NEW
216
        builder
×
NEW
217
            .add_computational_memlet(bcast_block, in_access, bcast_node, "X", {}, intermediate_type, this->debug_info());
×
NEW
218
        builder
×
NEW
219
            .add_computational_memlet(bcast_block, bcast_node, "Y", out_access, {}, intermediate_type, this->debug_info());
×
220
    }
221

222
    // 5. Div(TmpExp, TmpSumBcast) -> Output
223
    {
NEW
224
        auto& div_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
NEW
225
        auto& div_node = builder.add_library_node<DivNode>(div_block, this->debug_info(), this->shape_);
×
226

NEW
227
        auto& in1_access = builder.add_access(div_block, tmp_exp_name);
×
NEW
228
        auto& in2_access = builder.add_access(div_block, tmp_sum_bcast_name);
×
NEW
229
        auto& out_access = builder.add_access(div_block, out_node.data());
×
230

NEW
231
        builder
×
NEW
232
            .add_computational_memlet(div_block, in1_access, div_node, "A", {}, intermediate_type, this->debug_info());
×
NEW
233
        builder
×
NEW
234
            .add_computational_memlet(div_block, in2_access, div_node, "B", {}, intermediate_type, this->debug_info());
×
NEW
235
        builder
×
NEW
236
            .add_computational_memlet(div_block, div_node, "C", out_access, {}, out_edge.base_type(), this->debug_info());
×
237
    }
238

239
    // Cleanup
NEW
240
    builder.remove_memlet(block, in_edge);
×
NEW
241
    builder.remove_memlet(block, out_edge);
×
NEW
242
    builder.remove_node(block, in_node);
×
NEW
243
    builder.remove_node(block, out_node);
×
NEW
244
    builder.remove_node(block, *this);
×
245

NEW
246
    int last_index = parent.index(block);
×
NEW
247
    builder.remove_child(parent, last_index);
×
248

NEW
249
    return true;
×
NEW
250
}
×
251

252
std::unique_ptr<data_flow::DataFlowNode> SoftmaxNode::
NEW
253
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
NEW
254
    return std::unique_ptr<
×
NEW
255
        data_flow::DataFlowNode>(new SoftmaxNode(element_id, this->debug_info(), vertex, parent, this->shape_, this->axes_)
×
256
    );
NEW
257
}
×
258

259
} // namespace tensor
260
} // namespace math
261
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc