• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

daisytuner / docc / 22851518700

09 Mar 2026 11:34AM UTC coverage: 64.448% (+0.03%) from 64.415%
22851518700

Pull #556

github

web-flow
Merge 523df0a24 into 64dd02640
Pull Request #556: [MLIR] add support for linalg broadcast

24493 of 38004 relevant lines covered (64.45%)

386.53 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

72.01
/python/docc/compiler/compiled_sdfg.py
1
import ctypes
4✔
2
from docc.sdfg import Scalar, Array, Pointer, Structure, PrimitiveType
4✔
3

4
import numpy as np
4✔
5
import ml_dtypes
4✔
6

7

8
def idiv(a, b):
    """Return the integer quotient of ``a`` by ``b``.

    Both operands are coerced with ``int()`` (so floats are truncated first),
    then divided with Python's ``//``, which rounds toward negative infinity
    for operands of any sign.

    Raises:
        ZeroDivisionError: ``int(b)`` is zero.
    """
    numerator = int(a)
    denominator = int(b)
    return numerator // denominator
11

12

13
# Globals visible to eval() when evaluating SDFG shape/stride expressions.
_EVAL_GLOBALS = dict(idiv=idiv)

# SDFG primitive type -> ctypes type used when passing values to C.
# Half and BFloat are 2-byte types with no ctypes equivalent, so they are
# stored as raw c_uint16 bit patterns.
_CTYPES_MAP = dict(
    (
        (PrimitiveType.Bool, ctypes.c_bool),
        (PrimitiveType.Int8, ctypes.c_int8),
        (PrimitiveType.Int16, ctypes.c_int16),
        (PrimitiveType.Int32, ctypes.c_int32),
        (PrimitiveType.Int64, ctypes.c_int64),
        (PrimitiveType.UInt8, ctypes.c_uint8),
        (PrimitiveType.UInt16, ctypes.c_uint16),
        (PrimitiveType.UInt32, ctypes.c_uint32),
        (PrimitiveType.UInt64, ctypes.c_uint64),
        (PrimitiveType.Float, ctypes.c_float),
        (PrimitiveType.Double, ctypes.c_double),
        (PrimitiveType.Half, ctypes.c_uint16),
        (PrimitiveType.BFloat, ctypes.c_uint16),
    )
)
32

33

34
class CompiledSDFG:
    """Callable wrapper around a compiled SDFG shared library.

    The library at ``lib_path`` is loaded with ``ctypes.CDLL`` and the
    function named ``sdfg.name`` is configured with ctypes argument and
    return types derived from the SDFG's type information. Calling the
    instance converts Python/NumPy arguments, fills in the implicit shape
    arguments (``_s0``, ``_s1``, ...), allocates output buffers, and converts
    the results back into Python/NumPy values.
    """

    # Names that remain function calls when rewriting shape expressions; any
    # other ``name(expr)`` is treated as an array access ``name[expr]``.
    _KNOWN_FUNCTIONS = frozenset(
        {"int", "float", "abs", "min", "max", "sum", "len", "idiv"}
    )

    # Temporary stand-ins for the parentheses of known function calls while
    # rewriting. They are not "(" / ")", so the rewrite pattern never matches
    # them again, which keeps nested calls like ``idiv(idiv(a, 2), 2)``
    # intact. Assumes control characters never occur in SDFG expressions.
    _CALL_OPEN = "\x00"
    _CALL_CLOSE = "\x01"

    def __init__(
        self,
        lib_path,
        sdfg,
        shape_sources=None,
        structure_member_info=None,
        output_args=None,
        output_shapes=None,
        output_strides=None,
    ):
        """Load the compiled library and configure the ctypes signature.

        Args:
            lib_path: Path of the shared library, passed to ``ctypes.CDLL``.
            sdfg: SDFG object; its ``name``, ``arguments``, ``type(name)``,
                ``return_type`` and (if present) ``metadata(key)`` are used.
            shape_sources: Sequence of ``(user_arg_index, dim_index)`` pairs.
                Entry ``i`` supplies the implicit ``_s{i}`` argument as
                ``args[user_arg_index].shape[dim_index]``.
            structure_member_info: ``{struct_name: {member: (index, type)}}``.
            output_args: Names of output-buffer arguments. When empty, the
                comma-separated ``"output_args"`` SDFG metadata is used.
            output_shapes: ``{output_name: [dim_expr, ...]}``.
            output_strides: ``{output_name: [stride_expr, ...]}``; strides
                are in elements and converted to bytes on return.
        """
        self.lib_path = lib_path
        self.sdfg = sdfg
        self.shape_sources = shape_sources or []
        self.structure_member_info = structure_member_info or {}
        self.lib = ctypes.CDLL(lib_path)
        self.func = getattr(self.lib, sdfg.name)

        # Output argument names: explicit list first, else SDFG metadata.
        self.output_args = output_args or []
        if not self.output_args and hasattr(sdfg, "metadata"):
            out_args_str = sdfg.metadata("output_args")
            if out_args_str:
                # Strip whitespace so "a, b" matches argument names "a", "b".
                self.output_args = [
                    name.strip() for name in out_args_str.split(",") if name.strip()
                ]

        self.output_shapes = output_shapes or {}
        self.output_strides = output_strides or {}

        # Cache of generated ctypes.Structure subclasses, keyed by name.
        self._ctypes_structures = {}

        self.arg_names = sdfg.arguments
        self.arg_types = []
        self.arg_sdfg_types = []  # original SDFG types, parallel to arg_types
        for arg_name in sdfg.arguments:
            arg_type = sdfg.type(arg_name)
            self.arg_sdfg_types.append(arg_type)
            self.arg_types.append(self._get_ctypes_type(arg_type))

        self.func.argtypes = self.arg_types
        self.func.restype = self._get_ctypes_type(sdfg.return_type)

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_shape_arg(name):
        """Return True for implicit shape arguments named ``_s<digits>``."""
        return name.startswith("_s") and name[2:].isdigit()

    def _user_arg_names(self):
        """Return the argument names the caller supplies, in call order.

        Output buffers and implicit ``_s<digits>`` shape arguments are
        excluded because ``__call__`` fills those in itself.
        """
        return [
            name
            for name in self.arg_names
            if name not in self.output_args and not self._is_shape_arg(name)
        ]

    @staticmethod
    def _numpy_dtype(primitive_type, default=None):
        """Return the NumPy dtype for ``primitive_type``, or ``default``."""
        dtype_map = {
            PrimitiveType.Float: np.float32,
            PrimitiveType.Double: np.float64,
            PrimitiveType.Int8: np.int8,
            PrimitiveType.Int16: np.int16,
            PrimitiveType.Int32: np.int32,
            PrimitiveType.Int64: np.int64,
            PrimitiveType.UInt8: np.uint8,
            PrimitiveType.UInt16: np.uint16,
            PrimitiveType.UInt32: np.uint32,
            PrimitiveType.UInt64: np.uint64,
            PrimitiveType.Bool: np.bool_,
            PrimitiveType.Half: np.float16,
            PrimitiveType.BFloat: ml_dtypes.bfloat16,
        }
        return dtype_map.get(primitive_type, default)

    @staticmethod
    def _output_sort_key(name):
        """Sort key ordering outputs by their trailing ``_<int>`` suffix.

        Names whose last ``_``-separated part is not an integer sort after
        the numbered ones (keeping their relative order, since ``sorted`` is
        stable) instead of raising ``ValueError``.
        """
        try:
            return (0, int(name.split("_")[-1]))
        except ValueError:
            return (1, 0)

    def _eval_shape_expr(self, expr, symbols):
        """Evaluate one SDFG shape/stride expression against ``symbols``.

        The expression is rewritten with ``_convert_to_python_syntax`` and
        evaluated with ``_EVAL_GLOBALS`` (which provides ``idiv``). This uses
        ``eval``: the expressions are expected to come from the compiler,
        never from end users.
        """
        python_expr = self._convert_to_python_syntax(str(expr))
        return eval(python_expr, _EVAL_GLOBALS, symbols)

    def _convert_to_python_syntax(self, expr_str):
        """Rewrite an SDFG expression string into evaluable Python.

        SDFG expressions write both calls and array accesses as
        ``name(index)``. Calls to names in ``_KNOWN_FUNCTIONS`` (compared
        case-insensitively) stay calls; every other ``name(index)`` becomes
        ``name[index]``, e.g. ``"A_row(0)"`` -> ``"A_row[0]"``. Innermost
        groups are rewritten first, so nested forms such as
        ``"idiv(idiv(a, 2), 2)"`` and ``"A(idiv(i, 2))"`` are handled.
        """
        import re

        pattern = re.compile(r"([a-zA-Z_][a-zA-Z0-9_]*)\(([^()]+)\)")
        result = expr_str
        while True:
            match = pattern.search(result)
            if not match:
                break
            name, index = match.group(1), match.group(2)
            if name.lower() in self._KNOWN_FUNCTIONS:
                replacement = f"{name}{self._CALL_OPEN}{index}{self._CALL_CLOSE}"
            else:
                replacement = f"{name}[{index}]"
            result = result[: match.start()] + replacement + result[match.end() :]

        return result.replace(self._CALL_OPEN, "(").replace(self._CALL_CLOSE, ")")

    # ------------------------------------------------------------------
    # ctypes type construction
    # ------------------------------------------------------------------

    def _create_ctypes_structure(self, struct_name):
        """Return (and cache) a ``ctypes.Structure`` subclass for ``struct_name``.

        Fields follow ``structure_member_info[struct_name]`` ordered by member
        index. The class is cached before its fields are resolved (a ctypes
        "incomplete type"), so a structure holding a pointer to itself does
        not recurse forever.

        Raises:
            ValueError: ``struct_name`` has no entry in ``structure_member_info``.
        """
        if struct_name in self._ctypes_structures:
            return self._ctypes_structures[struct_name]

        if struct_name not in self.structure_member_info:
            raise ValueError(f"Structure '{struct_name}' not found in member info")

        members = self.structure_member_info[struct_name]
        sorted_members = sorted(members.items(), key=lambda item: item[1][0])

        class CStructure(ctypes.Structure):
            pass

        self._ctypes_structures[struct_name] = CStructure
        try:
            CStructure._fields_ = [
                (member_name, self._get_ctypes_type(member_type))
                for member_name, (_index, member_type) in sorted_members
            ]
        except Exception:
            # Do not leave a half-built class in the cache.
            del self._ctypes_structures[struct_name]
            raise
        return CStructure

    def _get_ctypes_type(self, sdfg_type):
        """Map an SDFG type to the ctypes type used at the C boundary.

        Scalars map through ``_CTYPES_MAP``; arrays and pointers to scalars
        become ``POINTER(element_type)``; pointers to structures become
        ``POINTER(<generated Structure>)``. Primitive types missing from
        ``_CTYPES_MAP`` use ``ctypes.c_void_p`` as the (element) type, and any
        other SDFG type maps to ``ctypes.c_void_p``.
        """
        if isinstance(sdfg_type, Scalar):
            return _CTYPES_MAP.get(sdfg_type.primitive_type, ctypes.c_void_p)
        if isinstance(sdfg_type, Array):
            # Arrays are passed as pointers to their element type.
            elem_type = _CTYPES_MAP.get(sdfg_type.primitive_type, ctypes.c_void_p)
            return ctypes.POINTER(elem_type)
        if isinstance(sdfg_type, Pointer):
            # has_pointee_type() is provided by the C++ Pointer binding.
            if sdfg_type.has_pointee_type():
                pointee = sdfg_type.pointee_type
                if isinstance(pointee, Structure):
                    return ctypes.POINTER(self._create_ctypes_structure(pointee.name))
                if isinstance(pointee, Scalar):
                    elem_type = _CTYPES_MAP.get(pointee.primitive_type, ctypes.c_void_p)
                    return ctypes.POINTER(elem_type)
            return ctypes.c_void_p
        return ctypes.c_void_p

    # ------------------------------------------------------------------
    # Result conversion
    # ------------------------------------------------------------------

    def _array_from_pointer(self, ptr, primitive_type, shape):
        """Return the C array at ``ptr`` as a NumPy array of ``shape``.

        Half/BFloat data is copied into a new array. Other element types are
        wrapped with ``np.ctypeslib.as_array`` over the pointed-to memory
        (no copy). Unknown element types default to float64 / c_double.
        """
        dtype = self._numpy_dtype(primitive_type, np.float64)
        total_size = 1
        for dim in shape:
            total_size *= dim

        if primitive_type in (PrimitiveType.Half, PrimitiveType.BFloat):
            # np.ctypeslib.as_array cannot express these types (PEP 3118
            # buffer format limitation), so read the raw 2-byte elements.
            address = ctypes.cast(ptr, ctypes.c_void_p).value
            raw = (ctypes.c_char * (total_size * 2)).from_address(address)
            arr = np.frombuffer(raw, dtype=dtype).copy()
        else:
            ct_type = _CTYPES_MAP.get(primitive_type, ctypes.c_double)
            array_type = ct_type * total_size
            arr = np.ctypeslib.as_array(
                ctypes.cast(ptr, ctypes.POINTER(array_type)).contents
            )
        return arr.reshape(shape)

    def _convert_return_value(self, func_result, shape_symbol_values):
        """Convert the C function's return value into a Python/NumPy value.

        A returned pointer-to-scalar is turned into a NumPy array when its
        shape is known. The shape comes from the ``"return_shape"`` metadata
        (optionally bracketed, e.g. ``"[10,10]"``) or, failing that, from the
        ``_s{i}`` shape symbols. Otherwise ``func_result`` is returned as-is,
        including when a ``return_shape`` entry cannot be evaluated.
        """
        return_type = self.sdfg.return_type
        if not (isinstance(return_type, Pointer) and return_type.has_pointee_type()):
            return func_result
        pointee = return_type.pointee_type
        if not isinstance(pointee, Scalar):
            return func_result
        primitive_type = pointee.primitive_type

        return_shape_str = self.sdfg.metadata("return_shape")
        if return_shape_str:
            shape = []
            for dim_str in return_shape_str.strip("[]").split(","):
                try:
                    shape.append(int(self._eval_shape_expr(dim_str, shape_symbol_values)))
                except Exception:
                    # Shape not evaluable: hand back the raw pointer.
                    return func_result
            return self._array_from_pointer(func_result, primitive_type, shape)

        # No shape metadata: fall back to the input shape symbols, which is
        # right for identity-like operations.
        if self.shape_sources and shape_symbol_values:
            shape = [
                shape_symbol_values[f"_s{i}"]
                for i in range(len(self.shape_sources))
                if f"_s{i}" in shape_symbol_values
            ]
            if shape:
                return self._array_from_pointer(func_result, primitive_type, shape)

        return func_result

    def _convert_output_buffer(self, name, buf, size, dims, sdfg_type, shape_symbol_values):
        """Convert one filled output buffer into the value returned to the caller.

        Scalar outputs (``dims is None``) become a Python number; array
        outputs become a NumPy array, reshaped (or strided, when
        ``output_strides`` has an entry) on a best-effort basis. If the
        shape cannot be evaluated, the flat 1-D array is returned.
        """
        primitive_type = None
        if isinstance(sdfg_type, Pointer) and sdfg_type.has_pointee_type():
            pointee = sdfg_type.pointee_type
            if isinstance(pointee, Scalar):
                primitive_type = pointee.primitive_type
        is_half = primitive_type in (PrimitiveType.Half, PrimitiveType.BFloat)

        if size == 1 and dims is None:
            if is_half:
                # buf is a c_uint16 holding the raw 16-bit pattern; decode it
                # rather than returning the integer bits.
                dtype = self._numpy_dtype(primitive_type)
                return float(np.frombuffer(bytes(buf), dtype=dtype)[0])
            return buf.value

        if is_half:
            # np.ctypeslib.as_array cannot express these types (PEP 3118
            # buffer format limitation); reinterpret the raw bytes instead.
            raw = (ctypes.c_char * (size * 2)).from_address(ctypes.addressof(buf))
            arr = np.frombuffer(raw, dtype=self._numpy_dtype(primitive_type)).copy()
        else:
            arr = np.ctypeslib.as_array(buf)  # 1-D view over buf

        if not dims:
            return arr

        try:
            shape = [int(self._eval_shape_expr(d, shape_symbol_values)) for d in dims]
            if name in self.output_strides:
                try:
                    # Stride expressions are in elements; NumPy wants bytes.
                    byte_strides = tuple(
                        int(self._eval_shape_expr(s, shape_symbol_values)) * arr.itemsize
                        for s in self.output_strides[name]
                    )
                    return np.lib.stride_tricks.as_strided(
                        arr, shape=shape, strides=byte_strides
                    )
                except Exception:
                    pass  # fall back to a plain reshape below
            return arr.reshape(shape)
        except Exception:
            # Best effort: keep the flat array if the shape is unusable.
            return arr

    # ------------------------------------------------------------------
    # Calling
    # ------------------------------------------------------------------

    def _compute_shape_symbol_values(self, args):
        """Build the symbol table used to evaluate shape expressions.

        ``_s{i}`` is bound to ``args[u].shape[d]`` for the ``i``-th entry
        ``(u, d)`` of ``shape_sources``. Entry positions are used directly,
        so duplicate entries each get their own symbol. User parameters that
        are Python/NumPy numbers or NumPy arrays are bound under their
        argument names, so indirect accesses like ``A_row[0]`` can be
        evaluated.
        """
        symbols = {}
        for s_idx, (u_idx, dim_idx) in enumerate(self.shape_sources):
            if u_idx < len(args):
                symbols[f"_s{s_idx}"] = args[u_idx].shape[dim_idx]

        for name, value in zip(self._user_arg_names(), args):
            if isinstance(value, (int, float, np.integer, np.floating, np.ndarray)):
                symbols[name] = value
        return symbols

    def _build_c_struct(self, struct_name, value):
        """Create a C structure from the matching attributes of ``value``."""
        struct_class = self._create_ctypes_structure(struct_name)
        members = self.structure_member_info[struct_name]
        field_values = {
            member_name: getattr(value, member_name)
            for member_name, _info in sorted(members.items(), key=lambda item: item[1][0])
            if hasattr(value, member_name)
        }
        return struct_class(**field_values)

    def __call__(self, *args):
        """Invoke the compiled function.

        ``args`` are the user parameters in SDFG argument order. Output
        buffers and implicit ``_s{i}`` shape arguments are not passed by the
        caller; they are allocated/computed here.

        Returns:
            The single output value, a tuple of outputs when there are
            several, the converted C return value when there are no output
            arguments, or ``None``.

        Raises:
            RuntimeError: An output shape expression could not be evaluated.
            ValueError: Too few user arguments were supplied.
        """
        shape_symbol_values = self._compute_shape_symbol_values(args)

        converted_args = []
        structure_refs = []  # hold C structures for the duration of the call
        return_buffers = {}
        next_user_arg_idx = 0

        for i, arg_name in enumerate(self.arg_names):
            target_type = self.arg_types[i]
            sdfg_type = self.arg_sdfg_types[i]

            if arg_name in self.output_args:
                base_type = target_type._type_
                if arg_name in self.output_shapes:
                    # Array output: allocate prod(dims) elements.
                    dims = self.output_shapes[arg_name]
                    size = 1
                    for dim_str in dims:
                        try:
                            # Handles SDFG notation such as "A_row(0)".
                            size *= int(self._eval_shape_expr(dim_str, shape_symbol_values))
                        except Exception as e:
                            raise RuntimeError(
                                f"Could not evaluate shape {dim_str} for {arg_name}: {e}"
                            ) from e
                    buf = (base_type * size)()
                    converted_args.append(ctypes.cast(ctypes.addressof(buf), target_type))
                else:
                    # Scalar output: pointer to a single value.
                    size, dims = 1, None
                    buf = base_type()
                    converted_args.append(ctypes.byref(buf))
                # sdfg_type is kept for dtype decoding (needed for Half/BFloat).
                return_buffers[arg_name] = (buf, size, dims, sdfg_type)
                continue

            if self._is_shape_arg(arg_name):
                key = f"_s{int(arg_name[2:])}"
                converted_args.append(ctypes.c_int64(shape_symbol_values.get(key, 0)))
                continue

            # User argument
            if next_user_arg_idx >= len(args):
                raise ValueError("Not enough arguments provided")
            arg = args[next_user_arg_idx]
            next_user_arg_idx += 1

            if isinstance(arg, np.ndarray):
                if hasattr(target_type, "contents"):
                    converted_args.append(arg.ctypes.data_as(target_type))
                else:
                    converted_args.append(arg)
            elif (
                isinstance(sdfg_type, Pointer)
                and sdfg_type.has_pointee_type()
                and isinstance(sdfg_type.pointee_type, Structure)
            ):
                c_struct = self._build_c_struct(sdfg_type.pointee_type.name, arg)
                structure_refs.append(c_struct)
                converted_args.append(ctypes.pointer(c_struct))
            else:
                # Explicit cast to ensure int stays int
                converted_args.append(target_type(arg))

        func_result = self.func(*converted_args)

        results = [
            self._convert_output_buffer(name, *return_buffers[name], shape_symbol_values)
            for name in sorted(return_buffers, key=self._output_sort_key)
        ]
        if len(results) == 1:
            return results[0]
        if results:
            return tuple(results)

        # No output args - check if function has a non-void return type.
        if func_result is not None:
            return self._convert_return_value(func_result, shape_symbol_values)
        return None

    def get_return_shape(self, *args):
        """Evaluate the ``"return_shape"`` metadata for the given user arguments.

        ``_s{i}`` symbols come from NumPy-array arguments listed in
        ``shape_sources``, and integer user parameters are bound by name
        using the same argument order as ``__call__``. Brackets around the
        metadata (e.g. ``"[N, 10]"``) are accepted.

        Returns:
            A tuple of ints, or ``None`` when there is no return shape or any
            dimension cannot be evaluated.
        """
        shape_str = self.sdfg.metadata("return_shape")
        if not shape_str:
            return None

        shape_values = {}
        for i, (arg_idx, dim_idx) in enumerate(self.shape_sources):
            if arg_idx < len(args) and isinstance(args[arg_idx], np.ndarray):
                shape_values[f"_s{i}"] = args[arg_idx].shape[dim_idx]

        for arg_name, arg_val in zip(self._user_arg_names(), args):
            if isinstance(arg_val, (int, np.integer)):
                shape_values[arg_name] = int(arg_val)

        evaluated_shape = []
        for expr in shape_str.strip("[]").split(","):
            # eval is unsafe in general; these expressions come from our compiler.
            try:
                evaluated_shape.append(int(self._eval_shape_expr(expr, shape_values)))
            except Exception:
                return None
        return tuple(evaluated_shape)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc