• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 22023884668

14 Feb 2026 08:36PM UTC coverage: 64.903% (-1.4%) from 66.315%
22023884668

Pull #525

github

web-flow
Merge 1d47f8bf2 into 9d01cacd5
Pull Request #525: Step 3 (Native Tensor Support): Refactor Python Frontend

2522 of 3435 new or added lines in 32 files covered. (73.42%)

320 existing lines in 15 files now uncovered.

23204 of 35752 relevant lines covered (64.9%)

370.03 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/sdfg/src/data_flow/library_nodes/math/tensor/transpose_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/transpose_node.h"
2

3
#include "sdfg/analysis/scope_analysis.h"
4
#include "sdfg/builder/structured_sdfg_builder.h"
5

6
namespace sdfg {
7
namespace math {
8
namespace tensor {
9

10
TransposeNode::TransposeNode(
11
    size_t element_id,
12
    const DebugInfo& debug_info,
13
    const graph::Vertex vertex,
14
    data_flow::DataFlowGraph& parent,
15
    const std::vector<symbolic::Expression>& shape,
16
    const std::vector<int64_t>& perm
17
)
18
    : TensorNode(
×
19
          element_id,
×
20
          debug_info,
×
21
          vertex,
×
22
          parent,
×
23
          LibraryNodeType_Transpose,
×
24
          {"Y"},
×
25
          {"X"},
×
26
          data_flow::ImplementationType_NONE
×
27
      ),
×
28
      shape_(shape), perm_(perm) {
×
29
    if (perm_.empty()) {
×
30
        // Default permutation: reverse
31
        for (size_t i = 0; i < shape.size(); ++i) {
×
32
            perm_.push_back(shape.size() - 1 - i);
×
33
        }
×
34
    } else {
×
35
        if (perm_.size() != shape_.size()) {
×
36
            throw std::invalid_argument("Permutation rank must match shape rank");
×
37
        }
×
38
    }
×
39
}
×
40

NEW
41
void TransposeNode::validate(const Function& function) const {
×
NEW
42
    TensorNode::validate(function);
×
43

NEW
44
    auto& graph = this->get_parent();
×
45

NEW
46
    auto& iedge = *graph.in_edges(*this).begin();
×
NEW
47
    auto& shape = static_cast<const types::Tensor&>(iedge.base_type());
×
NEW
48
    if (shape.shape().size() != this->shape_.size()) {
×
NEW
49
        throw InvalidSDFGException(
×
NEW
50
            "Library Node: Tensor shape must match node shape. Tensor shape: " + std::to_string(shape.shape().size()) +
×
NEW
51
            " Node shape: " + std::to_string(this->shape_.size())
×
NEW
52
        );
×
NEW
53
    }
×
NEW
54
    for (size_t i = 0; i < this->shape_.size(); ++i) {
×
NEW
55
        if (!symbolic::eq(shape.shape().at(i), this->shape_.at(i))) {
×
NEW
56
            throw InvalidSDFGException(
×
NEW
57
                "Library Node: Tensor shape does not match expected shape. Tensor shape: " +
×
NEW
58
                shape.shape().at(i)->__str__() + " Expected shape: " + this->shape_.at(i)->__str__()
×
NEW
59
            );
×
NEW
60
        }
×
NEW
61
    }
×
62

NEW
63
    auto& oedge = *graph.out_edges(*this).begin();
×
NEW
64
    auto& output_shape = static_cast<const types::Tensor&>(oedge.base_type());
×
NEW
65
    if (output_shape.shape().size() != this->shape_.size()) {
×
NEW
66
        throw InvalidSDFGException(
×
NEW
67
            "Library Node: Output tensor shape must match node shape. Output tensor shape: " +
×
NEW
68
            std::to_string(output_shape.shape().size()) + " Node shape: " + std::to_string(this->shape_.size())
×
NEW
69
        );
×
NEW
70
    }
×
71

NEW
72
    for (size_t i = 0; i < this->shape_.size(); ++i) {
×
NEW
73
        if (!symbolic::eq(output_shape.shape().at(i), this->shape_.at(perm_.at(i)))) {
×
NEW
74
            throw InvalidSDFGException(
×
NEW
75
                "Library Node: Output tensor shape does not match expected shape. Output tensor shape: " +
×
NEW
76
                output_shape.shape().at(i)->__str__() + " Expected shape: " + this->shape_.at(perm_[i])->__str__()
×
NEW
77
            );
×
NEW
78
        }
×
NEW
79
    }
×
NEW
80
}
×
81

82
/**
 * Collects every symbol that appears in any dimension expression of shape_.
 *
 * @return The union of the atoms of all shape dimensions.
 */
symbolic::SymbolSet TransposeNode::symbols() const {
    symbolic::SymbolSet result;
    for (size_t d = 0; d < this->shape_.size(); ++d) {
        const auto extent_atoms = symbolic::atoms(this->shape_[d]);
        result.insert(extent_atoms.begin(), extent_atoms.end());
    }
    return result;
}
×
91

92
/**
 * Substitutes old_expression with new_expression in every dimension of shape_.
 * The permutation is purely positional and therefore left untouched.
 */
void TransposeNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    const size_t rank = this->shape_.size();
    for (size_t d = 0; d < rank; ++d) {
        auto substituted = symbolic::subs(this->shape_[d], old_expression, new_expression);
        this->shape_[d] = substituted;
    }
}
×
97

98
/**
 * Lowers the transpose into an explicit loop nest.
 *
 * A new sequence is inserted before this node's block. It holds one sequential map
 * per input dimension and a single assign tasklet copying
 * X[i_0, ..., i_{n-1}] into Y[i_{perm[0]}, ..., i_{perm[n-1]}]. The original block
 * (child index + 1 after the insertion) is then removed. The assignments of the
 * block's transition are carried over to the new sequence.
 *
 * Expansion only happens when this node has exactly one input and one output memlet,
 * and both endpoints are access nodes connected to nothing but this node.
 *
 * NOTE(review): the whole original block is removed via remove_child; any other,
 * unrelated nodes living in the same block would be dropped. Confirm callers only
 * expand blocks holding just access(X) -> transpose -> access(Y).
 *
 * @return true if the node was expanded, false if the preconditions were not met
 *         (in which case the graph is left unchanged).
 */
bool TransposeNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
    auto& dataflow = this->get_parent();
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
    if (dataflow.in_degree(*this) != 1 || dataflow.out_degree(*this) != 1) {
        return false;
    }

    auto& scope_analysis = analysis_manager.get<analysis::ScopeAnalysis>();
    auto& parent = static_cast<structured_control_flow::Sequence&>(*scope_analysis.parent_scope(&block));
    int index = parent.index(block);
    auto& transition = parent.at(index).second;

    auto& iedge = *dataflow.in_edges(*this).begin();
    auto& oedge = *dataflow.out_edges(*this).begin();

    // Checks if legal: both endpoints must really be access nodes, and they must
    // not be shared with other nodes, because they are deleted below.
    auto* input_node = dynamic_cast<data_flow::AccessNode*>(&iedge.src());
    auto* output_node = dynamic_cast<data_flow::AccessNode*>(&oedge.dst());
    if (input_node == nullptr || output_node == nullptr) {
        return false;
    }
    if (dataflow.in_degree(*input_node) != 0 || dataflow.out_degree(*input_node) != 1 ||
        dataflow.in_degree(*output_node) != 1 || dataflow.out_degree(*output_node) != 0) {
        return false;
    }

    // Insert the loop nest before the current block; the block itself is removed at
    // the end, so its transition assignments move to the new sequence.
    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), block.debug_info());

    // One sequential map per input dimension, nested inside each other.
    structured_control_flow::Sequence* last_scope = &new_sequence;
    std::vector<symbolic::Expression> loop_vars;
    loop_vars.reserve(this->shape_.size());

    for (size_t i = 0; i < this->shape_.size(); i++) {
        std::string indvar_str = builder.find_new_name("_i");
        builder.add_container(indvar_str, types::Scalar(types::PrimitiveType::UInt64));

        auto indvar = symbolic::symbol(indvar_str);
        auto init = symbolic::zero();
        auto update = symbolic::add(indvar, symbolic::one());
        auto condition = symbolic::Lt(indvar, this->shape_[i]);
        structured_control_flow::Map& loop = builder.add_map(
            *last_scope,
            indvar,
            condition,
            init,
            update,
            structured_control_flow::ScheduleType_Sequential::create(),
            {},
            block.debug_info()
        );
        last_scope = &loop.root();

        loop_vars.push_back(indvar);
    }

    auto& body = builder.add_block(*last_scope, {}, block.debug_info());

    // Output dimension i iterates input dimension perm_[i], hence Y's index i is loop_vars[perm_[i]].
    std::vector<symbolic::Expression> output_indices(shape_.size());
    for (size_t i = 0; i < shape_.size(); ++i) {
        output_indices[i] = loop_vars[static_cast<size_t>(perm_[i])];
    }

    // Read input, write output through a plain assignment tasklet.
    auto& x_access = builder.add_access(body, input_node->data(), debug_info());
    auto& y_access = builder.add_access(body, output_node->data(), debug_info());

    auto& tasklet = builder.add_tasklet(body, data_flow::assign, "_out", {"_in"}, debug_info());

    // Access memlets
    builder.add_computational_memlet(body, x_access, tasklet, "_in", loop_vars, iedge.base_type(), debug_info());

    builder.add_computational_memlet(body, tasklet, "_out", y_access, output_indices, oedge.base_type(), debug_info());

    // Remove the original node together with its access nodes and its block.
    builder.remove_memlet(block, iedge);
    builder.remove_memlet(block, oedge);
    builder.remove_node(block, *input_node);
    builder.remove_node(block, *output_node);
    builder.remove_node(block, *this);
    builder.remove_child(parent, index + 1);

    return true;
}
×
185

186
/**
 * Serializes a TransposeNode into JSON with the keys "code", "shape" (list of
 * expression strings produced by JSONSerializer::expression) and "perm" (list of integers).
 */
nlohmann::json TransposeNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& node = static_cast<const TransposeNode&>(library_node);
    serializer::JSONSerializer expr_serializer;

    nlohmann::json shape_json = nlohmann::json::array();
    for (const auto& extent : node.shape()) {
        shape_json.push_back(expr_serializer.expression(extent));
    }

    nlohmann::json perm_json = nlohmann::json::array();
    for (const auto axis : node.perm()) {
        perm_json.push_back(axis);
    }

    nlohmann::json result;
    result["code"] = node.code().value();
    result["shape"] = shape_json;
    result["perm"] = perm_json;
    return result;
}
×
206

207
/**
 * Rebuilds a TransposeNode from its JSON form and adds it to the given block.
 *
 * Expected keys: "element_id", "code", "debug_info", "shape", "perm". "shape" and
 * "perm" are read only when present; a missing "perm" yields an empty permutation,
 * which the TransposeNode constructor turns into the default (reversed axes).
 *
 * @throws nlohmann::json::out_of_range if "debug_info" is missing (the asserts are
 *         compiled out in release builds, so lookups must not rely on them).
 */
data_flow::LibraryNode& TransposeNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));
    assert(j.contains("shape"));
    assert(j.contains("perm"));

    std::vector<symbolic::Expression> shape;
    if (j.contains("shape")) {
        for (const auto& dim : j.at("shape")) {
            shape.push_back(symbolic::parse(dim.get<std::string>()));
        }
    }

    // const json::operator[] on a missing key is undefined behavior, so guard it
    // the same way as "shape".
    std::vector<int64_t> perm;
    if (j.contains("perm")) {
        for (const auto& axis : j.at("perm")) {
            perm.push_back(axis.get<int64_t>());
        }
    }

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j.at("debug_info"));

    return builder.add_library_node<TransposeNode>(parent, debug_info, shape, perm);
}
×
233

234
} // namespace tensor
235
} // namespace math
236
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc