• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 24124115635

08 Apr 2026 07:47AM UTC coverage: 85.226% (+0.6%) from 84.674%
24124115635

Pull #829

github

jorgensd
Fix typing
Pull Request #829: Support `ufl.interpolate`

373 of 397 new or added lines in 17 files covered. (93.95%)

1 existing line in 1 file now uncovered.

4511 of 5293 relevant lines covered (85.23%)

0.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.81
/ffcx/codegeneration/definitions.py
1
# Copyright (C) 2011-2023 Martin Sandve Alnæs, Igor A. Baratta
2
#
3
# This file is part of FFCx. (https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6
"""FFCx/UFCx specific variable definitions."""
7

8
import logging
1✔
9
import typing
1✔
10

11
import ufl
1✔
12

13
import ffcx.codegeneration.lnodes as L
1✔
14
from ffcx.analysis import ProxyCoefficient
1✔
15
from ffcx.definitions import entity_types
1✔
16
from ffcx.ir.analysis.modified_terminals import ModifiedTerminal
1✔
17
from ffcx.ir.elementtables import UniqueTableReferenceT
1✔
18
from ffcx.ir.representationutils import QuadratureRule
1✔
19

20
# Module-level logger, registered under the "ffcx" logger name.
logger = logging.getLogger("ffcx")
1✔
21

22

23
def create_quadrature_index(quadrature_rule, quadrature_index_symbol):
    """Build the multi-index that iterates over quadrature points.

    Without a (truthy) quadrature rule, a single integer index named after
    ``quadrature_index_symbol`` with range 0 is returned. With a rule, the
    single index ranges over the number of quadrature weights, unless the
    rule is tensor factorised: then one index per tensor factor is created
    (the symbol name suffixed with the factor number), each ranging over
    the size of that factor's second entry.
    """
    base_name = quadrature_index_symbol.name

    if not quadrature_rule:
        return L.MultiIndex([L.Symbol(base_name, dtype=L.DataType.INT)], [0])

    if not quadrature_rule.has_tensor_factors:
        return L.MultiIndex(
            [L.Symbol(base_name, dtype=L.DataType.INT)],
            [quadrature_rule.weights.size],
        )

    factors = quadrature_rule.tensor_factors
    symbols = [L.Symbol(f"{base_name}{k}", dtype=L.DataType.INT) for k in range(len(factors))]
    extents = [factor[1].size for factor in factors]
    return L.MultiIndex(symbols, extents)
36

37

38
def create_dof_index(tabledata, dof_index_symbol):
    """Build the multi-index that sums over coefficient degrees of freedom.

    For tensor-factorised table data, one integer index is created per
    factor (the symbol name suffixed with the factor number), each ranging
    over the last axis of that factor's ``values``. Otherwise a single index
    named after ``dof_index_symbol`` ranges over the last axis of
    ``tabledata.values``.
    """
    base_name = dof_index_symbol.name

    if not tabledata.has_tensor_factorisation:
        return L.MultiIndex(
            [L.Symbol(base_name, dtype=L.DataType.INT)],
            [tabledata.values.shape[-1]],
        )

    factors = tabledata.tensor_factors
    symbols = [L.Symbol(f"{base_name}{k}", dtype=L.DataType.INT) for k in range(len(factors))]
    extents = [factor.values.shape[-1] for factor in factors]
    return L.MultiIndex(symbols, extents)
50

51

52
class FFCXBackendDefinitions:
    """FFCx specific code definitions.

    Generates the code that defines a terminal's value (coefficients,
    spatial coordinates, Jacobians, ...), dispatching on the terminal's type
    through ``handler_lookup``.
    """

    # Entity type forwarded to ``access.table_access``.
    entity_type: entity_types
    # Maps a UFL terminal class to the method that generates its definition.
    handler_lookup: dict[ufl.core.ufl_type.UFLType, typing.Callable]

    def __init__(self, entity_type: entity_types, integral_type: str, access, options):
        """Initialise.

        Args:
            entity_type: Entity type forwarded to ``access.table_access``.
            integral_type: UFL integral type; compared against
                ``ufl.custom_integral_types`` in ``spatial_coordinate``.
            access: Access helper providing ``symbols`` and ``table_access``.
            options: Code generation options (stored; not read in this class).
        """
        # Store ir and options
        self.integral_type = integral_type
        self.entity_type = entity_type
        self.access = access
        self.options = options

        # Lookup table of handlers: ``get`` calls the handler registered for
        # the type of ``mt.terminal`` (or for the nearest class on its
        # first-base chain).
        self.handler_lookup = {
            ufl.coefficient.Coefficient: self.coefficient,
            ProxyCoefficient: self.proxy_coefficient,
            ufl.geometry.Jacobian: self.jacobian,
            ufl.geometry.SpatialCoordinate: self.spatial_coordinate,
            ufl.constant.Constant: self.pass_through,
            ufl.geometry.CellVertices: self.pass_through,
            ufl.geometry.FacetEdgeVectors: self.pass_through,
            ufl.geometry.CellEdgeVectors: self.pass_through,
            ufl.geometry.CellFacetJacobian: self.pass_through,
            ufl.geometry.CellRidgeJacobian: self.pass_through,
            ufl.geometry.ReferenceCellVolume: self.pass_through,
            ufl.geometry.ReferenceFacetVolume: self.pass_through,
            ufl.geometry.ReferenceCellEdgeVectors: self.pass_through,
            ufl.geometry.ReferenceFacetEdgeVectors: self.pass_through,
            ufl.geometry.ReferenceNormal: self.pass_through,
            ufl.geometry.CellOrientation: self.pass_through,
            ufl.geometry.FacetOrientation: self.pass_through,
        }

    @property
    def symbols(self):
        """Return formatter."""
        return self.access.symbols

    def get(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> L.Section | list:
        """Return definition code for a terminal.

        The handler registered for the terminal's exact type is used. If there
        is none, the first base class is tried, then its first base, and so on.

        Raises:
            NotImplementedError: If no handler is registered for the terminal
                type or any class on its first-base chain.
        """
        terminal_type = type(mt.terminal)

        # Look for a direct handler, or one for a parent class (first base only)
        ttype = terminal_type
        while ttype not in self.handler_lookup and ttype.__bases__:
            ttype = ttype.__bases__[0]

        handler = self.handler_lookup.get(ttype)  # type: ignore
        if handler is None:
            # ``ttype`` has been walked up to ``object`` at this point, so
            # report the terminal's own type instead.
            raise NotImplementedError(f"No handler for terminal type: {terminal_type}")

        return handler(mt, tabledata, quadrature_rule, access)  # type: ignore

    def _coefficient_definition(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
        dof_access_fn: typing.Callable,
    ) -> L.Section | list:
        """Define ``access`` as the dof/table contraction of a (proxy) coefficient.

        Emits ``access = 0; for ic: access += dofs[ic * bs + offset] * FE``,
        where ``dofs[...]`` is produced by ``dof_access_fn`` and ``FE`` is the
        element table entry returned by ``access.table_access``.

        Args:
            mt: Modified terminal being defined.
            tabledata: Element table reference for the terminal.
            quadrature_rule: Quadrature rule of the enclosing loop.
            access: Symbol that receives the value.
            dof_access_fn: Called as ``dof_access_fn(terminal, index)``; must
                return an ``L.ArrayAccess`` into the dof array.

        Returns:
            An ``L.Section`` with the summation loop, or ``[]`` when no
            definition is needed (zero table, or a constant table covering
            a single dof).
        """
        # For applying tensor product to coefficients, we need to know
        # if the coefficient has a tensor factorisation and if the
        # quadrature rule has a tensor factorisation. If both are true,
        # we can apply the tensor product to the coefficient.
        iq_symbol = self.symbols.quadrature_loop_index
        ic_symbol = self.symbols.coefficient_dof_sum_index

        iq = create_quadrature_index(quadrature_rule, iq_symbol)
        ic = create_dof_index(tabledata, ic_symbol)

        # Get properties of tables
        ttype = tabledata.ttype
        num_dofs = tabledata.values.shape[3]
        bs = tabledata.block_size
        begin = tabledata.offset
        assert bs is not None
        assert begin is not None
        end = begin + bs * (num_dofs - 1) + 1

        if ttype == "zeros":
            logger.debug("Not expecting zero coefficients to get this far.")
            return []

        # For a constant coefficient we reference the dofs directly, so
        # no definition needed
        if ttype == "ones" and end - begin == 1:
            return []

        assert begin < end

        # Get access to element table
        FE, tables = self.access.table_access(tabledata, self.entity_type, mt.restriction, iq, ic)
        dof_access: L.ArrayAccess = dof_access_fn(mt.terminal, ic.global_index * bs + begin)

        declaration: list[L.Declaration] = [L.VariableDecl(access, 0.0)]
        body = [L.AssignAdd(access, dof_access * FE)]
        code = [L.create_nested_for_loops([ic], body)]

        name = type(mt.terminal).__name__
        inputs = [dof_access.array, *tables]
        outputs = [access]
        annotations = [L.Annotation.fuse]

        # The section's inputs and outputs must be plain symbols
        assert all(isinstance(i, L.Symbol) for i in inputs)
        assert all(isinstance(o, L.Symbol) for o in outputs)
        return L.Section(name, code, declaration, inputs, outputs, annotations)

    def coefficient(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> L.Section | list:
        """Return definition code for coefficients.

        The dofs are read through ``symbols.coefficient_dof_access``.
        """
        return self._coefficient_definition(
            mt, tabledata, quadrature_rule, access, self.symbols.coefficient_dof_access
        )

    def proxy_coefficient(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> L.Section | list:
        """Return definition code for proxy coefficients.

        Identical to ``coefficient`` except that the dofs are read through
        ``symbols.proxy_coefficient_dof_access``.
        """
        return self._coefficient_definition(
            mt, tabledata, quadrature_rule, access, self.symbols.proxy_coefficient_dof_access
        )

    def _define_coordinate_dofs_lincomb(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> L.Section | list:
        """Define x or J as a linear combination of coordinate dofs with given table data.

        Emits ``access = 0; for ic: access += coordinate_dofs[ic * 3 + offset + shift] * FE``.
        ``shift`` is ``num_scalar_dofs * 3`` for a ``"-"``-restricted terminal
        and 0 otherwise. For the ``"-"`` case, this presumably selects the second
        cell's block in ``coordinate_dofs``; confirm with the caller.
        """
        # Get properties of domain
        domain = ufl.domain.extract_unique_domain(mt.terminal)
        assert isinstance(domain, ufl.Mesh)
        coordinate_element = domain.ufl_coordinate_element()
        num_scalar_dofs = coordinate_element._sub_element.dim

        num_dofs = tabledata.values.shape[3]
        begin = tabledata.offset
        assert begin is not None

        assert num_scalar_dofs == num_dofs

        # Zero and constant tables are not expected for coordinate fields
        ttype = tabledata.ttype
        assert ttype != "zeros"
        assert ttype != "ones"

        # Get access to element table
        ic_symbol = self.symbols.coefficient_dof_sum_index
        iq_symbol = self.symbols.quadrature_loop_index
        ic = create_dof_index(tabledata, ic_symbol)
        iq = create_quadrature_index(quadrature_rule, iq_symbol)
        FE, tables = self.access.table_access(tabledata, self.entity_type, mt.restriction, iq, ic)

        dof_access = L.Symbol("coordinate_dofs", dtype=L.DataType.REAL)

        # coordinate dofs is always 3d
        dim = 3
        offset = 0
        if mt.restriction == "-":
            offset = num_scalar_dofs * dim

        declaration = [L.VariableDecl(access, 0.0)]
        body = [L.AssignAdd(access, dof_access[ic.global_index * dim + begin + offset] * FE)]
        code = [L.create_nested_for_loops([ic], body)]

        name = type(mt.terminal).__name__
        outputs = [access]
        inputs = [dof_access, *tables]
        annotations = [L.Annotation.fuse]

        # The section's inputs and outputs must be plain symbols
        assert all(isinstance(i, L.Symbol) for i in inputs)
        assert all(isinstance(o, L.Symbol) for o in outputs)

        return L.Section(name, code, declaration, inputs, outputs, annotations)

    def spatial_coordinate(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> L.Section | list:
        """Return definition code for the physical spatial coordinates.

        If physical coordinates are given:
          No definition needed.

        If reference coordinates are given:
          x = sum_k xdof_k xphi_k(X)

        If reference facet coordinates are given:
          x = sum_k xdof_k xphi_k(Xf)

        For custom integral types no definition is emitted. An error is logged
        if derivatives of x (the Jacobian) are requested there.
        """
        if self.integral_type not in ufl.custom_integral_types:
            return self._define_coordinate_dofs_lincomb(mt, tabledata, quadrature_rule, access)

        # FIXME: Jacobian may need adjustment for custom_integral_types
        if mt.local_derivatives:
            # Not inside an ``except`` block, so ``logger.exception`` would
            # attach a bogus "NoneType: None" traceback; log at the same
            # (ERROR) level without it.
            logger.error("FIXME: Jacobian in custom integrals is not implemented.")
        return []

    def jacobian(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> L.Section | list:
        """Return definition code for the Jacobian of x(X)."""
        return self._define_coordinate_dofs_lincomb(mt, tabledata, quadrature_rule, access)

    def pass_through(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> L.Section | list:
        """Return definition code for pass through terminals.

        These terminals need no definition, so an empty list is returned.
        """
        return []
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc