• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 17506072291

05 Sep 2025 10:42PM UTC coverage: 83.068% (+0.09%) from 82.98%
17506072291

Pull #783

github

web-flow
Another visualise
Pull Request #783: Diagonal assembly of matrices

60 of 61 new or added lines in 4 files covered. (98.36%)

31 existing lines in 4 files now uncovered.

3699 of 4453 relevant lines covered (83.07%)

0.83 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.61
/ffcx/codegeneration/integral_generator.py
1
# Copyright (C) 2015-2024 Martin Sandve Alnæs, Michal Habera, Igor Baratta, Chris Richardson
2
#
3
# Modified by Jørgen S. Dokken, 2024
4
#
5
# This file is part of FFCx. (https://www.fenicsproject.org)
6
#
7
# SPDX-License-Identifier:    LGPL-3.0-or-later
8
"""Integral generator."""
9

10
import collections
1✔
11
import logging
1✔
12
from numbers import Integral
1✔
13
from typing import Any
1✔
14

15
import basix
1✔
16
import ufl
1✔
17

18
import ffcx.codegeneration.lnodes as L
1✔
19
from ffcx.codegeneration import geometry
1✔
20
from ffcx.codegeneration.definitions import create_dof_index, create_quadrature_index
1✔
21
from ffcx.codegeneration.optimizer import optimize
1✔
22
from ffcx.ir.elementtables import piecewise_ttypes
1✔
23
from ffcx.ir.integral import BlockDataT
1✔
24
from ffcx.ir.representationutils import QuadratureRule
1✔
25

26
# Module logger shared with the rest of FFCx (all modules log under "ffcx").
logger: logging.Logger = logging.getLogger(name="ffcx")
27

28

29
def extract_dtype(v, vops: list[Any]):
    """Return the LNodes data type of the value of UFL expression ``v``.

    Every operand access in ``vops`` must expose a data type: either
    directly (``.dtype``), through a wrapped symbol (``.symbol.dtype``),
    or by being a Python integral number (mapped to ``L.DataType.INT``).
    UFL conditions always produce ``L.DataType.BOOL``; any other
    expression gets the merge of its operand dtypes.

    Raises:
        RuntimeError: If an operand has no recognisable data type.
    """

    def _operand_dtype(op):
        # Order matters: an object with its own dtype wins over a wrapped symbol.
        if hasattr(op, "dtype"):
            return op.dtype
        if hasattr(op, "symbol"):
            return op.symbol.dtype
        if isinstance(op, Integral):
            return L.DataType.INT
        raise RuntimeError(f"Not expecting this type of operand {type(op)}")

    # Resolve all operand dtypes first so invalid operands are always reported,
    # even for conditions.
    operand_dtypes = [_operand_dtype(op) for op in vops]
    if isinstance(v, ufl.classes.Condition):
        return L.DataType.BOOL
    return L.merge_dtypes(operand_dtypes)
43

44

45
class IntegralGenerator:
    """Generate the ``tabulate_tensor`` body of an integral kernel.

    Code is produced as LNodes by walking the factorised integrand stored
    in ``ir.expression``; target-language specifics are delegated to a
    backend plugin.
    """

    def __init__(self, ir, backend):
        """Initialise.

        Args:
            ir: Intermediate representation of the integral. Its
                ``expression`` attribute provides the integrand, element
                tables, integral type and element tensor shape.
            backend: Backend plugin with attributes ``symbols`` (for
                translating UFL operators to the target language),
                ``definitions`` (for defining backend specific variables)
                and ``access`` (for accessing backend specific variables).
        """
        self.ir = ir
        self.backend = backend

        # Names of UFL operators code has been generated for, used in the
        # end for selecting necessary includes
        self._ufl_names = set()

        # Initialize lookup tables for variable scopes
        self.init_scopes()

        # Cache of temporary symbols, see get_temp_symbol
        self.temp_symbols = {}

        # Counters used for assigning names to intermediate variables
        self.symbol_counters = collections.defaultdict(int)

    def init_scopes(self):
        """Initialise the variable scope dicts.

        One empty scope is created per key of ``ir.expression.integrand``
        (a ``(cell, quadrature_rule)`` pair), plus the piecewise scope keyed
        by ``(None, None)`` for values computed outside quadrature loops.
        """
        self.scopes = {key: {} for key in self.ir.expression.integrand.keys()}
        self.scopes[(None, None)] = {}

    def set_var(self, quadrature_rule, domain, v, vaccess):
        """Set a new variable in variable scope dicts.

        Scope is determined by ``(domain, quadrature_rule)``, which
        identifies the quadrature loop scope, or ``(None, None)`` for the
        piecewise scope outside quadrature loops.

        Args:
            quadrature_rule: Quadrature rule
            domain: The domain of the integral
            v: the ufl expression
            vaccess: the LNodes expression to access the value in the code
        """
        self.scopes[(domain, quadrature_rule)][v] = vaccess

    def get_var(self, quadrature_rule, domain, v):
        """Look up UFL expression ``v`` in the variable scope dicts.

        Literals are translated directly without consulting the scopes.
        Otherwise the quadrature loop scope ``(domain, quadrature_rule)``
        is searched first, then the piecewise scope ``(None, None)``.

        Returns:
            The LNodes expression to access the value in the code, or None
            if ``v`` is in neither scope.
        """
        if v._ufl_is_literal_:
            return L.ufl_to_lnodes(v)

        # quadrature loop scope
        f = self.scopes[(domain, quadrature_rule)].get(v)

        # piecewise scope
        if f is None:
            f = self.scopes[(None, None)].get(v)
        return f

    def new_temp_symbol(self, basename):
        """Create a new SCALAR code symbol named basename + running counter."""
        name = f"{basename}{self.symbol_counters[basename]:d}"
        self.symbol_counters[basename] += 1
        return L.Symbol(name, dtype=L.DataType.SCALAR)

    def get_temp_symbol(self, tempname, key):
        """Get a cached temporary symbol, creating it on first request.

        Args:
            tempname: Base name of the symbol.
            key: Tuple identifying the value held by the symbol.

        Returns:
            A pair ``(symbol, defined)`` where ``defined`` is True if the
            symbol was already in the cache (so its declaration has already
            been emitted by the caller).
        """
        key = (tempname,) + key
        s = self.temp_symbols.get(key)
        defined = s is not None
        if not defined:
            s = self.new_temp_symbol(tempname)
            self.temp_symbols[key] = s
        return s, defined

    def generate(self, domain: basix.CellType):
        """Generate entire tabulate_tensor body.

        Assumes that the code returned from here will be wrapped in a
        context that matches a suitable version of the UFC
        tabulate_tensor signatures.

        Args:
            domain: Cell type; only integrand entries on this cell are
                generated.

        Returns:
            An ``L.StatementList`` with tables, piecewise code and
            quadrature loops, in that order.
        """
        # Assert that scopes are empty: expecting this to be called only
        # once
        assert not any(d for d in self.scopes.values())

        parts = []

        # Generate the tables of quadrature points and weights
        parts += self.generate_quadrature_tables(domain)

        # Generate the tables of basis function values and
        # pre-integrated blocks
        parts += self.generate_element_tables(domain)

        # Generate the tables of geometry data that are needed
        parts += self.generate_geometry_tables()

        # Loop generation code will produce parts to go before
        # quadloops and to define the quadloops
        all_preparts = []
        all_quadparts = []

        # Pre-definitions are collected across all quadrature loops to
        # improve re-use and avoid name clashes
        for cell, rule in self.ir.expression.integrand.keys():
            if domain == cell:
                # Generate code to compute piecewise constant scalar factors.
                # NOTE: this returns a (definitions, intermediates) pair, so
                # `+=` appends both lists as nested elements of all_preparts.
                all_preparts += self.generate_piecewise_partition(rule, cell)

                # Generate code to integrate reusable blocks of final
                # element tensor
                all_quadparts += self.generate_quadrature_loop(rule, cell)

        # Collect parts before and during quadrature loops
        parts += all_preparts
        parts += all_quadparts

        return L.StatementList(parts)

    def generate_quadrature_tables(self, domain: basix.CellType):
        """Generate static tables of quadrature weights for ``domain``.

        Custom integral types get no tables: their weights are passed in
        as an argument.
        """
        parts: list[L.LNode] = []

        # No quadrature tables for custom (given argument)
        if self.ir.expression.integral_type in ufl.custom_integral_types:
            return parts

        # Loop over quadrature rules
        for cell, quadrature_rule in self.ir.expression.integrand.keys():
            if domain == cell:
                # Generate quadrature weights array
                wsym = self.backend.symbols.weights_table(quadrature_rule)
                parts += [L.ArrayDecl(wsym, values=quadrature_rule.weights, const=True)]

        # Add leading comment if there are any tables
        parts = L.commented_code_list(parts, "Quadrature rules")
        return parts

    def generate_geometry_tables(self):
        """Generate static tables of reference geometry data.

        A table is written for each geometry quantity type used by any
        modified terminal of the integrand, once per cell name on which it
        appears.
        """
        ufl_geometry = {
            ufl.geometry.FacetEdgeVectors: "facet_edge_vertices",
            ufl.geometry.CellFacetJacobian: "cell_facet_jacobian",
            ufl.geometry.CellRidgeJacobian: "cell_ridge_jacobian",
            ufl.geometry.ReferenceCellVolume: "reference_cell_volume",
            ufl.geometry.ReferenceFacetVolume: "reference_facet_volume",
            ufl.geometry.ReferenceCellEdgeVectors: "reference_cell_edge_vectors",
            ufl.geometry.ReferenceFacetEdgeVectors: "reference_facet_edge_vectors",
            ufl.geometry.ReferenceNormal: "reference_normals",
            ufl.geometry.FacetOrientation: "facet_orientation",
        }
        cells: dict[Any, set[Any]] = {t: set() for t in ufl_geometry.keys()}  # type: ignore

        for integrand in self.ir.expression.integrand.values():
            for attr in integrand["factorization"].nodes.values():
                mt = attr.get("mt")
                if mt is not None:
                    t = type(mt.terminal)
                    if t in ufl_geometry:
                        cells[t].add(
                            ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
                        )

        parts = []
        for i, cell_list in cells.items():
            for c in cell_list:
                parts.append(geometry.write_table(ufl_geometry[i], c))

        return parts

    def generate_element_tables(self, domain: basix.CellType):
        """Generate static tables.

        With precomputed element basis function values in quadrature points.
        For custom integral types only piecewise tables are declared.
        """
        parts = []
        tables = self.ir.expression.unique_tables[domain]
        table_types = self.ir.expression.unique_table_types[domain]
        if self.ir.expression.integral_type in ufl.custom_integral_types:
            # Define only piecewise tables
            table_names = [name for name in sorted(tables) if table_types[name] in piecewise_ttypes]
        else:
            # Define all tables
            table_names = sorted(tables)

        for name in table_names:
            table = tables[name]
            parts += self.declare_table(name, table)

        # Add leading comment if there are any tables
        parts = L.commented_code_list(
            parts,
            [
                "Precomputed values of basis functions and precomputations",
                "FE* dimensions: [permutation][entities][points][dofs]",
            ],
        )
        return parts

    def declare_table(self, name, table):
        """Declare a constant element table.

        Registers a REAL symbol named ``name`` in the backend's
        ``element_tables`` and returns a one-element list holding its
        constant array declaration. The table values are emitted as given;
        no dof transformations are applied here.
        """
        table_symbol = L.Symbol(name, dtype=L.DataType.REAL)
        self.backend.symbols.element_tables[name] = table_symbol
        return [L.ArrayDecl(table_symbol, values=table, const=True)]

    def generate_quadrature_loop(self, quadrature_rule: QuadratureRule, domain: basix.CellType):
        """Generate the quadrature loop for this quadrature_rule.

        The loop body consists of the varying definitions, an
        "Intermediates" section and the tensor computation sections,
        optimised together.
        """
        # Generate varying partition
        definitions, intermediates_0 = self.generate_varying_partition(quadrature_rule, domain)

        # Generate dofblock parts
        tensor_comp, intermediates_fw = self.generate_dofblock_partition(quadrature_rule, domain)
        assert all(isinstance(tc, L.Section) for tc in tensor_comp)

        # Check that definitions are only Section objects, and collect their
        # outputs: these are the inputs of the intermediates section
        inputs = []
        for definition in definitions:
            assert isinstance(definition, L.Section)
            inputs += definition.output

        # Create intermediates section: each fw variable is declared
        # (initialised to 0) and assigned its value inside the section
        output = []
        declarations = []
        for fw in intermediates_fw:
            assert isinstance(fw, L.VariableDecl)
            output += [fw.symbol]
            declarations += [L.VariableDecl(fw.symbol, 0)]
            intermediates_0 += [L.Assign(fw.symbol, fw.value)]
        intermediates = [L.Section("Intermediates", intermediates_0, declarations, inputs, output)]

        iq_symbol = self.backend.symbols.quadrature_loop_index
        iq = create_quadrature_index(quadrature_rule, iq_symbol)

        code = definitions + intermediates + tensor_comp
        code = optimize(code, quadrature_rule)

        return [L.create_nested_for_loops([iq], code)]

    def generate_piecewise_partition(self, quadrature_rule, domain: basix.CellType):
        """Generate code for the piecewise (quadrature independent) nodes.

        Variables are stored in the piecewise scope ``(None, None)``.

        Returns:
            The ``(definitions, intermediates)`` pair from generate_partition.
        """
        # Get annotated graph of factorisation
        F = self.ir.expression.integrand[(domain, quadrature_rule)]["factorization"]
        arraysymbol = L.Symbol(f"sp_{quadrature_rule.id()}", dtype=L.DataType.SCALAR)
        return self.generate_partition(arraysymbol, F, "piecewise", None, None)

    def generate_varying_partition(self, quadrature_rule, domain: basix.CellType):
        """Generate code for the nodes varying over quadrature points.

        Returns:
            The ``(definitions, intermediates)`` pair from generate_partition.
        """
        # Get annotated graph of factorisation
        F = self.ir.expression.integrand[(domain, quadrature_rule)]["factorization"]
        arraysymbol = L.Symbol(f"sv_{quadrature_rule.id()}", dtype=L.DataType.SCALAR)
        return self.generate_partition(arraysymbol, F, "varying", quadrature_rule, domain)

    def generate_partition(self, symbol, F, mode, quadrature_rule, domain):
        """Generate code for all factorisation nodes with status ``mode``.

        Args:
            symbol: Symbol whose name prefixes the intermediate variables.
            F: Annotated factorisation graph; ``F.nodes`` maps node ids to
                attribute dicts with ``"status"`` and ``"expression"`` (and
                ``"mt"``/``"tr"`` for modified terminals).
            mode: Node status to generate code for ("piecewise" or "varying").
            quadrature_rule: Quadrature rule (None for the piecewise scope).
            domain: Cell type (None for the piecewise scope).

        Returns:
            A pair ``(definitions, intermediates)``: the optimised, unique
            backend definitions (``L.Section``) of modified terminals, and
            ``L.VariableDecl`` statements for intermediate operator results.
        """
        definitions = []
        intermediates = []

        for attr in F.nodes.values():
            if attr["status"] != mode:
                continue
            v = attr["expression"]

            # Generate code only if the expression is not already in cache.
            # get_var signals "not found" with None; compare explicitly, as
            # a cached access expression might itself be falsy.
            if self.get_var(quadrature_rule, domain, v) is not None:
                continue

            if v._ufl_is_literal_:
                vaccess = L.ufl_to_lnodes(v)
            elif mt := attr.get("mt"):
                tabledata = attr.get("tr")

                # Backend specific modified terminal translation
                vaccess = self.backend.access.get(mt, tabledata, quadrature_rule)
                vdef = self.backend.definitions.get(mt, tabledata, quadrature_rule, vaccess)

                # Terminals needing no definition code may get an empty
                # (falsy) definition; skip it so that `definitions` only
                # holds L.Section objects (generate_quadrature_loop asserts
                # this). Only add unique definitions: duplicates can happen
                # when using sub-meshes.
                if vdef:
                    assert isinstance(vdef, L.Section)
                    if vdef not in definitions:
                        definitions.append(vdef)
            else:
                # Get previously visited operands
                vops = [self.get_var(quadrature_rule, domain, op) for op in v.ufl_operands]
                dtype = extract_dtype(v, vops)

                # Mapping UFL operator to target language
                self._ufl_names.add(v._ufl_handler_name_)
                vexpr = L.ufl_to_lnodes(v, *vops)

                j = len(intermediates)
                vaccess = L.Symbol(f"{symbol.name}_{j}", dtype=dtype)
                intermediates.append(L.VariableDecl(vaccess, vexpr))

            # Store access node for future reference
            self.set_var(quadrature_rule, domain, v, vaccess)

        # Optimize definitions
        definitions = optimize(definitions, quadrature_rule)
        return definitions, intermediates

    def generate_dofblock_partition(self, quadrature_rule: QuadratureRule, domain: basix.CellType):
        """Generate the tensor computation sections for all dof blocks.

        Returns:
            A pair ``(quadparts, intermediates)``: the "Tensor Computation"
            sections and the ``fw`` variable declarations they depend on.
        """
        block_contributions = self.ir.expression.integrand[(domain, quadrature_rule)][
            "block_contributions"
        ]
        quadparts = []
        blocks = [
            (blockmap, blockdata)
            for blockmap, contributions in sorted(block_contributions.items())
            for blockdata in contributions
        ]

        # Group blocks by their scalar blockmap: each dof index is shifted
        # by the table offset and divided by the block size. In vector
        # elements each component has a different blockmap; presumably this
        # lets the components share one group (and one loop nest).
        block_groups = collections.defaultdict(list)
        for blockmap, blockdata in blocks:
            assert len(blockdata.ma_data) == len(blockmap)
            scalar_blockmap = []
            for i, b in enumerate(blockmap):
                bs = blockdata.ma_data[i].tabledata.block_size
                offset = blockdata.ma_data[i].tabledata.offset
                scalar_blockmap.append(tuple([(idx - offset) // bs for idx in b]))
            block_groups[tuple(scalar_blockmap)].append(blockdata)

        intermediates = []
        for blockmap, blocklist in block_groups.items():
            block_quadparts, intermediate = self.generate_block_parts(
                quadrature_rule, domain, blockmap, blocklist
            )
            intermediates += intermediate

            # Add computations
            quadparts.extend(block_quadparts)

        return quadparts, intermediates

    def get_arg_factors(self, blockdata, block_rank, quadrature_rule, domain, iq, indices):
        """Get the table access factors of the block's arguments.

        Returns:
            A pair ``(arg_factors, tables)``: one factor per argument (``1``
            for "ones" tables, otherwise the backend table access) and the
            table symbols accessed.
        """
        arg_factors = []
        tables = []
        for i in range(block_rank):
            mad = blockdata.ma_data[i]
            td = mad.tabledata
            scope = self.ir.expression.integrand[(domain, quadrature_rule)]["modified_arguments"]
            mt = scope[mad.ma_index]
            arg_tables = []

            # Translate modified terminal to code
            # TODO: Move element table access out of backend?
            #       Not using self.backend.access.argument() here
            #       now because it assumes too much about indices.

            assert td.ttype != "zeros"

            if td.ttype == "ones":
                arg_factor = 1
            else:
                # Assuming B sparsity follows element table sparsity
                arg_factor, arg_tables = self.backend.access.table_access(
                    td, self.ir.expression.entity_type, mt.restriction, iq, indices[i]
                )

            tables += arg_tables
            arg_factors.append(arg_factor)

        return arg_factors, tables

    def generate_block_parts(
        self,
        quadrature_rule: QuadratureRule,
        domain: basix.CellType,
        blockmap: tuple,
        blocklist: list[BlockDataT],
    ):
        """Generate and return code parts for a group of blocks.

        All blocks in ``blocklist`` share the scalar ``blockmap`` and are
        accumulated into the element tensor inside a single loop nest.
        When the element tensor has rank 1 but the blocks have rank 2, only
        the diagonal is assembled: both arguments use the same dof index.

        Args:
            quadrature_rule: Quadrature rule of the enclosing quadrature loop.
            domain: Cell type of the integration domain.
            blockmap: Scalar blockmap shared by all blocks in ``blocklist``.
            blocklist: Block data to generate code for; must be non-empty.

        Returns:
            A pair ``(quadparts, intermediates)``: a list holding the
            "Tensor Computation" section, and the ``fw = f * weight``
            variable declarations to place in the intermediates section.
        """
        # The parts to return
        quadparts: list[L.LNode] = []
        intermediates: list[L.LNode] = []
        tables = []
        fw_arrays = []

        # RHS expressions grouped by LHS element tensor indices
        rhs_expressions = collections.defaultdict(list)

        block_rank = len(blockmap)
        iq_symbol = self.backend.symbols.quadrature_loop_index
        iq = create_quadrature_index(quadrature_rule, iq_symbol)

        A_shape = self.ir.expression.tensor_shape
        diagonalise = len(A_shape) == 1 and block_rank == 2

        for blockdata in blocklist:
            B_indices = []
            for i in range(block_rank):
                table_ref = blockdata.ma_data[i].tabledata
                symbol = self.backend.symbols.argument_loop_index(i)
                B_indices.append(create_dof_index(table_ref, symbol))

            if diagonalise:
                B_indices = [B_indices[0], B_indices[0]]

            if "zeros" in blockdata.ttypes:
                raise RuntimeError(
                    "Not expecting zero arguments to be left in dofblock generation."
                )

            if len(blockdata.factor_indices_comp_indices) > 1:
                raise RuntimeError("Code generation for non-scalar integrals unsupported")

            # We have scalar integrand here, take just the factor index
            factor_index = blockdata.factor_indices_comp_indices[0][0]

            # Get factor expression
            F = self.ir.expression.integrand[(domain, quadrature_rule)]["factorization"]
            v = F.nodes[factor_index]["expression"]
            f = self.get_var(quadrature_rule, domain, v)

            # Quadrature weight was removed in representation, add it back now
            if self.ir.expression.integral_type in ufl.custom_integral_types:
                weights = self.backend.symbols.custom_weights_table
            else:
                weights = self.backend.symbols.weights_table(quadrature_rule)
            weight = weights[iq.global_index]

            # Define fw = f * weight
            fw_rhs = L.float_product([f, weight])
            if not isinstance(fw_rhs, L.Product):
                fw = fw_rhs
            else:
                # Define and cache scalar temp variable
                key = (quadrature_rule, factor_index, blockdata.all_factors_piecewise)
                fw, defined = self.get_temp_symbol("fw", key)
                if not defined:
                    intermediates.append(L.VariableDecl(fw, fw_rhs))

            fw_arrays.append(fw if isinstance(fw, L.Symbol) else fw.array)
            assert not blockdata.transposed, "Not handled yet"

            # Fetch code to access modified arguments
            arg_factors, arg_tables = self.get_arg_factors(
                blockdata, block_rank, quadrature_rule, domain, iq, B_indices
            )
            tables += arg_tables

            # Define B_rhs = fw * arg_factors
            insert_rank = block_rank
            if diagonalise:
                insert_rank = 1
                B_indices = [B_indices[0]]
            B_rhs = L.float_product([fw] + arg_factors)

            # Map the block dof indices to element tensor indices
            A_indices = []
            for i in range(insert_rank):
                index = B_indices[i]
                tabledata = blockdata.ma_data[i].tabledata
                if len(blockmap[i]) == 1:
                    A_indices.append(index.global_index + tabledata.offset)
                else:
                    A_indices.append(tabledata.block_size * index.global_index + tabledata.offset)
            rhs_expressions[tuple(A_indices)].append(B_rhs)

        body: list[L.LNode] = []
        A = self.backend.symbols.element_tensor
        for indices, expressions in rhs_expressions.items():
            multi_index = L.MultiIndex(list(indices), A_shape)
            for expression in expressions:
                body.append(L.AssignAdd(A[multi_index], expression))

        # The loop nest is built from the (reversed) dof indices of the last
        # block in the group; all blocks in the group share the scalar blockmap.
        B_indices = B_indices[::-1]
        body = [L.create_nested_for_loops(B_indices, body)]

        # Remove repeated symbols while keeping first-occurrence order, so the
        # section inputs do not depend on (hash-dependent) set iteration order
        # and the generated code is reproducible across runs.
        inputs = list(dict.fromkeys([*fw_arrays, *tables]))
        outputs = [A]

        # assert input and output are Symbol objects
        assert all(isinstance(i, L.Symbol) for i in inputs)
        assert all(isinstance(o, L.Symbol) for o in outputs)

        annotations = []
        if len(B_indices) > 1:
            annotations.append(L.Annotation.licm)

        quadparts.append(L.Section("Tensor Computation", body, [], inputs, outputs, annotations))

        return quadparts, intermediates
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc