• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 11642291501

02 Nov 2024 11:16AM UTC coverage: 81.168% (+0.5%) from 80.657%
11642291501

push

github

web-flow
Upload to coveralls and docs from CI job running against python 3.12 (#726)

3474 of 4280 relevant lines covered (81.17%)

0.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.1
/ffcx/codegeneration/definitions.py
1
# Copyright (C) 2011-2023 Martin Sandve Alnæs, Igor A. Baratta
2
#
3
# This file is part of FFCx. (https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6
"""FFCx/UFC specific variable definitions."""
7

8
import logging
1✔
9
from typing import Union
1✔
10

11
import ufl
1✔
12

13
import ffcx.codegeneration.lnodes as L
1✔
14
from ffcx.ir.analysis.modified_terminals import ModifiedTerminal
1✔
15
from ffcx.ir.elementtables import UniqueTableReferenceT
1✔
16
from ffcx.ir.representationutils import QuadratureRule
1✔
17

18
# Named "ffcx" logger used for diagnostics emitted by this module.
logger = logging.getLogger(name="ffcx")
1✔
19

20

21
def create_quadrature_index(quadrature_rule, quadrature_index_symbol):
    """Create a multi index for the quadrature loop.

    Without a quadrature rule the result is a single integer index named after
    ``quadrature_index_symbol`` with range 0.  With a rule, the range is
    ``quadrature_rule.weights.size``; if the rule has tensor factors, one
    integer index per factor is created instead (``<name>0``, ``<name>1``, ...)
    with ranges ``factor[1].size``.
    """
    base = quadrature_index_symbol.name
    flat_index = [L.Symbol(base, dtype=L.DataType.INT)]

    # No rule: a single index with an empty range.
    if not quadrature_rule:
        return L.MultiIndex(flat_index, [0])

    num_points = quadrature_rule.weights.size
    if not quadrature_rule.has_tensor_factors:
        return L.MultiIndex(flat_index, [num_points])

    # Tensor-product rule: one index per factor.
    factors = quadrature_rule.tensor_factors
    num_factors = len(factors)
    extents = [factor[1].size for factor in factors]
    factor_indices = [
        L.Symbol(f"{base}{k}", dtype=L.DataType.INT) for k in range(num_factors)
    ]
    return L.MultiIndex(factor_indices, extents)
34

35

36
def create_dof_index(tabledata, dof_index_symbol):
    """Create a multi index for the coefficient dofs.

    For a tensor-factorised table one integer index per factor is created
    (``<name>0``, ``<name>1``, ...), each ranging over the last axis of that
    factor's ``values``.  Otherwise a single integer index named after
    ``dof_index_symbol`` ranges over the last axis of ``tabledata.values``.
    """
    base = dof_index_symbol.name

    if not tabledata.has_tensor_factorisation:
        extent = tabledata.values.shape[-1]
        return L.MultiIndex([L.Symbol(base, dtype=L.DataType.INT)], [extent])

    extents = []
    factor_indices = []
    for k, factor in enumerate(tabledata.tensor_factors):
        extents.append(factor.values.shape[-1])
        factor_indices.append(L.Symbol(f"{base}{k}", dtype=L.DataType.INT))
    return L.MultiIndex(factor_indices, extents)
48

49

50
class FFCXBackendDefinitions:
    """FFCx specific code definitions.

    Maps UFL terminal types to handlers that return the definition code for a
    modified terminal.  The result is either an ``L.Section`` that computes the
    terminal's value into a given symbol, or an empty list when no definition
    is needed.
    """

    def __init__(self, entity_type: str, integral_type: str, access, options):
        """Initialise.

        Args:
            entity_type: Entity type forwarded to ``access.table_access``.
            integral_type: UFL integral type; for custom integral types no
                spatial-coordinate definition is generated.
            access: Access helper providing ``symbols`` and ``table_access``.
            options: Code generation options (stored; not read in this class).
        """
        # Store integral/entity information, the access helper and options
        self.integral_type = integral_type
        self.entity_type = entity_type
        self.access = access
        self.options = options

        # Lookup table for the handler that ``get`` calls, keyed by terminal
        # type; ``get`` falls back to handlers registered for parent classes.
        self.handler_lookup = {
            ufl.coefficient.Coefficient: self.coefficient,
            ufl.geometry.Jacobian: self.jacobian,
            ufl.geometry.SpatialCoordinate: self.spatial_coordinate,
            ufl.constant.Constant: self.pass_through,
            ufl.geometry.CellVertices: self.pass_through,
            ufl.geometry.FacetEdgeVectors: self.pass_through,
            ufl.geometry.CellEdgeVectors: self.pass_through,
            ufl.geometry.CellFacetJacobian: self.pass_through,
            ufl.geometry.ReferenceCellVolume: self.pass_through,
            ufl.geometry.ReferenceFacetVolume: self.pass_through,
            ufl.geometry.ReferenceCellEdgeVectors: self.pass_through,
            ufl.geometry.ReferenceFacetEdgeVectors: self.pass_through,
            ufl.geometry.ReferenceNormal: self.pass_through,
            ufl.geometry.CellOrientation: self.pass_through,
            ufl.geometry.FacetOrientation: self.pass_through,
        }

    @property
    def symbols(self):
        """Return the symbol provider of the access helper (``access.symbols``)."""
        return self.access.symbols

    def get(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> Union[L.Section, list]:
        """Return definition code for a terminal.

        Dispatches on the type of ``mt.terminal``.  It uses the handler
        registered for that type or, failing that, the handler of the nearest
        ancestor reached by repeatedly following the first base class.

        Raises:
            NotImplementedError: If no handler is registered for the terminal
                type or any of its first-base ancestors.
        """
        terminal_type = type(mt.terminal)

        # Look for a direct handler, or one registered for a parent class
        ttype = terminal_type
        while ttype not in self.handler_lookup and ttype.__bases__:
            ttype = ttype.__bases__[0]

        handler = self.handler_lookup.get(ttype)
        if handler is None:
            # Report the terminal's own type: by now ``ttype`` has been walked
            # up to the root of the class hierarchy (typically ``object``).
            raise NotImplementedError(f"No handler for terminal type: {terminal_type}")

        return handler(mt, tabledata, quadrature_rule, access)

    def coefficient(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> Union[L.Section, list]:
        """Return definition code for coefficients.

        Accumulates ``access += dof[ic * bs + begin] * FE`` in a loop over the
        dof index ``ic``.  The result is wrapped in a section annotated with
        ``L.Annotation.fuse``.  Returns an empty list for "zeros" tables and for
        "ones" tables spanning a single dof, which need no definition.
        """
        # Get properties of tables
        ttype = tabledata.ttype
        num_dofs = tabledata.values.shape[3]
        bs = tabledata.block_size
        begin = tabledata.offset
        end = begin + bs * (num_dofs - 1) + 1

        if ttype == "zeros":
            # Use the module's "ffcx" logger: module-level ``logging.debug``
            # goes to the root logger (and may call ``basicConfig()``).
            logger.debug("Not expecting zero coefficients to get this far.")
            return []

        # For a constant coefficient we reference the dofs directly, so no definition needed
        if ttype == "ones" and end - begin == 1:
            return []

        assert begin < end

        # Build the quadrature and dof indices only once a definition is
        # actually emitted.  When both the coefficient table and the quadrature
        # rule are tensor factorised, these helpers return one index per
        # factor, which is what enables a tensor-product evaluation.
        iq = create_quadrature_index(quadrature_rule, self.symbols.quadrature_loop_index)
        ic = create_dof_index(tabledata, self.symbols.coefficient_dof_sum_index)

        # Get access to element table
        FE, tables = self.access.table_access(tabledata, self.entity_type, mt.restriction, iq, ic)
        dof_access: L.ArrayAccess = self.symbols.coefficient_dof_access(
            mt.terminal, ic.global_index * bs + begin
        )

        declaration: list[L.Declaration] = [L.VariableDecl(access, 0.0)]
        body = [L.AssignAdd(access, dof_access * FE)]
        code = [L.create_nested_for_loops([ic], body)]

        name = type(mt.terminal).__name__
        inputs = [dof_access.array, *tables]
        outputs = [access]
        annotations = [L.Annotation.fuse]

        # assert input and output are Symbol objects
        assert all(isinstance(i, L.Symbol) for i in inputs)
        assert all(isinstance(o, L.Symbol) for o in outputs)

        return L.Section(name, code, declaration, inputs, outputs, annotations)

    def _define_coordinate_dofs_lincomb(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> Union[L.Section, list]:
        """Define x or J as a linear combination of coordinate dofs with given table data.

        Accumulates ``access += coordinate_dofs[ic * 3 + begin + offset] * FE``
        in a loop over the dof index ``ic``.  ``offset`` is non-zero only for
        the "-" restriction.
        """
        # Get properties of domain
        domain = ufl.domain.extract_unique_domain(mt.terminal)
        coordinate_element = domain.ufl_coordinate_element()
        num_scalar_dofs = coordinate_element._sub_element.dim

        num_dofs = tabledata.values.shape[3]
        begin = tabledata.offset

        assert num_scalar_dofs == num_dofs

        # Zero and constant tables are not expected for coordinate fields
        ttype = tabledata.ttype
        assert ttype != "zeros"
        assert ttype != "ones"

        # Get access to element table
        ic = create_dof_index(tabledata, self.symbols.coefficient_dof_sum_index)
        iq = create_quadrature_index(quadrature_rule, self.symbols.quadrature_loop_index)
        FE, tables = self.access.table_access(tabledata, self.entity_type, mt.restriction, iq, ic)

        dof_access = L.Symbol("coordinate_dofs", dtype=L.DataType.REAL)

        # coordinate dofs is always 3d
        dim = 3
        # For the "-" side, skip num_scalar_dofs * 3 entries.  Presumably this
        # is the "+" cell's block of coordinate dofs stored first; confirm
        # against the caller's packing.
        offset = num_scalar_dofs * dim if mt.restriction == "-" else 0

        declaration = [L.VariableDecl(access, 0.0)]
        body = [L.AssignAdd(access, dof_access[ic.global_index * dim + begin + offset] * FE)]
        code = [L.create_nested_for_loops([ic], body)]

        name = type(mt.terminal).__name__
        outputs = [access]
        inputs = [dof_access, *tables]
        annotations = [L.Annotation.fuse]

        # assert input and output are Symbol objects
        assert all(isinstance(i, L.Symbol) for i in inputs)
        assert all(isinstance(o, L.Symbol) for o in outputs)

        return L.Section(name, code, declaration, inputs, outputs, annotations)

    def spatial_coordinate(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> Union[L.Section, list]:
        """Return definition code for the physical spatial coordinates.

        If physical coordinates are given:
          No definition needed.

        If reference coordinates are given:
          x = sum_k xdof_k xphi_k(X)

        If reference facet coordinates are given:
          x = sum_k xdof_k xphi_k(Xf)

        For custom integral types no definition is generated (empty list).
        """
        if self.integral_type in ufl.custom_integral_types:
            # FIXME: Jacobian may need adjustment for custom_integral_types
            if mt.local_derivatives:
                # Not inside an ``except`` block, so ``exception()`` would attach
                # a meaningless "NoneType: None" traceback; log at ERROR instead.
                logger.error("FIXME: Jacobian in custom integrals is not implemented.")
            return []
        return self._define_coordinate_dofs_lincomb(mt, tabledata, quadrature_rule, access)

    def jacobian(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> Union[L.Section, list]:
        """Return definition code for the Jacobian of x(X)."""
        return self._define_coordinate_dofs_lincomb(mt, tabledata, quadrature_rule, access)

    def pass_through(
        self,
        mt: ModifiedTerminal,
        tabledata: UniqueTableReferenceT,
        quadrature_rule: QuadratureRule,
        access: L.Symbol,
    ) -> Union[L.Section, list]:
        """Return definition code for pass through terminals (none is needed)."""
        return []
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc