• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 11642291501

02 Nov 2024 11:16AM UTC coverage: 81.168% (+0.5%) from 80.657%
11642291501

push

github

web-flow
Upload to coveralls and docs from CI job running against python 3.12 (#726)

3474 of 4280 relevant lines covered (81.17%)

0.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.81
/ffcx/codegeneration/expression_generator.py
1
# Copyright (C) 2019 Michal Habera
2
#
3
# This file is part of FFCx.(https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6
"""Expression generator."""
7

8
import collections
1✔
9
import logging
1✔
10
from itertools import product
1✔
11
from typing import Any
1✔
12

13
import ufl
1✔
14

15
import ffcx.codegeneration.lnodes as L
1✔
16
from ffcx.codegeneration import geometry
1✔
17
from ffcx.codegeneration.backend import FFCXBackend
1✔
18
from ffcx.codegeneration.lnodes import LNode
1✔
19
from ffcx.ir.representation import ExpressionIR
1✔
20

21
# Logger named "ffcx"; logging.getLogger returns the same instance to every
# caller that asks for this name.
logger = logging.getLogger("ffcx")
1✔
22

23

24
class ExpressionGenerator:
    """Expression generator.

    Turns the factorised integrand of an ``ExpressionIR`` that has exactly one
    set of evaluation points into an LNodes statement list.
    """

    def __init__(self, ir: ExpressionIR, backend: FFCXBackend):
        """Initialise.

        Args:
            ir: Intermediate representation of the expression.
            backend: Backend supplying symbols, access and definition helpers.

        Raises:
            RuntimeError: If the integrand does not hold exactly one set of points.
        """
        if len(ir.expression.integrand) != 1:
            raise RuntimeError("Only one set of points allowed for expression evaluation")

        self.ir = ir
        self.backend = backend
        # UFL expression node -> LNode used to access its value in generated code
        self.scope: dict[Any, LNode] = {}
        # Handler names of the UFL operators translated by generate_partition
        self._ufl_names: set[Any] = set()
        # Per-basename running counters used by new_temp_symbol
        self.symbol_counters: collections.defaultdict[Any, int] = collections.defaultdict(int)
        self.shared_symbols: dict[Any, Any] = {}
        # The single quadrature rule (point set) the expression is evaluated at
        self.quadrature_rule = next(iter(self.ir.expression.integrand.keys()))

    def generate(self):
        """Generate the complete code for evaluating the expression.

        The statements are emitted in this order: element tables, geometry
        tables, piecewise (point-independent) computations, code placed before
        the quadrature loop, and the quadrature loop itself.

        Returns:
            An ``L.StatementList`` holding all generated parts.
        """
        # Element tables first: this registers their symbols with the backend
        parts = list(self.generate_element_tables())

        # Generate the tables of geometry data that are needed
        parts += self.generate_geometry_tables()
        # Piecewise values are computed (and put in self.scope) before the loop
        parts += self.generate_piecewise_partition()

        # Collect parts placed before and inside the quadrature loop
        preparts, quadparts = self.generate_quadrature_loop()
        parts += preparts
        parts += quadparts

        return L.StatementList(parts)

    def generate_geometry_tables(self):
        """Generate static tables of geometry data.

        Scans every factorisation node carrying a modified terminal. For
        terminal types that have a static table, it records the cell names of
        their domains and emits one table per (quantity, cell name) pair.

        Returns:
            A list of table declarations.

        Raises:
            RuntimeError: If a facet geometric quantity appears in an
                expression evaluated on cells.
        """
        # Geometric quantities for which a static table can be written
        ufl_geometry = {
            ufl.geometry.ReferenceCellVolume: "reference_cell_volume",
            ufl.geometry.ReferenceNormal: "reference_facet_normals",
        }

        cells: dict[Any, set[Any]] = {t: set() for t in ufl_geometry.keys()}  # type: ignore
        for integrand in self.ir.expression.integrand.values():
            for attr in integrand["factorization"].nodes.values():
                mt = attr.get("mt")
                if mt is None:
                    continue
                t = type(mt.terminal)
                if self.ir.expression.entity_type == "cell" and issubclass(
                    t, ufl.geometry.GeometricFacetQuantity
                ):
                    raise RuntimeError(f"Expressions for cells do not support {t}.")
                if t in ufl_geometry:
                    cells[t].add(
                        ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
                    )

        parts = []
        for quantity, cell_names in cells.items():
            # Sort the set so the emitted code does not depend on hash ordering
            for cellname in sorted(cell_names):
                parts.append(geometry.write_table(ufl_geometry[quantity], cellname))

        return parts

    def generate_element_tables(self):
        """Generate tables of FE basis evaluated at specified points.

        Each unique table is declared as a constant REAL array in name order,
        and its symbol is registered in ``self.backend.symbols.element_tables``.

        Returns:
            The list of array declarations, passed through ``L.commented_code_list``.
        """
        tables = self.ir.expression.unique_tables

        parts = []
        for name in sorted(tables):
            table = tables[name]
            symbol = L.Symbol(name, dtype=L.DataType.REAL)
            self.backend.symbols.element_tables[name] = symbol
            parts.append(L.ArrayDecl(symbol, sizes=table.shape, values=table, const=True))

        # Add leading comment if there are any tables
        return L.commented_code_list(
            parts,
            [
                "Precomputed values of basis functions",
                "FE* dimensions: [entities][points][dofs]",
            ],
        )

    def generate_quadrature_loop(self):
        """Generate quadrature loop for this quadrature rule.

        In the context of expressions quadrature loop is not accumulated.

        Returns:
            ``(preparts, quadparts)``. ``preparts`` is the code that goes
            before the loop. ``quadparts`` is a list holding the loop over all
            points, or an empty list when the loop body is empty.
        """
        # Generate varying partition
        body = self.generate_varying_partition()
        body = L.commented_code_list(
            body, f"Points loop body setup quadrature loop {self.quadrature_rule.id()}"
        )

        # Generate dofblock parts, some of this
        # will be placed before or after quadloop
        preparts, quadparts = self.generate_dofblock_partition()
        body += quadparts

        if not body:
            # Could happen for integral with everything zero and optimized away
            return preparts, []

        # Wrap body in a loop over all points of the rule
        iq = self.backend.symbols.quadrature_loop_index
        num_points = self.quadrature_rule.points.shape[0]
        return preparts, [L.ForRange(iq, 0, num_points, body=body)]

    def generate_varying_partition(self):
        """Generate factors of blocks which are not cellwise constant.

        Intermediates are named ``sv_<rule id>_<j>``.
        """
        # Get annotated graph of factorisation
        F = self.ir.expression.integrand[self.quadrature_rule]["factorization"]

        arraysymbol = L.Symbol(f"sv_{self.quadrature_rule.id()}", dtype=L.DataType.SCALAR)
        parts = self.generate_partition(arraysymbol, F, "varying")
        return L.commented_code_list(
            parts,
            f"Unstructured varying computations for quadrature rule {self.quadrature_rule.id()}",
        )

    def generate_piecewise_partition(self):
        """Generate factors of blocks which are constant.

        I.e. do not depend on quadrature points. Intermediates are named
        ``sp_<j>``.
        """
        # Get annotated graph of factorisation
        F = self.ir.expression.integrand[self.quadrature_rule]["factorization"]

        arraysymbol = L.Symbol("sp", dtype=L.DataType.SCALAR)
        parts = self.generate_partition(arraysymbol, F, "piecewise")
        return L.commented_code_list(parts, "Unstructured piecewise computations")

    def generate_dofblock_partition(self):
        """Generate assignments of blocks multiplied with their factors into final tensor A.

        Returns:
            ``(preparts, quadparts)``: the concatenated per-block code for
            before the quadrature loop and for its body.
        """
        block_contributions = self.ir.expression.integrand[self.quadrature_rule][
            "block_contributions"
        ]

        preparts = []
        quadparts = []
        # Sorted blockmap order fixes the order of the generated statements
        for blockmap, contributions in sorted(block_contributions.items()):
            for blockdata in contributions:
                block_preparts, block_quadparts = self.generate_block_parts(blockmap, blockdata)
                preparts.extend(block_preparts)
                quadparts.extend(block_quadparts)

        return preparts, quadparts

    @staticmethod
    def _is_equally_spaced(dofs):
        """Return True if all consecutive entries of ``dofs`` differ by the same stride."""
        if len(dofs) < 3:
            # Zero, one or two entries are trivially equally spaced
            return True
        stride = dofs[1] - dofs[0]
        return all(b - a == stride for a, b in zip(dofs[1:-1], dofs[2:]))

    def generate_block_parts(self, blockmap, blockdata):
        """Generate and return code parts for a given block.

        Args:
            blockmap: One sequence of DOF indices per argument of the block.
            blockdata: Block data (table types, argument data, factor and
                component indices).

        Returns:
            ``(preparts, quadparts)``. ``preparts`` is always empty here.
            ``quadparts`` holds the statements for the quadrature loop body.

        Raises:
            RuntimeError: If a "zeros" table type is still present in the block.
        """
        # The parts to return
        preparts = []
        quadparts = []

        block_rank = len(blockmap)
        blockdims = tuple(len(dofmap) for dofmap in blockmap)

        if "zeros" in blockdata.ttypes:
            raise RuntimeError("Not expecting zero arguments to be left in dofblock generation.")

        arg_indices = tuple(self.backend.symbols.argument_loop_index(i) for i in range(block_rank))

        F = self.ir.expression.integrand[self.quadrature_rule]["factorization"]

        assert not blockdata.transposed, "Not handled yet"
        components = ufl.product(self.ir.expression.shape)

        num_points = self.quadrature_rule.points.shape[0]
        A_shape = [num_points, components] + self.ir.expression.tensor_shape
        A = self.backend.symbols.element_tensor
        iq = self.backend.symbols.quadrature_loop_index

        if not all(self._is_equally_spaced(bm) for bm in blockmap):
            # If DOFs in dofrange are not equally spaced, then expand out the for loop:
            # emit one assignment per (tensor DOF, table column) combination.
            for A_dofs, B_dofs in zip(product(*blockmap), product(*(range(d) for d in blockdims))):
                # product() yields tuples, so build the index lists explicitly
                # (list + tuple concatenation would raise TypeError).
                B_indices = (iq, *B_dofs)
                A_indices = [iq, *A_dofs]
                # Table accesses do not depend on the factor; fetch them once
                arg_factors = self.get_arg_factors(blockdata, block_rank, B_indices)
                for fi_ci in blockdata.factor_indices_comp_indices:
                    f = self.get_var(F.nodes[fi_ci[0]]["expression"])
                    Brhs = L.float_product([f] + arg_factors)
                    multi_index = L.MultiIndex([A_indices[0], fi_ci[1]] + A_indices[1:], A_shape)
                    quadparts.append(L.AssignAdd(A[multi_index], Brhs))
        else:
            # Prepend dimensions of dofmap block with free index
            # for quadrature points and expression components
            B_indices = tuple([iq] + list(arg_indices))

            # Fetch code to access modified arguments
            # An access to FE table data
            arg_factors = self.get_arg_factors(blockdata, block_rank, B_indices)

            # Map each argument loop index to its DOF: offset + stride * index
            A_indices = [iq]
            for bm, index in zip(blockmap, arg_indices):
                # TODO: switch order here? (optionally)
                offset = bm[0]
                if len(bm) == 1:
                    A_indices.append(index + offset)
                else:
                    block_size = bm[1] - bm[0]
                    A_indices.append(block_size * index + offset)

            # Multiply collected factors
            # For each component of the factor expression
            # add result inside quadloop
            body = []
            for fi_ci in blockdata.factor_indices_comp_indices:
                f = self.get_var(F.nodes[fi_ci[0]]["expression"])
                Brhs = L.float_product([f] + arg_factors)
                indices = [A_indices[0], fi_ci[1]] + list(A_indices[1:])
                multi_index = L.MultiIndex(indices, A_shape)
                body.append(L.AssignAdd(A[multi_index], Brhs))

            # Nest the assignments in one loop per argument, outermost first
            for i in reversed(range(block_rank)):
                body = L.ForRange(B_indices[i + 1], 0, blockdims[i], body=body)
            quadparts += [body]

        return preparts, quadparts

    def get_arg_factors(self, blockdata, block_rank, indices):
        """Get argument factors (i.e. blocks).

        Args:
            blockdata: block data
            block_rank: block rank
            indices: Indices used to index element tables; ``indices[i + 1]``
                indexes the table of argument ``i``.

        Returns:
            One factor per argument: the literal ``1.0`` for "ones" tables,
            otherwise an indexed access into the element table.
        """
        arg_factors = []
        for i in range(block_rank):
            mad = blockdata.ma_data[i]
            td = mad.tabledata
            mt = self.ir.expression.integrand[self.quadrature_rule]["modified_arguments"][
                mad.ma_index
            ]

            table = self.backend.symbols.element_table(
                td, self.ir.expression.entity_type, mt.restriction
            )

            assert td.ttype != "zeros"

            if td.ttype == "ones":
                arg_factor = L.LiteralFloat(1.0)
            else:
                arg_factor = table[indices[i + 1]]
            arg_factors.append(arg_factor)
        return arg_factors

    def new_temp_symbol(self, basename):
        """Create a new code symbol named basename + running counter."""
        count = self.symbol_counters[basename]
        self.symbol_counters[basename] += 1
        return L.Symbol(f"{basename}{count}", dtype=L.DataType.SCALAR)

    def get_var(self, v):
        """Get a variable.

        Literals are translated directly. Any other node is looked up in
        ``self.scope``, which yields None if it has not been generated yet.
        """
        if v._ufl_is_literal_:
            return L.ufl_to_lnodes(v)
        return self.scope.get(v)

    def generate_partition(self, symbol, F, mode):
        """Generate computations of factors of blocks.

        Args:
            symbol: Symbol whose name prefixes the generated intermediates
                (``<name>_<j>``).
            F: Annotated factorisation graph.
            mode: Only nodes whose ``status`` equals this value are processed.

        Returns:
            Terminal definitions followed by intermediate declarations. As a
            side effect, the access node for every processed expression is
            stored in ``self.scope``.
        """
        definitions = []
        intermediates = []

        for attr in F.nodes.values():
            if attr["status"] != mode:
                continue
            v = attr["expression"]
            mt = attr.get("mt")

            if v._ufl_is_literal_:
                vaccess = L.ufl_to_lnodes(v)
            elif mt is not None:
                # All finite element based terminals have table data, as well
                # as some, but not all, of the symbolic geometric terminals
                tabledata = attr.get("tr")

                # Backend specific modified terminal translation
                vaccess = self.backend.access.get(mt, tabledata, 0)
                vdef = self.backend.definitions.get(mt, tabledata, 0, vaccess)

                if vdef:
                    assert isinstance(vdef, L.Section)
                    vdef = vdef.declarations + vdef.statements

                # Store definitions of terminals in list
                assert isinstance(vdef, list)
                definitions.extend(vdef)
            else:
                # Get previously visited operands
                vops = [self.get_var(op) for op in v.ufl_operands]

                # Mapping UFL operator to target language
                self._ufl_names.add(v._ufl_handler_name_)
                vexpr = L.ufl_to_lnodes(v, *vops)

                # Conditions are stored as booleans, everything else as scalars
                is_cond = isinstance(v, ufl.classes.Condition)
                dtype = L.DataType.BOOL if is_cond else L.DataType.SCALAR

                j = len(intermediates)
                vaccess = L.Symbol(f"{symbol.name}_{j}", dtype=dtype)
                intermediates.append(L.VariableDecl(vaccess, vexpr))

            # Store access node for future reference
            self.scope[v] = vaccess

        # Join terminal computation and intermediate computations
        return definitions + intermediates
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc