• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 20621841097

31 Dec 2025 03:18PM UTC coverage: 39.655% (-0.06%) from 39.712%
20621841097

Pull #421

github

web-flow
Merge 7662c1b88 into 3b72c335e
Pull Request #421: Extend tensor library nodes with primitive type support and refactor CMathNode to use enums

14996 of 49220 branches covered (30.47%)

Branch coverage included in aggregate %.

247 of 608 new or added lines in 52 files covered. (40.63%)

38 existing lines in 5 files now uncovered.

12874 of 21062 relevant lines covered (61.12%)

89.39 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/src/data_flow/library_nodes/math/tensor/reduce_ops/softmax_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/reduce_ops/softmax_node.h"
2
#include "sdfg/analysis/scope_analysis.h"
3
#include "sdfg/builder/structured_sdfg_builder.h"
4
#include "sdfg/data_flow/library_nodes/math/tensor/broadcast_node.h"
5
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_ops/div_node.h"
6
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_ops/exp_node.h"
7
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_ops/sub_node.h"
8
#include "sdfg/data_flow/library_nodes/math/tensor/reduce_ops/max_node.h"
9
#include "sdfg/data_flow/library_nodes/math/tensor/reduce_ops/sum_node.h"
10
#include "sdfg/data_flow/library_nodes/stdlib/malloc.h"
11
#include "sdfg/structured_control_flow/block.h"
12
#include "sdfg/structured_control_flow/for.h"
13
#include "sdfg/types/pointer.h"
14
#include "sdfg/types/scalar.h"
15
#include "sdfg/types/utils.h"
16

17
namespace sdfg {
18
namespace math {
19
namespace tensor {
20

21
/// Softmax reduce library node: validates attributes and delegates to ReduceNode.
///
/// @param element_id  unique element id assigned by the builder
/// @param debug_info  source-location info propagated to expanded nodes
/// @param vertex      graph vertex backing this node
/// @param parent      data-flow graph owning the node
/// @param shape       symbolic shape of the input tensor
/// @param axes        axes over which the softmax is computed
/// @param keepdims    must be false; the softmax expansion does not implement it
///
/// @throws InvalidSDFGException if keepdims is true.
SoftmaxNode::SoftmaxNode(
    size_t element_id, const DebugInfo& debug_info, const graph::Vertex vertex, data_flow::DataFlowGraph& parent,
    const std::vector<symbolic::Expression>& shape, const std::vector<int64_t>& axes, bool keepdims
)
    : ReduceNode(element_id, debug_info, vertex, parent, LibraryNodeType_Softmax, shape, axes, keepdims) {
    // Reject keepdims up front: expand() assumes the reduced dimensions are dropped.
    if (keepdims) {
        throw InvalidSDFGException("Unsupported attribute on library node: softmax");
    }
}
35

36
bool SoftmaxNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
×
37
    auto& dataflow = this->get_parent();
×
38
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
×
39

40
    if (dataflow.in_degree(*this) != 1 || dataflow.out_degree(*this) != 1) {
×
41
        return false;
×
42
    }
43

44
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
×
45
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
×
46
    int index = parent.index(block);
×
47

48
    auto& in_edge = *dataflow.in_edges(*this).begin();
×
49
    auto& out_edge = *dataflow.out_edges(*this).begin();
×
50
    auto& in_node = static_cast<data_flow::AccessNode&>(in_edge.src());
×
51
    auto& out_node = static_cast<data_flow::AccessNode&>(out_edge.dst());
×
52

53
    // Calculate reduced shape (for Max and Sum)
54
    std::vector<symbolic::Expression> reduced_shape;
×
55
    std::vector<int64_t> sorted_axes = axes_;
×
56
    std::sort(sorted_axes.begin(), sorted_axes.end());
×
57

58
    for (size_t i = 0; i < shape_.size(); ++i) {
×
59
        bool is_axis = false;
×
60
        for (auto axis : sorted_axes) {
×
61
            if (axis == (int64_t) i) {
×
62
                is_axis = true;
×
63
                break;
×
64
            }
65
        }
66

67
        if (is_axis) {
×
68
            reduced_shape.push_back(symbolic::one());
×
69
        } else {
×
70
            reduced_shape.push_back(shape_[i]);
×
71
        }
72
    }
×
73

NEW
74
    types::Scalar element_type(this->primitive_type(dataflow));
×
NEW
75
    types::Pointer pointer_type(element_type);
×
76

77
    // Temporary buffers
78
    std::string tmp_max_name = builder.find_new_name("_softmax_max");
×
NEW
79
    builder.add_container(tmp_max_name, pointer_type);
×
80

81
    std::string tmp_max_bcast_name = builder.find_new_name("_softmax_max_bcast");
×
NEW
82
    builder.add_container(tmp_max_bcast_name, pointer_type);
×
83

84
    std::string tmp_sub_name = builder.find_new_name("_softmax_sub");
×
NEW
85
    builder.add_container(tmp_sub_name, pointer_type);
×
86

87
    std::string tmp_exp_name = builder.find_new_name("_softmax_exp");
×
NEW
88
    builder.add_container(tmp_exp_name, pointer_type);
×
89

90
    std::string tmp_sum_name = builder.find_new_name("_softmax_sum");
×
NEW
91
    builder.add_container(tmp_sum_name, pointer_type);
×
92

93
    std::string tmp_sum_bcast_name = builder.find_new_name("_softmax_sum_bcast");
×
NEW
94
    builder.add_container(tmp_sum_bcast_name, pointer_type);
×
95

96
    // Mallocs
97
    if (in_edge.base_type().type_id() == types::TypeID::Pointer) {
×
NEW
98
        symbolic::Expression bytes_elem = types::get_type_size(element_type, false);
×
99

100
        symbolic::Expression bytes_full = bytes_elem;
×
101
        for (auto& dim : this->shape_) {
×
102
            bytes_full = symbolic::mul(dim, bytes_full);
×
103
        }
104

105
        symbolic::Expression bytes_reduced = bytes_elem;
×
106
        for (auto& dim : reduced_shape) {
×
107
            bytes_reduced = symbolic::mul(dim, bytes_reduced);
×
108
        }
109

110
        auto& alloc_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
111

112
        auto malloc_helper = [&](const std::string& name, const symbolic::Expression& size) {
×
113
            auto& access = builder.add_access(alloc_block, name);
×
114
            auto& malloc_node = builder.add_library_node<stdlib::MallocNode>(alloc_block, this->debug_info(), size);
×
115
            builder.add_computational_memlet(
×
116
                alloc_block, malloc_node, "_ret", access, {}, in_edge.base_type(), this->debug_info()
×
117
            );
118
        };
×
119

120
        malloc_helper(tmp_max_name, bytes_reduced);
×
121
        malloc_helper(tmp_max_bcast_name, bytes_full);
×
122
        malloc_helper(tmp_sub_name, bytes_full);
×
123
        malloc_helper(tmp_exp_name, bytes_full);
×
124
        malloc_helper(tmp_sum_name, bytes_reduced);
×
125
        malloc_helper(tmp_sum_bcast_name, bytes_full);
×
126
    }
×
127

128
    // 1. Max(X) -> TmpMax
129
    {
130
        auto& max_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
131
        auto& max_node =
×
132
            builder.add_library_node<MaxNode>(max_block, this->debug_info(), this->shape_, this->axes_, true);
×
133

134
        auto& in_access = builder.add_access(max_block, in_node.data());
×
135
        auto& out_access = builder.add_access(max_block, tmp_max_name);
×
136

137
        builder
×
138
            .add_computational_memlet(max_block, in_access, max_node, "X", {}, in_edge.base_type(), this->debug_info());
×
NEW
139
        builder.add_computational_memlet(max_block, max_node, "Y", out_access, {}, pointer_type, this->debug_info());
×
140
    }
141

142
    // 1.5 Broadcast Max -> TmpMaxBcast
143
    {
144
        auto& bcast_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
145
        auto& bcast_node =
×
146
            builder.add_library_node<BroadcastNode>(bcast_block, this->debug_info(), reduced_shape, this->shape_);
×
147

148
        auto& in_access = builder.add_access(bcast_block, tmp_max_name);
×
149
        auto& out_access = builder.add_access(bcast_block, tmp_max_bcast_name);
×
150

NEW
151
        builder.add_computational_memlet(bcast_block, in_access, bcast_node, "X", {}, pointer_type, this->debug_info());
×
NEW
152
        builder.add_computational_memlet(bcast_block, bcast_node, "Y", out_access, {}, pointer_type, this->debug_info());
×
153
    }
154

155
    // 2. Sub(X, TmpMaxBcast) -> TmpSub
156
    {
157
        auto& sub_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
158
        auto& sub_node = builder.add_library_node<SubNode>(sub_block, this->debug_info(), this->shape_);
×
159

160
        auto& in1_access = builder.add_access(sub_block, in_node.data());
×
161
        auto& in2_access = builder.add_access(sub_block, tmp_max_bcast_name);
×
162
        auto& out_access = builder.add_access(sub_block, tmp_sub_name);
×
163

164
        builder
×
165
            .add_computational_memlet(sub_block, in1_access, sub_node, "A", {}, in_edge.base_type(), this->debug_info());
×
NEW
166
        builder.add_computational_memlet(sub_block, in2_access, sub_node, "B", {}, pointer_type, this->debug_info());
×
NEW
167
        builder.add_computational_memlet(sub_block, sub_node, "C", out_access, {}, pointer_type, this->debug_info());
×
168
    }
169

170
    // 3. Exp(TmpSub) -> TmpExp
171
    {
172
        auto& exp_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
173
        auto& exp_node = builder.add_library_node<ExpNode>(exp_block, this->debug_info(), this->shape_);
×
174

175
        auto& in_access = builder.add_access(exp_block, tmp_sub_name);
×
176
        auto& out_access = builder.add_access(exp_block, tmp_exp_name);
×
177

NEW
178
        builder.add_computational_memlet(exp_block, in_access, exp_node, "X", {}, pointer_type, this->debug_info());
×
NEW
179
        builder.add_computational_memlet(exp_block, exp_node, "Y", out_access, {}, pointer_type, this->debug_info());
×
180
    }
181

182
    // 4. Sum(TmpExp) -> TmpSum
183
    {
184
        auto& sum_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
185
        auto& sum_node =
×
186
            builder.add_library_node<SumNode>(sum_block, this->debug_info(), this->shape_, this->axes_, true);
×
187

188
        auto& in_access = builder.add_access(sum_block, tmp_exp_name);
×
189
        auto& out_access = builder.add_access(sum_block, tmp_sum_name);
×
190

NEW
191
        builder.add_computational_memlet(sum_block, in_access, sum_node, "X", {}, pointer_type, this->debug_info());
×
NEW
192
        builder.add_computational_memlet(sum_block, sum_node, "Y", out_access, {}, pointer_type, this->debug_info());
×
193
    }
194

195
    // 4.5 Broadcast Sum -> TmpSumBcast
196
    {
197
        auto& bcast_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
198
        auto& bcast_node =
×
199
            builder.add_library_node<BroadcastNode>(bcast_block, this->debug_info(), reduced_shape, this->shape_);
×
200

201
        auto& in_access = builder.add_access(bcast_block, tmp_sum_name);
×
202
        auto& out_access = builder.add_access(bcast_block, tmp_sum_bcast_name);
×
203

NEW
204
        builder.add_computational_memlet(bcast_block, in_access, bcast_node, "X", {}, pointer_type, this->debug_info());
×
NEW
205
        builder.add_computational_memlet(bcast_block, bcast_node, "Y", out_access, {}, pointer_type, this->debug_info());
×
206
    }
207

208
    // 5. Div(TmpExp, TmpSumBcast) -> Output
209
    {
210
        auto& div_block = builder.add_block_before(parent, block, {}, this->debug_info());
×
211
        auto& div_node = builder.add_library_node<DivNode>(div_block, this->debug_info(), this->shape_);
×
212

213
        auto& in1_access = builder.add_access(div_block, tmp_exp_name);
×
214
        auto& in2_access = builder.add_access(div_block, tmp_sum_bcast_name);
×
215
        auto& out_access = builder.add_access(div_block, out_node.data());
×
216

NEW
217
        builder.add_computational_memlet(div_block, in1_access, div_node, "A", {}, pointer_type, this->debug_info());
×
NEW
218
        builder.add_computational_memlet(div_block, in2_access, div_node, "B", {}, pointer_type, this->debug_info());
×
219
        builder
×
220
            .add_computational_memlet(div_block, div_node, "C", out_access, {}, out_edge.base_type(), this->debug_info());
×
221
    }
222

223
    // Cleanup
224
    builder.remove_memlet(block, in_edge);
×
225
    builder.remove_memlet(block, out_edge);
×
226
    builder.remove_node(block, in_node);
×
227
    builder.remove_node(block, out_node);
×
228
    builder.remove_node(block, *this);
×
229

230
    int last_index = parent.index(block);
×
231
    builder.remove_child(parent, last_index);
×
232

233
    return true;
×
234
}
×
235

236
/// Creates a copy of this node bound to the given vertex and parent graph.
///
/// NOTE(review): keepdims is not forwarded. The constructor throws for
/// keepdims == true, so every live SoftmaxNode has keepdims == false and the
/// constructor's default reproduces it — confirm the header default is false.
std::unique_ptr<data_flow::DataFlowNode> SoftmaxNode::
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
    auto* copy = new SoftmaxNode(element_id, this->debug_info(), vertex, parent, this->shape_, this->axes_);
    return std::unique_ptr<data_flow::DataFlowNode>(copy);
}
242

243
} // namespace tensor
244
} // namespace math
245
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc