• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 27981272983

22 Jun 2026 08:18PM UTC coverage: 61.754% (-0.03%) from 61.782%
27981272983

Pull #781

github

web-flow
Merge bddaa3724 into fe87d162b
Pull Request #781: Extend Segformer benchmarks setup

987 of 1432 new or added lines in 62 files covered. (68.92%)

9 existing lines in 7 files now uncovered.

38121 of 61730 relevant lines covered (61.75%)

993.19 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/sdfg/src/data_flow/library_nodes/math/tensor/transpose_node.cpp
1
#include "sdfg/data_flow/library_nodes/math/tensor/transpose_node.h"
2

3
#include "sdfg/builder/structured_sdfg_builder.h"
4

5
namespace sdfg {
6
namespace math {
7
namespace tensor {
8

9
TransposeNode::TransposeNode(
10
    size_t element_id,
11
    const DebugInfo& debug_info,
12
    const graph::Vertex vertex,
13
    data_flow::DataFlowGraph& parent,
14
    const std::vector<symbolic::Expression>& shape,
15
    const std::vector<int64_t>& perm
16
)
17
    : TensorNode(
×
18
          element_id,
×
19
          debug_info,
×
20
          vertex,
×
21
          parent,
×
22
          LibraryNodeType_Transpose,
×
23
          {"Y"},
×
24
          {"X"},
×
25
          data_flow::ImplementationType_NONE
×
26
      ),
×
27
      shape_(shape), perm_(perm) {
×
28
    if (perm_.empty()) {
×
29
        // Default permutation: reverse
30
        for (size_t i = 0; i < shape.size(); ++i) {
×
31
            perm_.push_back(shape.size() - 1 - i);
×
32
        }
×
33
    } else {
×
34
        if (perm_.size() != shape_.size()) {
×
35
            throw std::invalid_argument("Permutation rank must match shape rank");
×
36
        }
×
37
    }
×
38
}
×
39

40
void TransposeNode::validate(const Function& function) const {
×
41
    TensorNode::validate(function);
×
42

43
    auto& graph = this->get_parent();
×
44

45
    auto& iedge = *graph.in_edges(*this).begin();
×
46
    auto& shape = static_cast<const types::Tensor&>(iedge.base_type());
×
47
    if (shape.shape().size() != this->shape_.size()) {
×
48
        throw InvalidSDFGException(
×
49
            "Library Node: Tensor shape must match node shape. Tensor shape: " + std::to_string(shape.shape().size()) +
×
50
            " Node shape: " + std::to_string(this->shape_.size())
×
51
        );
×
52
    }
×
53
    for (size_t i = 0; i < this->shape_.size(); ++i) {
×
54
        if (!symbolic::eq(shape.shape().at(i), this->shape_.at(i))) {
×
55
            throw InvalidSDFGException(
×
56
                "Library Node: Tensor shape does not match expected shape. Tensor shape: " +
×
57
                shape.shape().at(i)->__str__() + " Expected shape: " + this->shape_.at(i)->__str__()
×
58
            );
×
59
        }
×
60
    }
×
61

62
    auto& oedge = *graph.out_edges(*this).begin();
×
63
    auto& output_shape = static_cast<const types::Tensor&>(oedge.base_type());
×
64
    if (output_shape.shape().size() != this->shape_.size()) {
×
65
        throw InvalidSDFGException(
×
66
            "Library Node: Output tensor shape must match node shape. Output tensor shape: " +
×
67
            std::to_string(output_shape.shape().size()) + " Node shape: " + std::to_string(this->shape_.size())
×
68
        );
×
69
    }
×
70

71
    for (size_t i = 0; i < this->shape_.size(); ++i) {
×
72
        if (!symbolic::eq(output_shape.shape().at(i), this->shape_.at(perm_.at(i)))) {
×
73
            throw InvalidSDFGException(
×
74
                "Library Node: Output tensor shape does not match expected shape. Output tensor shape: " +
×
75
                output_shape.shape().at(i)->__str__() + " Expected shape: " + this->shape_.at(perm_[i])->__str__()
×
76
            );
×
77
        }
×
78
    }
×
79
}
×
80

81
/// Returns the set of all atoms (free symbols) that appear in any extent of the node's shape.
symbolic::SymbolSet TransposeNode::symbols() const {
    symbolic::SymbolSet result;
    for (const auto& extent : this->shape_) {
        const auto extent_atoms = symbolic::atoms(extent);
        result.insert(extent_atoms.begin(), extent_atoms.end());
    }
    return result;
}
90

91
/// Substitutes old_expression with new_expression in every extent of the node's shape.
void TransposeNode::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) {
    const size_t rank = this->shape_.size();
    for (size_t axis = 0; axis < rank; ++axis) {
        this->shape_[axis] = symbolic::subs(this->shape_[axis], old_expression, new_expression);
    }
}
96

NEW
97
/// Applies all substitutions in `replacements` to every extent of the node's shape.
void TransposeNode::replace(const symbolic::ExpressionMapping& replacements) {
    const size_t rank = this->shape_.size();
    for (size_t axis = 0; axis < rank; ++axis) {
        this->shape_[axis] = symbolic::subs(this->shape_[axis], replacements);
    }
}
102

103
/// Lowers the transpose into a nest of sequential maps (one per input axis) whose innermost block
/// copies X[i_0, ..., i_{n-1}] into Y[i_{perm[0]}, ..., i_{perm[n-1]}] with an assign tasklet.
///
/// The replacement sequence is inserted before the node's block, and the original block is removed.
/// @return false (graph unchanged) if the node/block does not have the simple shape
///         `source access node -> transpose -> sink access node`; true after a successful expansion.
bool TransposeNode::expand(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
    auto& dataflow = this->get_parent();
    auto& block = static_cast<structured_control_flow::Block&>(*dataflow.get_parent());
    if (dataflow.in_degree(*this) != 1 || dataflow.out_degree(*this) != 1) {
        return false;
    }

    auto& parent = static_cast<structured_control_flow::Sequence&>(*block.get_parent());
    int index = parent.index(block);
    auto& transition = parent.at(index).second;

    auto& iedge = *dataflow.in_edges(*this).begin();
    auto& oedge = *dataflow.out_edges(*this).begin();

    // Checks if legal: input must be a source node, output a sink node.
    auto& input_node = static_cast<data_flow::AccessNode&>(iedge.src());
    auto& output_node = static_cast<data_flow::AccessNode&>(oedge.dst());
    if (dataflow.in_degree(input_node) != 0 || dataflow.out_degree(output_node) != 0) {
        return false;
    }
    // Both access nodes are deleted below, so they must not be connected to anything other than
    // this transpose node; otherwise remove_node would drop nodes that other edges still use.
    // NOTE(review): unrelated, disconnected nodes in the same block would still be dropped by
    // remove_child below — confirm callers only expand single-purpose blocks.
    if (dataflow.out_degree(input_node) != 1 || dataflow.in_degree(output_node) != 1) {
        return false;
    }

    // Add the new sequence before the current block, carrying over the block's transition
    // assignments; the original block then sits at index + 1 and is removed at the end.
    auto& new_sequence = builder.add_sequence_before(parent, block, transition.assignments(), block.debug_info());

    // One sequential map per input axis: loop_vars[d] runs over [0, shape_[d]).
    structured_control_flow::Sequence* last_scope = &new_sequence;
    std::vector<symbolic::Expression> loop_vars;
    loop_vars.reserve(this->shape_.size());
    for (size_t i = 0; i < this->shape_.size(); i++) {
        std::string indvar_str = builder.find_new_name("_i");
        builder.add_container(indvar_str, types::Scalar(types::PrimitiveType::UInt64));

        auto indvar = symbolic::symbol(indvar_str);
        auto init = symbolic::zero();
        auto update = symbolic::add(indvar, symbolic::one());
        auto condition = symbolic::Lt(indvar, this->shape_[i]);
        auto& loop = builder.add_map(
            *last_scope,
            indvar,
            condition,
            init,
            update,
            structured_control_flow::ScheduleType_Sequential::create(),
            {},
            block.debug_info()
        );
        last_scope = &loop.root();

        loop_vars.push_back(indvar);
    }

    auto& body = builder.add_block(*last_scope, {}, block.debug_info());

    // Output axis i corresponds to input axis perm_[i], so it is indexed by loop_vars[perm_[i]].
    std::vector<symbolic::Expression> output_indices(shape_.size());
    for (size_t i = 0; i < shape_.size(); ++i) {
        output_indices[i] = loop_vars[perm_[i]];
    }

    // Element copy: Y[output_indices] = X[loop_vars]
    auto& x_access = builder.add_access(body, input_node.data(), debug_info());
    auto& y_access = builder.add_access(body, output_node.data(), debug_info());

    auto& tasklet = builder.add_tasklet(body, data_flow::assign, "_out", {"_in"}, debug_info());

    builder.add_computational_memlet(body, x_access, tasklet, "_in", loop_vars, iedge.base_type(), debug_info());
    builder.add_computational_memlet(body, tasklet, "_out", y_access, output_indices, oedge.base_type(), debug_info());

    // Remove the original node, its access nodes, and finally the (now shifted) original block.
    builder.remove_memlet(block, iedge);
    builder.remove_memlet(block, oedge);
    builder.remove_node(block, input_node);
    builder.remove_node(block, output_node);
    builder.remove_node(block, *this);
    builder.remove_child(parent, index + 1);

    return true;
}
189

190
/// Serializes a TransposeNode to JSON with keys "code", "shape" (array of expression strings
/// produced by JSONSerializer::expression) and "perm" (array of integer axes).
nlohmann::json TransposeNodeSerializer::serialize(const data_flow::LibraryNode& library_node) {
    const auto& node = static_cast<const TransposeNode&>(library_node);

    nlohmann::json out;
    out["code"] = node.code().value();

    serializer::JSONSerializer expr_writer;

    nlohmann::json shape_json = nlohmann::json::array();
    for (const auto& extent : node.shape()) {
        shape_json.push_back(expr_writer.expression(extent));
    }
    out["shape"] = shape_json;

    nlohmann::json perm_json = nlohmann::json::array();
    for (const auto axis : node.perm()) {
        perm_json.push_back(axis);
    }
    out["perm"] = perm_json;

    return out;
}
210

211
/// Reconstructs a TransposeNode from JSON written by TransposeNodeSerializer::serialize and adds it
/// to `parent` via the builder.
///
/// Expects keys "element_id", "code", "debug_info", "shape" and "perm".
/// @throws nlohmann::json::out_of_range if "perm" or "debug_info" is missing (asserts are compiled
///         out in release builds, so this is the release-mode guard).
/// @throws std::invalid_argument (from the TransposeNode constructor) for an invalid permutation.
data_flow::LibraryNode& TransposeNodeSerializer::deserialize(
    const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, structured_control_flow::Block& parent
) {
    assert(j.contains("element_id"));
    assert(j.contains("code"));
    assert(j.contains("debug_info"));
    assert(j.contains("shape"));
    assert(j.contains("perm"));

    // Use at() rather than operator[]: operator[] on a const json with a missing key is undefined
    // behavior, while at() throws json::out_of_range.
    std::vector<symbolic::Expression> shape;
    if (j.contains("shape")) {
        const auto& shape_json = j.at("shape");
        shape.reserve(shape_json.size());
        for (const auto& dim : shape_json) {
            shape.push_back(symbolic::parse(dim.get<std::string>()));
        }
    }

    const auto& perm_json = j.at("perm");
    std::vector<int64_t> perm;
    perm.reserve(perm_json.size());
    for (const auto& axis : perm_json) {
        perm.push_back(axis.get<int64_t>());
    }

    sdfg::serializer::JSONSerializer serializer;
    DebugInfo debug_info = serializer.json_to_debug_info(j.at("debug_info"));

    return builder.add_library_node<TransposeNode>(parent, debug_info, shape, perm);
}
237

238
} // namespace tensor
239
} // namespace math
240
} // namespace sdfg
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc