• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 22023884668

14 Feb 2026 08:36PM UTC coverage: 64.903% (-1.4%) from 66.315%
22023884668

Pull #525

github

web-flow
Merge 1d47f8bf2 into 9d01cacd5
Pull Request #525: Step 3 (Native Tensor Support): Refactor Python Frontend

2522 of 3435 new or added lines in 32 files covered. (73.42%)

320 existing lines in 15 files now uncovered.

23204 of 35752 relevant lines covered (64.9%)

370.03 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/sdfg/src/data_flow/library_nodes/math/tensor/reduce_ops/softmax_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/reduce_ops/softmax_node.h"
2
#include "sdfg/analysis/scope_analysis.h"
3
#include "sdfg/builder/structured_sdfg_builder.h"
4
#include "sdfg/data_flow/library_nodes/math/tensor/broadcast_node.h"
5
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_ops/div_node.h"
6
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_ops/exp_node.h"
7
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_ops/sub_node.h"
8
#include "sdfg/data_flow/library_nodes/math/tensor/reduce_ops/max_node.h"
9
#include "sdfg/data_flow/library_nodes/math/tensor/reduce_ops/sum_node.h"
10
#include "sdfg/data_flow/library_nodes/stdlib/malloc.h"
11
#include "sdfg/structured_control_flow/block.h"
12
#include "sdfg/structured_control_flow/for.h"
13
#include "sdfg/types/pointer.h"
14
#include "sdfg/types/scalar.h"
15
#include "sdfg/types/utils.h"
16

17
namespace sdfg {
18
namespace math {
19
namespace tensor {
20

21
SoftmaxNode::SoftmaxNode(
    size_t element_id,
    const DebugInfo& debug_info,
    const graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    const std::vector<symbolic::Expression>& shape,
    const std::vector<int64_t>& axes,
    bool keepdims
)
    : ReduceNode(element_id, debug_info, vertex, parent, LibraryNodeType_Softmax, shape, axes, keepdims) {
    // Softmax is only implemented for keepdims == false; anything else is
    // rejected up front so an unsupported node never enters the graph.
    if (!keepdims) {
        return;
    }
    throw InvalidSDFGException("Unsupported attribute on library node: softmax");
}
35

NEW
36
// Intentionally a no-op: softmax imposes no structural invariants beyond
// those already enforced by the constructor (keepdims rejection) and by
// whatever checks ReduceNode/the framework perform elsewhere.
void SoftmaxNode::validate(const Function& function) const {}
37

38
bool SoftmaxNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
×
39
    auto& dataflow = this->get_parent();
×
40
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
×
41

42
    if (dataflow.in_degree(*this) != 1 || dataflow.out_degree(*this) != 1) {
×
43
        return false;
×
44
    }
×
45

46
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
×
47
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
×
48
    int index = parent.index(block);
×
49

50
    auto& in_edge = *dataflow.in_edges(*this).begin();
×
51
    auto& out_edge = *dataflow.out_edges(*this).begin();
×
52
    auto& in_node = static_cast<data_flow::AccessNode&>(in_edge.src());
×
53
    auto& out_node = static_cast<data_flow::AccessNode&>(out_edge.dst());
×
54

55
    // Calculate reduced shape (for Max and Sum)
56
    std::vector<symbolic::Expression> reduced_shape;
×
57
    std::vector<int64_t> sorted_axes = axes_;
×
58
    // Normalize negative axes
NEW
59
    for (auto& axis : sorted_axes) {
×
NEW
60
        if (axis < 0) {
×
NEW
61
            axis = static_cast<int64_t>(shape_.size()) + axis;
×
NEW
62
        }
×
63
        // Validate axis is in bounds
NEW
64
        if (axis < 0 || axis >= static_cast<int64_t>(shape_.size())) {
×
NEW
65
            throw InvalidSDFGException(
×
NEW
66
                "Library Node: Axis value out of bounds. Axis: " + std::to_string(axis) +
×
NEW
67
                " Shape size: " + std::to_string(shape_.size())
×
NEW
68
            );
×
NEW
69
        }
×
NEW
70
    }
×
UNCOV
71
    std::sort(sorted_axes.begin(), sorted_axes.end());
×
72

73
    for (size_t i = 0; i < shape_.size(); ++i) {
×
74
        bool is_axis = false;
×
75
        for (auto axis : sorted_axes) {
×
76
            if (axis == (int64_t) i) {
×
77
                is_axis = true;
×
78
                break;
×
79
            }
×
80
        }
×
81

82
        if (is_axis) {
×
83
            reduced_shape.push_back(symbolic::one());
×
84
        } else {
×
85
            reduced_shape.push_back(shape_[i]);
×
86
        }
×
87
    }
×
88

89
    types::Scalar element_type(this->primitive_type(dataflow));
×
90
    types::Pointer pointer_type(element_type);
×
91

92
    // Temporary buffers
93
    std::string tmp_max_name = builder.find_new_name("_softmax_max");
×
94
    builder.add_container(tmp_max_name, pointer_type);
×
95

96
    std::string tmp_max_bcast_name = builder.find_new_name("_softmax_max_bcast");
×
97
    builder.add_container(tmp_max_bcast_name, pointer_type);
×
98

99
    std::string tmp_sub_name = builder.find_new_name("_softmax_sub");
×
100
    builder.add_container(tmp_sub_name, pointer_type);
×
101

102
    std::string tmp_exp_name = builder.find_new_name("_softmax_exp");
×
103
    builder.add_container(tmp_exp_name, pointer_type);
×
104

105
    std::string tmp_sum_name = builder.find_new_name("_softmax_sum");
×
106
    builder.add_container(tmp_sum_name, pointer_type);
×
107

108
    std::string tmp_sum_bcast_name = builder.find_new_name("_softmax_sum_bcast");
×
109
    builder.add_container(tmp_sum_bcast_name, pointer_type);
×
110

111
    // Mallocs
NEW
112
    {
×
113
        symbolic::Expression bytes_elem = types::get_type_size(element_type, false);
×
114

115
        symbolic::Expression bytes_full = bytes_elem;
×
116
        for (auto& dim : this->shape_) {
×
117
            bytes_full = symbolic::mul(dim, bytes_full);
×
118
        }
×
119

120
        symbolic::Expression bytes_reduced = bytes_elem;
×
121
        for (auto& dim : reduced_shape) {
×
122
            bytes_reduced = symbolic::mul(dim, bytes_reduced);
×
123
        }
×
124

125
        auto& alloc_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
126

127
        auto malloc_helper = [&](const std::string& name, const symbolic::Expression& size) {
×
128
            auto& access = builder.add_access(alloc_block, name);
×
129
            auto& malloc_node = builder.add_library_node<stdlib::MallocNode>(alloc_block, this->debug_info(), size);
×
NEW
130
            builder
×
NEW
131
                .add_computational_memlet(alloc_block, malloc_node, "_ret", access, {}, pointer_type, this->debug_info());
×
UNCOV
132
        };
×
133

134
        malloc_helper(tmp_max_name, bytes_reduced);
×
135
        malloc_helper(tmp_max_bcast_name, bytes_full);
×
136
        malloc_helper(tmp_sub_name, bytes_full);
×
137
        malloc_helper(tmp_exp_name, bytes_full);
×
138
        malloc_helper(tmp_sum_name, bytes_reduced);
×
139
        malloc_helper(tmp_sum_bcast_name, bytes_full);
×
140
    }
×
141

142
    // 1. Max(X) -> TmpMax
143
    {
×
144
        auto& max_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
145
        auto& max_node =
×
146
            builder.add_library_node<MaxNode>(max_block, this->debug_info(), this->shape_, this->axes_, true);
×
147

148
        auto& in_access = builder.add_access(max_block, in_node.data());
×
149
        auto& out_access = builder.add_access(max_block, tmp_max_name);
×
150

151
        builder
×
152
            .add_computational_memlet(max_block, in_access, max_node, "X", {}, in_edge.base_type(), this->debug_info());
×
NEW
153
        builder
×
NEW
154
            .add_computational_memlet(max_block, max_node, "Y", out_access, {}, out_edge.base_type(), this->debug_info());
×
UNCOV
155
    }
×
156

157
    // 1.5 Broadcast Max -> TmpMaxBcast
158
    {
×
159
        auto& bcast_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
160
        auto& bcast_node =
×
161
            builder.add_library_node<BroadcastNode>(bcast_block, this->debug_info(), reduced_shape, this->shape_);
×
162

163
        auto& in_access = builder.add_access(bcast_block, tmp_max_name);
×
164
        auto& out_access = builder.add_access(bcast_block, tmp_max_bcast_name);
×
165

NEW
166
        builder.add_computational_memlet(
×
NEW
167
            bcast_block, in_access, bcast_node, "X", {}, out_edge.base_type(), this->debug_info()
×
NEW
168
        );
×
NEW
169
        builder.add_computational_memlet(
×
NEW
170
            bcast_block, bcast_node, "Y", out_access, {}, out_edge.base_type(), this->debug_info()
×
NEW
171
        );
×
UNCOV
172
    }
×
173

174
    // 2. Sub(X, TmpMaxBcast) -> TmpSub
175
    {
×
176
        auto& sub_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
177
        auto& sub_node = builder.add_library_node<SubNode>(sub_block, this->debug_info(), this->shape_);
×
178

179
        auto& in1_access = builder.add_access(sub_block, in_node.data());
×
180
        auto& in2_access = builder.add_access(sub_block, tmp_max_bcast_name);
×
181
        auto& out_access = builder.add_access(sub_block, tmp_sub_name);
×
182

183
        builder
×
184
            .add_computational_memlet(sub_block, in1_access, sub_node, "A", {}, in_edge.base_type(), this->debug_info());
×
NEW
185
        builder
×
NEW
186
            .add_computational_memlet(sub_block, in2_access, sub_node, "B", {}, out_edge.base_type(), this->debug_info());
×
NEW
187
        builder
×
NEW
188
            .add_computational_memlet(sub_block, sub_node, "C", out_access, {}, out_edge.base_type(), this->debug_info());
×
UNCOV
189
    }
×
190

191
    // 3. Exp(TmpSub) -> TmpExp
192
    {
×
193
        auto& exp_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
194
        auto& exp_node = builder.add_library_node<ExpNode>(exp_block, this->debug_info(), this->shape_);
×
195

196
        auto& in_access = builder.add_access(exp_block, tmp_sub_name);
×
197
        auto& out_access = builder.add_access(exp_block, tmp_exp_name);
×
198

NEW
199
        builder
×
NEW
200
            .add_computational_memlet(exp_block, in_access, exp_node, "X", {}, out_edge.base_type(), this->debug_info());
×
NEW
201
        builder
×
NEW
202
            .add_computational_memlet(exp_block, exp_node, "Y", out_access, {}, out_edge.base_type(), this->debug_info());
×
UNCOV
203
    }
×
204

205
    // 4. Sum(TmpExp) -> TmpSum
206
    {
×
207
        auto& sum_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
208
        auto& sum_node =
×
209
            builder.add_library_node<SumNode>(sum_block, this->debug_info(), this->shape_, this->axes_, true);
×
210

211
        auto& in_access = builder.add_access(sum_block, tmp_exp_name);
×
212
        auto& out_access = builder.add_access(sum_block, tmp_sum_name);
×
213

NEW
214
        builder
×
NEW
215
            .add_computational_memlet(sum_block, in_access, sum_node, "X", {}, out_edge.base_type(), this->debug_info());
×
NEW
216
        builder
×
NEW
217
            .add_computational_memlet(sum_block, sum_node, "Y", out_access, {}, out_edge.base_type(), this->debug_info());
×
UNCOV
218
    }
×
219

220
    // 4.5 Broadcast Sum -> TmpSumBcast
221
    {
×
222
        auto& bcast_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
223
        auto& bcast_node =
×
224
            builder.add_library_node<BroadcastNode>(bcast_block, this->debug_info(), reduced_shape, this->shape_);
×
225

226
        auto& in_access = builder.add_access(bcast_block, tmp_sum_name);
×
227
        auto& out_access = builder.add_access(bcast_block, tmp_sum_bcast_name);
×
228

NEW
229
        builder.add_computational_memlet(
×
NEW
230
            bcast_block, in_access, bcast_node, "X", {}, out_edge.base_type(), this->debug_info()
×
NEW
231
        );
×
NEW
232
        builder.add_computational_memlet(
×
NEW
233
            bcast_block, bcast_node, "Y", out_access, {}, out_edge.base_type(), this->debug_info()
×
NEW
234
        );
×
UNCOV
235
    }
×
236

237
    // 5. Div(TmpExp, TmpSumBcast) -> Output
238
    {
×
239
        auto& div_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
240
        auto& div_node = builder.add_library_node<DivNode>(div_block, this->debug_info(), this->shape_);
×
241

242
        auto& in1_access = builder.add_access(div_block, tmp_exp_name);
×
243
        auto& in2_access = builder.add_access(div_block, tmp_sum_bcast_name);
×
244
        auto& out_access = builder.add_access(div_block, out_node.data());
×
245

NEW
246
        builder
×
NEW
247
            .add_computational_memlet(div_block, in1_access, div_node, "A", {}, out_edge.base_type(), this->debug_info());
×
NEW
248
        builder
×
NEW
249
            .add_computational_memlet(div_block, in2_access, div_node, "B", {}, out_edge.base_type(), this->debug_info());
×
250
        builder
×
251
            .add_computational_memlet(div_block, div_node, "C", out_access, {}, out_edge.base_type(), this->debug_info());
×
252
    }
×
253

254
    // Cleanup
255
    builder.remove_memlet(block, in_edge);
×
256
    builder.remove_memlet(block, out_edge);
×
257
    builder.remove_node(block, in_node);
×
258
    builder.remove_node(block, out_node);
×
259
    builder.remove_node(block, *this);
×
260

261
    int last_index = parent.index(block);
×
262
    builder.remove_child(parent, last_index);
×
263

264
    return true;
×
265
}
×
266

267
std::unique_ptr<data_flow::DataFlowNode> SoftmaxNode::
268
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
269
    return std::unique_ptr<
×
270
        data_flow::DataFlowNode>(new SoftmaxNode(element_id, this->debug_info(), vertex, parent, this->shape_, this->axes_)
×
271
    );
×
272
}
×
273

274
} // namespace tensor
275
} // namespace math
276
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc