• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 20112112948

10 Dec 2025 08:19PM UTC coverage: 84.573% (+0.1%) from 84.44%
20112112948

Pull #806

github

schnellerhase
Dispatch on __call__
Pull Request #806: Backend formatter interface

155 of 167 new or added lines in 6 files covered. (92.81%)

40 existing lines in 2 files now uncovered.

4172 of 4933 relevant lines covered (84.57%)

0.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

81.82
/ffcx/codegeneration/numba/implementation.py
1
# Copyright (C) 2025 Chris Richardson and Paul T. Kühner
2
#
3
# This file is part of FFCx. (https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6
"""Numba implementation for output."""
7

8
from functools import singledispatchmethod
1✔
9

10
import numpy as np
1✔
11
from numpy import typing as npt
1✔
12

13
import ffcx.codegeneration.lnodes as L
1✔
14
from ffcx.codegeneration.utils import dtype_to_scalar_dtype
1✔
15

16

17
def build_initializer_lists(values: npt.NDArray) -> str:
    """Build a (nested) Python list literal from the entries of ``values``.

    A 1-D array becomes ``"[v0, v1, ...]"``. Higher-rank arrays are built
    recursively, with sub-lists separated by a comma followed by a newline.
    A 0-d array yields the bare scalar text, e.g. ``"5"``.

    Args:
        values: Array to serialise; every entry is converted with ``str``.

    Returns:
        Python source text suitable as the first argument of ``np.array``.
    """
    if values.ndim == 0:
        # Without this branch a scalar would be rendered as "[]" and its
        # value silently lost.
        return str(values[()])
    if values.ndim == 1:
        return "[" + ", ".join(str(v) for v in values) + "]"
    # Iterating a rank-k array yields rank-(k-1) sub-arrays; recurse on each.
    return "[" + ",\n".join(build_initializer_lists(v) for v in values) + "]"
26

27

28
class Formatter:
    """Implementation for numba output backend.

    Instances are callable: ``formatter(node)`` dispatches on the type of the
    L node (via ``functools.singledispatchmethod``) and returns Python source
    text for that node.
    """

    scalar_type: np.dtype  # dtype used for L.DataType.SCALAR symbols
    real_type: np.dtype  # dtype used for L.DataType.REAL symbols

    def __init__(self, dtype: npt.DTypeLike) -> None:
        """Initialise.

        Args:
            dtype: Scalar dtype of the generated code. The real dtype is
                derived from it with ``dtype_to_scalar_dtype``.
        """
        self.scalar_type = np.dtype(dtype)
        self.real_type = dtype_to_scalar_dtype(dtype)

    def _dtype_to_name(self, dtype: L.DataType) -> str:
        """Convert an L data type to the name of the matching numpy type.

        Returns:
            A string such as ``"np.float64"`` for use in generated source.

        Raises:
            ValueError: If ``dtype`` is not one of the handled ``L.DataType`` members.
        """
        if dtype == L.DataType.SCALAR:
            return f"np.{self.scalar_type}"
        if dtype == L.DataType.REAL:
            return f"np.{self.real_type}"
        # Spell these names out: interpolating the type objects themselves
        # (f"np.{np.int32}") produces "np.<class 'numpy.int32'>".
        if dtype == L.DataType.INT:
            return "np.int32"
        if dtype == L.DataType.BOOL:
            # np.bool_ exists in every numpy release; np.bool only in numpy >= 2.
            return "np.bool_"
        raise ValueError(f"Invalid dtype: {dtype}")

    def _format_operand(self, node, precedence) -> str:
        """Format ``node``, parenthesised if it binds no tighter than ``precedence``.

        Follows the file's convention: a smaller precedence value binds tighter,
        so an operand whose precedence is >= the operator's is wrapped.
        """
        text = self(node)
        if node.precedence >= precedence:
            return f"({text})"
        return text

    @singledispatchmethod
    def __call__(self, obj) -> str:
        """Format an L Node.

        This is the fallback for node types with no registered formatter.

        Raises:
            NotImplementedError: Always, since ``obj`` has no registered formatter.
        """
        raise NotImplementedError(f"Can not format object to type {type(obj)}")

    @__call__.register
    def _(self, section: L.Section) -> str:
        """Format a section: header comments, declarations, statements, footer comment."""
        comments = self._format_comment_str("------------------------")
        comments += self._format_comment_str(f"Section: {section.name}")
        comments += self._format_comment_str(f"Inputs: {', '.join(w.name for w in section.input)}")
        comments += self._format_comment_str(
            f"Outputs: {', '.join(w.name for w in section.output)}"
        )
        declarations = "".join(self(s) for s in section.declarations)
        # Joining an empty sequence yields "", so no emptiness check is needed.
        body = "".join(self(s) for s in section.statements)
        body += self._format_comment_str("------------------------")
        return comments + declarations + body

    @__call__.register
    def _(self, slist: L.StatementList) -> str:
        """Format a list of statements by concatenating their formatted text."""
        return "".join(self(s) for s in slist.statements)

    def _format_comment_str(self, comment: str) -> str:
        """Format str to a single-line Python comment (with trailing newline)."""
        return f"# {comment} \n"

    @__call__.register
    def _(self, c: L.Comment) -> str:
        """Format a comment."""
        return self._format_comment_str(c.comment)

    @__call__.register
    def _(self, arr: L.ArrayDecl) -> str:
        """Format an array declaration as a numpy allocation.

        - no values: ``np.empty`` (uninitialised storage),
        - exactly one value: ``np.full``, broadcasting that value over ``sizes``,
        - otherwise: ``np.array`` built from a nested list literal.
        """
        typename = self._dtype_to_name(arr.symbol.dtype)
        symbol = self(arr.symbol)
        if arr.values is None:
            return f"{symbol} = np.empty({arr.sizes}, dtype={typename})\n"
        if arr.values.size == 1:
            # .flat[0] extracts the scalar whatever the rank of values; [0]
            # would give a sub-array for multi-dimensional input and fail on 0-d.
            fill = arr.values.flat[0]
            return f"{symbol} = np.full({arr.sizes}, {fill}, dtype={typename})\n"
        av = build_initializer_lists(arr.values)
        return f"{symbol} = np.array({av}, dtype={typename})\n"

    @__call__.register
    def _(self, arr: L.ArrayAccess) -> str:
        """Format array access as ``name[i, j, ...]``."""
        array = self(arr.array)
        idx = ", ".join(self(ix) for ix in arr.indices)
        return f"{array}[{idx}]"

    @__call__.register
    def _(self, index: L.MultiIndex) -> str:
        """Format a multi-index via its flattened global index."""
        return self(index.global_index)

    @__call__.register
    def _(self, v: L.VariableDecl) -> str:
        """Format a variable declaration as a plain assignment."""
        sym = self(v.symbol)
        val = self(v.value)
        return f"{sym} = {val}\n"

    @__call__.register
    def _(self, oper: L.NaryOp) -> str:
        """Format a n argument operation."""
        args = [self._format_operand(arg, oper.precedence) for arg in oper.args]
        return f" {oper.op} ".join(args)

    @__call__.register
    def _(self, oper: L.BinOp) -> str:
        """Format a binary operation."""
        lhs = self._format_operand(oper.lhs, oper.precedence)
        rhs = self._format_operand(oper.rhs, oper.precedence)
        return f"{lhs} {oper.op} {rhs}"

    @__call__.register
    def _(self, val: L.Neg) -> str:
        """Format unary negation.

        The operand is parenthesised when it binds no tighter than negation,
        so ``-(a + b)`` is not emitted as ``-a + b``.
        """
        arg = self._format_operand(val.arg, val.precedence)
        return f"-{arg}"

    @__call__.register
    def _(self, val: L.Not) -> str:
        """Format not operation."""
        arg = self(val.arg)
        return f"not({arg})"

    def _format_and_or(self, oper: L.And | L.Or) -> str:
        """Format and or or operation, translating ``&&``/``||`` to Python keywords."""
        lhs = self._format_operand(oper.lhs, oper.precedence)
        rhs = self._format_operand(oper.rhs, oper.precedence)
        opstr = {"||": "or", "&&": "and"}[oper.op]
        return f"{lhs} {opstr} {rhs}"

    @__call__.register
    def _(self, oper: L.And) -> str:
        """Format a logical and."""
        return self._format_and_or(oper)

    @__call__.register
    def _(self, oper: L.Or) -> str:
        """Format a logical or."""
        return self._format_and_or(oper)

    @__call__.register
    def _(self, val: L.LiteralFloat) -> str:
        """Format a literal float."""
        return f"{val.value}"

    @__call__.register
    def _(self, val: L.LiteralInt) -> str:
        """Format a literal int."""
        return f"{val.value}"

    @__call__.register
    def _(self, r: L.ForRange) -> str:
        """Format a loop over a range, indenting every line of the body."""
        begin = self(r.begin)
        end = self(r.end)
        index = self(r.index)
        header = f"for {index} in range({begin}, {end}):\n"
        # split("\n") also yields the empty tail after the body's final newline,
        # which becomes a whitespace-only line (harmless in Python).
        body = "".join(f"    {line}\n" for line in self(r.body).split("\n"))
        return header + body

    @__call__.register
    def _(self, s: L.Statement) -> str:
        """Format a statement by formatting its wrapped expression."""
        return self(s.expr)

    def _format_assign(self, expr: L.Assign | L.AssignAdd) -> str:
        """Format an assignment as ``lhs <op> rhs;`` followed by a newline.

        The trailing ';' is legal (if unnecessary) Python.
        """
        rhs = self(expr.rhs)
        lhs = self(expr.lhs)
        return f"{lhs} {expr.op} {rhs};\n"

    @__call__.register
    def _(self, expr: L.Assign) -> str:
        """Format assignment."""
        return self._format_assign(expr)

    @__call__.register
    def _(self, expr: L.AssignAdd) -> str:
        """Format assignment add."""
        return self._format_assign(expr)

    @__call__.register
    def _(self, s: L.Conditional) -> str:
        """Format a conditional as a parenthesised Python ternary expression."""
        cond = self._format_operand(s.condition, s.precedence)
        true_ = self._format_operand(s.true, s.precedence)
        false_ = self._format_operand(s.false, s.precedence)
        return f"({true_} if {cond} else {false_})"

    @__call__.register
    def _(self, s: L.Symbol) -> str:
        """Format a symbol by its name."""
        return f"{s.name}"

    @__call__.register
    def _(self, f: L.MathFunction) -> str:
        """Format a math function call.

        Most functions map onto a ``np.*`` ufunc (C-style names are translated
        via ``function_map``). Bessel functions map to ``scipy.special.yn`` /
        ``scipy.special.jn`` and ``erf`` maps to ``math.erf``.
        NOTE(review): the generated module must import ``scipy.special`` and
        ``math`` for those branches. Confirm that numba can compile the scipy calls.
        """
        function_map = {
            "ln": "log",
            "acos": "arccos",
            "asin": "arcsin",
            "atan": "arctan",
            "atan2": "arctan2",
            "acosh": "arccosh",
            "asinh": "arcsinh",
            "atanh": "arctanh",
        }
        function = function_map.get(f.function, f.function)
        args = [self(arg) for arg in f.args]
        argstr = ", ".join(args)
        # The Bessel calls previously returned the bare function name without
        # any argument list. Forward the arguments like every other branch.
        if "bessel_y" in function:
            return f"scipy.special.yn({argstr})"
        if "bessel_j" in function:
            return f"scipy.special.jn({argstr})"
        if function == "erf":
            return f"math.erf({args[0]})"
        return f"np.{function}({argstr})"
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc