• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 14179953154

31 Mar 2025 07:38PM UTC coverage: 82.568% (+0.5%) from 82.058%
14179953154

Pull #731

github

jorgensd
Ruff + mypy
Pull Request #731: Codim 2 coupling

79 of 103 new or added lines in 10 files covered. (76.7%)

3 existing lines in 1 file now uncovered.

3633 of 4400 relevant lines covered (82.57%)

0.83 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

43.89
/ffcx/codegeneration/access.py
1
# Copyright (C) 2011-2017 Martin Sandve Alnæs
2
#
3
# This file is part of FFCx. (https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6
"""FFCx/UFC specific variable access."""
7

8
import logging
1✔
9
import warnings
1✔
10
from typing import Optional
1✔
11

12
import basix.ufl
1✔
13
import ufl
1✔
14

15
import ffcx.codegeneration.lnodes as L
1✔
16
from ffcx.definitions import entity_types
1✔
17
from ffcx.ir.analysis.modified_terminals import ModifiedTerminal
1✔
18
from ffcx.ir.elementtables import UniqueTableReferenceT
1✔
19
from ffcx.ir.representationutils import QuadratureRule
1✔
20

21
# Shared package logger, looked up by the explicit "ffcx" name rather than __name__.
logger = logging.getLogger(name="ffcx")
1✔
22

23

24
class FFCXBackendAccess:
    """FFCx specific formatter class.

    Translates UFL terminals (coefficients, constants and geometric
    quantities) into expressions of the ``lnodes`` language. Names of the
    generated variables and tables are obtained from the ``symbols`` object
    passed at construction.
    """

    entity_type: entity_types

    def __init__(self, entity_type: entity_types, integral_type: str, symbols, options):
        """Initialise.

        Args:
            entity_type: Entity type the integral is defined over.
            integral_type: Integral type (compared in this class against
                ``"cell"``, ``"interior_facet"``, ``"exterior_facet"``,
                ``"expression"`` and ``ufl.custom_integral_types``).
            symbols: Symbol provider used to build the accessed names.
            options: Code-generation options (stored; not read by this class).
        """
        # Store ir and options
        self.entity_type = entity_type
        self.integral_type = integral_type
        self.symbols = symbols
        self.options = options

        # Lookup table for handler to call when the "get" method (below) is
        # called, depending on the first argument type.
        self.call_lookup = {
            ufl.coefficient.Coefficient: self.coefficient,
            ufl.constant.Constant: self.constant,
            ufl.geometry.Jacobian: self.jacobian,
            ufl.geometry.CellCoordinate: self.cell_coordinate,
            ufl.geometry.FacetCoordinate: self.facet_coordinate,
            ufl.geometry.CellVertices: self.cell_vertices,
            ufl.geometry.FacetEdgeVectors: self.facet_edge_vectors,
            ufl.geometry.CellEdgeVectors: self.cell_edge_vectors,
            ufl.geometry.CellFacetJacobian: self.cell_facet_jacobian,
            ufl.geometry.CellRidgeJacobian: self.cell_ridge_jacobian,
            ufl.geometry.ReferenceCellVolume: self.reference_cell_volume,
            ufl.geometry.ReferenceFacetVolume: self.reference_facet_volume,
            ufl.geometry.ReferenceCellEdgeVectors: self.reference_cell_edge_vectors,
            ufl.geometry.ReferenceFacetEdgeVectors: self.reference_facet_edge_vectors,
            ufl.geometry.ReferenceNormal: self.reference_normal,
            ufl.geometry.CellOrientation: self._pass,
            ufl.geometry.FacetOrientation: self.facet_orientation,
            ufl.geometry.SpatialCoordinate: self.spatial_coordinate,
        }

    def get(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
    ):
        """Format a terminal.

        The handler is chosen by the exact type of ``mt.terminal``; if no
        handler is registered for that type, the first registered type the
        terminal is an instance of (in registration order) is used.

        Args:
            mt: Modified terminal to access.
            tabledata: Element table reference for the terminal.
            quadrature_rule: Quadrature rule of the enclosing loop.

        Returns:
            The expression produced by the matching handler.

        Raises:
            RuntimeError: If no handler matches the terminal type.
        """
        e = mt.terminal
        handler = self.call_lookup.get(type(e))
        if handler is None:
            # Fall back to parent-class matching, first registered match wins
            handler = next(
                (h for k, h in self.call_lookup.items() if isinstance(e, k)),
                None,
            )
        if handler is None:
            raise RuntimeError(f"Not handled: {type(e)}")
        return handler(mt, tabledata, quadrature_rule)  # type: ignore

    def coefficient(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
    ):
        """Access a coefficient.

        If the table is all ones and spans a single dof, the dof value is
        referenced directly; otherwise the precomputed coefficient value
        symbol is returned.

        Args:
            mt: Modified terminal wrapping the coefficient.
            tabledata: Element table reference (must not be of type "zeros").
            quadrature_rule: Quadrature rule (unused).
        """
        ttype = tabledata.ttype
        assert ttype != "zeros"

        # The 4th axis of the table values is taken as the dof axis
        num_dofs = tabledata.values.shape[3]
        begin = tabledata.offset
        assert begin is not None
        assert tabledata.block_size is not None
        # Dofs are strided by block_size, so the range spans (num_dofs - 1) strides
        end = begin + tabledata.block_size * (num_dofs - 1) + 1

        if ttype == "ones" and (end - begin) == 1:
            # f = 1.0 * f_{begin}, just return direct reference to dof
            # array at dof begin (if mt is restricted, begin contains
            # cell offset)
            return self.symbols.coefficient_dof_access(mt.terminal, begin)
        # Return symbol, see definitions for computation
        return self.symbols.coefficient_value(mt)

    def constant(
        self,
        mt: ModifiedTerminal,
        tabledata: Optional[UniqueTableReferenceT],
        quadrature_rule: Optional[QuadratureRule],
    ):
        """Access a constant.

        Constants are read directly from the constants symbol at the flat
        component index; ``tabledata`` and ``quadrature_rule`` are unused.
        """
        return self.symbols.constant_index_access(mt.terminal, mt.flat_component)

    def spatial_coordinate(
        self, mt: ModifiedTerminal, tabledata: UniqueTableReferenceT, num_points: QuadratureRule
    ):
        """Access a spatial coordinate.

        For custom integral types the coordinate is read from the predefined
        quadrature points table; otherwise it comes from the physical
        coordinate symbol computed in definitions.

        Raises:
            RuntimeError: On global derivatives, averaging, or (for custom
                integrals) local derivatives.
        """
        if mt.global_derivatives:
            raise RuntimeError("Not expecting global derivatives of SpatialCoordinate.")
        if mt.averaged is not None:
            raise RuntimeError("Not expecting average of SpatialCoordinates.")

        if self.integral_type not in ufl.custom_integral_types:
            # Physical coordinates (including for expressions) are computed
            # by code generated in definitions
            return self.symbols.x_component(mt)

        if mt.local_derivatives:
            raise RuntimeError("FIXME: Jacobian in custom integrals is not implemented.")

        # Access predefined quadrature points table
        x = self.symbols.custom_points_table
        # NOTE(review): read as an attribute here but called as a method in
        # cell_coordinate/facet_coordinate - confirm which form symbols provides.
        iq = self.symbols.quadrature_loop_index
        (gdim,) = mt.terminal.ufl_shape
        index = iq if gdim == 1 else iq * gdim + mt.flat_component
        return x[index]

    def cell_coordinate(self, mt, tabledata, num_points):
        """Access a cell coordinate.

        Only non-restricted cell integrals read the reference points table;
        everywhere else the coordinate must already have been rewritten
        symbolically.

        Raises:
            RuntimeError: On derivatives, averaging, or any other integral
                type / restriction.
        """
        if mt.global_derivatives or mt.local_derivatives:
            raise RuntimeError("Not expecting derivatives of CellCoordinate.")
        if mt.averaged is not None:
            raise RuntimeError("Not expecting average of CellCoordinate.")

        if self.integral_type != "cell" or mt.restriction:
            # X should be computed from x or Xf symbolically instead of
            # getting here
            raise RuntimeError("Expecting reference cell coordinate to be symbolically rewritten.")

        # Access predefined quadrature points table
        X = self.symbols.points_table(num_points)
        (tdim,) = mt.terminal.ufl_shape
        iq = self.symbols.quadrature_loop_index()
        # NOTE(review): get() passes a QuadratureRule here, so comparing
        # num_points with 1 looks like legacy code - confirm.
        if num_points == 1:
            index = mt.flat_component
        elif tdim == 1:
            index = iq
        else:
            index = iq * tdim + mt.flat_component
        return X[index]

    def facet_coordinate(self, mt, tabledata, num_points):
        """Access a facet coordinate.

        Only facet integrals read the reference facet points table; everywhere
        else the coordinate must already have been rewritten symbolically.

        Raises:
            RuntimeError: On derivatives, averaging, restriction, vertex
                terminals, or non-facet integral types.
        """
        if mt.global_derivatives or mt.local_derivatives:
            raise RuntimeError("Not expecting derivatives of FacetCoordinate.")
        if mt.averaged is not None:
            raise RuntimeError("Not expecting average of FacetCoordinate.")
        if mt.restriction:
            raise RuntimeError("Not expecting restriction of FacetCoordinate.")

        if self.integral_type not in ("interior_facet", "exterior_facet"):
            # Xf should be computed from X or x symbolically instead of
            # getting here
            raise RuntimeError("Expecting reference facet coordinate to be symbolically rewritten.")

        (tdim,) = mt.terminal.ufl_shape
        if tdim == 0:
            raise RuntimeError("Vertices have no facet coordinates.")
        elif tdim == 1:
            warnings.warn(
                "Vertex coordinate is always 0, should get rid of this in UFL "
                "geometry lowering."
            )
            return L.LiteralFloat(0.0)
        # The points table lives on the symbols object (as in cell_coordinate);
        # this class has no points_table of its own.
        Xf = self.symbols.points_table(num_points)
        iq = self.symbols.quadrature_loop_index()
        assert 0 <= mt.flat_component < (tdim - 1)
        # NOTE(review): see cell_coordinate about comparing num_points with 1.
        if num_points == 1:
            index = mt.flat_component
        elif tdim == 2:
            index = iq
        else:
            index = iq * (tdim - 1) + mt.flat_component
        return Xf[index]

    def jacobian(self, mt, tabledata, num_points):
        """Access a jacobian component via the symbols object.

        Raises:
            RuntimeError: If the Jacobian is averaged.
        """
        if mt.averaged is not None:
            raise RuntimeError("Not expecting average of Jacobian.")
        return self.symbols.J_component(mt)

    def reference_cell_volume(self, mt, tabledata, access):
        """Access a reference cell volume as a per-cell-type named symbol.

        Raises:
            RuntimeError: For cell types without a tabulated volume.
        """
        cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
        if cellname not in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"):
            raise RuntimeError(f"Unhandled cell types {cellname}.")
        return L.Symbol(f"{cellname}_reference_cell_volume", dtype=L.DataType.REAL)

    def reference_facet_volume(self, mt, tabledata, access):
        """Access a reference facet volume as a per-cell-type named symbol.

        Raises:
            RuntimeError: For cell types without a tabulated facet volume.
        """
        cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
        if cellname not in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"):
            raise RuntimeError(f"Unhandled cell types {cellname}.")
        return L.Symbol(f"{cellname}_reference_facet_volume", dtype=L.DataType.REAL)

    def reference_normal(self, mt, tabledata, access):
        """Access a reference normal component for the current facet.

        Indexes the per-cell-type normals table by the (restricted) facet
        index and the first component of ``mt``.

        Raises:
            RuntimeError: For cell types without tabulated normals.
        """
        cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
        if cellname not in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"):
            raise RuntimeError(f"Unhandled cell types {cellname}.")
        table = L.Symbol(f"{cellname}_reference_normals", dtype=L.DataType.REAL)
        facet = self.symbols.entity("facet", mt.restriction)
        return table[facet][mt.component[0]]

    def cell_facet_jacobian(self, mt, tabledata, num_points):
        """Access a cell facet jacobian entry for the current facet.

        Raises:
            RuntimeError: For intervals (no meaningful facet jacobian) and
                unsupported cell types.
        """
        cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
        if cellname in (
            "triangle",
            "tetrahedron",
            "quadrilateral",
            "hexahedron",
            "prism",
            "pyramid",
        ):
            table = L.Symbol(f"{cellname}_cell_facet_jacobian", dtype=L.DataType.REAL)
            facet = self.symbols.entity("facet", mt.restriction)
            return table[facet][mt.component[0]][mt.component[1]]
        elif cellname == "interval":
            raise RuntimeError("The reference facet jacobian doesn't make sense for interval cell.")
        else:
            raise RuntimeError(f"Unhandled cell types {cellname}.")

    def cell_ridge_jacobian(self, mt, tabledata, num_points):
        """Access a cell ridge jacobian entry for the current ridge.

        Raises:
            RuntimeError: For 2D cells (no meaningful ridge jacobian) and
                unsupported cell types.
        """
        cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
        if cellname in ("tetrahedron", "prism", "hexahedron"):
            table = L.Symbol(f"{cellname}_cell_ridge_jacobian", dtype=L.DataType.REAL)
            ridge = self.symbols.entity("ridge", mt.restriction)
            return table[ridge][mt.component[0]][mt.component[1]]
        elif cellname in ("triangle", "quadrilateral"):
            raise RuntimeError("The ridge jacobian doesn't make sense for 2D cells.")
        else:
            raise RuntimeError(f"Unhandled cell types {cellname}.")

    def reference_cell_edge_vectors(self, mt, tabledata, num_points):
        """Access a reference cell edge vector entry from the per-cell-type table.

        Raises:
            RuntimeError: For intervals and unsupported cell types.
        """
        cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
        if cellname in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"):
            table = L.Symbol(f"{cellname}_reference_cell_edge_vectors", dtype=L.DataType.REAL)
            return table[mt.component[0]][mt.component[1]]
        elif cellname == "interval":
            raise RuntimeError(
                "The reference cell edge vectors doesn't make sense for interval cell."
            )
        else:
            raise RuntimeError(f"Unhandled cell types {cellname}.")

    def reference_facet_edge_vectors(self, mt, tabledata, num_points):
        """Access a reference facet edge vector entry from the per-cell-type table.

        Raises:
            RuntimeError: For cells whose facets have no edges (interval,
                triangle, quadrilateral) and unsupported cell types.
        """
        cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
        if cellname in ("tetrahedron", "hexahedron"):
            table = L.Symbol(f"{cellname}_reference_facet_edge_vectors", dtype=L.DataType.REAL)
            return table[mt.component[0]][mt.component[1]]
        elif cellname in ("interval", "triangle", "quadrilateral"):
            # Report the actual cell: the old message omitted quadrilateral
            raise RuntimeError(
                f"The reference facet edge vectors doesn't make sense for {cellname} cell."
            )
        else:
            raise RuntimeError(f"Unhandled cell types {cellname}.")

    def facet_orientation(self, mt, tabledata, num_points):
        """Access a facet orientation from the per-cell-type integer table.

        Raises:
            RuntimeError: For cell types other than interval/triangle/tetrahedron.
        """
        cellname = ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()
        if cellname not in ("interval", "triangle", "tetrahedron"):
            raise RuntimeError(f"Unhandled cell types {cellname}.")

        table = L.Symbol(f"{cellname}_facet_orientation", dtype=L.DataType.INT)
        facet = self.symbols.entity("facet", mt.restriction)
        return table[facet]

    @staticmethod
    def _scalar_coordinate_element(coordinate_element, gdim):
        """Return the single scalar sub-element of a blocked coordinate element.

        Asserts that the coordinate element is a basix blocked element with
        value shape ``(gdim,)`` whose sub-elements are all the same scalar,
        unblocked element.
        """
        assert isinstance(coordinate_element, basix.ufl._BlockedElement)
        assert coordinate_element.reference_value_shape == (gdim,)
        (scalar_element,) = set(coordinate_element.sub_elements)
        assert scalar_element.reference_value_size == 1 and scalar_element.block_size == 1
        return scalar_element

    def cell_vertices(self, mt, tabledata, num_points):
        """Access a cell vertex coordinate.

        ``mt.component`` is (vertex, coordinate component); the vertex is
        mapped to its single dof of the scalar coordinate element.
        """
        # Get properties of domain
        domain = ufl.domain.extract_unique_domain(mt.terminal)
        gdim = domain.geometric_dimension()
        coordinate_element = domain.ufl_coordinate_element()

        # Get dimension and dofmap of scalar element
        scalar_element = self._scalar_coordinate_element(coordinate_element, gdim)
        vertex_scalar_dofs = scalar_element.entity_dofs[0]
        num_scalar_dofs = scalar_element.dim

        # Get dof and component
        (dof,) = vertex_scalar_dofs[mt.component[0]]
        component = mt.component[1]

        return self.symbols.domain_dof_access(dof, component, gdim, num_scalar_dofs, mt.restriction)

    def cell_edge_vectors(self, mt, tabledata, num_points):
        """Access a physical cell edge vector component.

        ``mt.component`` is (edge, coordinate component); the result is the
        coordinate of the edge's first vertex minus that of its second.

        Raises:
            RuntimeError: For intervals and unsupported cell types.
        """
        # Get properties of domain
        domain = ufl.domain.extract_unique_domain(mt.terminal)
        cellname = domain.ufl_cell().cellname()
        gdim = domain.geometric_dimension()
        coordinate_element = domain.ufl_coordinate_element()

        if cellname not in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"):
            if cellname == "interval":
                raise RuntimeError(
                    "The physical cell edge vectors doesn't make sense for interval cell."
                )
            raise RuntimeError(f"Unhandled cell types {cellname}.")

        # Get dimension and dofmap of scalar element
        scalar_element = self._scalar_coordinate_element(coordinate_element, gdim)
        vertex_scalar_dofs = scalar_element.entity_dofs[0]
        num_scalar_dofs = scalar_element.dim

        # Get edge vertices
        edge = mt.component[0]
        vertex0, vertex1 = scalar_element.reference_topology[1][edge]

        # Get dofs and component
        (dof0,) = vertex_scalar_dofs[vertex0]
        (dof1,) = vertex_scalar_dofs[vertex1]
        component = mt.component[1]

        return self.symbols.domain_dof_access(
            dof0, component, gdim, num_scalar_dofs, mt.restriction
        ) - self.symbols.domain_dof_access(dof1, component, gdim, num_scalar_dofs, mt.restriction)

    def facet_edge_vectors(self, mt, tabledata, num_points):
        """Access a physical facet edge vector component.

        ``mt.component`` is (facet edge, coordinate component). The edge
        vertices are looked up at run time in the per-cell-type
        ``facet_edge_vertices`` table for the current facet.

        Raises:
            RuntimeError: For cells whose facets have no edges and
                unsupported cell types.
        """
        # Get properties of domain
        domain = ufl.domain.extract_unique_domain(mt.terminal)
        cellname = domain.ufl_cell().cellname()
        gdim = domain.geometric_dimension()
        coordinate_element = domain.ufl_coordinate_element()

        if cellname not in ("tetrahedron", "hexahedron"):
            if cellname in ("interval", "triangle", "quadrilateral"):
                raise RuntimeError(
                    f"The physical facet edge vectors doesn't make sense for {cellname} cell."
                )
            raise RuntimeError(f"Unhandled cell types {cellname}.")

        # Get dimension of scalar element
        scalar_element = self._scalar_coordinate_element(coordinate_element, gdim)
        num_scalar_dofs = scalar_element.dim

        # Get edge vertices
        facet = self.symbols.entity("facet", mt.restriction)
        facet_edge = mt.component[0]
        facet_edge_vertices = L.Symbol(f"{cellname}_facet_edge_vertices", dtype=L.DataType.INT)
        vertex0 = facet_edge_vertices[facet][facet_edge][0]
        vertex1 = facet_edge_vertices[facet][facet_edge][1]

        # Get dofs and component
        component = mt.component[1]
        # Vertex numbers are used directly as dof numbers, which is only
        # valid for a degree-1 coordinate element
        assert coordinate_element.embedded_superdegree == 1, "Assuming degree 1 element"
        return self.symbols.domain_dof_access(
            vertex0, component, gdim, num_scalar_dofs, mt.restriction
        ) - self.symbols.domain_dof_access(vertex1, component, gdim, num_scalar_dofs, mt.restriction)

    def _pass(self, *args, **kwargs):
        """Return one (registered as the CellOrientation handler)."""
        return 1

    def table_access(
        self,
        tabledata: UniqueTableReferenceT,
        entity_type: entity_types,
        restriction: str,
        quadrature_index: L.MultiIndex,
        dof_index: L.MultiIndex,
    ):
        """Access element table for given entity, quadrature point, and dof index.

        Uniform tables are indexed at entity 0, piecewise tables at
        quadrature point 0. Permuted tables use the quadrature permutation of
        the matching restriction. Multi-dimensional indices produce a product
        of the per-dimension tensor-factor tables.

        Args:
            tabledata: Table data object
            entity_type: Entity type
            restriction: Restriction ("+", "-")
            quadrature_index: Quadrature index
            dof_index: Dof index

        Returns:
            A pair (table access expression, list of ``L.Symbol`` for the
            tables referenced).
        """
        entity = self.symbols.entity(entity_type, restriction)
        iq_global_index = quadrature_index.global_index
        ic_global_index = dof_index.global_index
        qp = 0  # quadrature permutation

        symbols = []
        if tabledata.is_uniform:
            entity = L.LiteralInt(0)

        if tabledata.is_piecewise:
            iq_global_index = L.LiteralInt(0)

        # FIXME: Hopefully tabledata is not permuted when applying sum
        # factorization
        if tabledata.is_permuted:
            qp = self.symbols.quadrature_permutation[0]
            if restriction == "-":
                qp = self.symbols.quadrature_permutation[1]

        if dof_index.dim == 1 and quadrature_index.dim == 1:
            symbols.append(L.Symbol(tabledata.name, dtype=L.DataType.REAL))
            table = self.symbols.element_tables[tabledata.name]
            return table[qp][entity][iq_global_index][ic_global_index], symbols

        # Tensor-product case: one factor table per dimension
        FE = []
        assert tabledata.tensor_factors is not None
        for i in range(dof_index.dim):
            factor = tabledata.tensor_factors[i]
            iq_i = quadrature_index.local_index(i)
            ic_i = dof_index.local_index(i)
            FE.append(self.symbols.element_tables[factor.name][qp][entity][iq_i][ic_i])
            symbols.append(L.Symbol(factor.name, dtype=L.DataType.REAL))
        return L.Product(FE), symbols
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc