• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 7299320904

22 Dec 2023 11:18AM UTC coverage: 80.35% (+0.4%) from 79.915%
7299320904

push

github

web-flow
[Sum Factorization - 1] Transition to 1D basis evaluation matrices for hexes and quads (#642)

* add multi index and tensor tables

* update parameters

* Updates to MultiIndex

* Add a simple formatting rule in C

* Replace FlattenedArray

* Minor fixes

* minor fix to pass the tests

* hopefully fix some tests

* update quadrature rule

* Fix MultiIndex when zero size

* Minor tweak

* minor update to documentation

* update quadrature

* update entity dofs

* fix tensor factors

* update code

* add mass action

* update mass action

* create quadrature rule with tensor factors

* fix tables

* use tensor structure for matrices as well

* improve tests

* minor improvements

* Fix error

* Simplify and fix mypy

* Fix options parser for bool types

* update quadrature generation

* fixes for piecewise tabledata

* add precedence to MultiIndex

* fix quadrature permutation

* update loop hoisting

* Try with np.prod

* add test

* update test

* update test

* update test for hexes

* use property for dim

* use argument_loop_index

* remove comments

* add doc strings

* add license

* fix documentation

* remove extra if

---------

Co-authored-by: Chris Richardson <chris@bpi.cam.ac.uk>

157 of 170 new or added lines in 9 files covered. (92.35%)

4 existing lines in 1 file now uncovered.

3676 of 4575 relevant lines covered (80.35%)

0.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.01
/ffcx/codegeneration/integral_generator.py
1
# Copyright (C) 2015-2023 Martin Sandve Alnæs, Michal Habera, Igor Baratta, Chris Richardson
2
#
3
# This file is part of FFCx. (https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6

7
import collections
1✔
8
import logging
1✔
9
from typing import Any, Dict, List, Set, Tuple
1✔
10

11
import ufl
1✔
12
from ffcx.codegeneration import geometry
1✔
13
from ffcx.ir.elementtables import piecewise_ttypes
1✔
14
from ffcx.ir.integral import BlockDataT
1✔
15
import ffcx.codegeneration.lnodes as L
1✔
16
from ffcx.codegeneration.lnodes import LNode, BinOp
1✔
17
from ffcx.ir.representationutils import QuadratureRule
1✔
18
from ffcx.codegeneration.definitions import create_quadrature_index, create_nested_for_loops, create_dof_index
1✔
19

20
# Module-level logger registered under the "ffcx" name (not referenced
# by the code visible in this module section).
logger = logging.getLogger("ffcx")
1✔
21

22

23
class IntegralGenerator(object):
    """Generate the body of a tabulate_tensor function for one integral.

    The generator reads the integral representation ``ir`` (its
    ``integrand`` per quadrature rule, ``integral_type``,
    ``unique_tables``/``unique_table_types``, ``entitytype`` and
    ``tensor_shape``) and uses the ``backend`` plugin (``symbols``,
    ``definitions``, ``access``) to translate it into LNodes statements.
    """

    def __init__(self, ir, backend):
        # Store ir
        self.ir = ir

        # Backend specific plugin with attributes
        # - symbols: for translating ufl operators to target language
        # - definitions: for defining backend specific variables
        # - access: for accessing backend specific variables
        self.backend = backend

        # Set of operator names code has been generated for, used in the
        # end for selecting necessary includes
        self._ufl_names = set()

        # Initialize lookup tables for variable scopes
        self.init_scopes()

        # Cache of shared temporary symbols, keyed by (tempname,) + key;
        # see get_temp_symbol
        self.shared_symbols = {}

        # Counters used for assigning unique names to intermediate
        # variables, one running counter per basename
        self.symbol_counters = collections.defaultdict(int)

    def init_scopes(self):
        """Initialize variable scope dicts."""
        # Reset variables, separate sets for each quadrature rule
        self.scopes = {quadrature_rule: {} for quadrature_rule in self.ir.integrand.keys()}
        # Scope for quadrature-independent (piecewise) values
        self.scopes[None] = {}

    def set_var(self, quadrature_rule, v, vaccess):
        """Set a new variable in variable scope dicts.

        Scope is determined by quadrature_rule which identifies the
        quadrature loop scope or None if outside quadrature loops.

        v is the ufl expression and vaccess is the LNodes
        expression to access the value in the code.

        """
        self.scopes[quadrature_rule][v] = vaccess

    def get_var(self, quadrature_rule, v):
        """Lookup ufl expression v in variable scope dicts.

        Scope is determined by quadrature rule which identifies the
        quadrature loop scope or None if outside quadrature loops.

        If v is not found in quadrature loop scope, the piecewise
        scope (None) is checked.

        Returns the LNodes expression to access the value in the code,
        or None if v has not been stored in either scope.
        """
        # Literals are translated directly and never stored in a scope
        if v._ufl_is_literal_:
            return L.ufl_to_lnodes(v)
        f = self.scopes[quadrature_rule].get(v)
        if f is None:
            f = self.scopes[None].get(v)
        return f

    def new_temp_symbol(self, basename):
        """Create a new code symbol named basename + running counter."""
        name = f"{basename}{self.symbol_counters[basename]}"
        self.symbol_counters[basename] += 1
        return L.Symbol(name, dtype=L.DataType.SCALAR)

    def get_temp_symbol(self, tempname, key):
        """Return a cached temporary symbol for (tempname,) + key.

        Returns a tuple ``(symbol, defined)`` where ``defined`` is True
        if the symbol already existed in the cache (so the caller must
        not declare it again) and False if it was just created.
        """
        key = (tempname,) + key
        s = self.shared_symbols.get(key)
        defined = s is not None
        if not defined:
            s = self.new_temp_symbol(tempname)
            self.shared_symbols[key] = s
        return s, defined

    def generate(self):
        """Generate entire tabulate_tensor body.

        Assumes that the code returned from here will be wrapped in a
        context that matches a suitable version of the UFC
        tabulate_tensor signatures.
        """
        # Assert that scopes are empty: expecting this to be called only
        # once
        assert not any(d for d in self.scopes.values())

        parts = []

        # Generate the tables of quadrature points and weights
        parts += self.generate_quadrature_tables()

        # Generate the tables of basis function values and
        # pre-integrated blocks
        parts += self.generate_element_tables()

        # Generate the tables of geometry data that are needed
        parts += self.generate_geometry_tables()

        # Loop generation code will produce parts to go before
        # quadloops, to define the quadloops, and to go after the
        # quadloops
        all_preparts = []
        all_quadparts = []

        # Pre-definitions are collected across all quadrature loops to
        # improve re-use and avoid name clashes
        all_predefinitions = dict()
        for rule in self.ir.integrand.keys():
            # Generate code to compute piecewise constant scalar factors
            all_preparts += self.generate_piecewise_partition(rule)

            # Generate code to integrate reusable blocks of final
            # element tensor
            pre_definitions, preparts, quadparts = self.generate_quadrature_loop(rule)
            all_preparts += preparts
            all_quadparts += quadparts
            all_predefinitions.update(pre_definitions)

        parts += L.commented_code_list(self.fuse_loops(all_predefinitions),
                                       "Pre-definitions of modified terminals to enable unit-stride access")

        # Collect parts before, during, and after quadrature loops
        parts += all_preparts
        parts += all_quadparts

        return L.StatementList(parts)

    def generate_quadrature_tables(self):
        """Generate static tables of quadrature weights, one per quadrature rule."""
        parts = []

        # No quadrature tables for custom (given argument) or point
        # (evaluation in single vertex)
        skip = ufl.custom_integral_types + ufl.measure.point_integral_types
        if self.ir.integral_type in skip:
            return parts

        # Loop over quadrature rules
        for quadrature_rule in self.ir.integrand.keys():
            # Generate quadrature weights array
            wsym = self.backend.symbols.weights_table(quadrature_rule)
            parts += [L.ArrayDecl(wsym, values=quadrature_rule.weights, const=True)]

        # Add leading comment if there are any tables
        parts = L.commented_code_list(parts, "Quadrature rules")
        return parts

    def generate_geometry_tables(self):
        """Generate static tables of geometry data.

        A table is written for every (geometry quantity, cell name) pair
        whose terminal appears as a modified terminal in any integrand.
        """
        ufl_geometry = {
            ufl.geometry.FacetEdgeVectors: "facet_edge_vertices",
            ufl.geometry.CellFacetJacobian: "reference_facet_jacobian",
            ufl.geometry.ReferenceCellVolume: "reference_cell_volume",
            ufl.geometry.ReferenceFacetVolume: "reference_facet_volume",
            ufl.geometry.ReferenceCellEdgeVectors: "reference_edge_vectors",
            ufl.geometry.ReferenceFacetEdgeVectors: "facet_reference_edge_vectors",
            ufl.geometry.ReferenceNormal: "reference_facet_normals",
            ufl.geometry.FacetOrientation: "facet_orientation"
        }
        cells: Dict[Any, Set[Any]] = {t: set() for t in ufl_geometry.keys()}

        for integrand in self.ir.integrand.values():
            for attr in integrand["factorization"].nodes.values():
                mt = attr.get("mt")
                if mt is not None:
                    t = type(mt.terminal)
                    if t in ufl_geometry:
                        cells[t].add(ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname())

        parts = []
        for geometry_type, cell_names in cells.items():
            for cellname in cell_names:
                parts.append(geometry.write_table(ufl_geometry[geometry_type], cellname))

        return parts

    def generate_element_tables(self):
        """Generate static tables with precomputed element basisfunction values in quadrature points."""
        parts = []
        tables = self.ir.unique_tables
        table_types = self.ir.unique_table_types
        if self.ir.integral_type in ufl.custom_integral_types:
            # Define only piecewise tables
            table_names = [name for name in sorted(tables) if table_types[name] in piecewise_ttypes]
        else:
            # Define all tables
            table_names = sorted(tables)

        for name in table_names:
            table = tables[name]
            parts += self.declare_table(name, table)

        # Add leading comment if there are any tables
        parts = L.commented_code_list(parts, [
            "Precomputed values of basis functions and precomputations",
            "FE* dimensions: [permutation][entities][points][dofs]"])
        return parts

    def declare_table(self, name, table):
        """Declare a constant table array and register its symbol.

        The table symbol is stored in ``backend.symbols.element_tables``
        under ``name`` so later table accesses can refer to it. The
        values are emitted as given; no dof transformation is applied
        here.

        Returns a single-element list holding the array declaration.
        """
        table_symbol = L.Symbol(name, dtype=L.DataType.REAL)
        self.backend.symbols.element_tables[name] = table_symbol
        return [L.ArrayDecl(table_symbol, values=table, const=True)]

    def generate_quadrature_loop(self, quadrature_rule: QuadratureRule):
        """Generate the quadrature loop for this quadrature_rule.

        Returns ``(pre_definitions, preparts, quadparts)`` where
        ``quadparts`` holds the loop over quadrature points (or is empty
        when there is nothing to compute).
        """
        # Generate varying partition
        pre_definitions, body = self.generate_varying_partition(quadrature_rule)

        body = L.commented_code_list(body, f"Quadrature loop body setup for quadrature rule {quadrature_rule.id()}")

        # Generate dofblock parts, some of this will be placed before or
        # after quadloop
        preparts, quadparts = self.generate_dofblock_partition(quadrature_rule)
        body += quadparts

        # Wrap body in loop or scope
        if not body:
            # Could happen for integral with everything zero and
            # optimized away
            quadparts = []
        else:
            iq_symbol = self.backend.symbols.quadrature_loop_index
            iq = create_quadrature_index(quadrature_rule, iq_symbol)
            quadparts = [create_nested_for_loops([iq], body)]

        return pre_definitions, preparts, quadparts

    def generate_piecewise_partition(self, quadrature_rule):
        """Generate quadrature-independent ("piecewise") computations for a rule.

        Intermediates are stored in the array ``sp_<rule id>``; values are
        registered in the None (piecewise) scope.
        """
        # Get annotated graph of factorisation
        F = self.ir.integrand[quadrature_rule]["factorization"]

        arraysymbol = L.Symbol(f"sp_{quadrature_rule.id()}", dtype=L.DataType.SCALAR)
        pre_definitions, parts = self.generate_partition(arraysymbol, F, "piecewise", None)
        assert len(pre_definitions) == 0, "Quadrature independent code should have not pre-definitions"
        parts = L.commented_code_list(
            parts, f"Quadrature loop independent computations for quadrature rule {quadrature_rule.id()}")

        return parts

    def generate_varying_partition(self, quadrature_rule):
        """Generate quadrature-point-dependent ("varying") computations for a rule.

        Intermediates are stored in the array ``sv_<rule id>``. Returns
        ``(pre_definitions, parts)``.
        """
        # Get annotated graph of factorisation
        F = self.ir.integrand[quadrature_rule]["factorization"]

        arraysymbol = L.Symbol(f"sv_{quadrature_rule.id()}", dtype=L.DataType.SCALAR)
        pre_definitions, parts = self.generate_partition(arraysymbol, F, "varying", quadrature_rule)
        parts = L.commented_code_list(parts, f"Varying computations for quadrature rule {quadrature_rule.id()}")

        return pre_definitions, parts

    def generate_partition(self, symbol, F, mode, quadrature_rule):
        """Generate code for all factorization nodes whose status equals ``mode``.

        Args:
            symbol: array symbol used to hold intermediate values.
            F: factorization graph; each node dict has 'status' and
                'expression', and optionally 'mt' (modified terminal) and
                'tr' (table data).
            mode: node status to select (callers pass "piecewise" or
                "varying").
            quadrature_rule: variable scope to read/write (None for the
                piecewise scope).

        Returns:
            ``(pre_definitions, parts)``: backend pre-definitions keyed by
            symbol name, and the list of generated statements.
        """
        definitions = dict()
        pre_definitions = dict()
        intermediates = []

        use_symbol_array = True

        for i, attr in F.nodes.items():
            if attr['status'] != mode:
                continue
            v = attr['expression']
            mt = attr.get('mt')

            # Generate code only if the expression is not already in
            # cache
            if not self.get_var(quadrature_rule, v):
                if v._ufl_is_literal_:
                    vaccess = L.ufl_to_lnodes(v)
                elif mt is not None:
                    # All finite element based terminals have table
                    # data, as well as some, but not all, of the
                    # symbolic geometric terminals
                    tabledata = attr.get('tr')

                    # Backend specific modified terminal translation
                    vaccess = self.backend.access.get(mt.terminal, mt, tabledata, quadrature_rule)
                    predef, vdef = self.backend.definitions.get(mt.terminal, mt, tabledata, quadrature_rule, vaccess)
                    if predef:
                        access = predef[0].symbol.name
                        pre_definitions[str(access)] = predef

                    # Store definitions of terminals in list
                    assert isinstance(vdef, list)
                    definitions[str(vaccess)] = vdef
                else:
                    # Get previously visited operands
                    vops = [self.get_var(quadrature_rule, op) for op in v.ufl_operands]

                    # get parent operand
                    pid = F.in_edges[i][0] if F.in_edges[i] else -1
                    if pid and pid > i:
                        parent_exp = F.nodes.get(pid)['expression']
                    else:
                        parent_exp = None

                    # Mapping UFL operator to target language
                    self._ufl_names.add(v._ufl_handler_name_)
                    vexpr = L.ufl_to_lnodes(v, *vops)

                    # Create a new intermediate for each subexpression
                    # except boolean conditions and their children
                    if isinstance(parent_exp, ufl.classes.Condition):
                        # Skip intermediates for 'x' and 'y' in x<y
                        # Avoid the creation of complex valued intermediates
                        vaccess = vexpr
                    elif isinstance(v, ufl.classes.Condition):
                        # Inline the conditions x < y, condition values
                        # This removes the need to handle boolean
                        # intermediate variables. With tensor-valued
                        # conditionals it may not be optimal but we let
                        # the compiler take responsibility for
                        # optimizing those cases.
                        vaccess = vexpr
                    elif any(op._ufl_is_literal_ for op in v.ufl_operands):
                        # Skip intermediates for e.g. -2.0*x,
                        # resulting in lines like z = y + -2.0*x
                        vaccess = vexpr
                    else:
                        # Record assignment of vexpr to intermediate variable
                        j = len(intermediates)
                        if use_symbol_array:
                            vaccess = symbol[j]
                            intermediates.append(L.Assign(vaccess, vexpr))
                        else:
                            vaccess = L.Symbol("%s_%d" % (symbol.name, j))
                            intermediates.append(L.VariableDecl(vaccess, vexpr))

                # Store access node for future reference
                self.set_var(quadrature_rule, v, vaccess)

        # Join terminal computation, array of intermediate expressions,
        # and intermediate computations
        parts = []
        parts += self.fuse_loops(definitions)

        if intermediates:
            if use_symbol_array:
                parts += [L.ArrayDecl(symbol, sizes=len(intermediates))]
            parts += intermediates
        return pre_definitions, parts

    def generate_dofblock_partition(self, quadrature_rule: QuadratureRule):
        """Generate the element-tensor accumulation code for a quadrature rule.

        Block contributions are grouped by their blockmap rescaled to
        scalar dofs, and each group is generated by generate_block_parts.
        Returns ``(preparts, quadparts)``.
        """
        block_contributions = self.ir.integrand[quadrature_rule]["block_contributions"]
        preparts = []
        quadparts = []
        blocks = [(blockmap, blockdata)
                  for blockmap, contributions in sorted(block_contributions.items())
                  for blockdata in contributions]

        block_groups = collections.defaultdict(list)

        # Group loops by blockmap, in Vector elements each component has
        # a different blockmap
        for blockmap, blockdata in blocks:
            assert len(blockdata.ma_data) == len(blockmap)
            scalar_blockmap = []
            for i, dofmap in enumerate(blockmap):
                tabledata = blockdata.ma_data[i].tabledata
                bs = tabledata.block_size
                offset = tabledata.offset
                # Map blocked dof indices back to scalar (per-component) indices
                scalar_blockmap.append(tuple((idx - offset) // bs for idx in dofmap))
            block_groups[tuple(scalar_blockmap)].append(blockdata)

        for blockmap, blocklist in block_groups.items():
            block_preparts, block_quadparts = self.generate_block_parts(
                quadrature_rule, blockmap, blocklist)

            # Add definitions
            preparts.extend(block_preparts)

            # Add computations
            quadparts.extend(block_quadparts)

        return preparts, quadparts

    def get_arg_factors(self, blockdata, block_rank, quadrature_rule, iq, indices):
        """Return the table-access factor for each argument of a block.

        For each of the ``block_rank`` arguments, returns the literal 1
        for a "ones" table, otherwise the backend table access indexed by
        the quadrature index ``iq`` and the argument's dof index.
        """
        # The modified-argument scope does not depend on the argument
        # number, so look it up once
        scope = self.ir.integrand[quadrature_rule]["modified_arguments"]
        arg_factors = []
        for i in range(block_rank):
            mad = blockdata.ma_data[i]
            td = mad.tabledata
            mt = scope[mad.ma_index]

            # Translate modified terminal to code
            # TODO: Move element table access out of backend?
            #       Not using self.backend.access.argument() here
            #       now because it assumes too much about indices.

            assert td.ttype != "zeros"

            if td.ttype == "ones":
                arg_factor = 1
            else:
                # Assuming B sparsity follows element table sparsity
                arg_factor = self.backend.symbols.table_access(
                    td, self.ir.entitytype, mt.restriction, iq, indices[i])
            arg_factors.append(arg_factor)
        return arg_factors

    def generate_block_parts(self, quadrature_rule: QuadratureRule, blockmap: Tuple, blocklist: List[BlockDataT]):
        """Generate and return code parts for a given block.

        Returns ``(preparts, quadparts)``. ``preparts`` is currently
        always empty. ``quadparts`` contains the weighted-factor
        declarations, temporary arrays, hoisted loop-invariant loop and
        the loop nest accumulating into the element tensor; the caller
        places these inside the quadrature loop identified by
        ``quadrature_rule``.
        """
        # The parts to return
        preparts: List[LNode] = []
        quadparts: List[LNode] = []

        # RHS expressions grouped by LHS "dofmap"
        rhs_expressions = collections.defaultdict(list)

        block_rank = len(blockmap)
        blockdims = tuple(len(dofmap) for dofmap in blockmap)

        iq_symbol = self.backend.symbols.quadrature_loop_index
        iq = create_quadrature_index(quadrature_rule, iq_symbol)

        for blockdata in blocklist:

            B_indices = []
            for i in range(block_rank):
                table_ref = blockdata.ma_data[i].tabledata
                symbol = self.backend.symbols.argument_loop_index(i)
                index = create_dof_index(table_ref, symbol)
                B_indices.append(index)

            ttypes = blockdata.ttypes
            if "zeros" in ttypes:
                raise RuntimeError("Not expecting zero arguments to be left in dofblock generation.")

            if len(blockdata.factor_indices_comp_indices) > 1:
                raise RuntimeError("Code generation for non-scalar integrals unsupported")

            # We have scalar integrand here, take just the factor index
            factor_index = blockdata.factor_indices_comp_indices[0][0]

            # Get factor expression
            F = self.ir.integrand[quadrature_rule]["factorization"]

            v = F.nodes[factor_index]['expression']
            f = self.get_var(quadrature_rule, v)

            # Quadrature weight was removed in representation, add it back now
            if self.ir.integral_type in ufl.custom_integral_types:
                weights = self.backend.symbols.custom_weights_table
            else:
                weights = self.backend.symbols.weights_table(quadrature_rule)
            weight = weights[iq.global_index]

            # Define fw = f * weight
            fw_rhs = L.float_product([f, weight])
            if not isinstance(fw_rhs, L.Product):
                fw = fw_rhs
            else:
                # Define and cache scalar temp variable
                key = (quadrature_rule, factor_index, blockdata.all_factors_piecewise)
                fw, defined = self.get_temp_symbol("fw", key)
                if not defined:
                    quadparts.append(L.VariableDecl(fw, fw_rhs))

            assert not blockdata.transposed, "Not handled yet"

            # Fetch code to access modified arguments
            arg_factors = self.get_arg_factors(blockdata, block_rank, quadrature_rule, iq, B_indices)

            # Define B_rhs = fw * arg_factors
            B_rhs = L.float_product([fw] + arg_factors)

            A_indices = []
            for i in range(block_rank):
                index = B_indices[i]
                tabledata = blockdata.ma_data[i].tabledata
                offset = tabledata.offset
                if len(blockmap[i]) == 1:
                    A_indices.append(index.global_index + offset)
                else:
                    A_indices.append(tabledata.block_size * index.global_index + offset)
            rhs_expressions[tuple(A_indices)].append(B_rhs)

        # List of statements to keep in the inner loop
        keep = collections.defaultdict(list)
        # List of temporary array declarations
        pre_loop: List[LNode] = []
        # List of loop invariant expressions to hoist
        hoist: List[BinOp] = []

        for indices in rhs_expressions:
            # Hoist loop invariant code and group array access (each
            # table should only be read one time in the inner loop)
            if block_rank == 2:
                hoist_rhs = collections.defaultdict(list)
                ind = B_indices[-1]
                for rhs in rhs_expressions[indices]:
                    if len(rhs.args) <= 2:
                        keep[indices].append(rhs)
                    else:
                        # Factor depending on the innermost (last) dof index
                        varying = next((x for x in rhs.args if hasattr(x, 'indices')
                                        and (ind.global_index in x.indices)), None)
                        if varying:
                            invariant = [x for x in rhs.args if x is not varying]
                            hoist_rhs[varying].append(invariant)
                        else:
                            keep[indices].append(rhs)

                # Perform algebraic manipulations to reduce number of
                # floating point operations (factorize expressions by
                # grouping)
                for statement, invariants in hoist_rhs.items():
                    factor_sum = L.Sum([L.float_product(rhs) for rhs in invariants])

                    # Reuse an already-hoisted temporary with the same sum
                    lhs = None
                    for h in hoist:
                        if h.rhs == factor_sum:
                            lhs = h.lhs
                            break
                    if lhs:
                        keep[indices].append(L.float_product([statement, lhs]))
                    else:
                        # The hoisted loop below iterates over the first
                        # (outer) dof index, so the temporary is indexed
                        # by B_indices[0] both when written and when read.
                        t = self.new_temp_symbol("t")
                        t_access = t[B_indices[0].global_index]
                        pre_loop.append(L.ArrayDecl(t, sizes=blockdims[0]))
                        keep[indices].append(L.float_product([statement, t_access]))
                        hoist.append(L.Assign(t_access, factor_sum))
            else:
                keep[indices] = rhs_expressions[indices]

        hoist_code: List[LNode] = [L.ForRange(B_indices[0], 0, blockdims[0], body=hoist)] if hoist else []

        body: List[LNode] = []

        A = self.backend.symbols.element_tensor
        A_shape = self.ir.tensor_shape
        for indices in keep:
            multi_index = L.MultiIndex(list(indices), A_shape)
            body.append(L.AssignAdd(A[multi_index], L.Sum(keep[indices])))

        # Wrap the accumulation in one loop per argument, innermost last
        for i in reversed(range(block_rank)):
            body = [create_nested_for_loops([B_indices[i]], body)]

        quadparts += pre_loop
        quadparts += hoist_code
        quadparts += body

        return preparts, quadparts

    def fuse_loops(self, definitions):
        """Merge a sequence of loops with the same iteration space into a single loop.

        Loop fusion improves data locality, cache reuse and decreases
        the loop control overhead.

        NOTE: Loop fusion might increase the pressure on register
        allocation. Ideally, we should define a cost function to
        determine how many loops should fuse at a time.

        Args:
            definitions: mapping whose values are lists of statements.

        Returns:
            All non-loop statements (in order) followed by one fused
            ``L.ForRange`` per distinct (index, begin, end).
        """
        loops = collections.defaultdict(list)
        pre_loop = []
        for definition in definitions.values():
            for d in definition:
                if isinstance(d, L.ForRange):
                    loops[(d.index, d.begin, d.end)] += [d.body]
                else:
                    pre_loop += [d]

        fused = [L.ForRange(index, begin, end, body) for (index, begin, end), body in loops.items()]

        return pre_loop + fused
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc