• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 18834018400

23 Oct 2025 12:21PM UTC coverage: 83.0%. Remained the same
18834018400

push

github

web-flow
`AbstractCell` members are now properties (#789)

* Adapt to cell members now properties

* Change ufl branch

* One more

* Change UFL branch for dolfinx CI

* Adapt demos

* Last?

* Change ref branch

* Apply suggestions from code review

26 of 35 new or added lines in 7 files covered. (74.29%)

2 existing lines in 1 file now uncovered.

3730 of 4494 relevant lines covered (83.0%)

0.83 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

43.68
/ffcx/codegeneration/access.py
1
# Copyright (C) 2011-2017 Martin Sandve Alnæs
2
#
3
# This file is part of FFCx. (https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6
"""FFCx/UFC specific variable access."""
7

8
import logging
1✔
9
import warnings
1✔
10

11
import basix.ufl
1✔
12
import ufl
1✔
13

14
import ffcx.codegeneration.lnodes as L
1✔
15
from ffcx.definitions import entity_types
1✔
16
from ffcx.ir.analysis.modified_terminals import ModifiedTerminal
1✔
17
from ffcx.ir.elementtables import UniqueTableReferenceT
1✔
18
from ffcx.ir.representationutils import QuadratureRule
1✔
19

20
# Module-level logger, registered under the "ffcx" logger name.
logger = logging.getLogger(name="ffcx")
21

22

23
class FFCXBackendAccess:
    """FFCx specific formatter class.

    Maps UFL terminals (coefficients, constants and geometric quantities) to the
    LNodes expressions used to access them in generated code.
    """

    # Entity type of the integral this backend generates code for.
    entity_type: entity_types

    def __init__(self, entity_type: entity_types, integral_type: str, symbols, options):
        """Initialise.

        Args:
            entity_type: Entity type of the integral.
            integral_type: UFL integral type, e.g. ``"cell"``, ``"exterior_facet"``,
                ``"interior_facet"`` or ``"expression"``.
            symbols: Backend symbols object providing the LNodes symbols and
                accessors used by the handlers below.
            options: Code generation options (stored; not read by this class).
        """
        self.entity_type = entity_type
        self.integral_type = integral_type
        self.symbols = symbols
        self.options = options

        # Handler to call from ``get``, keyed by the type of the terminal.
        self.call_lookup = {
            ufl.coefficient.Coefficient: self.coefficient,
            ufl.constant.Constant: self.constant,
            ufl.geometry.Jacobian: self.jacobian,
            ufl.geometry.CellCoordinate: self.cell_coordinate,
            ufl.geometry.FacetCoordinate: self.facet_coordinate,
            ufl.geometry.CellVertices: self.cell_vertices,
            ufl.geometry.FacetEdgeVectors: self.facet_edge_vectors,
            ufl.geometry.CellEdgeVectors: self.cell_edge_vectors,
            ufl.geometry.CellFacetJacobian: self.cell_facet_jacobian,
            ufl.geometry.CellRidgeJacobian: self.cell_ridge_jacobian,
            ufl.geometry.ReferenceCellVolume: self.reference_cell_volume,
            ufl.geometry.ReferenceFacetVolume: self.reference_facet_volume,
            ufl.geometry.ReferenceCellEdgeVectors: self.reference_cell_edge_vectors,
            ufl.geometry.ReferenceFacetEdgeVectors: self.reference_facet_edge_vectors,
            ufl.geometry.ReferenceNormal: self.reference_normal,
            ufl.geometry.CellOrientation: self._pass,
            ufl.geometry.FacetOrientation: self.facet_orientation,
            ufl.geometry.SpatialCoordinate: self.spatial_coordinate,
        }

    def get(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
    ):
        """Format a terminal.

        The handler is looked up by the exact type of ``mt.terminal``. If none is
        registered, the first registered base class the terminal is an instance
        of is used instead.

        Args:
            mt: Modified terminal to access.
            tabledata: Table reference for the terminal.
            quadrature_rule: Quadrature rule, passed through as the handler's
                third argument.

        Returns:
            The expression produced by the matching handler.

        Raises:
            RuntimeError: If no handler matches the terminal type.
        """
        e = mt.terminal
        handler = self.call_lookup.get(type(e))  # type: ignore
        if handler is None:
            # Look for parent class types instead
            handler = next(
                (h for k, h in self.call_lookup.items() if isinstance(e, k)), None
            )
        if handler is None:
            raise RuntimeError(f"Not handled: {type(e)}")
        return handler(mt, tabledata, quadrature_rule)

    @staticmethod
    def _cellname(mt):
        """Return the cell name of the unique domain of ``mt.terminal``."""
        return ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname

    @staticmethod
    def _scalar_coordinate_element(domain):
        """Return ``(gdim, coordinate_element, scalar_element)`` for ``domain``.

        The coordinate element must be a basix blocked element with value shape
        ``(gdim,)`` built from a single scalar sub-element.
        """
        gdim = domain.geometric_dimension
        coordinate_element = domain.ufl_coordinate_element()
        assert isinstance(coordinate_element, basix.ufl._BlockedElement)
        assert coordinate_element.reference_value_shape == (gdim,)
        # All sub-elements must be the same scalar element
        (scalar_element,) = set(coordinate_element.sub_elements)
        assert scalar_element.reference_value_size == 1 and scalar_element.block_size == 1
        return gdim, coordinate_element, scalar_element

    def coefficient(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
    ):
        """Access a coefficient.

        Returns a direct reference to a single dof when the table is all ones
        and spans exactly one dof; otherwise returns the coefficient value
        symbol, whose computation is generated in definitions.
        """
        ttype = tabledata.ttype
        assert ttype != "zeros"

        # Axis 3 of the table values is read as the number of dofs
        num_dofs = tabledata.values.shape[3]
        begin = tabledata.offset
        assert begin is not None
        assert tabledata.block_size is not None
        end = begin + tabledata.block_size * (num_dofs - 1) + 1

        if ttype == "ones" and (end - begin) == 1:
            # f = 1.0 * f_{begin}, just return direct reference to dof
            # array at dof begin (if mt is restricted, begin contains
            # cell offset)
            return self.symbols.coefficient_dof_access(mt.terminal, begin)
        # Return symbol, see definitions for computation
        return self.symbols.coefficient_value(mt)

    def constant(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT | None,
        quadrature_rule: QuadratureRule | None,
    ):
        """Access a constant directly through the constants symbol."""
        return self.symbols.constant_index_access(mt.terminal, mt.flat_component)

    def spatial_coordinate(
        self, mt: ModifiedTerminal, tabledata: UniqueTableReferenceT, num_points: QuadratureRule
    ):
        """Access a spatial coordinate.

        For custom integral types the coordinate is read from the predefined
        quadrature points table; otherwise it is computed by code generated in
        definitions.

        Raises:
            RuntimeError: For global derivatives, averages, or local derivatives
                in custom integrals.
        """
        if mt.global_derivatives:
            raise RuntimeError("Not expecting global derivatives of SpatialCoordinate.")
        if mt.averaged is not None:
            raise RuntimeError("Not expecting average of SpatialCoordinates.")

        if self.integral_type in ufl.custom_integral_types:
            if mt.local_derivatives:
                raise RuntimeError("FIXME: Jacobian in custom integrals is not implemented.")

            # Access predefined quadrature points table, flattened as
            # (point, component)
            x = self.symbols.custom_points_table
            iq = self.symbols.quadrature_loop_index
            (gdim,) = mt.terminal.ufl_shape
            index = iq if gdim == 1 else iq * gdim + mt.flat_component
            return x[index]

        # Physical coordinates (for "expression" and every other integral type)
        # are computed by code generated in definitions
        return self.symbols.x_component(mt)

    def cell_coordinate(self, mt, tabledata, num_points):
        """Access a reference cell coordinate.

        Only supported for unrestricted terminals in cell integrals; other cases
        are expected to have been rewritten symbolically before reaching here.
        ``num_points`` is the third argument passed by ``get`` (a quadrature rule).
        """
        if mt.global_derivatives or mt.local_derivatives:
            raise RuntimeError("Not expecting derivatives of CellCoordinate.")
        if mt.averaged is not None:
            raise RuntimeError("Not expecting average of CellCoordinate.")

        if self.integral_type == "cell" and not mt.restriction:
            # Access predefined quadrature points table, flattened as
            # (point, component)
            X = self.symbols.points_table(num_points)
            (tdim,) = mt.terminal.ufl_shape
            iq = self.symbols.quadrature_loop_index
            # NOTE(review): ``num_points`` is the quadrature rule passed by
            # ``get``; confirm this comparison with 1 is intended.
            if num_points == 1:
                index = mt.flat_component
            elif tdim == 1:
                index = iq
            else:
                index = iq * tdim + mt.flat_component
            return X[index]

        # X should be computed from x or Xf symbolically instead of getting here
        raise RuntimeError("Expecting reference cell coordinate to be symbolically rewritten.")

    def facet_coordinate(self, mt, tabledata, num_points):
        """Access a reference facet coordinate.

        Only supported in interior/exterior facet integrals; other cases are
        expected to have been rewritten symbolically before reaching here.
        ``num_points`` is the third argument passed by ``get`` (a quadrature rule).
        """
        if mt.global_derivatives or mt.local_derivatives:
            raise RuntimeError("Not expecting derivatives of FacetCoordinate.")
        if mt.averaged is not None:
            raise RuntimeError("Not expecting average of FacetCoordinate.")
        if mt.restriction:
            raise RuntimeError("Not expecting restriction of FacetCoordinate.")

        if self.integral_type not in ("interior_facet", "exterior_facet"):
            # Xf should be computed from X or x symbolically instead of getting here
            raise RuntimeError("Expecting reference facet coordinate to be symbolically rewritten.")

        # NOTE(review): ``tdim`` is the length of the FacetCoordinate shape, but
        # the checks and indexing below treat it as the cell dimension (facet
        # dimension ``tdim - 1``); confirm against UFL's FacetCoordinate shape.
        (tdim,) = mt.terminal.ufl_shape
        if tdim == 0:
            raise RuntimeError("Vertices have no facet coordinates.")
        elif tdim == 1:
            warnings.warn(
                "Vertex coordinate is always 0, should get rid of this in UFL "
                "geometry lowering."
            )
            return L.LiteralFloat(0.0)

        # The points table is provided by the symbols object (this class has no
        # points_table of its own)
        Xf = self.symbols.points_table(num_points)
        iq = self.symbols.quadrature_loop_index
        assert 0 <= mt.flat_component < (tdim - 1)
        # NOTE(review): ``num_points`` is the quadrature rule passed by ``get``;
        # confirm this comparison with 1 is intended.
        if num_points == 1:
            index = mt.flat_component
        elif tdim == 2:
            index = iq
        else:
            index = iq * (tdim - 1) + mt.flat_component
        return Xf[index]

    def jacobian(self, mt, tabledata, num_points):
        """Access a Jacobian component through the symbols object."""
        if mt.averaged is not None:
            raise RuntimeError("Not expecting average of Jacobian.")
        return self.symbols.J_component(mt)

    def reference_cell_volume(self, mt, tabledata, access):
        """Access the reference cell volume symbol for the terminal's cell."""
        cellname = self._cellname(mt)
        if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"):
            return L.Symbol(f"{cellname}_reference_cell_volume", dtype=L.DataType.REAL)
        else:
            raise RuntimeError(f"Unhandled cell types {cellname}.")

    def reference_facet_volume(self, mt, tabledata, access):
        """Access the reference facet volume symbol for the terminal's cell."""
        cellname = self._cellname(mt)
        if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"):
            return L.Symbol(f"{cellname}_reference_facet_volume", dtype=L.DataType.REAL)
        else:
            raise RuntimeError(f"Unhandled cell types {cellname}.")

    def reference_normal(self, mt, tabledata, access):
        """Access a reference normal component, indexed by (facet, component)."""
        cellname = self._cellname(mt)
        if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"):
            table = L.Symbol(f"{cellname}_reference_normals", dtype=L.DataType.REAL)
            facet = self.symbols.entity("facet", mt.restriction)
            return table[facet][mt.component[0]]
        else:
            raise RuntimeError(f"Unhandled cell types {cellname}.")

    def cell_facet_jacobian(self, mt, tabledata, num_points):
        """Access a cell facet Jacobian entry, indexed by (facet, row, column)."""
        cellname = self._cellname(mt)
        if cellname in (
            "triangle",
            "tetrahedron",
            "quadrilateral",
            "hexahedron",
            "prism",
            "pyramid",
        ):
            table = L.Symbol(f"{cellname}_cell_facet_jacobian", dtype=L.DataType.REAL)
            facet = self.symbols.entity("facet", mt.restriction)
            return table[facet][mt.component[0]][mt.component[1]]
        elif cellname == "interval":
            raise RuntimeError("The reference facet jacobian doesn't make sense for interval cell.")
        else:
            raise RuntimeError(f"Unhandled cell types {cellname}.")

    def cell_ridge_jacobian(self, mt, tabledata, num_points):
        """Access a cell ridge Jacobian entry, indexed by (ridge, row, column)."""
        cellname = self._cellname(mt)
        if cellname in ("tetrahedron", "prism", "hexahedron"):
            table = L.Symbol(f"{cellname}_cell_ridge_jacobian", dtype=L.DataType.REAL)
            ridge = self.symbols.entity("ridge", mt.restriction)
            return table[ridge][mt.component[0]][mt.component[1]]
        elif cellname in ["triangle", "quadrilateral"]:
            raise RuntimeError("The ridge jacobian doesn't make sense for 2D cells.")
        else:
            raise RuntimeError(f"Unhandled cell types {cellname}.")

    def reference_cell_edge_vectors(self, mt, tabledata, num_points):
        """Access a reference cell edge vector component, indexed by (edge, component)."""
        cellname = self._cellname(mt)
        if cellname in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"):
            table = L.Symbol(f"{cellname}_reference_cell_edge_vectors", dtype=L.DataType.REAL)
            return table[mt.component[0]][mt.component[1]]
        elif cellname == "interval":
            raise RuntimeError(
                "The reference cell edge vectors doesn't make sense for interval cell."
            )
        else:
            raise RuntimeError(f"Unhandled cell types {cellname}.")

    def reference_facet_edge_vectors(self, mt, tabledata, num_points):
        """Access a reference facet edge vector component, indexed by (edge, component)."""
        cellname = self._cellname(mt)
        if cellname in ("tetrahedron", "hexahedron"):
            table = L.Symbol(f"{cellname}_reference_facet_edge_vectors", dtype=L.DataType.REAL)
            return table[mt.component[0]][mt.component[1]]
        elif cellname in ("interval", "triangle", "quadrilateral"):
            # Report the actual cell: this branch also covers quadrilaterals
            raise RuntimeError(
                f"The reference cell facet edge vectors doesn't make sense for {cellname} cell."
            )
        else:
            raise RuntimeError(f"Unhandled cell types {cellname}.")

    def facet_orientation(self, mt, tabledata, num_points):
        """Access the facet orientation table entry for the current facet."""
        cellname = self._cellname(mt)
        if cellname not in ("interval", "triangle", "tetrahedron"):
            raise RuntimeError(f"Unhandled cell types {cellname}.")

        table = L.Symbol(f"{cellname}_facet_orientation", dtype=L.DataType.INT)
        facet = self.symbols.entity("facet", mt.restriction)
        return table[facet]

    def cell_vertices(self, mt, tabledata, num_points):
        """Access a cell vertex coordinate.

        ``mt.component`` is read as (vertex, geometric component); the value is
        the coordinate dof of that vertex.
        """
        domain = ufl.domain.extract_unique_domain(mt.terminal)
        gdim, _, scalar_element = self._scalar_coordinate_element(domain)

        vertex_scalar_dofs = scalar_element.entity_dofs[0]
        num_scalar_dofs = scalar_element.dim

        # Each vertex carries exactly one scalar dof
        (dof,) = vertex_scalar_dofs[mt.component[0]]
        component = mt.component[1]

        return self.symbols.domain_dof_access(dof, component, gdim, num_scalar_dofs, mt.restriction)

    def cell_edge_vectors(self, mt, tabledata, num_points):
        """Access a physical cell edge vector component.

        ``mt.component`` is read as (edge, geometric component). The result is
        the coordinate dof at the edge's first vertex minus the one at its
        second vertex.
        """
        domain = ufl.domain.extract_unique_domain(mt.terminal)
        cellname = domain.ufl_cell().cellname

        if cellname == "interval":
            raise RuntimeError(
                "The physical cell edge vectors doesn't make sense for interval cell."
            )
        if cellname not in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"):
            raise RuntimeError(f"Unhandled cell types {cellname}.")

        gdim, _, scalar_element = self._scalar_coordinate_element(domain)
        vertex_scalar_dofs = scalar_element.entity_dofs[0]
        num_scalar_dofs = scalar_element.dim

        # Edge vertices, then the single scalar dof on each vertex
        edge = mt.component[0]
        vertex0, vertex1 = scalar_element.reference_topology[1][edge]
        (dof0,) = vertex_scalar_dofs[vertex0]
        (dof1,) = vertex_scalar_dofs[vertex1]
        component = mt.component[1]

        x0 = self.symbols.domain_dof_access(dof0, component, gdim, num_scalar_dofs, mt.restriction)
        x1 = self.symbols.domain_dof_access(dof1, component, gdim, num_scalar_dofs, mt.restriction)
        return x0 - x1

    def facet_edge_vectors(self, mt, tabledata, num_points):
        """Access a physical facet edge vector component.

        ``mt.component`` is read as (facet edge, geometric component). The edge
        vertices are looked up in the ``<cell>_facet_edge_vertices`` table for
        the current facet, and the result is the difference of their coordinate
        dofs.
        """
        domain = ufl.domain.extract_unique_domain(mt.terminal)
        cellname = domain.ufl_cell().cellname

        if cellname in ("interval", "triangle", "quadrilateral"):
            raise RuntimeError(
                f"The physical facet edge vectors doesn't make sense for {cellname} cell."
            )
        if cellname not in ("tetrahedron", "hexahedron"):
            raise RuntimeError(f"Unhandled cell types {cellname}.")

        gdim, coordinate_element, scalar_element = self._scalar_coordinate_element(domain)
        num_scalar_dofs = scalar_element.dim

        # Get edge vertices
        facet = self.symbols.entity("facet", mt.restriction)
        facet_edge = mt.component[0]
        facet_edge_vertices = L.Symbol(f"{cellname}_facet_edge_vertices", dtype=L.DataType.INT)
        vertex0 = facet_edge_vertices[facet][facet_edge][0]
        vertex1 = facet_edge_vertices[facet][facet_edge][1]

        component = mt.component[1]
        # Vertex numbers are used directly as scalar dof numbers, which is only
        # valid for a degree-1 coordinate element.
        assert coordinate_element.embedded_superdegree == 1, "Assuming degree 1 element"
        x0 = self.symbols.domain_dof_access(
            vertex0, component, gdim, num_scalar_dofs, mt.restriction
        )
        x1 = self.symbols.domain_dof_access(
            vertex1, component, gdim, num_scalar_dofs, mt.restriction
        )
        return x0 - x1

    def _pass(self, *args, **kwargs):
        """Return one (handler registered for CellOrientation)."""
        return 1

    def table_access(
        self,
        tabledata: UniqueTableReferenceT,
        entity_type: entity_types,
        restriction: str,
        quadrature_index: L.MultiIndex,
        dof_index: L.MultiIndex,
    ):
        """Access element table for given entity, quadrature point, and dof index.

        Args:
            tabledata: Table data object
            entity_type: Entity type
            restriction: Restriction ("+", "-")
            quadrature_index: Quadrature index
            dof_index: Dof index

        Returns:
            Tuple ``(expression, symbols)``: the table access expression (a
            product of per-dimension factor accesses for multi-dimensional
            indices) and the list of table symbols it references.
        """
        entity = self.symbols.entity(entity_type, restriction)
        iq_global_index = quadrature_index.global_index
        ic_global_index = dof_index.global_index

        if tabledata.is_uniform:
            # Uniform tables are always read at entity 0
            entity = L.LiteralInt(0)
        if tabledata.is_piecewise:
            # Piecewise tables are always read at quadrature point 0
            iq_global_index = L.LiteralInt(0)

        # Quadrature permutation: entry 1 is used for the "-" restriction.
        # FIXME: Hopefully tabledata is not permuted when applying sum
        # factorization
        qp = 0
        if tabledata.is_permuted:
            qp = self.symbols.quadrature_permutation[1 if restriction == "-" else 0]

        if dof_index.dim == 1 and quadrature_index.dim == 1:
            table = self.symbols.element_tables[tabledata.name][qp][entity]
            symbols = [L.Symbol(tabledata.name, dtype=L.DataType.REAL)]
            return table[iq_global_index][ic_global_index], symbols

        # Tensor-product table: one factor per index dimension, multiplied
        assert tabledata.tensor_factors is not None
        factors = []
        symbols = []
        for i in range(dof_index.dim):
            factor = tabledata.tensor_factors[i]
            iq_i = quadrature_index.local_index(i)
            ic_i = dof_index.local_index(i)
            factors.append(self.symbols.element_tables[factor.name][qp][entity][iq_i][ic_i])
            symbols.append(L.Symbol(factor.name, dtype=L.DataType.REAL))
        return L.Product(factors), symbols
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc