• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 18695412111

21 Oct 2025 07:31PM UTC coverage: 83.0%. Remained the same
18695412111

push

github

web-flow
Typing fixes (#794)

6 of 6 new or added lines in 3 files covered. (100.0%)

27 existing lines in 3 files now uncovered.

3730 of 4494 relevant lines covered (83.0%)

0.83 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.33
/ffcx/codegeneration/definitions.py
1
# Copyright (C) 2011-2023 Martin Sandve Alnæs, Igor A. Baratta
2
#
3
# This file is part of FFCx. (https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6
"""FFCx/UFC specific variable definitions."""
7

8
import logging
1✔
9

10
import ufl
1✔
11

12
import ffcx.codegeneration.lnodes as L
1✔
13
from ffcx.definitions import entity_types
1✔
14
from ffcx.ir.analysis.modified_terminals import ModifiedTerminal
1✔
15
from ffcx.ir.elementtables import UniqueTableReferenceT
1✔
16
from ffcx.ir.representationutils import QuadratureRule
1✔
17

18
# Logger for this module, registered under the shared "ffcx" logger name
logger = logging.getLogger("ffcx")
1✔
19

20

21
def create_quadrature_index(quadrature_rule, quadrature_index_symbol):
    """Create a multi index for the quadrature loop.

    Args:
        quadrature_rule: Quadrature rule to index, or a falsy value when there
            is none (a single index with range 0 is produced in that case).
        quadrature_index_symbol: Symbol whose ``name`` is used as the base name
            of the generated integer index symbols.

    Returns:
        An ``L.MultiIndex``. It has one index per tensor factor (named
        ``<name>0``, ``<name>1``, ...) when the rule has tensor factors.
        Otherwise it has a single index named ``<name>`` whose range is the
        number of quadrature weights.
    """
    base_name = quadrature_index_symbol.name

    if not quadrature_rule:
        return L.MultiIndex([L.Symbol(base_name, dtype=L.DataType.INT)], [0])

    if not quadrature_rule.has_tensor_factors:
        flat_index = L.Symbol(base_name, dtype=L.DataType.INT)
        return L.MultiIndex([flat_index], [quadrature_rule.weights.size])

    # Tensor-product rule: one index per factor. The range of each index is
    # the size of the second entry of that factor.
    factors = quadrature_rule.tensor_factors
    index_symbols = [
        L.Symbol(f"{base_name}{k}", dtype=L.DataType.INT) for k in range(len(factors))
    ]
    extents = [factor[1].size for factor in factors]
    return L.MultiIndex(index_symbols, extents)
1✔
34

35

36
def create_dof_index(tabledata, dof_index_symbol):
    """Create a multi index for the coefficient dofs.

    Args:
        tabledata: Table reference. Its ``has_tensor_factorisation``,
            ``tensor_factors`` and ``values`` attributes decide the index layout.
        dof_index_symbol: Symbol whose ``name`` is used as the base name of the
            generated integer index symbols.

    Returns:
        An ``L.MultiIndex``. For a tensor-factorised table it has one index per
        factor (named ``<name>0``, ``<name>1``, ...), each ranging over the last
        axis of that factor's values. Otherwise it has a single index ``<name>``
        ranging over the last axis of ``tabledata.values``.
    """
    base_name = dof_index_symbol.name

    if not tabledata.has_tensor_factorisation:
        flat_index = L.Symbol(base_name, dtype=L.DataType.INT)
        return L.MultiIndex([flat_index], [tabledata.values.shape[-1]])

    factors = tabledata.tensor_factors
    count = len(factors)
    extents = [factor.values.shape[-1] for factor in factors]
    index_symbols = [L.Symbol(f"{base_name}{k}", dtype=L.DataType.INT) for k in range(count)]
    return L.MultiIndex(index_symbols, extents)
1✔
48

49

50
class FFCXBackendDefinitions:
    """FFCx specific code definitions.

    Maps UFL terminal types to handlers that return the code defining a
    terminal's value: an ``L.Section`` (for example a coefficient accumulated as
    a sum over its dofs), or an empty list when no definition is needed.
    """

    entity_type: entity_types

    def __init__(self, entity_type: entity_types, integral_type: str, access, options):
        """Initialise.

        Args:
            entity_type: Entity type; forwarded to ``access.table_access``.
            integral_type: UFL integral type name. It is compared against
                ``ufl.custom_integral_types`` in ``spatial_coordinate``.
            access: Access object providing ``symbols`` and ``table_access``.
            options: Code-generation options (stored; not read in this class).
        """
        # Store the integral context and options
        self.integral_type = integral_type
        self.entity_type = entity_type
        self.access = access
        self.options = options

        # Lookup table of handlers to be called, depending on the type of the
        # terminal (``get`` also falls back to base classes).
        self.handler_lookup = {
            ufl.coefficient.Coefficient: self.coefficient,
            ufl.geometry.Jacobian: self._define_coordinate_dofs_lincomb,
            ufl.geometry.SpatialCoordinate: self.spatial_coordinate,
            ufl.constant.Constant: self.pass_through,
            ufl.geometry.CellVertices: self.pass_through,
            ufl.geometry.FacetEdgeVectors: self.pass_through,
            ufl.geometry.CellEdgeVectors: self.pass_through,
            ufl.geometry.CellFacetJacobian: self.pass_through,
            ufl.geometry.CellRidgeJacobian: self.pass_through,
            ufl.geometry.ReferenceCellVolume: self.pass_through,
            ufl.geometry.ReferenceFacetVolume: self.pass_through,
            ufl.geometry.ReferenceCellEdgeVectors: self.pass_through,
            ufl.geometry.ReferenceFacetEdgeVectors: self.pass_through,
            ufl.geometry.ReferenceNormal: self.pass_through,
            ufl.geometry.CellOrientation: self.pass_through,
            ufl.geometry.FacetOrientation: self.pass_through,
        }

    @property
    def symbols(self):
        """Return the symbols object of the access object (``self.access.symbols``)."""
        return self.access.symbols

    def get(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> L.Section | list:
        """Return definition code for a terminal.

        The handler is chosen by the type of ``mt.terminal``. If that type has
        no entry in ``handler_lookup``, its first base class is tried, and so on
        up the chain.

        Raises:
            NotImplementedError: If no handler is found for the terminal type
                or any base class visited.
        """
        terminal = mt.terminal
        terminal_type = type(terminal)

        # Look for a direct handler, or walk up the first-base-class chain
        ttype = terminal_type
        while ttype not in self.handler_lookup and ttype.__bases__:
            ttype = ttype.__bases__[0]

        # Get the handler from the lookup, or None if not found
        handler = self.handler_lookup.get(ttype)  # type: ignore

        if handler is None:
            # Report the terminal's own type. After the failed walk above,
            # ``ttype`` is ``object``, which would hide the real culprit.
            raise NotImplementedError(f"No handler for terminal type: {terminal_type}")

        # Call the handler
        return handler(mt, tabledata, quadrature_rule, access)  # type: ignore

    def coefficient(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> L.Section | list:
        """Return definition code for coefficients.

        Emits a section that declares ``access = 0.0`` and then accumulates
        ``access += dof * FE`` in a loop over the dof multi-index.

        Returns an empty list when the table is ``"zeros"``. It also returns an
        empty list when the table is ``"ones"`` with a single dof; in that case
        the dof is referenced directly.
        """
        # For applying tensor product to coefficients, we need to know if the coefficient
        # has a tensor factorisation and if the quadrature rule has a tensor factorisation.
        # If both are true, we can apply the tensor product to the coefficient.

        iq_symbol = self.symbols.quadrature_loop_index
        ic_symbol = self.symbols.coefficient_dof_sum_index

        iq = create_quadrature_index(quadrature_rule, iq_symbol)
        ic = create_dof_index(tabledata, ic_symbol)

        # Get properties of tables
        ttype = tabledata.ttype
        num_dofs = tabledata.values.shape[3]
        bs = tabledata.block_size
        begin = tabledata.offset
        assert bs is not None
        assert begin is not None
        # One past the last dof touched when stepping by the block size
        end = begin + bs * (num_dofs - 1) + 1

        if ttype == "zeros":
            logger.debug("Not expecting zero coefficients to get this far.")
            return []

        # For a constant coefficient we reference the dofs directly, so no definition needed
        if ttype == "ones" and end - begin == 1:
            return []

        assert begin < end

        # Get access to element table
        FE, tables = self.access.table_access(tabledata, self.entity_type, mt.restriction, iq, ic)
        dof_access: L.ArrayAccess = self.symbols.coefficient_dof_access(
            mt.terminal, ic.global_index * bs + begin
        )

        declaration: list[L.Declaration] = [L.VariableDecl(access, 0.0)]
        body = [L.AssignAdd(access, dof_access * FE)]
        code = [L.create_nested_for_loops([ic], body)]

        name = type(mt.terminal).__name__
        inputs = [dof_access.array, *tables]
        outputs = [access]
        annotations = [L.Annotation.fuse]

        # assert inputs and outputs are Symbol objects
        assert all(isinstance(i, L.Symbol) for i in inputs)
        assert all(isinstance(o, L.Symbol) for o in outputs)

        return L.Section(name, code, declaration, inputs, outputs, annotations)

    def _define_coordinate_dofs_lincomb(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> L.Section | list:
        """Define x or J as a linear combination of coordinate dofs with given table data.

        Reads from the ``coordinate_dofs`` array, which holds 3 components per
        node. For the ``"-"`` restriction, reads are shifted past the first
        cell's ``num_scalar_dofs * 3`` entries.
        """
        # Get properties of domain
        domain = ufl.domain.extract_unique_domain(mt.terminal)
        assert isinstance(domain, ufl.Mesh)
        coordinate_element = domain.ufl_coordinate_element()
        num_scalar_dofs = coordinate_element._sub_element.dim

        num_dofs = tabledata.values.shape[3]
        begin = tabledata.offset
        # Consistent with ``coefficient``: the offset is required below
        assert begin is not None

        assert num_scalar_dofs == num_dofs

        # Find table name
        ttype = tabledata.ttype

        assert ttype != "zeros"
        assert ttype != "ones"

        # Get access to element table
        ic_symbol = self.symbols.coefficient_dof_sum_index
        iq_symbol = self.symbols.quadrature_loop_index
        ic = create_dof_index(tabledata, ic_symbol)
        iq = create_quadrature_index(quadrature_rule, iq_symbol)
        FE, tables = self.access.table_access(tabledata, self.entity_type, mt.restriction, iq, ic)

        dof_access = L.Symbol("coordinate_dofs", dtype=L.DataType.REAL)

        # coordinate dofs is always 3d
        dim = 3
        offset = 0
        if mt.restriction == "-":
            offset = num_scalar_dofs * dim

        declaration = [L.VariableDecl(access, 0.0)]
        body = [L.AssignAdd(access, dof_access[ic.global_index * dim + begin + offset] * FE)]
        code = [L.create_nested_for_loops([ic], body)]

        name = type(mt.terminal).__name__
        outputs = [access]
        inputs = [dof_access, *tables]
        annotations = [L.Annotation.fuse]

        # assert inputs and outputs are Symbol objects
        assert all(isinstance(i, L.Symbol) for i in inputs)
        assert all(isinstance(o, L.Symbol) for o in outputs)

        return L.Section(name, code, declaration, inputs, outputs, annotations)

    def spatial_coordinate(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> L.Section | list:
        """Return definition code for the physical spatial coordinates.

        If physical coordinates are given:
          No definition needed.

        If reference coordinates are given:
          x = sum_k xdof_k xphi_k(X)

        If reference facet coordinates are given:
          x = sum_k xdof_k xphi_k(Xf)
        """
        if self.integral_type in ufl.custom_integral_types:
            # FIXME: Jacobian may need adjustment for custom_integral_types
            if mt.local_derivatives:
                # Not inside an ``except`` block, so use ``error`` rather than
                # ``exception``. The latter would append a spurious
                # "NoneType: None" traceback.
                logger.error("FIXME: Jacobian in custom integrals is not implemented.")
            return []
        return self._define_coordinate_dofs_lincomb(mt, tabledata, quadrature_rule, access)

    def jacobian(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> L.Section | list:
        """Return definition code for the Jacobian of x(X).

        Equivalent to ``_define_coordinate_dofs_lincomb``, which is what
        ``handler_lookup`` registers directly for ``ufl.geometry.Jacobian``.
        """
        return self._define_coordinate_dofs_lincomb(mt, tabledata, quadrature_rule, access)

    def pass_through(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> L.Section | list:
        """Return definition code for pass through terminals (always empty)."""
        return []
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc