• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 27007027060

05 Jun 2026 09:28AM UTC coverage: 61.275% (-0.02%) from 61.292%
27007027060

push

github

web-flow
Improve Quantization support on TensorNodes (#736)

* Added DataFlowGraph.find_standalone_exit() following the pattern of find_standalone_entry() to abstract away edge types.
* LibNodeDispatcher no longer allows missing inputs.
  To accommodate this, ConvNode is now explicitly configured with whether it has a bias.
* Fixed elementwise CMath node toStr()

---------

Co-authored-by: Moritz Timmer <25349452+Moehre2@users.noreply.github.com>

10 of 43 new or added lines in 8 files covered. (23.26%)

1 existing line in 1 file now uncovered.

35592 of 58086 relevant lines covered (61.27%)

11015.05 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

58.52
/sdfg/src/data_flow/library_nodes/math/tensor/elementwise_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_node.h"
2

3
#include "sdfg/analysis/analysis.h"
4
#include "sdfg/builder/structured_sdfg_builder.h"
5
#include "sdfg/data_flow/tasklet.h"
6
#include "sdfg/types/type.h"
7

8
#include "sdfg/analysis/scope_analysis.h"
9

10
namespace sdfg {
11
namespace math {
12
namespace tensor {
13

14
ElementWiseDataflowTensorNode::ElementWiseDataflowTensorNode(
15
    size_t element_id,
16
    const DebugInfo& debug_info,
17
    const graph::Vertex vertex,
18
    data_flow::DataFlowGraph& parent,
19
    const data_flow::LibraryNodeCode& code,
20
    const std::vector<symbolic::Expression>& shape,
21
    const std::string& modified_tensor_conn,
22
    const std::vector<std::string>& tensor_inputs,
23
    QuantizationType quantization,
24
    const data_flow::ImplementationType& impl_type
25
)
26
    : TensorNode(
550✔
27
          element_id,
550✔
28
          debug_info,
550✔
29
          vertex,
550✔
30
          parent,
550✔
31
          code,
550✔
32
          {},
550✔
33
          build_input_conns(modified_tensor_conn, tensor_inputs),
550✔
34
          impl_type
550✔
35
      ),
550✔
36
      fixed_quantization_(quantization), shape_(shape) {}
550✔
37

38
// Builds the input connector list for the node: the connector of the tensor
// being written (`modified_tensor_conn`) first, followed by every entry of
// `tensor_inputs` in order.
std::vector<std::string> ElementWiseDataflowTensorNode::
    build_input_conns(const std::string& modified_tensor_conn, const std::vector<std::string>& tensor_inputs) {
    std::vector<std::string> input_conns;
    // Room for the modified tensor plus all tensor inputs. Sizing from
    // `input_conns` itself would only reserve one slot, since it is still empty.
    input_conns.reserve(1 + tensor_inputs.size());
    input_conns.push_back(modified_tensor_conn);
    input_conns.insert(input_conns.end(), tensor_inputs.begin(), tensor_inputs.end());
    return input_conns;
}
46

47
// Returns the stored (fixed) result quantization as set at construction or via
// set_fixed_quantization(); may be the QUANTIZATION_MATCH_INPUTS sentinel.
types::PrimitiveType ElementWiseDataflowTensorNode::fixed_quantization() const {
    return this->fixed_quantization_;
}
48

NEW
49
void ElementWiseDataflowTensorNode::set_fixed_quantization(const QuantizationType quant) {
×
NEW
50
    fixed_quantization_ = quant;
×
NEW
51
}
×
52

53
// Effective result quantization: the primitive type inferred from
// `data_flow_graph` when the node is set to QUANTIZATION_MATCH_INPUTS,
// otherwise the fixed quantization.
types::PrimitiveType ElementWiseDataflowTensorNode::quantization(const data_flow::DataFlowGraph& data_flow_graph
) const {
    if (fixed_quantization_ == QUANTIZATION_MATCH_INPUTS) {
        return this->primitive_type(data_flow_graph);
    }
    return fixed_quantization_;
}
61

62
std::optional<types::PrimitiveType> ElementWiseDataflowTensorNode::uniform_quantization(const data_flow::DataFlowGraph&
63
                                                                                            data_flow_graph) const {
×
64
    if (fixed_quantization_ != QUANTIZATION_MATCH_INPUTS) {
×
65
        auto inferred = this->primitive_type(data_flow_graph);
×
66
        if (inferred == fixed_quantization_) {
×
67
            return fixed_quantization_;
×
68
        } else {
×
69
            return std::nullopt;
×
70
        }
×
71
    } else {
×
72
        return this->primitive_type(data_flow_graph);
×
73
    }
×
74
}
×
75

76
// Checks the tensor written by this node (input connector 0): it must be
// connected, and its layout must match the node's iteration shape.
// Throws InvalidSDFGException if connector 0 has no incoming edge; shape
// mismatches are reported by validate_shape_matches().
void ElementWiseDataflowTensorNode::validate_target_tensor(const data_flow::DataFlowGraph& graph) const {
    auto* target_ptr_edge = graph.in_edge_for_connector(*this, inputs_.at(0));
    // in_edge_for_connector may return null (expand() checks for it), so report an
    // unconnected target the same way the other input validators do instead of
    // dereferencing a null pointer.
    if (!target_ptr_edge) {
        throw InvalidSDFGException(
            "On libNode #" + std::to_string(element_id()) + ": input " + inputs_.at(0) + " is not connected"
        );
    }
    // NOTE(review): assumes the target edge's base type is a types::Tensor; it is not type-checked here.
    auto& tensor_output = static_cast<const types::Tensor&>(target_ptr_edge->base_type());

    validate_shape_matches(shape_, tensor_output.layout(), "output tensor");
}
82

83
// Validates the tensor inputs (connectors 1 .. tensor_input_count()-1).
// Each must be connected (otherwise InvalidSDFGException is thrown). Edges of
// Scalar type and tensors reporting is_scalar() are accepted without further
// checks; every other tensor must match shape_ exactly.
void ElementWiseDataflowTensorNode::validate_all_input_tensors(const data_flow::DataFlowGraph& graph) const {
    for (int idx = 1; idx < tensor_input_count(); ++idx) {
        const auto& conn = inputs_.at(idx);
        auto* edge = graph.in_edge_for_connector(*this, conn);
        if (edge == nullptr) {
            throw InvalidSDFGException(
                "On libNode #" + std::to_string(element_id()) + ": input " + conn + " is not connected"
            );
        }

        const auto& edge_type = edge->base_type();
        if (edge_type.type_id() == types::TypeID::Scalar) {
            continue; // scalar-typed input: no layout to compare
        }

        auto& tensor_type = static_cast<const types::Tensor&>(edge_type);
        if (!tensor_type.is_scalar()) {
            // No general broadcasting yet: non-scalar inputs must match the iteration shape.
            validate_shape_matches(shape_, tensor_type.layout(), "input " + conn);
        }
    }
}
104

105
// Validates the trailing non-tensor inputs (connectors tensor_input_count() ..
// end). A missing edge is an error only for connectors below
// mandatory_input_count(); a connected edge must be of Scalar type.
// Throws InvalidSDFGException on either violation.
void ElementWiseDataflowTensorNode::validate_non_tensor_inputs(const data_flow::DataFlowGraph& graph) const {
    for (int idx = tensor_input_count(); idx < inputs_.size(); ++idx) {
        const auto& conn = inputs_.at(idx);
        auto* edge = graph.in_edge_for_connector(*this, conn);
        if (edge != nullptr) {
            if (edge->base_type().type_id() != types::TypeID::Scalar) {
                throw InvalidSDFGException(
                    "On libNode #" + std::to_string(element_id()) + ": input " + conn + " is not scalar"
                );
            }
        } else if (idx < mandatory_input_count()) {
            throw InvalidSDFGException(
                "On libNode #" + std::to_string(element_id()) + ": input " + conn + " is not connected"
            );
        }
    }
}
124

125
void ElementWiseDataflowTensorNode::validate(const Function& function) const {
330✔
126
    TensorNode::validate(function);
330✔
127

128
    auto& graph = this->get_parent();
330✔
129

130
    validate_target_tensor(graph);
330✔
131

132
    validate_all_input_tensors(graph);
330✔
133

134
    validate_non_tensor_inputs(graph);
330✔
135
}
330✔
136

137
// Collects every atom (as returned by symbolic::atoms) that occurs in any
// dimension of the node's iteration shape.
symbolic::SymbolSet ElementWiseDataflowTensorNode::symbols() const {
    symbolic::SymbolSet result;
    for (const auto& extent : shape_) {
        const auto extent_atoms = symbolic::atoms(extent);
        for (const auto& sym : extent_atoms) {
            result.insert(sym);
        }
    }
    return result;
}
146

147
void ElementWiseDataflowTensorNode::
148
    replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
×
149
    for (auto& dim : shape_) {
×
150
        dim = symbolic::subs(dim, old_expression, new_expression);
×
151
    }
×
152
}
×
153

154
std::pair<structured_control_flow::Sequence*, std::vector<symbolic::Expression>> ElementWiseDataflowTensorNode::
155
    add_eltwise_scope(
156
        builder::StructuredSDFGBuilder& builder,
157
        const DebugInfo& scope_deb_info,
158
        Sequence& parent,
159
        const std::vector<symbolic::Expression>& shape
160
    ) {
521✔
161
    // Add maps
162
    data_flow::Subset new_subset;
521✔
163
    std::vector<symbolic::Expression> loop_vars;
521✔
164
    structured_control_flow::Sequence* last_scope = &parent;
521✔
165
    structured_control_flow::Map* last_map = nullptr;
521✔
166

167
    for (size_t i = 0; i < shape.size(); i++) {
1,798✔
168
        std::string indvar_str = builder.find_new_name("_i");
1,277✔
169
        builder.add_container(indvar_str, types::Scalar(types::PrimitiveType::UInt64));
1,277✔
170

171
        auto indvar = symbolic::symbol(indvar_str);
1,277✔
172
        auto init = symbolic::zero();
1,277✔
173
        auto update = symbolic::add(indvar, symbolic::one());
1,277✔
174
        auto condition = symbolic::Lt(indvar, shape.at(i));
1,277✔
175
        last_map = &builder.add_map(
1,277✔
176
            *last_scope,
1,277✔
177
            indvar,
1,277✔
178
            condition,
1,277✔
179
            init,
1,277✔
180
            update,
1,277✔
181
            structured_control_flow::ScheduleType_Sequential::create(),
1,277✔
182
            {},
1,277✔
183
            scope_deb_info
1,277✔
184
        );
1,277✔
185
        last_scope = &last_map->root();
1,277✔
186

187
        loop_vars.push_back(indvar);
1,277✔
188
    }
1,277✔
189
    return {last_scope, loop_vars};
521✔
190
}
521✔
191

192
// Builds the memlet type for an input described by (element type, layout):
// a Scalar when the layout pointer is null, otherwise a Tensor with a copy of
// that layout.
std::unique_ptr<types::IType> ElementWiseDataflowTensorNode::access_type(const std::pair<
                                                                         types::PrimitiveType,
                                                                         const TensorLayout*>& pair) {
    const auto& [element_type, layout] = pair;
    if (layout == nullptr) {
        return std::make_unique<types::Scalar>(element_type);
    }
    return std::make_unique<types::Tensor>(element_type, *layout);
}
201

202
bool ElementWiseDataflowTensorNode::create_input(
203
    builder::StructuredSDFGBuilder& builder,
204
    structured_control_flow::Block& block,
205
    const data_flow::AccessNode& org_src,
206
    const std::pair<types::PrimitiveType, const TensorLayout*>& src_type,
207
    const ElementInput& needed_input,
208
    const std::vector<symbolic::Expression>& eltwise_subset,
209
    std::unordered_map<const data_flow::AccessNode*, data_flow::AccessNode*>& new_node_mapping
210
) {
812✔
211
    auto* new_consumer = needed_input.consumer;
812✔
212
    if (new_consumer) {
812✔
213
        if (src_type.first != needed_input.required_type) {
812✔
214
            throw InvalidSDFGException(
×
215
                "Input " + std::to_string(needed_input.input_conn_index) + " on node #" +
×
216
                std::to_string(new_consumer->element_id()) + " is required as " +
×
217
                types::primitive_type_to_string(needed_input.required_type) + " but provided as " +
×
218
                types::primitive_type_to_string(src_type.first)
×
219
            );
×
220
        }
×
221
        auto existing_input_it = new_node_mapping.find(&org_src);
812✔
222
        data_flow::AccessNode* input_node;
812✔
223
        std::vector<symbolic::Expression> empty_subset;
812✔
224
        const std::vector<symbolic::Expression>* memlet_subset;
812✔
225
        if (src_type.second && !src_type.second->is_scalar()) {
812✔
226
            memlet_subset = &eltwise_subset;
812✔
227
        } else {
812✔
228
            memlet_subset = &empty_subset;
×
229
        }
×
230
        auto new_type = access_type(src_type);
812✔
231
        if (existing_input_it != new_node_mapping.end()) {
812✔
232
            input_node = existing_input_it->second;
×
233
        } else {
812✔
234
            if (org_src.is_constant()) {
812✔
235
                types::Scalar const_type(src_type.first);
×
236
                input_node = &builder.add_constant(block, org_src.data(), const_type);
×
237
            } else {
812✔
238
                input_node = &builder.add_access(block, org_src.data());
812✔
239
            }
812✔
240
            new_node_mapping.emplace(&org_src, input_node);
812✔
241
        }
812✔
242

243
        builder.add_computational_memlet(
812✔
244
            block,
812✔
245
            *input_node,
812✔
246
            *new_consumer,
812✔
247
            new_consumer->input(needed_input.input_conn_index),
812✔
248
            *memlet_subset,
812✔
249
            *new_type
812✔
250
        );
812✔
251
        return true;
812✔
252
    } else {
812✔
253
        return false;
×
254
    }
×
255
}
812✔
256

257
void ElementWiseDataflowTensorNode::create_output(
258
    builder::StructuredSDFGBuilder& builder,
259
    structured_control_flow::Block& block,
260
    const data_flow::AccessNode& org_dst,
261
    const types::Tensor& dst_type,
262
    const ElementOutput& provided_output,
263
    const std::vector<symbolic::Expression>& eltwise_subset
264
) {
521✔
265
    auto* producer = provided_output.producer;
521✔
266
    if (dst_type.primitive_type() != provided_output.type) {
521✔
267
        throw InvalidSDFGException(
×
268
            "Output " + std::to_string(provided_output.output_conn_index) + " on node #" +
×
269
            std::to_string(producer->element_id()) + " is provided as " +
×
270
            types::primitive_type_to_string(provided_output.type) + " but required as " +
×
271
            types::primitive_type_to_string(dst_type.primitive_type())
×
272
        );
×
273
    }
×
274
    auto& output_node = builder.add_access(block, org_dst.data());
521✔
275
    builder.add_computational_memlet(
521✔
276
        block, *producer, producer->output(provided_output.output_conn_index), output_node, eltwise_subset, dst_type
521✔
277
    );
521✔
278
}
521✔
279

280
bool ElementWiseDataflowTensorNode::
281
    expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
521✔
282
    auto& dataflow = this->get_parent();
521✔
283
    auto& org_block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
521✔
284

285
    auto* output_tensor_iedge = dataflow.in_edge_for_connector(*this, inputs_.at(0));
521✔
286
    if (!output_tensor_iedge) {
521✔
287
        return false;
×
288
    }
×
289
    auto& target_tensor = static_cast<const types::Tensor&>(output_tensor_iedge->base_type());
521✔
290
    std::vector<const data_flow::Memlet*> iedges;
521✔
291
    std::vector<const data_flow::AccessNode*> inputs_sa;
521✔
292
    std::vector<std::pair<types::PrimitiveType, const TensorLayout*>> input_types;
521✔
293
    iedges.reserve(inputs_.size() - 1);
521✔
294
    for (int i = 1; i < this->inputs_.size(); ++i) {
1,333✔
295
        auto* iedge = dataflow.in_edge_for_connector(*this, inputs_.at(i));
812✔
296
        if (!iedge) {
812✔
297
            if (i < mandatory_input_count()) {
×
298
                return false;
×
299
            } else {
×
300
                continue;
×
301
            }
×
302
        }
×
303
        iedges.push_back(iedge);
812✔
304
        auto* input_sa = dataflow.find_standalone_entry(iedge);
812✔
305
        if (!input_sa) {
812✔
306
            return false;
×
307
        }
×
308
        inputs_sa.push_back(input_sa);
812✔
309
        auto& input_type = iedge->base_type();
812✔
310
        if (input_type.type_id() == types::TypeID::Scalar) {
812✔
311
            input_types.emplace_back(input_type.primitive_type(), nullptr);
×
312
        } else {
812✔
313
            auto& tensor_type = static_cast<const types::Tensor&>(iedge->base_type());
812✔
314
            input_types.emplace_back(input_type.primitive_type(), &tensor_type.layout());
812✔
315
        }
812✔
316
    }
812✔
317

318
    auto* output_tensor_sa = dataflow.find_standalone_entry(output_tensor_iedge);
521✔
319
    if (!output_tensor_sa) {
521✔
320
        return false;
×
321
    }
×
322

323
    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
521✔
324
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&org_block));
521✔
325
    int index = parent.index(org_block);
521✔
326
    auto& transition = parent.at(index).second;
521✔
327

328
    // Add new graph after the current block
329
    auto& new_sequence =
521✔
330
        builder.add_sequence_before(parent, org_block, transition.assignments(), org_block.debug_info());
521✔
331

332
    auto [eltw_scope, loop_vars] = add_eltwise_scope(builder, org_block.debug_info(), new_sequence, shape_);
521✔
333

334
    std::vector<tensor::ElementWiseDataflowTensorNode::ElementInput> eltwise_inputs;
521✔
335
    eltwise_inputs.reserve(inputs_.size() - 1);
521✔
336
    for (int i = 0; i < input_types.size(); ++i) {
1,333✔
337
        eltwise_inputs.push_back({.required_type = input_types.at(i).first});
812✔
338
    }
812✔
339

340
    auto& new_block = builder.add_block(*eltw_scope);
521✔
341

342
    auto produced_output =
521✔
343
        expand_operation_dataflow(builder, analysis_manager, new_block, eltwise_inputs, target_tensor.primitive_type());
521✔
344
    if (!produced_output.producer) {
521✔
345
        return false;
×
346
    }
×
347

348
    std::unordered_map<const data_flow::AccessNode*, data_flow::AccessNode*> new_node_mapping;
521✔
349

350
    // for all old input edge, remove old, create new
351
    for (int i = 0; i < iedges.size(); ++i) {
1,333✔
352
        create_input(
812✔
353
            builder, new_block, *inputs_sa.at(i), input_types.at(i), eltwise_inputs.at(i), loop_vars, new_node_mapping
812✔
354
        );
812✔
355
    }
812✔
356
    create_output(builder, new_block, *output_tensor_sa, target_tensor, produced_output, loop_vars);
521✔
357
    builder.clear_code_node_legacy(org_block, *this);
521✔
358
    // WARNING: this has been deallocated at this point!!
359
    builder.remove_child(parent, index + 1);
521✔
360

361
    return true;
521✔
362
}
521✔
363

364
// Pointer access classification per input connector:
//  - connector 0 (the written tensor) is a full write-only access;
//  - the remaining tensor inputs (below tensor_input_count()) are read-only;
//  - any further connectors defer to TensorNode::pointer_access_type.
data_flow::PointerAccessType ElementWiseDataflowTensorNode::pointer_access_type(int input_idx) const {
    const bool is_target = (input_idx == 0);
    if (is_target) {
        return data_flow::PointerAccessMeta::create_full_write_only(symbolic::__nullptr__(), true);
    }
    const bool is_tensor_input = input_idx < tensor_input_count();
    if (is_tensor_input) {
        return data_flow::PointerAccessMeta::create_read_only(symbolic::__nullptr__(), true);
    }
    return TensorNode::pointer_access_type(input_idx);
}
373

374
// Declares a new container of `type`, with a name derived from `prefix` via
// builder.find_new_name(), and returns a fresh access node for it in `block`.
data_flow::AccessNode& ElementWiseDataflowTensorNode::create_tmp_access_node(
    builder::StructuredSDFGBuilder& builder,
    structured_control_flow::Block& block,
    const std::string& prefix,
    const types::IType& type
) const {
    const auto container_name = builder.find_new_name(prefix);
    builder.add_container(container_name, type);
    return builder.add_access(block, container_name);
}
385

386
// Serializes the element-wise specific fields of `library_node`:
//  "code"         - the library node code value,
//  "shape"        - array of serialized shape expressions,
//  "result_quant" - the fixed result quantization.
nlohmann::json BaseElementWiseDataflowTensorNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& elem_node = static_cast<const ElementWiseDataflowTensorNode&>(library_node);
    serializer::JSONSerializer expr_serializer;

    nlohmann::json shape_json = nlohmann::json::array();
    for (const auto& dim : elem_node.shape()) {
        shape_json.push_back(expr_serializer.expression(dim));
    }

    nlohmann::json j;
    j["code"] = elem_node.code().value();
    j["shape"] = shape_json;
    j["result_quant"] = elem_node.fixed_quantization();
    return j;
}
402

403
// Reads the fields shared by all element-wise serializers:
//  - "shape": optional array of expression strings (empty shape when absent);
//  - "result_quant": via deserialize_quantization, defaulting to
//    QUANTIZATION_MATCH_INPUTS;
//  - "debug_info": required.
// "element_id", "code" and "debug_info" are asserted in debug builds. Keys are
// read with json::at(), so a missing "debug_info" throws
// nlohmann::json::out_of_range when asserts are compiled out. The const
// operator[] has undefined behavior for a missing key.
BaseElementWiseDataflowTensorNodeSerializer::BaseDeser BaseElementWiseDataflowTensorNodeSerializer::
    deserialize_base_values(const nlohmann::json& j) {
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));

    std::vector<symbolic::Expression> shape;
    if (j.contains("shape")) {
        for (const auto& dim : j.at("shape")) {
            shape.push_back(symbolic::parse(dim.get<std::string>()));
        }
    }

    serializer::JSONSerializer serializer;
    auto debug_info = serializer.json_to_debug_info(j.at("debug_info"));
    return {
        .shape = shape,
        .quantization = deserialize_quantization(j, "result_quant", QUANTIZATION_MATCH_INPUTS),
        .debug_info = debug_info
    };
}
424

425
} // namespace tensor
426
} // namespace math
427
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc