• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 22023884668

14 Feb 2026 08:36PM UTC coverage: 64.903% (-1.4%) from 66.315%
22023884668

Pull #525

github

web-flow
Merge 1d47f8bf2 into 9d01cacd5
Pull Request #525: Step 3 (Native Tensor Support): Refactor Python Frontend

2522 of 3435 new or added lines in 32 files covered. (73.42%)

320 existing lines in 15 files now uncovered.

23204 of 35752 relevant lines covered (64.9%)

370.03 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/sdfg/src/data_flow/library_nodes/math/tensor/broadcast_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/broadcast_node.h"
2
#include "sdfg/analysis/scope_analysis.h"
3
#include "sdfg/builder/structured_sdfg_builder.h"
4
#include "sdfg/structured_control_flow/for.h"
5

6
namespace sdfg {
7
namespace math {
8
namespace tensor {
9

10
BroadcastNode::BroadcastNode(
11
    size_t element_id,
12
    const DebugInfo& debug_info,
13
    const graph::Vertex vertex,
14
    data_flow::DataFlowGraph& parent,
15
    const std::vector<symbolic::Expression>& input_shape,
16
    const std::vector<symbolic::Expression>& output_shape
17
)
18
    : TensorNode(
×
19
          element_id,
×
20
          debug_info,
×
21
          vertex,
×
22
          parent,
×
23
          LibraryNodeType_Broadcast,
×
24
          {"Y"},
×
25
          {"X"},
×
26
          data_flow::ImplementationType_NONE
×
27
      ),
×
28
      input_shape_(input_shape), output_shape_(output_shape) {}
×
29

NEW
30
void BroadcastNode::validate(const Function& function) const {
×
NEW
31
    TensorNode::validate(function);
×
32

NEW
33
    auto& graph = this->get_parent();
×
34

NEW
35
    auto& iedge = *graph.in_edges(*this).begin();
×
NEW
36
    auto& shape = static_cast<const types::Tensor&>(iedge.base_type());
×
NEW
37
    if (!shape.is_scalar()) {
×
NEW
38
        if (shape.shape().size() != this->input_shape_.size()) {
×
NEW
39
            throw InvalidSDFGException(
×
NEW
40
                "Library Node: Tensor shape must match node shape. Tensor shape: " +
×
NEW
41
                std::to_string(shape.shape().size()) + " Node shape: " + std::to_string(this->input_shape_.size())
×
NEW
42
            );
×
NEW
43
        }
×
NEW
44
        for (size_t i = 0; i < this->input_shape_.size(); ++i) {
×
NEW
45
            if (!symbolic::eq(shape.shape().at(i), this->input_shape_.at(i))) {
×
NEW
46
                throw InvalidSDFGException(
×
NEW
47
                    "Library Node: Tensor shape does not match expected shape. Tensor shape: " +
×
NEW
48
                    shape.shape().at(i)->__str__() + " Expected shape: " + this->input_shape_.at(i)->__str__()
×
NEW
49
                );
×
NEW
50
            }
×
NEW
51
        }
×
NEW
52
    }
×
53

NEW
54
    auto& oedge = *graph.out_edges(*this).begin();
×
NEW
55
    auto& output_shape = static_cast<const types::Tensor&>(oedge.base_type());
×
NEW
56
    if (output_shape.shape().size() != this->output_shape_.size()) {
×
NEW
57
        throw InvalidSDFGException(
×
NEW
58
            "Library Node: Output tensor shape must match node shape. Output tensor shape: " +
×
NEW
59
            std::to_string(output_shape.shape().size()) + " Node shape: " + std::to_string(this->output_shape_.size())
×
NEW
60
        );
×
NEW
61
    }
×
62

NEW
63
    for (size_t i = 0; i < this->output_shape_.size(); ++i) {
×
NEW
64
        if (!symbolic::eq(output_shape.shape().at(i), this->output_shape_.at(i))) {
×
NEW
65
            throw InvalidSDFGException(
×
NEW
66
                "Library Node: Output tensor shape does not match expected shape. Output tensor shape: " +
×
NEW
67
                output_shape.shape().at(i)->__str__() + " Expected shape: " + this->output_shape_.at(i)->__str__()
×
NEW
68
            );
×
NEW
69
        }
×
NEW
70
    }
×
NEW
71
}
×
72

73
/// Returns every atom returned by symbolic::atoms for any dimension expression
/// of the input shape or the output shape.
symbolic::SymbolSet BroadcastNode::symbols() const {
    symbolic::SymbolSet result;

    // Adds the atoms of each extent of `shape` to `result`.
    auto collect = [&result](const auto& shape) {
        for (const auto& extent : shape) {
            for (auto& sym : symbolic::atoms(extent)) {
                result.insert(sym);
            }
        }
    };

    collect(input_shape_);
    collect(output_shape_);
    return result;
}
×
87

88
/// Substitutes `old_expression` with `new_expression` (via symbolic::subs) in
/// every dimension of the input shape and then of the output shape.
void BroadcastNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    // Rewrites all extents of `shape` in place.
    auto rewrite = [&old_expression, &new_expression](auto& shape) {
        for (auto& extent : shape) {
            extent = symbolic::subs(extent, old_expression, new_expression);
        }
    };

    rewrite(input_shape_);
    rewrite(output_shape_);
}
×
96

97
bool BroadcastNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
×
98
    auto& dataflow = this->get_parent();
×
99
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
×
100

101
    if (dataflow.in_degree(*this) != 1 || dataflow.out_degree(*this) != 1) {
×
102
        return false;
×
103
    }
×
104

105
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
×
106
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
×
107

108
    auto& in_edge = *dataflow.in_edges(*this).begin();
×
109
    auto& out_edge = *dataflow.out_edges(*this).begin();
×
110
    auto& in_node = static_cast<data_flow::AccessNode&>(in_edge.src());
×
111
    auto& out_node = static_cast<data_flow::AccessNode&>(out_edge.dst());
×
112

NEW
113
    symbolic::MultiExpression loop_vars;
×
114
    structured_control_flow::Sequence* inner_scope = nullptr;
×
115

116
    for (size_t i = 0; i < output_shape_.size(); ++i) {
×
117
        std::string var_name = builder.find_new_name("_i" + std::to_string(i));
×
118
        builder.add_container(var_name, types::Scalar(types::PrimitiveType::Int64));
×
119

120
        auto sym_var = symbolic::symbol(var_name);
×
121
        auto condition = symbolic::Lt(sym_var, output_shape_[i]);
×
122
        auto init = symbolic::zero();
×
123
        auto update = symbolic::add(sym_var, symbolic::one());
×
124

125
        if (i == 0) {
×
126
            auto& loop = builder.add_map_before(
×
127
                parent,
×
128
                block,
×
129
                sym_var,
×
130
                condition,
×
131
                init,
×
132
                update,
×
133
                structured_control_flow::ScheduleType_Sequential::create(),
×
134
                {},
×
135
                this->debug_info()
×
136
            );
×
137
            inner_scope = &loop.root();
×
138
        } else {
×
139
            auto& loop = builder.add_map(
×
140
                *inner_scope,
×
141
                sym_var,
×
142
                condition,
×
143
                init,
×
144
                update,
×
145
                structured_control_flow::ScheduleType_Sequential::create(),
×
146
                {},
×
147
                this->debug_info()
×
148
            );
×
149
            inner_scope = &loop.root();
×
150
        }
×
151
        loop_vars.push_back(sym_var);
×
152
    }
×
153

154
    auto& tasklet_block = builder.add_block(*inner_scope, {}, this->debug_info());
×
155

156
    auto& in_acc = builder.add_access(tasklet_block, in_node.data());
×
157
    auto& out_acc = builder.add_access(tasklet_block, out_node.data());
×
158

NEW
159
    symbolic::MultiExpression input_subset = {};
×
160
    for (size_t i = 0; i < input_shape_.size(); ++i) {
×
161
        if (!symbolic::eq(input_shape_[i], symbolic::one())) {
×
NEW
162
            input_subset.push_back(loop_vars[i]);
×
NEW
163
        } else {
×
NEW
164
            input_subset.push_back(symbolic::zero());
×
165
        }
×
166
    }
×
NEW
167
    auto& iedge_tensor = static_cast<const types::Tensor&>(in_edge.base_type());
×
NEW
168
    if (iedge_tensor.is_scalar()) {
×
NEW
169
        input_subset = {};
×
UNCOV
170
    }
×
171

NEW
172
    auto& tasklet =
×
NEW
173
        builder.add_tasklet(tasklet_block, data_flow::TaskletCode::assign, "_out", {"_in"}, this->debug_info());
×
174

175
    builder.add_computational_memlet(
×
176
        tasklet_block, in_acc, tasklet, "_in", input_subset, in_edge.base_type(), this->debug_info()
×
177
    );
×
178
    builder.add_computational_memlet(
×
NEW
179
        tasklet_block, tasklet, "_out", out_acc, loop_vars, out_edge.base_type(), this->debug_info()
×
180
    );
×
181

182
    builder.remove_memlet(block, in_edge);
×
183
    builder.remove_memlet(block, out_edge);
×
184
    builder.remove_node(block, in_node);
×
185
    builder.remove_node(block, out_node);
×
186
    builder.remove_node(block, *this);
×
187

188
    int index = parent.index(block);
×
189
    builder.remove_child(parent, index);
×
190

191
    return true;
×
192
}
×
193

194
std::unique_ptr<data_flow::DataFlowNode> BroadcastNode::
195
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
196
    return std::unique_ptr<data_flow::DataFlowNode>(
×
197
        new BroadcastNode(element_id, this->debug_info(), vertex, parent, input_shape_, output_shape_)
×
198
    );
×
199
}
×
200

201
} // namespace tensor
202
} // namespace math
203
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc