• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 14172337523

31 Mar 2025 01:09PM UTC coverage: 82.473% (+0.5%) from 81.955%
14172337523

Pull #731

github

jorgensd
Ruff formatting
Pull Request #731: Codim 2 coupling

88 of 112 new or added lines in 9 files covered. (78.57%)

21 existing lines in 5 files now uncovered.

3609 of 4376 relevant lines covered (82.47%)

0.82 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

42.8
/ffcx/codegeneration/access.py
1
# Copyright (C) 2011-2017 Martin Sandve Alnæs
2
#
3
# This file is part of FFCx. (https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6
"""FFCx/UFC specific variable access."""
7

8
import logging
1✔
9
import warnings
1✔
10
from typing import Optional
1✔
11

12
import basix.ufl
1✔
13
import ufl
1✔
14

15
import ffcx.codegeneration.lnodes as L
1✔
16
from ffcx.ir.analysis.modified_terminals import ModifiedTerminal
1✔
17
from ffcx.ir.elementtables import UniqueTableReferenceT
1✔
18
from ffcx.ir.representationutils import QuadratureRule
1✔
19

20
# Logger for the "ffcx" namespace; logging.getLogger returns the same logger
# object for every call made with this name.
logger = logging.getLogger("ffcx")
1✔
21

22

23
class FFCXBackendAccess:
    """FFCx specific formatter class.

    Translates UFL terminals (wrapped in ``ModifiedTerminal`` objects) into the
    ``lnodes`` expressions that access them in generated code. Dispatch is by
    terminal type through ``call_lookup``. Every handler is called as
    ``handler(mt, tabledata, quadrature_rule)``.
    """

    def __init__(self, entity_type: str, integral_type: str, symbols, options):
        """Initialise.

        Args:
            entity_type: Entity type of the integral (stored as-is).
            integral_type: Integral type. Handlers below check for "cell",
                "interior_facet", "exterior_facet", "expression" and the UFL
                custom integral types.
            symbols: Symbol provider used to build the access expressions.
            options: FFCx options (stored as-is).
        """
        # Store ir and options
        self.entity_type = entity_type
        self.integral_type = integral_type
        self.symbols = symbols
        self.options = options

        # Lookup table for handler to call when the "get" method (below) is
        # called, depending on the first argument type. Insertion order matters:
        # it is the order used for the isinstance fallback in "get".
        self.call_lookup = {
            ufl.coefficient.Coefficient: self.coefficient,
            ufl.constant.Constant: self.constant,
            ufl.geometry.Jacobian: self.jacobian,
            ufl.geometry.CellCoordinate: self.cell_coordinate,
            ufl.geometry.FacetCoordinate: self.facet_coordinate,
            ufl.geometry.CellVertices: self.cell_vertices,
            ufl.geometry.FacetEdgeVectors: self.facet_edge_vectors,
            ufl.geometry.CellEdgeVectors: self.cell_edge_vectors,
            ufl.geometry.CellFacetJacobian: self.cell_facet_jacobian,
            ufl.geometry.CellRidgeJacobian: self.cell_ridge_jacobian,
            ufl.geometry.ReferenceCellVolume: self.reference_cell_volume,
            ufl.geometry.ReferenceFacetVolume: self.reference_facet_volume,
            ufl.geometry.ReferenceCellEdgeVectors: self.reference_cell_edge_vectors,
            ufl.geometry.ReferenceFacetEdgeVectors: self.reference_facet_edge_vectors,
            ufl.geometry.ReferenceNormal: self.reference_normal,
            ufl.geometry.CellOrientation: self._pass,
            ufl.geometry.FacetOrientation: self.facet_orientation,
            ufl.geometry.SpatialCoordinate: self.spatial_coordinate,
        }

    def get(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
    ):
        """Format a terminal.

        Args:
            mt: Modified terminal to access.
            tabledata: Element table reference for the terminal.
            quadrature_rule: Quadrature rule of the current integral.

        Returns:
            Whatever the registered handler returns for ``mt``.

        Raises:
            RuntimeError: If no handler is registered for the terminal type.
        """
        e = mt.terminal
        # Exact type match first. Otherwise use the first registered class
        # (in registration order) that the terminal is an instance of.
        handler = self.call_lookup.get(type(e))
        if handler is None:
            handler = next(
                (h for k, h in self.call_lookup.items() if isinstance(e, k)), None
            )
        if handler is None:
            raise RuntimeError(f"Not handled: {type(e)}")
        return handler(mt, tabledata, quadrature_rule)  # type: ignore

    @staticmethod
    def _cellname(mt) -> str:
        """Return the cell name of the unique domain of ``mt.terminal``."""
        return ufl.domain.extract_unique_domain(mt.terminal).ufl_cell().cellname()

    @staticmethod
    def _num_points(quadrature_rule) -> int:
        """Return the number of points described by ``quadrature_rule``.

        Handlers receive the quadrature rule forwarded by ``get``, so comparing
        it directly with an integer never detects a single-point rule. A plain
        integer is returned unchanged for backward compatibility.
        """
        if isinstance(quadrature_rule, int):
            return quadrature_rule
        # NOTE(review): assumes QuadratureRule.points has one row per point;
        # confirm against ffcx.ir.representationutils.QuadratureRule.
        return len(quadrature_rule.points)

    @staticmethod
    def _scalar_coordinate_element(domain):
        """Unpack the blocked coordinate element of ``domain``.

        Returns:
            Tuple ``(gdim, coordinate_element, scalar_element)``. The
            coordinate element must be a basix blocked element with reference
            value shape ``(gdim,)`` built from a single scalar sub-element.
        """
        gdim = domain.geometric_dimension()
        coordinate_element = domain.ufl_coordinate_element()

        assert isinstance(coordinate_element, basix.ufl._BlockedElement)
        assert coordinate_element.reference_value_shape == (gdim,)
        (scalar_element,) = set(coordinate_element.sub_elements)
        assert scalar_element.reference_value_size == 1 and scalar_element.block_size == 1
        return gdim, coordinate_element, scalar_element

    def coefficient(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
    ):
        """Access a coefficient.

        If the table is all ones and spans a single dof, the dof value is
        returned directly. Otherwise the coefficient value symbol (computed by
        code generated in definitions) is returned.
        """
        ttype = tabledata.ttype
        assert ttype != "zeros"

        num_dofs = tabledata.values.shape[3]
        begin = tabledata.offset
        end = begin + tabledata.block_size * (num_dofs - 1) + 1

        if ttype == "ones" and (end - begin) == 1:
            # f = 1.0 * f_{begin}, just return direct reference to dof
            # array at dof begin (if mt is restricted, begin contains
            # cell offset)
            return self.symbols.coefficient_dof_access(mt.terminal, begin)
        # Return symbol, see definitions for computation
        return self.symbols.coefficient_value(mt)

    def constant(
        self,
        mt: ModifiedTerminal,
        tabledata: Optional[UniqueTableReferenceT],
        quadrature_rule: Optional[QuadratureRule],
    ):
        """Access a constant.

        Tables and quadrature are unused. The flat component is looked up
        directly through the constants symbol.
        """
        return self.symbols.constant_index_access(mt.terminal, mt.flat_component)

    def spatial_coordinate(
        self, mt: ModifiedTerminal, tabledata: UniqueTableReferenceT, num_points: QuadratureRule
    ):
        """Access a spatial coordinate.

        For custom integrals the physical points are read from the custom
        points table. In every other case the component symbol from
        ``symbols.x_component`` is returned.

        Raises:
            RuntimeError: On global derivatives or averages, or on local
                derivatives inside custom integrals.
        """
        if mt.global_derivatives:
            raise RuntimeError("Not expecting global derivatives of SpatialCoordinate.")
        if mt.averaged is not None:
            raise RuntimeError("Not expecting average of SpatialCoordinates.")

        if self.integral_type not in ufl.custom_integral_types:
            # Physical coordinates (integrals and expressions alike) are
            # computed by code generated in definitions
            return self.symbols.x_component(mt)

        if mt.local_derivatives:
            raise RuntimeError("FIXME: Jacobian in custom integrals is not implemented.")

        # Access predefined quadrature points table, flattened as [iq][gdim]
        x = self.symbols.custom_points_table
        iq = self.symbols.quadrature_loop_index
        (gdim,) = mt.terminal.ufl_shape
        index = iq if gdim == 1 else iq * gdim + mt.flat_component
        return x[index]

    def cell_coordinate(self, mt, tabledata, num_points):
        """Access a cell coordinate.

        Only unrestricted terminals in cell integrals are supported. Their
        reference coordinate is read from the quadrature points table.

        Raises:
            RuntimeError: On derivatives, averages, or any other integral type
                or restriction (those should have been rewritten symbolically).
        """
        if mt.global_derivatives:
            raise RuntimeError("Not expecting derivatives of CellCoordinate.")
        if mt.local_derivatives:
            raise RuntimeError("Not expecting derivatives of CellCoordinate.")
        if mt.averaged is not None:
            raise RuntimeError("Not expecting average of CellCoordinate.")

        if self.integral_type != "cell" or mt.restriction:
            # X should be computed from x or Xf symbolically instead of
            # getting here
            raise RuntimeError("Expecting reference cell coordinate to be symbolically rewritten.")

        # Access predefined quadrature points table, flattened as [iq][tdim]
        X = self.symbols.points_table(num_points)
        (tdim,) = mt.terminal.ufl_shape
        iq = self.symbols.quadrature_loop_index
        if self._num_points(num_points) == 1:
            index = mt.flat_component
        elif tdim == 1:
            index = iq
        else:
            index = iq * tdim + mt.flat_component
        return X[index]

    def facet_coordinate(self, mt, tabledata, num_points):
        """Access a facet coordinate.

        Only supported in interior/exterior facet integrals. The coordinate is
        read from the quadrature points table.

        Raises:
            RuntimeError: On derivatives, averages, restrictions, other
                integral types, or when the shape value is 0.
        """
        if mt.global_derivatives:
            raise RuntimeError("Not expecting derivatives of FacetCoordinate.")
        if mt.local_derivatives:
            raise RuntimeError("Not expecting derivatives of FacetCoordinate.")
        if mt.averaged is not None:
            raise RuntimeError("Not expecting average of FacetCoordinate.")
        if mt.restriction:
            raise RuntimeError("Not expecting restriction of FacetCoordinate.")

        if self.integral_type not in ("interior_facet", "exterior_facet"):
            # Xf should be computed from X or x symbolically instead of
            # getting here
            raise RuntimeError("Expecting reference facet coordinate to be symbolically rewritten.")

        # NOTE(review): the branches below treat the terminal's shape as (tdim,).
        # If UFL defines the FacetCoordinate shape as (tdim - 1,), they are off
        # by one. Confirm before relying on this path.
        (tdim,) = mt.terminal.ufl_shape
        if tdim == 0:
            raise RuntimeError("Vertices have no facet coordinates.")
        elif tdim == 1:
            warnings.warn(
                "Vertex coordinate is always 0, should get rid of this in UFL "
                "geometry lowering."
            )
            return L.LiteralFloat(0.0)

        # Points table flattened as [iq][tdim - 1]
        Xf = self.symbols.points_table(num_points)
        iq = self.symbols.quadrature_loop_index
        assert 0 <= mt.flat_component < (tdim - 1)
        if self._num_points(num_points) == 1:
            index = mt.flat_component
        elif tdim == 2:
            index = iq
        else:
            index = iq * (tdim - 1) + mt.flat_component
        return Xf[index]

    def jacobian(self, mt, tabledata, num_points):
        """Access a Jacobian component via ``symbols.J_component``.

        Raises:
            RuntimeError: If the terminal is averaged.
        """
        if mt.averaged is not None:
            raise RuntimeError("Not expecting average of Jacobian.")
        return self.symbols.J_component(mt)

    def reference_cell_volume(self, mt, tabledata, access):
        """Access a reference cell volume through a named REAL symbol."""
        cellname = self._cellname(mt)
        if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"):
            return L.Symbol(f"{cellname}_reference_cell_volume", dtype=L.DataType.REAL)
        raise RuntimeError(f"Unhandled cell types {cellname}.")

    def reference_facet_volume(self, mt, tabledata, access):
        """Access a reference facet volume through a named REAL symbol."""
        cellname = self._cellname(mt)
        if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"):
            return L.Symbol(f"{cellname}_reference_facet_volume", dtype=L.DataType.REAL)
        raise RuntimeError(f"Unhandled cell types {cellname}.")

    def reference_normal(self, mt, tabledata, access):
        """Access a reference normal, indexed as table[facet][component]."""
        cellname = self._cellname(mt)
        if cellname in ("interval", "triangle", "tetrahedron", "quadrilateral", "hexahedron"):
            table = L.Symbol(f"{cellname}_reference_normals", dtype=L.DataType.REAL)
            facet = self.symbols.entity("facet", mt.restriction)
            return table[facet][mt.component[0]]
        raise RuntimeError(f"Unhandled cell types {cellname}.")

    def cell_facet_jacobian(self, mt, tabledata, num_points):
        """Access a cell facet Jacobian, indexed as table[facet][i][j]."""
        cellname = self._cellname(mt)
        if cellname in (
            "triangle",
            "tetrahedron",
            "quadrilateral",
            "hexahedron",
            "prism",
            "pyramid",
        ):
            table = L.Symbol(f"{cellname}_cell_facet_jacobian", dtype=L.DataType.REAL)
            facet = self.symbols.entity("facet", mt.restriction)
            return table[facet][mt.component[0]][mt.component[1]]
        elif cellname == "interval":
            raise RuntimeError("The reference facet jacobian doesn't make sense for interval cell.")
        raise RuntimeError(f"Unhandled cell types {cellname}.")

    def cell_ridge_jacobian(self, mt, tabledata, num_points):
        """Access a cell ridge Jacobian, indexed as table[ridge][i][j]."""
        cellname = self._cellname(mt)
        if cellname in ("tetrahedron", "prism", "hexahedron"):
            table = L.Symbol(f"{cellname}_cell_ridge_jacobian", dtype=L.DataType.REAL)
            ridge = self.symbols.entity("ridge", mt.restriction)
            return table[ridge][mt.component[0]][mt.component[1]]
        elif cellname in ("triangle", "quadrilateral"):
            raise RuntimeError("The ridge jacobian doesn't make sense for 2D cells.")
        raise RuntimeError(f"Unhandled cell types {cellname}.")

    def reference_cell_edge_vectors(self, mt, tabledata, num_points):
        """Access a reference cell edge vector, indexed as table[edge][component]."""
        cellname = self._cellname(mt)
        if cellname in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"):
            table = L.Symbol(f"{cellname}_reference_cell_edge_vectors", dtype=L.DataType.REAL)
            return table[mt.component[0]][mt.component[1]]
        elif cellname == "interval":
            raise RuntimeError(
                "The reference cell edge vectors doesn't make sense for interval cell."
            )
        raise RuntimeError(f"Unhandled cell types {cellname}.")

    def reference_facet_edge_vectors(self, mt, tabledata, num_points):
        """Access a reference facet edge vector, indexed as table[i][j]."""
        cellname = self._cellname(mt)
        if cellname in ("tetrahedron", "hexahedron"):
            table = L.Symbol(f"{cellname}_reference_facet_edge_vectors", dtype=L.DataType.REAL)
            return table[mt.component[0]][mt.component[1]]
        elif cellname in ("interval", "triangle", "quadrilateral"):
            # Name the actual cell: quadrilateral also lands here.
            raise RuntimeError(
                f"The reference cell facet edge vectors doesn't make sense for {cellname} cell."
            )
        raise RuntimeError(f"Unhandled cell types {cellname}.")

    def facet_orientation(self, mt, tabledata, num_points):
        """Access a facet orientation, indexed as table[facet] (INT symbol)."""
        cellname = self._cellname(mt)
        if cellname not in ("interval", "triangle", "tetrahedron"):
            raise RuntimeError(f"Unhandled cell types {cellname}.")

        table = L.Symbol(f"{cellname}_facet_orientation", dtype=L.DataType.INT)
        facet = self.symbols.entity("facet", mt.restriction)
        return table[facet]

    def cell_vertices(self, mt, tabledata, num_points):
        """Access a cell vertex coordinate.

        ``mt.component`` is ``(vertex, coordinate component)``. The vertex is
        mapped to its single dof in the scalar coordinate element.
        """
        domain = ufl.domain.extract_unique_domain(mt.terminal)
        gdim, _, scalar_element = self._scalar_coordinate_element(domain)

        vertex_scalar_dofs = scalar_element.entity_dofs[0]
        num_scalar_dofs = scalar_element.dim

        # Get dof and component
        (dof,) = vertex_scalar_dofs[mt.component[0]]
        component = mt.component[1]

        return self.symbols.domain_dof_access(dof, component, gdim, num_scalar_dofs, mt.restriction)

    def cell_edge_vectors(self, mt, tabledata, num_points):
        """Access a physical cell edge vector.

        ``mt.component`` is ``(edge, coordinate component)``. The result is
        ``x[vertex0] - x[vertex1]`` for the edge's two vertices.
        """
        domain = ufl.domain.extract_unique_domain(mt.terminal)
        cellname = domain.ufl_cell().cellname()

        if cellname == "interval":
            raise RuntimeError(
                "The physical cell edge vectors doesn't make sense for interval cell."
            )
        if cellname not in ("triangle", "tetrahedron", "quadrilateral", "hexahedron"):
            raise RuntimeError(f"Unhandled cell types {cellname}.")

        gdim, _, scalar_element = self._scalar_coordinate_element(domain)
        vertex_scalar_dofs = scalar_element.entity_dofs[0]
        num_scalar_dofs = scalar_element.dim

        # Get edge vertices
        edge = mt.component[0]
        vertex0, vertex1 = scalar_element.reference_topology[1][edge]

        # Get dofs and component
        (dof0,) = vertex_scalar_dofs[vertex0]
        (dof1,) = vertex_scalar_dofs[vertex1]
        component = mt.component[1]

        x0 = self.symbols.domain_dof_access(dof0, component, gdim, num_scalar_dofs, mt.restriction)
        x1 = self.symbols.domain_dof_access(dof1, component, gdim, num_scalar_dofs, mt.restriction)
        return x0 - x1

    def facet_edge_vectors(self, mt, tabledata, num_points):
        """Access a physical facet edge vector.

        ``mt.component`` is ``(facet edge, coordinate component)``. The edge
        vertices are looked up symbolically per facet, and the vertex number is
        used directly as the dof (only valid for degree-1 coordinate elements,
        which is asserted).
        """
        domain = ufl.domain.extract_unique_domain(mt.terminal)
        cellname = domain.ufl_cell().cellname()

        if cellname in ("interval", "triangle", "quadrilateral"):
            raise RuntimeError(
                f"The physical facet edge vectors doesn't make sense for {cellname} cell."
            )
        if cellname not in ("tetrahedron", "hexahedron"):
            raise RuntimeError(f"Unhandled cell types {cellname}.")

        gdim, coordinate_element, scalar_element = self._scalar_coordinate_element(domain)
        num_scalar_dofs = scalar_element.dim

        # Get edge vertices
        facet = self.symbols.entity("facet", mt.restriction)
        facet_edge = mt.component[0]
        facet_edge_vertices = L.Symbol(f"{cellname}_facet_edge_vertices", dtype=L.DataType.INT)
        vertex0 = facet_edge_vertices[facet][facet_edge][0]
        vertex1 = facet_edge_vertices[facet][facet_edge][1]

        # Get dofs and component
        component = mt.component[1]
        assert coordinate_element.embedded_superdegree == 1, "Assuming degree 1 element"
        dof0 = vertex0
        dof1 = vertex1
        x0 = self.symbols.domain_dof_access(dof0, component, gdim, num_scalar_dofs, mt.restriction)
        x1 = self.symbols.domain_dof_access(dof1, component, gdim, num_scalar_dofs, mt.restriction)
        return x0 - x1

    def _pass(self, *args, **kwargs):
        """Ignore all arguments and return 1 (registered for CellOrientation)."""
        return 1

    def table_access(
        self,
        tabledata: UniqueTableReferenceT,
        entity_type: str,
        restriction: str,
        quadrature_index: L.MultiIndex,
        dof_index: L.MultiIndex,
    ):
        """Access element table for given entity, quadrature point, and dof index.

        Args:
            tabledata: Table data object
            entity_type: Entity type ("cell", "facet", "vertex")
            restriction: Restriction ("+", "-")
            quadrature_index: Quadrature index
            dof_index: Dof index

        Returns:
            Tuple ``(expression, symbols)``. ``expression`` is either a single
            table access or, for multi-dimensional indices, the product of the
            per-dimension tensor-factor accesses. ``symbols`` lists the table
            symbols referenced.
        """
        entity = self.symbols.entity(entity_type, restriction)
        iq_global_index = quadrature_index.global_index
        ic_global_index = dof_index.global_index
        qp = 0  # quadrature permutation

        # Uniform tables do not depend on the entity.
        if tabledata.is_uniform:
            entity = L.LiteralInt(0)

        # Piecewise tables do not depend on the quadrature point.
        if tabledata.is_piecewise:
            iq_global_index = L.LiteralInt(0)

        # FIXME: Hopefully tabledata is not permuted when applying sum
        # factorization
        if tabledata.is_permuted:
            qp = self.symbols.quadrature_permutation[1 if restriction == "-" else 0]

        if dof_index.dim == 1 and quadrature_index.dim == 1:
            symbols = [L.Symbol(tabledata.name, dtype=L.DataType.REAL)]
            table = self.symbols.element_tables[tabledata.name]
            return table[qp][entity][iq_global_index][ic_global_index], symbols

        # Sum factorisation: one tensor factor per dimension.
        symbols = []
        FE = []
        for i in range(dof_index.dim):
            factor = tabledata.tensor_factors[i]
            iq_i = quadrature_index.local_index(i)
            ic_i = dof_index.local_index(i)
            FE.append(self.symbols.element_tables[factor.name][qp][entity][iq_i][ic_i])
            symbols.append(L.Symbol(factor.name, dtype=L.DataType.REAL))
        return L.Product(FE), symbols
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc