• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 20621841097

31 Dec 2025 03:18PM UTC coverage: 39.655% (-0.06%) from 39.712%
20621841097

Pull #421

github

web-flow
Merge 7662c1b88 into 3b72c335e
Pull Request #421: Extend tensor library nodes with primitive type support and refactor CMathNode to use enums

14996 of 49220 branches covered (30.47%)

Branch coverage included in aggregate %.

247 of 608 new or added lines in 52 files covered. (40.63%)

38 existing lines in 5 files now uncovered.

12874 of 21062 relevant lines covered (61.12%)

89.39 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

46.24
/src/data_flow/library_nodes/math/tensor/tensor_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/tensor_node.h"
2

3
#include "sdfg/types/pointer.h"
4
#include "sdfg/types/scalar.h"
5

6
namespace sdfg {
7
namespace math {
8
namespace tensor {
9

10
// Constructs a tensor library node.
//
// Every argument is forwarded unchanged, in the same order, to the MathNode
// base-class constructor; no further initialization happens here.
TensorNode::TensorNode(
    size_t element_id,
    const DebugInfo& debug_info,
    const graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    const data_flow::LibraryNodeCode& code,
    const std::vector<std::string>& outputs,
    const std::vector<std::string>& inputs,
    data_flow::ImplementationType impl_type
)
    : MathNode(
          element_id,
          debug_info,
          vertex,
          parent,
          code,
          outputs,
          inputs,
          impl_type
      ) {
    // Intentionally empty: the base class receives all state.
}
21

22
void TensorNode::validate(const Function& function) const {
24✔
23
    auto& graph = this->get_parent();
24✔
24

25
    // Check that all input memlets are scalar or pointer of scalar
26
    for (auto& iedge : graph.in_edges(*this)) {
58✔
27
        if (iedge.base_type().type_id() != types::TypeID::Scalar &&
34!
28
            iedge.base_type().type_id() != types::TypeID::Pointer) {
26✔
29
            throw InvalidSDFGException(
1!
NEW
30
                "TensorNode: Input memlet must be of scalar or pointer type. Found type: " + iedge.base_type().print()
×
31
            );
32
        }
33
        if (iedge.base_type().type_id() == types::TypeID::Pointer) {
34✔
34
            auto& ptr_type = static_cast<const types::Pointer&>(iedge.base_type());
26✔
35
            if (ptr_type.pointee_type().type_id() != types::TypeID::Scalar) {
26!
NEW
36
                throw InvalidSDFGException(
×
NEW
37
                    "TensorNode: Input memlet pointer must be flat (pointer to scalar). Found type: " +
×
NEW
38
                    ptr_type.pointee_type().print()
×
39
                );
40
            }
41
            if (!iedge.subset().empty()) {
26!
NEW
42
                throw InvalidSDFGException("TensorNode: Input memlet pointer must not be dereferenced.");
×
43
            }
44
        }
26✔
45
    }
46

47
    // Check that all output memlets are scalar or pointer of scalar
48
    for (auto& oedge : graph.out_edges(*this)) {
48✔
49
        if (oedge.base_type().type_id() != types::TypeID::Scalar &&
24!
50
            oedge.base_type().type_id() != types::TypeID::Pointer) {
17✔
NEW
51
            throw InvalidSDFGException(
×
NEW
52
                "TensorNode: Output memlet must be of scalar or pointer type. Found type: " + oedge.base_type().print()
×
53
            );
54
        }
55
        if (oedge.base_type().type_id() == types::TypeID::Pointer) {
24✔
56
            auto& ptr_type = static_cast<const types::Pointer&>(oedge.base_type());
17✔
57
            if (ptr_type.pointee_type().type_id() != types::TypeID::Scalar) {
17!
NEW
58
                throw InvalidSDFGException(
×
NEW
59
                    "TensorNode: Output memlet pointer must be flat (pointer to scalar). Found type: " +
×
NEW
60
                    ptr_type.pointee_type().print()
×
61
                );
62
            }
63
            if (!oedge.subset().empty()) {
17!
NEW
64
                throw InvalidSDFGException("TensorNode: Output memlet pointer must not be dereferenced.");
×
65
            }
66
        }
17✔
67
    }
68

69
    // Validate that all memlets have the same primitive type
70
    types::PrimitiveType prim_type = primitive_type(graph);
24✔
71

72
    // Check if this operation supports integer types
73
    if (!supports_integer_types() && types::is_integer(prim_type)) {
24✔
74
        throw InvalidSDFGException(
2!
75
            "TensorNode: This operation does not support integer types. Found type: " +
1!
76
            std::string(types::primitive_type_to_string(prim_type))
1!
77
        );
78
    }
79
}
24✔
80

81
// Returns the single primitive type shared by all memlets of this node.
//
// Input edges are examined first, then output edges. Each edge must be a
// scalar or a pointer to a scalar; the scalar's primitive type is the edge's
// type.
//
// Throws InvalidSDFGException if:
//   - an edge is neither scalar nor pointer,
//   - a pointer edge does not point to a scalar,
//   - two edges carry different primitive types (reported as
//     "<first seen> and <mismatching>"),
//   - the node has no edges at all.
types::PrimitiveType TensorNode::primitive_type(const data_flow::DataFlowGraph& graph) const {
    // Extracts the primitive type carried by one memlet. The result is always
    // initialized or the lambda throws.
    auto edge_primitive_type = [](auto& edge) -> types::PrimitiveType {
        const auto& base_type = edge.base_type();
        if (base_type.type_id() == types::TypeID::Scalar) {
            return static_cast<const types::Scalar&>(base_type).primitive_type();
        }
        if (base_type.type_id() == types::TypeID::Pointer) {
            const auto& pointee = static_cast<const types::Pointer&>(base_type).pointee_type();
            if (pointee.type_id() == types::TypeID::Scalar) {
                return static_cast<const types::Scalar&>(pointee).primitive_type();
            }
            throw InvalidSDFGException("TensorNode: Pointer must point to scalar type");
        }
        throw InvalidSDFGException("TensorNode: Edge must be scalar or pointer type");
    };

    types::PrimitiveType result_type = types::PrimitiveType::Void;
    bool found = false;

    // Records the first edge's type and rejects any later edge that differs.
    auto merge_edge = [&](auto& edge) {
        const types::PrimitiveType edge_type = edge_primitive_type(edge);
        if (!found) {
            result_type = edge_type;
            found = true;
        } else if (result_type != edge_type) {
            throw InvalidSDFGException(
                "TensorNode: All memlets must have the same primitive type. Found " +
                std::string(types::primitive_type_to_string(result_type)) + " and " +
                std::string(types::primitive_type_to_string(edge_type))
            );
        }
    };

    // Order matters for which mismatch is reported: inputs before outputs.
    for (auto& iedge : graph.in_edges(*this)) {
        merge_edge(iedge);
    }
    for (auto& oedge : graph.out_edges(*this)) {
        merge_edge(oedge);
    }

    if (!found) {
        throw InvalidSDFGException("TensorNode: No edges found to determine primitive type");
    }

    return result_type;
}
153

154
// Selects the integer min/max tasklet code for `prim_type`.
//
// Signed types (per types::is_signed) map to int_smax / int_smin; all others
// map to int_umax / int_umin. `is_max` chooses max over min.
data_flow::TaskletCode TensorNode::get_integer_minmax_tasklet(types::PrimitiveType prim_type, bool is_max) {
    if (types::is_signed(prim_type)) {
        return is_max ? data_flow::TaskletCode::int_smax : data_flow::TaskletCode::int_smin;
    }
    return is_max ? data_flow::TaskletCode::int_umax : data_flow::TaskletCode::int_umin;
}
162

163
} // namespace tensor
164
} // namespace math
165
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc