• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / sdfglib / 21112613465

18 Jan 2026 01:35PM UTC coverage: 64.355% (+0.2%) from 64.154%
21112613465

Pull #462

github

web-flow
Merge a4377317c into 92e9cbdc3
Pull Request #462: adds syntax support for multi-assignments and np.empty_like

45 of 52 new or added lines in 3 files covered. (86.54%)

1 existing line in 1 file now uncovered.

19555 of 30386 relevant lines covered (64.36%)

387.55 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

61.49
/python/docc/linear_algebra.py
1
import ast
3✔
2
from ._sdfg import Scalar, PrimitiveType
3✔
3

4

5
class LinearAlgebraHandler:
    """Lower NumPy-style matrix/vector products to ``builder`` GEMM/DOT calls.

    Operands must be arrays registered in ``array_info`` (each entry holds
    ``"shapes"`` and ``"ndim"``), optionally sliced or indexed.  The public
    ``handle_*`` entry points return ``False`` for any expression form they
    do not recognise; in that case no GEMM/DOT is emitted.
    """

    # Call names (``f(...)`` or ``mod.f(...)``) treated as matrix products.
    _GEMM_FUNCS = ("dot", "matmul")

    def __init__(self, builder, array_info, symbol_table, expr_visitor):
        # builder: receives add_gemm / add_dot / add_container / add_assignment.
        self.builder = builder
        # array_info: array name -> {"shapes": [...], "ndim": int, ...}.
        self.array_info = array_info
        # symbol_table: updated with the scalar temporaries created by handle_dot.
        self.symbol_table = symbol_table
        # expr_visitor: ``visit(node)`` renders an expression; the result is
        # used as a string/name (NOTE(review): assumed to return str — confirm).
        self.expr_visitor = expr_visitor
        self._unique_counter = 0

    def _get_unique_id(self):
        """Return a fresh positive integer for naming temporaries."""
        self._unique_counter += 1
        return self._unique_counter

    def _parse_expr(self, node):
        """Render ``node`` through the expression visitor."""
        return self.expr_visitor.visit(node)

    @staticmethod
    def _call_name(node):
        """Return the bare function name of a call (``np.dot(...)`` -> ``"dot"``).

        Returns None when ``node`` is not a call to a plain name or attribute.
        """
        if not isinstance(node, ast.Call):
            return None
        func = node.func
        if isinstance(func, ast.Attribute):
            return func.attr
        if isinstance(func, ast.Name):
            return func.id
        return None

    def is_gemm(self, node):
        """Return True if ``node`` contains a matrix product.

        Matches ``@``, ``dot(...)``/``matmul(...)`` (bare or attribute calls),
        and any ``+`` / ``*`` expression with such a product on either side.
        """
        if isinstance(node, ast.BinOp):
            if isinstance(node.op, ast.MatMult):
                return True
            if isinstance(node.op, (ast.Add, ast.Mult)):
                return self.is_gemm(node.left) or self.is_gemm(node.right)
            return False
        return self._call_name(node) in self._GEMM_FUNCS

    @staticmethod
    def _is_unit_step(step):
        """True for a missing slice step or a literal step of ``1``."""
        return step is None or (
            isinstance(step, ast.Constant)
            and type(step.value) is int
            and step.value == 1
        )

    def parse_arg(self, node):
        """Describe an array operand ``X`` or ``X[...]`` from ``array_info``.

        Args:
            node: AST of the operand.

        Returns:
            ``(name, start_indices, view_shape, index_nodes)``:
              * ``start_indices`` — one rendered start index per array
                dimension (empty for a bare name);
              * ``view_shape`` — the extent of each sliced dimension (the
                array's own shapes for a bare name);
              * ``index_nodes`` — the subscript's index AST nodes.
            ``(None, None, None, None)`` if the operand is not a known array
            or uses a slice step other than 1.
        """
        unsupported = (None, None, None, None)

        if isinstance(node, ast.Name):
            if node.id in self.array_info:
                return node.id, [], self.array_info[node.id]["shapes"], []
            return unsupported

        if not (
            isinstance(node, ast.Subscript)
            and isinstance(node.value, ast.Name)
            and node.value.id in self.array_info
        ):
            return unsupported

        name = node.value.id
        info = self.array_info[name]
        shapes = info["shapes"]
        ndim = info.get("ndim", len(shapes))
        if isinstance(node.slice, ast.Tuple):
            indices = list(node.slice.elts)
        else:
            indices = [node.slice]

        def dim_size(i):
            # Fall back to a symbolic name when the shape list is too short.
            return shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"

        start_indices = []
        slice_shape = []
        for i, idx in enumerate(indices):
            if not isinstance(idx, ast.Slice):
                start_indices.append(self._parse_expr(idx))
                continue
            # Strided slices would need a non-unit element stride that the
            # offset/shape description below cannot express.
            if not self._is_unit_step(idx.step):
                return unsupported
            start = self._parse_expr(idx.lower) if idx.lower is not None else "0"
            stop = self._parse_expr(idx.upper) if idx.upper is not None else dim_size(i)
            start_indices.append(start)
            slice_shape.append(f"({stop} - {start})")

        # Dimensions not mentioned in the subscript (e.g. ``A[i]`` on a 2-D
        # array) are implicit full slices in NumPy.  Record them so the start
        # indices cover every dimension (required by flatten_subset) and the
        # view shape is complete.
        for i in range(len(indices), ndim):
            start_indices.append("0")
            slice_shape.append(dim_size(i))

        return name, start_indices, slice_shape, indices

    def flatten_subset(self, name, start_indices):
        """Convert per-dimension start indices into one linear offset.

        Strides are computed row-major (last dimension has stride 1) and the
        offset is built as a symbolic string.

        Returns:
            ``[]`` for no indices, ``[offset]`` when one index per dimension is
            given, otherwise ``start_indices`` unchanged.
        """
        if not start_indices:
            return []
        info = self.array_info[name]
        shapes = info["shapes"]
        ndim = info["ndim"]
        if len(start_indices) != ndim:
            return start_indices

        # Build strides innermost-first, then flip to dimension order.
        strides = ["1"]
        for i in range(ndim - 1, 0, -1):
            prev = strides[-1]
            strides.append(str(shapes[i]) if prev == "1" else f"({prev} * {shapes[i]})")
        strides.reverse()

        offset = "0"
        for idx, stride in zip(start_indices, strides):
            term = idx if stride == "1" else f"({idx} * {stride})"
            offset = term if offset == "0" else f"({offset} + {term})"
        return [offset]

    @staticmethod
    def _mul_factors(a, b):
        """Multiply two rendered scalar factors, eliding ``"1.0"``."""
        if a == "1.0":
            return b
        if b == "1.0":
            return a
        return f"({a} * {b})"

    @staticmethod
    def _strip_transpose(node):
        """Split ``X.T`` into ``(X, True)``; any other node -> ``(node, False)``."""
        if isinstance(node, ast.Attribute) and node.attr == "T":
            return node.value, True
        return node, False

    def _get_ndim(self, name):
        """Rank of ``name`` from array_info (1 for unknown names)."""
        if name not in self.array_info:
            return 1
        return self.array_info[name]["ndim"]

    def _get_ld(self, name):
        """GEMM leading dimension: dim 1 for rank >= 2, ``"1"`` for 1-D, ``""`` if unknown."""
        if name not in self.array_info:
            return ""
        shapes = self.array_info[name]["shapes"]
        if len(shapes) >= 2:
            return str(shapes[1])
        return "1"

    def _extract_factor(self, node):
        """Split a matmul operand ``s * X`` / ``X * s`` into ``(X, s)``.

        ``X`` must be a product or a known array.  Any other node is returned
        as ``(node, "1.0")``.
        """
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
            pairs = ((node.left, node.right), (node.right, node.left))
            for operand, scale in pairs:
                if self.is_gemm(operand):
                    return operand, self._parse_expr(scale)
            for operand, scale in pairs:
                if self.parse_arg(operand)[0]:
                    return operand, self._parse_expr(scale)
        return node, "1.0"

    def _parse_gemm_term(self, node):
        """Match ``[s *] (A @ B)`` / ``dot(A, B)`` / ``matmul(A, B)``.

        Returns:
            ``(A, B, alpha)`` with alpha the product of all scalar factors
            found, or ``(None, None, None)`` if ``node`` is not such a term.
        """
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult):
            left, left_factor = self._extract_factor(node.left)
            right, right_factor = self._extract_factor(node.right)
            return left, right, self._mul_factors(left_factor, right_factor)

        if self._call_name(node) in self._GEMM_FUNCS and len(node.args) == 2:
            return node.args[0], node.args[1], "1.0"

        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
            for inner, scale in ((node.left, node.right), (node.right, node.left)):
                A, B, alpha = self._parse_gemm_term(inner)
                if A is not None:
                    # Keep factors already found inside the product, e.g.
                    # ``2 * (3 * A @ B)`` -> alpha = 3 * 2.
                    return A, B, self._mul_factors(alpha, self._parse_expr(scale))

        return None, None, None

    def _scaled_target_factor(self, node, target_name):
        """Return beta if ``node`` is ``target``, ``beta * target`` or ``target * beta``.

        Returns None for any other expression.
        """
        if self._is_target(node, target_name):
            return "1.0"
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
            if self._is_target(node.left, target_name):
                return self._parse_expr(node.right)
            if self._is_target(node.right, target_name):
                return self._parse_expr(node.left)
        return None

    def _match_gemm_expr(self, value_node, target_name):
        """Split a GEMM right-hand side into ``(A, B, alpha, beta)``.

        All four are None when the expression is not supported.
        """
        unsupported = (None, None, None, None)
        if not (isinstance(value_node, ast.BinOp) and isinstance(value_node.op, ast.Add)):
            A, B, alpha = self._parse_gemm_term(value_node)
            return (A, B, alpha, "0.0") if A is not None else unsupported

        sides = ((value_node.left, value_node.right), (value_node.right, value_node.left))
        for term, addend in sides:
            A, B, alpha = self._parse_gemm_term(term)
            if A is None:
                continue
            beta = self._scaled_target_factor(addend, target_name)
            # GEMM can only add a multiple of its own output; any other
            # addend would be silently dropped, so reject the expression.
            return (A, B, alpha, beta) if beta is not None else unsupported
        return unsupported

    def _resolve_gemm_target(self, target):
        """Return ``(target_name, flat_target_subset)``.

        Returns ``(None, [])`` when the target is not a known array.
        """
        target_subset = []
        if isinstance(target, str):
            name = target
        elif isinstance(target, ast.Name):
            name = target.id
        elif isinstance(target, ast.Subscript) and isinstance(target.value, ast.Name):
            name, start_indices, _, _ = self.parse_arg(target)
            if not name:
                # Unknown array or unsupported (e.g. strided) target slice:
                # writing the whole array instead would be wrong.
                return None, []
            target_subset = self.flatten_subset(name, start_indices)
        else:
            return None, []

        if not name or name not in self.array_info:
            return None, []
        return name, target_subset

    def handle_gemm(self, target, value_node):
        """Try to lower ``target = value_node`` to one ``builder.add_gemm`` call.

        Recognised right-hand sides (operands may be sliced, indexed, carry
        ``.T`` or scalar factors)::

            A @ B, dot(A, B), matmul(A, B), alpha * (A @ B)
            A @ B + C, A @ B + beta * C        (C is the target)

        Args:
            target: Target array as a name string, ``ast.Name`` or
                ``ast.Subscript``.
            value_node: AST of the right-hand side.

        Returns:
            True if a GEMM was emitted; False if the expression is not
            supported, in which case no GEMM is emitted.
        """
        target_name, target_subset = self._resolve_gemm_target(target)
        if target_name is None:
            return False

        A, B, alpha, beta = self._match_gemm_expr(value_node, target_name)
        if A is None or B is None:
            return False

        A_node, trans_a = self._strip_transpose(A)
        B_node, trans_b = self._strip_transpose(B)

        # Nested products such as ``(X @ Y) @ Z`` are evaluated first through
        # the expression visitor; its result is used as an array name.
        if self.is_gemm(A_node):
            A_node = ast.Name(id=self.expr_visitor.visit(A_node))
        if self.is_gemm(B_node):
            B_node = ast.Name(id=self.expr_visitor.visit(B_node))

        A_name, subset_a, shape_a, indices_a = self.parse_arg(A_node)
        B_name, subset_b, shape_b, indices_b = self.parse_arg(B_node)
        if not A_name or not B_name:
            return False

        # Only vector/matrix views are handled.  A fully indexed operand
        # (empty view shape) is a scalar.  _get_ld (dim 1 of the array) is
        # only a valid leading dimension for arrays of rank <= 2.
        if len(shape_a) not in (1, 2) or len(shape_b) not in (1, 2):
            return False
        if any(self._get_ndim(name) > 2 for name in (A_name, B_name, target_name)):
            return False

        flat_subset_a = self.flatten_subset(A_name, subset_a)
        flat_subset_b = self.flatten_subset(B_name, subset_b)

        if len(shape_a) == 2:
            m, k = (shape_a[1], shape_a[0]) if trans_a else (shape_a[0], shape_a[1])
        else:
            # 1D array A(K) -> (1, K)
            m, k = "1", shape_a[0]
            if self._is_stride_1(A_name, indices_a):
                trans_a = self._get_ndim(A_name) == 1
            else:
                trans_a = True

        if len(shape_b) == 2:
            n = shape_b[0] if trans_b else shape_b[1]
        else:
            # 1D array B(K) -> (K, 1)
            n = "1"
            if self._is_stride_1(B_name, indices_b):
                trans_b = self._get_ndim(B_name) != 1
            else:
                trans_b = False

        lda = self._get_ld(A_name)
        ldb = self._get_ld(B_name)
        if self._get_ndim(target_name) == 1 and m == "1":
            ldc = n
        else:
            ldc = self._get_ld(target_name)

        self.builder.add_gemm(
            A_name,
            B_name,
            target_name,
            alpha,
            beta,
            m,
            n,
            k,
            trans_a,
            trans_b,
            flat_subset_a,
            flat_subset_b,
            target_subset,
            lda,
            ldb,
            ldc,
        )
        return True

    def _is_stride_1(self, name, indices):
        """True if the 1-D view ``name[indices]`` runs along the last dimension.

        A view with no explicit slice is either a scalar or the implicit
        trailing dimension, so it counts as stride 1 as well.
        """
        if name not in self.array_info or not indices:
            return True
        ndim = self.array_info[name]["ndim"]
        for i, idx in enumerate(indices):
            if isinstance(idx, ast.Slice):
                return i == ndim - 1
        return True

    def _is_target(self, node, target_name):
        """True if ``node`` refers to the target.

        If ``target_name`` is an AST node, the rendered expressions are
        compared.  Otherwise ``node`` must be the name itself or a subscript
        of it.
        """
        if isinstance(target_name, ast.AST):
            return self._parse_expr(node) == self._parse_expr(target_name)
        if isinstance(node, ast.Name) and node.id == target_name:
            return True
        if isinstance(node, ast.Subscript):
            if isinstance(node.value, ast.Name) and node.value.id == target_name:
                return True
        return False

    def _is_dot_call(self, node):
        """True for ``dot(...)`` / ``mod.dot(...)`` calls and ``x @ y``."""
        if self._call_name(node) == "dot":
            return True
        return isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult)

    def _match_dot_expr(self, value_node, target):
        """Return ``(dot_node, is_accumulate)``, or ``(None, False)`` if unsupported."""
        if self._is_dot_call(value_node):
            return value_node, False
        if isinstance(value_node, ast.BinOp) and isinstance(value_node.op, ast.Add):
            sides = ((value_node.left, value_node.right), (value_node.right, value_node.left))
            for dot_side, other in sides:
                if self._is_dot_call(dot_side):
                    if self._is_target(other, target):
                        return dot_side, True
                    # Only ``t = dot + t`` is supported; any other addend
                    # would otherwise be silently dropped.
                    return None, False
        return None, False

    def _vector_stride(self, name, indices):
        """Element stride (as a string) of the 1-D view ``name[indices]``.

        Computed row-major: the product of all dimensions after the sliced
        one, or ``"1"`` when there is no explicit slice.
        """
        if not indices:
            return "1"
        info = self.array_info[name]
        shapes = info["shapes"]
        ndim = info["ndim"]

        sliced_dim = next(
            (i for i, idx in enumerate(indices) if isinstance(idx, ast.Slice)), -1
        )
        if sliced_dim == -1:
            return "1"

        stride = "1"
        for i in range(sliced_dim + 1, ndim):
            dim_size = shapes[i] if i < len(shapes) else f"_{name}_shape_{i}"
            stride = str(dim_size) if stride == "1" else f"({stride} * {dim_size})"
        return stride

    def handle_dot(self, target, value_node):
        """Try to lower ``target = value_node`` to a ``builder.add_dot`` call.

        Recognised forms: ``dot(x, y)``, ``x @ y``, and either of those plus
        the target itself (``t = dot(x, y) + t``).  ``x`` and ``y`` must be
        1-D views of arrays in ``array_info``.  The result is written to a
        fresh ``Scalar(PrimitiveType.Double)`` temporary, then assigned (or
        added) to the target.

        Returns:
            True if the dot product was emitted, False otherwise.
        """
        dot_node, is_accumulate = self._match_dot_expr(value_node, target)
        if dot_node is None:
            return False

        if isinstance(dot_node, ast.Call):
            if len(dot_node.args) != 2:
                return False
            arg0, arg1 = dot_node.args
        else:  # ``x @ y``
            arg0, arg1 = dot_node.left, dot_node.right

        name_a, subset_a, shape_a, indices_a = self.parse_arg(arg0)
        name_b, subset_b, shape_b, indices_b = self.parse_arg(arg1)
        if not name_a or not name_b:
            return False
        if len(shape_a) != 1 or len(shape_b) != 1:
            return False

        n = shape_a[0]
        incx = self._vector_stride(name_a, indices_a)
        incy = self._vector_stride(name_b, indices_b)
        flat_subset_a = self.flatten_subset(name_a, subset_a)
        flat_subset_b = self.flatten_subset(name_b, subset_b)

        tmp_res = f"_dot_res_{self._get_unique_id()}"
        self.builder.add_container(tmp_res, Scalar(PrimitiveType.Double), False)
        self.symbol_table[tmp_res] = Scalar(PrimitiveType.Double)

        self.builder.add_dot(
            name_a, name_b, tmp_res, n, incx, incy, flat_subset_a, flat_subset_b
        )

        target_str = target if isinstance(target, str) else self._parse_expr(target)
        if is_accumulate:
            self.builder.add_assignment(target_str, f"{target_str} + {tmp_res}")
        else:
            self.builder.add_assignment(target_str, tmp_res)
        return True
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc