• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 17512342502

06 Sep 2025 08:37AM UTC coverage: 82.98%. Remained the same
17512342502

Pull #783

github

jorgensd
Generalize test a bit
Pull Request #783: Diagonal assembly of matrices

47 of 56 new or added lines in 3 files covered. (83.93%)

22 existing lines in 2 files now uncovered.

3715 of 4477 relevant lines covered (82.98%)

0.83 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.66
/ffcx/codegeneration/integral_generator.py
1
# Copyright (C) 2015-2024 Martin Sandve Alnæs, Michal Habera, Igor Baratta, Chris Richardson
2
#
3
# Modified by Jørgen S. Dokken, 2024
4
#
5
# This file is part of FFCx. (https://www.fenicsproject.org)
6
#
7
# SPDX-License-Identifier:    LGPL-3.0-or-later
8
"""Integral generator."""
9

10
import collections
1✔
11
import logging
1✔
12
from numbers import Integral
1✔
13
from typing import Any
1✔
14

15
import basix
1✔
16
import ufl
1✔
17

18
import ffcx.codegeneration.lnodes as L
1✔
19
from ffcx.codegeneration import geometry
1✔
20
from ffcx.codegeneration.definitions import create_dof_index, create_quadrature_index
1✔
21
from ffcx.codegeneration.optimizer import optimize
1✔
22
from ffcx.ir.elementtables import piecewise_ttypes
1✔
23
from ffcx.ir.integral import BlockDataT, TensorPart
1✔
24
from ffcx.ir.representationutils import QuadratureRule
1✔
25

26
# Module logger; uses the fixed "ffcx" logger name rather than __name__.
logger = logging.getLogger("ffcx")
1✔
27

28

29
def extract_dtype(v, vops: list[Any]):
    """Extract dtype from ufl expression v and its operands.

    Every operand in ``vops`` is inspected first (raising if one has an
    unexpected type). Then the result is BOOL when ``v`` is a UFL Condition,
    and otherwise the merge of the operand dtypes.

    Args:
        v: The UFL expression whose generated value needs a dtype.
        vops: LNodes (or integer literals) representing the operands of ``v``.

    Returns:
        The LNodes data type for the node generated for ``v``.

    Raises:
        RuntimeError: If an operand has no ``dtype``, no ``symbol`` and is
            not an integer.
    """

    def _operand_dtype(operand):
        # Nodes carrying their own dtype (e.g. symbols)
        if hasattr(operand, "dtype"):
            return operand.dtype
        # Nodes that wrap a symbol (e.g. array accesses)
        if hasattr(operand, "symbol"):
            return operand.symbol.dtype
        # Plain integer literals
        if isinstance(operand, Integral):
            return L.DataType.INT
        raise RuntimeError(f"Not expecting this type of operand {type(operand)}")

    # Resolve all operand dtypes first so invalid operands always raise.
    operand_dtypes = [_operand_dtype(operand) for operand in vops]
    if isinstance(v, ufl.classes.Condition):
        return L.DataType.BOOL
    return L.merge_dtypes(operand_dtypes)
43

44

45
class IntegralGenerator:
    """Generate the body of a tabulate_tensor function from integral IR.

    Walks the factorised integrand stored in ``ir.expression``. It emits
    LNodes code for the quadrature, element and geometry tables, the
    quadrature loops, and the accumulation into the element tensor.
    """

    def __init__(self, ir, backend):
        """Initialise.

        Args:
            ir: Integral intermediate representation.
            backend: Backend plugin providing ``symbols``, ``definitions``
                and ``access`` helpers.
        """
        # Store ir
        self.ir = ir

        # Backend specific plugin with attributes
        # - symbols: for translating ufl operators to target language
        # - definitions: for defining backend specific variables
        # - access: for accessing backend specific variables
        self.backend = backend

        # Set of operator names code has been generated for, used in the
        # end for selecting necessary includes
        self._ufl_names = set()

        # Initialize lookup tables for variable scopes
        self.init_scopes()

        # Cache of temporary symbols, keyed by (tempname, *key)
        self.temp_symbols = {}

        # Set of counters used for assigning names to intermediate
        # variables
        self.symbol_counters = collections.defaultdict(int)

    def init_scopes(self):
        """Initialize variable scope dicts.

        There is one (empty) dict per ``(domain, quadrature_rule)`` key of the
        integrand. The extra ``(None, None)`` scope holds piecewise variables.
        """
        self.scopes = {
            quadrature_rule: {} for quadrature_rule in self.ir.expression.integrand.keys()
        }
        self.scopes[(None, None)] = {}

    def set_var(self, quadrature_rule, domain, v, vaccess):
        """Set a new variable in variable scope dicts.

        The scope is selected by the ``(domain, quadrature_rule)`` pair.
        ``(None, None)`` is the piecewise scope, outside the quadrature loops.

        Args:
            quadrature_rule: Quadrature rule
            domain: The domain of the integral
            v: the ufl expression
            vaccess: the LNodes expression to access the value in the code
        """
        self.scopes[(domain, quadrature_rule)][v] = vaccess

    def get_var(self, quadrature_rule, domain, v):
        """Lookup ufl expression v in variable scope dicts.

        Literals are translated directly. Otherwise the
        ``(domain, quadrature_rule)`` scope is searched first, then the
        piecewise ``(None, None)`` scope.

        Returns:
            The LNodes expression to access the value in the code, or None
            if ``v`` is in neither scope.
        """
        if v._ufl_is_literal_:
            return L.ufl_to_lnodes(v)

        # quadrature loop scope
        f = self.scopes[(domain, quadrature_rule)].get(v)

        # piecewise scope
        if f is None:
            f = self.scopes[(None, None)].get(v)
        return f

    def new_temp_symbol(self, basename):
        """Create a new SCALAR code symbol named basename + running counter."""
        name = f"{basename}{self.symbol_counters[basename]:d}"
        self.symbol_counters[basename] += 1
        return L.Symbol(name, dtype=L.DataType.SCALAR)

    def get_temp_symbol(self, tempname, key):
        """Get a cached temporary symbol, creating it on first request.

        Returns:
            ``(symbol, defined)`` where ``defined`` is True if the symbol was
            already in the cache.
        """
        key = (tempname,) + key
        s = self.temp_symbols.get(key)
        defined = s is not None
        if not defined:
            s = self.new_temp_symbol(tempname)
            self.temp_symbols[key] = s
        return s, defined

    def generate(self, domain: basix.CellType):
        """Generate entire tabulate_tensor body.

        Assumes that the code returned from here will be wrapped in a
        context that matches a suitable version of the UFC
        tabulate_tensor signatures.
        """
        # Assert that scopes are empty: expecting this to be called only
        # once
        assert not any(d for d in self.scopes.values())

        parts = []

        # Generate the tables of quadrature points and weights
        parts += self.generate_quadrature_tables(domain)

        # Generate the tables of basis function values and
        # pre-integrated blocks
        parts += self.generate_element_tables(domain)

        # Generate the tables of geometry data that are needed
        parts += self.generate_geometry_tables()

        # Loop generation code will produce parts to go before
        # quadloops, to define the quadloops, and to go after the
        # quadloops
        all_preparts = []
        all_quadparts = []

        # Pre-definitions are collected across all quadrature loops to
        # improve re-use and avoid name clashes
        for cell, rule in self.ir.expression.integrand.keys():
            if domain == cell:
                # Generate code to compute piecewise constant scalar factors
                all_preparts += self.generate_piecewise_partition(rule, cell)

                # Generate code to integrate reusable blocks of final
                # element tensor
                all_quadparts += self.generate_quadrature_loop(rule, cell)

        # Collect parts before, during, and after quadrature loops
        parts += all_preparts
        parts += all_quadparts

        return L.StatementList(parts)

    def generate_quadrature_tables(self, domain: basix.CellType):
        """Generate static tables of quadrature weights for ``domain``."""
        parts: list[L.LNode] = []

        # No quadrature tables for custom (given argument)
        skip = ufl.custom_integral_types
        if self.ir.expression.integral_type in skip:
            return parts

        # Loop over quadrature rules
        for (cell, quadrature_rule), _ in self.ir.expression.integrand.items():
            if domain == cell:
                # Generate quadrature weights array
                wsym = self.backend.symbols.weights_table(quadrature_rule)
                parts += [L.ArrayDecl(wsym, values=quadrature_rule.weights, const=True)]

        # Add leading comment if there are any tables
        parts = L.commented_code_list(parts, "Quadrature rules")
        return parts

    def generate_geometry_tables(self):
        """Generate static tables of geometry data.

        One table is written per (geometry quantity, cell name) pair found
        among the modified terminals of the integrand factorisations.
        """
        ufl_geometry = {
            ufl.geometry.FacetEdgeVectors: "facet_edge_vertices",
            ufl.geometry.CellFacetJacobian: "cell_facet_jacobian",
            ufl.geometry.CellRidgeJacobian: "cell_ridge_jacobian",
            ufl.geometry.ReferenceCellVolume: "reference_cell_volume",
            ufl.geometry.ReferenceFacetVolume: "reference_facet_volume",
            ufl.geometry.ReferenceCellEdgeVectors: "reference_cell_edge_vectors",
            ufl.geometry.ReferenceFacetEdgeVectors: "reference_facet_edge_vectors",
            ufl.geometry.ReferenceNormal: "reference_normals",
            ufl.geometry.FacetOrientation: "facet_orientation",
        }
        cells: dict[Any, set[Any]] = {t: set() for t in ufl_geometry.keys()}  # type: ignore

        for integrand in self.ir.expression.integrand.values():
            for attr in integrand["factorization"].nodes.values():
                mt = attr.get("mt")
                if mt is not None:
                    t = type(mt.terminal)
                    if t in ufl_geometry:
                        cells[t].add(
                            ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
                        )

        parts = []
        for i, cell_list in cells.items():
            for c in cell_list:
                parts.append(geometry.write_table(ufl_geometry[i], c))

        return parts

    def generate_element_tables(self, domain: basix.CellType):
        """Generate static tables.

        With precomputed element basis function values in quadrature points.
        For custom integrals only the piecewise tables are declared.
        """
        parts = []
        tables = self.ir.expression.unique_tables[domain]
        table_types = self.ir.expression.unique_table_types[domain]
        if self.ir.expression.integral_type in ufl.custom_integral_types:
            # Define only piecewise tables
            table_names = [name for name in sorted(tables) if table_types[name] in piecewise_ttypes]
        else:
            # Define all tables
            table_names = sorted(tables)

        for name in table_names:
            table = tables[name]
            parts += self.declare_table(name, table)

        # Add leading comment if there are any tables
        parts = L.commented_code_list(
            parts,
            [
                "Precomputed values of basis functions and precomputations",
                "FE* dimensions: [permutation][entities][points][dofs]",
            ],
        )
        return parts

    def declare_table(self, name, table):
        """Declare a constant table of REAL values.

        The table symbol is also registered in the backend's
        ``element_tables`` under ``name`` so later code can access it.

        Returns:
            A single-element list holding the table's ArrayDecl.
        """
        table_symbol = L.Symbol(name, dtype=L.DataType.REAL)
        self.backend.symbols.element_tables[name] = table_symbol
        return [L.ArrayDecl(table_symbol, values=table, const=True)]

    def generate_quadrature_loop(self, quadrature_rule: QuadratureRule, domain: basix.CellType):
        """Generate the quadrature loop for this quadrature_rule."""
        # Generate varying partition
        definitions, intermediates_0 = self.generate_varying_partition(quadrature_rule, domain)

        # Generate dofblock parts, some of this will be placed before or after quadloop
        tensor_comp, intermediates_fw = self.generate_dofblock_partition(quadrature_rule, domain)
        assert all(isinstance(tc, L.Section) for tc in tensor_comp)

        # Check if we only have Section objects
        inputs = []
        for definition in definitions:
            assert isinstance(definition, L.Section)
            inputs += definition.output

        # Create intermediates section: fw temporaries are declared
        # (initialised to 0) and then assigned after the other intermediates
        output = []
        declarations = []
        for fw in intermediates_fw:
            assert isinstance(fw, L.VariableDecl)
            output += [fw.symbol]
            declarations += [L.VariableDecl(fw.symbol, 0)]
            intermediates_0 += [L.Assign(fw.symbol, fw.value)]
        intermediates = [L.Section("Intermediates", intermediates_0, declarations, inputs, output)]

        iq_symbol = self.backend.symbols.quadrature_loop_index
        iq = create_quadrature_index(quadrature_rule, iq_symbol)

        code = definitions + intermediates + tensor_comp
        code = optimize(code, quadrature_rule)

        return [L.create_nested_for_loops([iq], code)]

    def generate_piecewise_partition(self, quadrature_rule, domain: basix.CellType):
        """Generate a piecewise partition (stored in the (None, None) scope)."""
        # Get annotated graph of factorisation
        F = self.ir.expression.integrand[(domain, quadrature_rule)]["factorization"]
        arraysymbol = L.Symbol(f"sp_{quadrature_rule.id()}", dtype=L.DataType.SCALAR)
        return self.generate_partition(arraysymbol, F, "piecewise", None, None)

    def generate_varying_partition(self, quadrature_rule, domain: basix.CellType):
        """Generate a varying partition (stored in the (domain, rule) scope)."""
        # Get annotated graph of factorisation
        F = self.ir.expression.integrand[(domain, quadrature_rule)]["factorization"]
        arraysymbol = L.Symbol(f"sv_{quadrature_rule.id()}", dtype=L.DataType.SCALAR)
        return self.generate_partition(arraysymbol, F, "varying", quadrature_rule, domain)

    def generate_partition(self, symbol, F, mode, quadrature_rule, domain):
        """Generate code for all factorisation nodes whose status is ``mode``.

        Args:
            symbol: Symbol whose name prefixes generated intermediate variables.
            F: Factorisation graph; nodes carry "status" and "expression",
                and optionally "mt"/"tr" for modified terminals.
            mode: Node status to generate code for ("piecewise" or "varying").
            quadrature_rule: Scope key part (None for the piecewise scope).
            domain: Scope key part (None for the piecewise scope).

        Returns:
            ``(definitions, intermediates)``: optimised backend definition
            Sections for modified terminals, and VariableDecls for operator
            results.
        """
        definitions = []
        intermediates = []

        for i, attr in F.nodes.items():
            if attr["status"] != mode:
                continue
            v = attr["expression"]

            # Generate code only if the expression is not already in cache
            if not self.get_var(quadrature_rule, domain, v):
                if v._ufl_is_literal_:
                    vaccess = L.ufl_to_lnodes(v)
                elif mt := attr.get("mt"):
                    tabledata = attr.get("tr")

                    # Backend specific modified terminal translation
                    vaccess = self.backend.access.get(mt, tabledata, quadrature_rule)
                    vdef = self.backend.definitions.get(mt, tabledata, quadrature_rule, vaccess)

                    if vdef:
                        assert isinstance(vdef, L.Section)
                    # Only add if definition is unique.
                    # This can happen when using sub-meshes
                    if vdef not in definitions:
                        definitions += [vdef]
                else:
                    # Get previously visited operands
                    vops = [self.get_var(quadrature_rule, domain, op) for op in v.ufl_operands]
                    dtype = extract_dtype(v, vops)

                    # Mapping UFL operator to target language
                    self._ufl_names.add(v._ufl_handler_name_)
                    vexpr = L.ufl_to_lnodes(v, *vops)

                    j = len(intermediates)
                    vaccess = L.Symbol(f"{symbol.name}_{j}", dtype=dtype)
                    intermediates.append(L.VariableDecl(vaccess, vexpr))

                # Store access node for future reference
                self.set_var(quadrature_rule, domain, v, vaccess)

        # Optimize definitions
        definitions = optimize(definitions, quadrature_rule)
        return definitions, intermediates

    def generate_dofblock_partition(
        self,
        quadrature_rule: QuadratureRule,
        domain: basix.CellType,
    ):
        """Generate a dofblock partition.

        Block contributions are grouped by their scalar blockmap: each
        argument's dof indices with the table offset removed and divided by
        the block size.

        Returns:
            ``(quadparts, intermediates)`` collected from
            :meth:`generate_block_parts` over all groups.
        """
        block_contributions = self.ir.expression.integrand[(domain, quadrature_rule)][
            "block_contributions"
        ]
        quadparts = []
        blocks = [
            (blockmap, blockdata)
            for blockmap, contributions in sorted(block_contributions.items())
            for blockdata in contributions
        ]

        block_groups = collections.defaultdict(list)

        # Group loops by blockmap, in Vector elements each component has
        # a different blockmap
        for blockmap, blockdata in blocks:
            scalar_blockmap = []
            assert len(blockdata.ma_data) == len(blockmap)
            for i, b in enumerate(blockmap):
                bs = blockdata.ma_data[i].tabledata.block_size
                offset = blockdata.ma_data[i].tabledata.offset
                scalar_blockmap.append(tuple([(idx - offset) // bs for idx in b]))
            block_groups[tuple(scalar_blockmap)].append(blockdata)

        intermediates = []
        for blockmap, blocklist in block_groups.items():
            block_quadparts, intermediate = self.generate_block_parts(
                quadrature_rule,
                domain,
                blockmap,
                blocklist,
            )
            intermediates += intermediate

            # Add computations
            quadparts.extend(block_quadparts)

        return quadparts, intermediates

    def get_arg_factors(self, blockdata, block_rank, quadrature_rule, domain, iq, indices):
        """Get the per-argument table factors of a block.

        Returns:
            ``(arg_factors, tables)``: one factor per argument (the literal 1
            for "ones" tables, otherwise a backend table access), and the
            table symbols those accesses use.
        """
        arg_factors = []
        tables = []
        for i in range(block_rank):
            mad = blockdata.ma_data[i]
            td = mad.tabledata
            scope = self.ir.expression.integrand[(domain, quadrature_rule)]["modified_arguments"]
            mt = scope[mad.ma_index]
            arg_tables = []

            # Translate modified terminal to code
            # TODO: Move element table access out of backend?
            #       Not using self.backend.access.argument() here
            #       now because it assumes too much about indices.

            assert td.ttype != "zeros"

            if td.ttype == "ones":
                arg_factor = 1
            else:
                # Assuming B sparsity follows element table sparsity
                arg_factor, arg_tables = self.backend.access.table_access(
                    td, self.ir.expression.entity_type, mt.restriction, iq, indices[i]
                )

            tables += arg_tables
            arg_factors.append(arg_factor)

        return arg_factors, tables

    def generate_block_parts(
        self,
        quadrature_rule: QuadratureRule,
        domain: basix.CellType,
        blockmap: tuple,
        blocklist: list[BlockDataT],
    ):
        """Generate code accumulating one group of blocks into the element tensor.

        All blocks in ``blocklist`` share the scalar ``blockmap`` and are
        accumulated inside a single nest of argument loops. For a diagonal
        tensor part of a bilinear form, only blocks whose test and trial dofs
        coincide contribute.

        Args:
            quadrature_rule: Quadrature rule whose loop the code belongs to.
            domain: Cell type of the integration domain.
            blockmap: Scalar blockmap, one tuple of dof indices per argument.
            blocklist: Block data of every contribution in this group.

        Returns:
            ``(quadparts, intermediates)``:
            - ``quadparts`` is a list holding one "Tensor Computation"
              Section.
            - ``intermediates`` holds the VariableDecls of the ``fw``
              temporaries (factor times quadrature weight) created here.
        """
        quadparts: list[L.LNode] = []
        intermediates: list[L.LNode] = []
        tables = []
        factor_symbols = []

        # RHS expressions grouped by LHS (element tensor) index
        rhs_expressions = collections.defaultdict(list)

        block_rank = len(blockmap)
        iq_symbol = self.backend.symbols.quadrature_loop_index
        iq = create_quadrature_index(quadrature_rule, iq_symbol)

        A_shape = self.ir.expression.tensor_shape
        is_diagonal = self.ir.part == TensorPart.diagonal

        # Argument loop indices of the last block that actually contributed.
        # Blocks that are skipped must not decide the loop nest.
        loop_indices = []

        for blockdata in blocklist:
            B_indices = []
            for i in range(block_rank):
                table_ref = blockdata.ma_data[i].tabledata
                symbol = self.backend.symbols.argument_loop_index(i)
                B_indices.append(create_dof_index(table_ref, symbol))

            if is_diagonal and block_rank == 2:
                assert len(A_shape) == 1
                test_td = blockdata.ma_data[0].tabledata
                trial_td = blockdata.ma_data[1].tabledata
                # The block lies on the diagonal only if test and trial dofs
                # are the same: same count, same offset and same stride.
                # Equal sizes alone do not suffice. Off-diagonal components
                # of a blocked element, or equal-size subspaces of a mixed
                # element, would otherwise add test_i * trial_i products to
                # the diagonal.
                on_diagonal = (
                    B_indices[0].size() == B_indices[1].size()
                    and test_td.offset == trial_td.offset
                    and test_td.block_size == trial_td.block_size
                )
                if not on_diagonal:
                    # Off-diagonal block: never touches the matrix diagonal
                    continue
                B_indices = [B_indices[0], B_indices[0]]

            ttypes = blockdata.ttypes
            if "zeros" in ttypes:
                raise RuntimeError(
                    "Not expecting zero arguments to be left in dofblock generation."
                )

            if len(blockdata.factor_indices_comp_indices) > 1:
                raise RuntimeError("Code generation for non-scalar integrals unsupported")

            # We have scalar integrand here, take just the factor index
            factor_index = blockdata.factor_indices_comp_indices[0][0]

            # Get factor expression
            F = self.ir.expression.integrand[(domain, quadrature_rule)]["factorization"]

            v = F.nodes[factor_index]["expression"]
            f = self.get_var(quadrature_rule, domain, v)

            # Quadrature weight was removed in representation, add it back now
            if self.ir.expression.integral_type in ufl.custom_integral_types:
                weights = self.backend.symbols.custom_weights_table
            else:
                weights = self.backend.symbols.weights_table(quadrature_rule)
            weight = weights[iq.global_index]

            # Define fw = f * weight
            fw_rhs = L.float_product([f, weight])
            if not isinstance(fw_rhs, L.Product):
                fw = fw_rhs
            else:
                # Define and cache scalar temp variable
                key = (quadrature_rule, factor_index, blockdata.all_factors_piecewise)
                fw, defined = self.get_temp_symbol("fw", key)
                if not defined:
                    intermediates += [L.VariableDecl(fw, fw_rhs)]

            factor_symbols.append(fw if isinstance(fw, L.Symbol) else fw.array)
            assert not blockdata.transposed, "Not handled yet"

            # Fetch code to access modified arguments
            arg_factors, arg_tables = self.get_arg_factors(
                blockdata, block_rank, quadrature_rule, domain, iq, B_indices
            )
            tables += arg_tables

            # Diagonal: only the test index addresses the (rank-1) tensor
            insert_rank = block_rank
            if is_diagonal:
                insert_rank = 1
                B_indices = [B_indices[0]]

            # Define B_rhs = fw * arg_factors
            B_rhs = L.float_product([fw] + arg_factors)

            A_indices = []
            for i in range(insert_rank):
                index = B_indices[i]
                tabledata = blockdata.ma_data[i].tabledata
                offset = tabledata.offset
                if len(blockmap[i]) == 1:
                    A_indices.append(index.global_index + offset)
                else:
                    A_indices.append(tabledata.block_size * index.global_index + offset)
            rhs_expressions[tuple(A_indices)].append(B_rhs)
            loop_indices = B_indices

        A = self.backend.symbols.element_tensor
        body: list[L.LNode] = []
        for indices, expressions in rhs_expressions.items():
            multi_index = L.MultiIndex(list(indices), A_shape)
            body.extend(L.AssignAdd(A[multi_index], expression) for expression in expressions)

        # Loop indices are handed to create_nested_for_loops in reverse order
        loop_indices = loop_indices[::-1]
        body = [L.create_nested_for_loops(loop_indices, body)]

        # Remove repeated symbols while keeping first-occurrence order, so the
        # section's input order does not depend on set iteration order.
        inputs = list(dict.fromkeys([*factor_symbols, *tables]))
        output = [A]

        # assert input and output are Symbol objects
        assert all(isinstance(i, L.Symbol) for i in inputs)
        assert all(isinstance(o, L.Symbol) for o in output)

        annotations = []
        if len(loop_indices) > 1:
            annotations.append(L.Annotation.licm)

        quadparts += [L.Section("Tensor Computation", body, [], inputs, output, annotations)]

        return quadparts, intermediates
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc