• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 27007027060

05 Jun 2026 09:28AM UTC coverage: 61.275% (-0.02%) from 61.292%
27007027060

push

github

web-flow
Improve Quantization support on TensorNodes (#736)

* Added DataFlowGraph.find_standalone_exit() following the pattern of find_standalone_entry() to abstract away edge types.
* LibNodeDispatcher does not allow missing inputs.
  ConvNode is now explicitly configured with whether or not it has a bias, to accommodate this.
* Fixed elementwise CMath node toStr()

---------

Co-authored-by: Moritz Timmer <25349452+Moehre2@users.noreply.github.com>

10 of 43 new or added lines in 8 files covered. (23.26%)

1 existing line in 1 file now uncovered.

35592 of 58086 relevant lines covered (61.27%)

11015.05 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

35.11
/sdfg/src/data_flow/library_nodes/math/tensor/elementwise_ops/cmath_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_ops/cmath_node.h"
2

3
#include <cstddef>
4
#include <memory>
5
#include <nlohmann/json_fwd.hpp>
6
#include <sstream>
7
#include <string>
8
#include <unordered_map>
9
#include <unordered_set>
10
#include <vector>
11

12
#include "sdfg/analysis/analysis.h"
13
#include "sdfg/analysis/scope_analysis.h"
14
#include "sdfg/builder/structured_sdfg_builder.h"
15
#include "sdfg/data_flow/access_node.h"
16
#include "sdfg/data_flow/data_flow_graph.h"
17
#include "sdfg/data_flow/data_flow_node.h"
18
#include "sdfg/data_flow/library_node.h"
19
#include "sdfg/data_flow/library_nodes/math/cmath/cmath_node.h"
20
#include "sdfg/data_flow/library_nodes/math/tensor/tensor_node.h"
21
#include "sdfg/element.h"
22
#include "sdfg/exceptions.h"
23
#include "sdfg/graph/graph.h"
24
#include "sdfg/serializer/json_serializer.h"
25
#include "sdfg/structured_control_flow/block.h"
26
#include "sdfg/structured_control_flow/sequence.h"
27
#include "sdfg/symbolic/symbolic.h"
28
#include "sdfg/types/scalar.h"
29
#include "sdfg/types/tensor.h"
30

31
namespace sdfg {
32
namespace math {
33
namespace tensor {
34

35
CMathTensorNode::CMathTensorNode(
36
    size_t element_id,
37
    const DebugInfo& debug_info,
38
    const graph::Vertex vertex,
39
    data_flow::DataFlowGraph& parent,
40
    const cmath::CMathFunction cmath_function,
41
    const std::string& modified_tensor_conn,
42
    const std::vector<std::string>& tensor_inputs,
43
    const std::vector<symbolic::Expression>& shape,
44
    QuantizationType quantization,
45
    const data_flow::ImplementationType& impl_type
46
)
47
    : ElementWiseDataflowTensorNode(
228✔
48
          element_id,
228✔
49
          debug_info,
228✔
50
          vertex,
228✔
51
          parent,
228✔
52
          LibraryNodeType_TensorCMath,
228✔
53
          shape,
228✔
54
          modified_tensor_conn,
228✔
55
          tensor_inputs,
228✔
56
          quantization,
228✔
57
          impl_type
228✔
58
      ),
228✔
59
      cmath_function_(cmath_function) {}
228✔
60

61
void CMathTensorNode::validate(const Function& function) const {
228✔
62
    auto& graph = this->get_parent();
228✔
63

64
    validate_target_tensor(graph);
228✔
65
    validate_all_input_tensors(graph);
228✔
66

67
    auto actual_inputs = this->inputs().size() - 1;
228✔
68
    // Validate: inputs match arity
69
    if (cmath::cmath_function_to_arity(this->cmath_function()) != actual_inputs) {
228✔
70
        throw InvalidSDFGException(
×
71
            "CMathTensorNode (Code: " + std::string(cmath::cmath_function_to_stem(this->cmath_function())) +
×
72
            "): Invalid number of inputs. Expected " +
×
73
            std::to_string(cmath::cmath_function_to_arity(this->cmath_function())) + ", got " +
×
74
            std::to_string(actual_inputs)
×
75
        );
×
76
    }
×
77
}
228✔
78

79
/// Returns the cmath function this node applies element-wise.
cmath::CMathFunction CMathTensorNode::cmath_function() const {
    return cmath_function_;
}
80

81
/// Expands the per-element computation of this node into a scalar cmath::CMathNode inside `block`.
///
/// The scalar node is created for the `required_type` of the first entry in `needed_inputs`. The
/// first `scalar_node.inputs().size()` entries of `needed_inputs` are wired to its input connectors
/// in order, by setting `consumer` and `input_conn_index` on each entry.
///
/// @return the scalar node as producer of output connector 0 with type `expected_type`. Returns a
///         default-constructed ElementOutput when `needed_inputs` holds fewer entries than the
///         function's arity.
ElementWiseDataflowTensorNode::ElementOutput CMathTensorNode::expand_operation_dataflow(
    builder::StructuredSDFGBuilder& builder,
    analysis::AnalysisManager& analysis_manager,
    Block& block,
    std::vector<ElementInput>& needed_inputs,
    types::PrimitiveType expected_type
) {
    const auto fn = this->cmath_function();
    const bool enough_inputs = cmath::cmath_function_to_arity(fn) <= needed_inputs.size();
    if (!enough_inputs) {
        return {}; // not mappable, probably invalid
    }

    auto elem_type = needed_inputs.at(0).required_type;
    auto& scalar_node = builder.add_library_node<cmath::CMathNode>(block, debug_info_, fn, elem_type);

    const auto& connectors = scalar_node.inputs();
    for (size_t conn = 0; conn < connectors.size(); ++conn) {
        ElementInput& entry = needed_inputs.at(conn);
        entry.consumer = &scalar_node;
        entry.input_conn_index = conn;
    }

    // TODO(review): expected_type is not checked against the type the cmath function actually
    // produces for elem_type; callers currently receive expected_type unverified.
    return {.producer = &scalar_node, .output_conn_index = 0, .type = expected_type};
}
106

107
bool CMathTensorNode::supports_integer_types() const {
×
108
    return this->cmath_function() == cmath::CMathFunction::lrint ||
×
109
           this->cmath_function() == cmath::CMathFunction::llrint ||
×
110
           this->cmath_function() == cmath::CMathFunction::lround ||
×
111
           this->cmath_function() == cmath::CMathFunction::llround;
×
112
}
×
113

114
std::unique_ptr<data_flow::DataFlowNode> CMathTensorNode::
115
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
116
    return std::unique_ptr<data_flow::DataFlowNode>(new CMathTensorNode(
×
117
        element_id,
×
118
        this->debug_info(),
×
119
        vertex,
×
120
        parent,
×
121
        this->cmath_function(),
×
122
        inputs_.at(0),
×
123
        std::vector<std::string>(inputs_.cbegin() + 1, inputs_.cend()),
×
124
        this->shape(),
×
125
        fixed_quantization_,
×
126
        implementation_type_
×
127
    ));
×
128
}
×
129

130
std::string CMathTensorNode::toStr() const {
×
131
    std::stringstream stream;
×
132

NEW
133
    const auto* iedge = this->get_parent().in_edge_for_connector(*this, this->input(0));
×
NEW
134
    stream << this->code().value() << "("
×
NEW
135
           << cmath::get_cmath_intrinsic_name(this->cmath_function(), iedge->base_type().primitive_type()) << ")";
×
136

137
    return stream.str();
×
138
}
×
139

140
/// Serializes a CMathTensorNode to JSON.
///
/// Starts from the base element-wise serializer output. It then writes the full connector list
/// under "inputs" and the cmath function stem under "cmath_function", overwriting any values the
/// base serializer may have stored under those keys.
nlohmann::json CMathTensorNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    nlohmann::json j = BaseElementWiseDataflowTensorNodeSerializer::serialize(library_node);
    const auto& node = static_cast<const CMathTensorNode&>(library_node);

    nlohmann::json connectors = nlohmann::json::array();
    for (const auto& conn : node.inputs()) {
        connectors.push_back(conn);
    }

    j["inputs"] = connectors;
    j["cmath_function"] = cmath::cmath_function_to_stem(node.cmath_function());
    return j;
}
154

155
/// Reconstructs a CMathTensorNode from JSON and adds it to `parent` through `builder`.
///
/// Expected fields:
///   - "inputs": a non-empty array of connector-name strings. Element 0 is the modified tensor
///     connector; the rest are tensor inputs.
///   - "cmath_function": the function stem string, parsed by cmath::string_to_cmath_function().
/// Debug info, shape and quantization come from deserialize_base_values().
///
/// @throws InvalidSDFGException if "inputs" is missing, is not an array or is empty, or if
///         "cmath_function" is missing.
data_flow::LibraryNode& CMathTensorNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    auto base = deserialize_base_values(j);

    // Checked explicitly rather than with assert(): assert() is compiled out under NDEBUG, and
    // indexing a const json with a missing key is undefined behavior.
    if (!j.contains("inputs") || !j.at("inputs").is_array()) {
        throw InvalidSDFGException("CMathTensorNode: missing or invalid \"inputs\" array");
    }
    if (!j.contains("cmath_function")) {
        throw InvalidSDFGException("CMathTensorNode: missing \"cmath_function\"");
    }

    const auto& input_arr = j.at("inputs");
    std::vector<std::string> inputs;
    inputs.reserve(input_arr.size());
    for (const auto& input : input_arr) {
        inputs.push_back(input.get<std::string>());
    }
    // Input 0 (the modified tensor connector) is mandatory; `cbegin() + 1` below needs it.
    if (inputs.empty()) {
        throw InvalidSDFGException("CMathTensorNode: \"inputs\" must contain at least the modified tensor");
    }

    auto cmath_function = cmath::string_to_cmath_function(j.at("cmath_function").get<std::string>());

    std::vector<std::string> tensor_inputs(inputs.cbegin() + 1, inputs.cend());

    auto& node = builder.add_library_node<CMathTensorNode>(
        parent, base.debug_info, cmath_function, inputs.front(), tensor_inputs, base.shape, base.quantization
    );
    return node;
}
177

178
} // namespace tensor
179
} // namespace math
180
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc