• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 20650731423

01 Jan 2026 09:02AM UTC coverage: 39.569% (-0.07%) from 39.635%
20650731423

push

github

web-flow
Merge pull request #422 from daisytuner/copilot/add-tensor-cast-unary-node

Add CastNode for tensor type casting between primitive types

15022 of 49410 branches covered (30.4%)

Branch coverage included in aggregate %.

30 of 74 new or added lines in 2 files covered. (40.54%)

22 existing lines in 2 files now uncovered.

12897 of 21148 relevant lines covered (60.98%)

91.48 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

22.87
/src/data_flow/library_nodes/math/tensor/elementwise_ops/cast_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/elementwise_ops/cast_node.h"
2

3
#include "sdfg/analysis/analysis.h"
4
#include "sdfg/builder/structured_sdfg_builder.h"
5

6
#include "sdfg/analysis/scope_analysis.h"
7

8
namespace sdfg {
9
namespace math {
10
namespace tensor {
11

12
// Constructs a CastNode: an element-wise unary tensor library node that
// converts each element of its input to `target_type` (see expand_operation,
// which lowers the cast to an assign tasklet).
//
// element_id / debug_info / vertex / parent: standard dataflow-node
//   bookkeeping, forwarded unchanged to the ElementWiseUnaryNode base.
// shape: symbolic tensor shape, forwarded to the base.
// target_type: primitive type the elements are converted to; stored for
//   serialization and cloning.
CastNode::CastNode(
    size_t element_id,
    const DebugInfo& debug_info,
    const graph::Vertex vertex,
    data_flow::DataFlowGraph& parent,
    const std::vector<symbolic::Expression>& shape,
    types::PrimitiveType target_type
)
    : ElementWiseUnaryNode(element_id, debug_info, vertex, parent, LibraryNodeType_Cast, shape),
      target_type_(target_type) {}
22

23
bool CastNode::expand_operation(
28✔
24
    builder::StructuredSDFGBuilder& builder,
25
    analysis::AnalysisManager& analysis_manager,
26
    structured_control_flow::Sequence& body,
27
    const std::string& input_name,
28
    const std::string& output_name,
29
    const types::IType& input_type,
30
    const types::IType& output_type,
31
    const data_flow::Subset& subset
32
) {
33
    // Add code block
34
    auto& code_block = builder.add_block(body);
28!
35

36
    auto& input_node_new = builder.add_access(code_block, input_name);
28!
37
    auto& output_node_new = builder.add_access(code_block, output_name);
28!
38

39
    // Use assign tasklet which handles type casting when input and output types differ
40
    auto& tasklet = builder.add_tasklet(code_block, data_flow::TaskletCode::assign, "_out", {"_in"});
28!
41
    builder.add_computational_memlet(code_block, input_node_new, tasklet, "_in", subset, input_type);
28!
42
    builder.add_computational_memlet(code_block, tasklet, "_out", output_node_new, subset, output_type);
28!
43

44
    return true;
28✔
NEW
45
}
×
46

47
void CastNode::validate(const Function& function) const {
4✔
48
    auto& graph = this->get_parent();
4✔
49

50
    // Check that all input memlets are scalar or pointer of scalar
51
    for (auto& iedge : graph.in_edges(*this)) {
8✔
52
        if (iedge.base_type().type_id() != types::TypeID::Scalar &&
4!
53
            iedge.base_type().type_id() != types::TypeID::Pointer) {
4✔
NEW
54
            throw InvalidSDFGException(
×
NEW
55
                "CastNode: Input memlet must be of scalar or pointer type. Found type: " + iedge.base_type().print()
×
56
            );
57
        }
58
        if (iedge.base_type().type_id() == types::TypeID::Pointer) {
4!
59
            auto& ptr_type = static_cast<const types::Pointer&>(iedge.base_type());
4✔
60
            if (ptr_type.pointee_type().type_id() != types::TypeID::Scalar) {
4!
NEW
61
                throw InvalidSDFGException(
×
NEW
62
                    "CastNode: Input memlet pointer must be flat (pointer to scalar). Found type: " +
×
NEW
63
                    ptr_type.pointee_type().print()
×
64
                );
65
            }
66
            if (!iedge.subset().empty()) {
4!
NEW
67
                throw InvalidSDFGException("CastNode: Input memlet pointer must not be dereferenced.");
×
68
            }
69
        }
4✔
70
    }
71

72
    // Check that all output memlets are scalar or pointer of scalar
73
    for (auto& oedge : graph.out_edges(*this)) {
8✔
74
        if (oedge.base_type().type_id() != types::TypeID::Scalar &&
4!
75
            oedge.base_type().type_id() != types::TypeID::Pointer) {
4✔
NEW
76
            throw InvalidSDFGException(
×
NEW
77
                "CastNode: Output memlet must be of scalar or pointer type. Found type: " + oedge.base_type().print()
×
78
            );
79
        }
80
        if (oedge.base_type().type_id() == types::TypeID::Pointer) {
4!
81
            auto& ptr_type = static_cast<const types::Pointer&>(oedge.base_type());
4✔
82
            if (ptr_type.pointee_type().type_id() != types::TypeID::Scalar) {
4!
NEW
83
                throw InvalidSDFGException(
×
NEW
84
                    "CastNode: Output memlet pointer must be flat (pointer to scalar). Found type: " +
×
NEW
85
                    ptr_type.pointee_type().print()
×
86
                );
87
            }
88
            if (!oedge.subset().empty()) {
4!
NEW
89
                throw InvalidSDFGException("CastNode: Output memlet pointer must not be dereferenced.");
×
90
            }
91
        }
4✔
92
    }
93

94
    // For CastNode, we DON'T check that all memlets have the same primitive type
95
    // because the whole point of casting is to convert between types
96
}
4✔
97

98
std::unique_ptr<data_flow::DataFlowNode> CastNode::
NEW
99
    clone(size_t element_id, const graph::Vertex vertex, data_flow::DataFlowGraph& parent) const {
×
NEW
100
    return std::unique_ptr<data_flow::DataFlowNode>(
×
NEW
101
        new CastNode(element_id, this->debug_info(), vertex, parent, this->shape_, this->target_type_)
×
102
    );
NEW
103
}
×
104

NEW
105
// Serializes a CastNode to JSON: its library-node code, the symbolic shape
// (one serialized expression per dimension), and the target primitive type
// encoded as its integer value.
nlohmann::json CastNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    auto& node = static_cast<const CastNode&>(library_node);
    serializer::JSONSerializer expr_serializer;

    nlohmann::json j;
    j["code"] = node.code().value();

    // Each shape dimension is serialized via the expression serializer.
    auto dims = nlohmann::json::array();
    for (const auto& dim : node.shape()) {
        dims.push_back(expr_serializer.expression(dim));
    }
    j["shape"] = std::move(dims);

    j["target_type"] = static_cast<int>(node.target_type());

    return j;
}
×
121

NEW
122
// Reconstructs a CastNode from its JSON representation (the inverse of
// serialize) and inserts it into `parent` via the builder. Returns a
// reference to the newly created node.
data_flow::LibraryNode& CastNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    // All required fields must be present in the payload.
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));
    assert(j.contains("shape"));
    assert(j.contains("target_type"));

    // Recover the debug info through the generic JSON serializer.
    sdfg::serializer::JSONSerializer serializer;
    const DebugInfo debug_info = serializer.json_to_debug_info(j["debug_info"]);

    // Parse each shape dimension back into a symbolic expression.
    std::vector<symbolic::Expression> shape;
    shape.reserve(j["shape"].size());
    for (const auto& dim : j["shape"]) {
        shape.push_back(symbolic::parse(dim.get<std::string>()));
    }

    const auto target_type = static_cast<types::PrimitiveType>(j["target_type"].get<int>());

    return static_cast<CastNode&>(builder.add_library_node<CastNode>(parent, debug_info, shape, target_type));
}
×
145

146
} // namespace tensor
147
} // namespace math
148
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc