• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 20112426700

10 Dec 2025 08:31PM UTC coverage: 84.612% (+0.2%) from 84.44%
20112426700

Pull #806

github

schnellerhase
Missed one
Pull Request #806: Backend formatter interface

140 of 148 new or added lines in 6 files covered. (94.59%)

33 existing lines in 2 files now uncovered.

4157 of 4913 relevant lines covered (84.61%)

0.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

82.63
/ffcx/codegeneration/numba/implementation.py
1
# Copyright (C) 2025 Chris Richardson and Paul T. Kühner
2
#
3
# This file is part of FFCx. (https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6
"""Numba implementation for output."""
7

8
from functools import singledispatchmethod
1✔
9

10
import numpy as np
1✔
11
from numpy import typing as npt
1✔
12

13
import ffcx.codegeneration.lnodes as L
1✔
14
from ffcx.codegeneration.utils import dtype_to_scalar_dtype
1✔
15

16

17
def build_initializer_lists(values: npt.NDArray) -> str:
    """Build a (nested) Python list literal holding the entries of ``values``.

    Each nesting level of the literal corresponds to one array axis; rows of
    multi-dimensional arrays are separated by ``",\\n"``.

    Args:
        values: Array to render. A 0-d array renders as its bare scalar value.

    Returns:
        Source text that evaluates to the same values, e.g. ``"[1, 2]"``.
    """
    if values.ndim == 0:
        # A 0-d array has no axis to iterate; emit the scalar itself instead of
        # an empty list.
        return str(values)
    if values.ndim == 1:
        return "[" + ", ".join(str(v) for v in values) + "]"
    # Recurse over the leading axis; each sub-array becomes one nested list.
    return "[" + ",\n".join(build_initializer_lists(v) for v in values) + "]"
1✔
26

27

28
class Formatter:
    """Format LNodes trees as Python source for the Numba output backend.

    An instance is callable: ``formatter(node)`` dispatches on the type of
    ``node`` (via ``functools.singledispatchmethod``) and returns the generated
    source text. Declarations and assignments are returned with a trailing
    newline; expressions are returned without one.
    """

    # NumPy dtype of scalar values in the generated code.
    scalar_type: np.dtype
    # Dtype obtained from ``dtype_to_scalar_dtype(scalar_type)``; presumably the
    # real counterpart of a complex scalar type -- confirm in codegeneration.utils.
    real_type: np.dtype

    def __init__(self, dtype: npt.DTypeLike) -> None:
        """Initialise.

        Args:
            dtype: Scalar type of the generated code.
        """
        self.scalar_type = np.dtype(dtype)
        self.real_type = dtype_to_scalar_dtype(dtype)

    def _dtype_to_name(self, dtype: L.DataType) -> str:
        """Convert an LNodes data type to the NumPy type name used in generated code.

        Raises:
            ValueError: If ``dtype`` is not SCALAR, REAL, INT or BOOL.
        """
        if dtype == L.DataType.SCALAR:
            return f"np.{self.scalar_type}"
        if dtype == L.DataType.REAL:
            return f"np.{self.real_type}"
        if dtype == L.DataType.INT:
            # Spell the name out: interpolating the type object itself would
            # produce "np.<class 'numpy.int32'>".
            return "np.int32"
        if dtype == L.DataType.BOOL:
            # ``np.bool_`` exists in every NumPy version; ``np.bool`` is missing
            # from NumPy 1.24-1.26.
            return "np.bool_"
        raise ValueError(f"Invalid dtype: {dtype}")

    def _format_comment_str(self, comment: str) -> str:
        """Format ``comment`` as one Python comment line (newline-terminated)."""
        return f"# {comment} \n"

    def _format_operand(self, operand: L.LNode, parent: L.LNode) -> str:
        """Format ``operand`` of the operation ``parent``, parenthesized if needed.

        The operand is wrapped in parentheses when its ``precedence`` value is
        greater than or equal to that of ``parent``.
        """
        text = self(operand)
        if operand.precedence >= parent.precedence:
            return f"({text})"
        return text

    @singledispatchmethod
    def __call__(self, obj: L.LNode) -> str:
        """Format an LNodes node as Python source.

        Raises:
            NotImplementedError: If no formatter is registered for ``type(obj)``.
        """
        raise NotImplementedError(f"Can not format object to type {type(obj)}")

    @__call__.register
    def _(self, section: L.Section) -> str:
        """Format a section: header comments, declarations, statements, footer rule."""
        rule = self._format_comment_str("------------------------")
        header = (
            rule
            + self._format_comment_str(f"Section: {section.name}")
            + self._format_comment_str(f"Inputs: {', '.join(w.name for w in section.input)}")
            + self._format_comment_str(f"Outputs: {', '.join(w.name for w in section.output)}")
        )
        declarations = "".join(self(s) for s in section.declarations)
        body = "".join(self(s) for s in section.statements)
        return header + declarations + body + rule

    @__call__.register
    def _(self, slist: L.StatementList) -> str:
        """Format a list of statements by concatenating their formatted text."""
        return "".join(self(s) for s in slist.statements)

    @__call__.register
    def _(self, c: L.Comment) -> str:
        """Format a comment."""
        return self._format_comment_str(c.comment)

    @__call__.register
    def _(self, arr: L.ArrayDecl) -> str:
        """Format an array declaration as a NumPy allocation.

        Without values the array is created with ``np.empty``; a single value is
        broadcast with ``np.full``; otherwise a literal is passed to ``np.array``.
        """
        typename = self._dtype_to_name(arr.symbol.dtype)
        symbol = self(arr.symbol)
        if arr.values is None:
            return f"{symbol} = np.empty({arr.sizes}, dtype={typename})\n"
        if arr.values.size == 1:
            # ``flat[0]`` is the lone element for any rank (``values[0]`` would be a
            # sub-array for rank > 1 and fail for a 0-d array).
            return f"{symbol} = np.full({arr.sizes}, {arr.values.flat[0]}, dtype={typename})\n"
        initializer = build_initializer_lists(arr.values)
        return f"{symbol} = np.array({initializer}, dtype={typename})\n"

    @__call__.register
    def _(self, arr: L.ArrayAccess) -> str:
        """Format an array access as ``array[i, j, ...]``."""
        array = self(arr.array)
        idx = ", ".join(self(ix) for ix in arr.indices)
        return f"{array}[{idx}]"

    @__call__.register
    def _(self, index: L.MultiIndex) -> str:
        """Format a multi-index by formatting its global (flattened) index."""
        return self(index.global_index)

    @__call__.register
    def _(self, v: L.VariableDecl) -> str:
        """Format a variable declaration as an assignment."""
        sym = self(v.symbol)
        val = self(v.value)
        return f"{sym} = {val}\n"

    @__call__.register
    def _(self, oper: L.NaryOp) -> str:
        """Format an n-argument operation, e.g. ``a + b + c``."""
        return f" {oper.op} ".join(self._format_operand(arg, oper) for arg in oper.args)

    @__call__.register
    def _(self, oper: L.BinOp) -> str:
        """Format a binary operation ``lhs op rhs``."""
        lhs = self._format_operand(oper.lhs, oper)
        rhs = self._format_operand(oper.rhs, oper)
        return f"{lhs} {oper.op} {rhs}"

    @__call__.register(L.Neg)
    @__call__.register(L.Not)
    def _(self, oper: L.Not | L.Neg) -> str:
        """Format a unary operation."""
        arg = self._format_operand(oper.arg, oper)
        if oper.op == "!":
            # LNodes spells logical negation C-style; Python uses ``not``, which
            # binds looser than comparisons, so wrap the whole expression to keep
            # the C binding (e.g. ``(not a) == b``).
            return f"(not {arg})"
        return f"{oper.op}{arg}"

    @__call__.register(L.And)
    @__call__.register(L.Or)
    def _(self, oper: L.And | L.Or) -> str:
        """Format a logical and/or, translating the C operators to Python keywords."""
        lhs = self._format_operand(oper.lhs, oper)
        rhs = self._format_operand(oper.rhs, oper)
        opstr = {"||": "or", "&&": "and"}[oper.op]
        return f"{lhs} {opstr} {rhs}"

    @__call__.register
    def _(self, val: L.LiteralFloat) -> str:
        """Format a literal float."""
        return f"{val.value}"

    @__call__.register
    def _(self, val: L.LiteralInt) -> str:
        """Format a literal int."""
        return f"{val.value}"

    @__call__.register
    def _(self, r: L.ForRange) -> str:
        """Format a loop over ``range(begin, end)`` with the body indented one level."""
        begin = self(r.begin)
        end = self(r.end)
        index = self(r.index)
        # splitlines() drops the empty piece after the body's final newline, so no
        # whitespace-only lines are emitted at any nesting depth.
        body = "".join(f"    {line}\n" for line in self(r.body).splitlines())
        return f"for {index} in range({begin}, {end}):\n{body}"

    @__call__.register
    def _(self, s: L.Statement) -> str:
        """Format a statement by formatting its wrapped expression."""
        return self(s.expr)

    @__call__.register(L.Assign)
    @__call__.register(L.AssignAdd)
    def _(self, expr: L.Assign | L.AssignAdd) -> str:
        """Format an assignment (``=`` or ``+=``) as one line."""
        rhs = self(expr.rhs)
        lhs = self(expr.lhs)
        return f"{lhs} {expr.op} {rhs}\n"

    @__call__.register
    def _(self, s: L.Conditional) -> str:
        """Format a conditional as a parenthesized Python ternary expression."""
        c = self._format_operand(s.condition, s)
        t = self._format_operand(s.true, s)
        f = self._format_operand(s.false, s)
        return f"({t} if {c} else {f})"

    @__call__.register
    def _(self, s: L.Symbol) -> str:
        """Format a symbol as its name."""
        return f"{s.name}"

    @__call__.register
    def _(self, f: L.MathFunction) -> str:
        """Format a math function call.

        Most functions map to ``np.<name>`` (with C names such as ``acos``
        renamed to their NumPy spelling); Bessel functions map to
        ``scipy.special`` and ``erf`` to ``math.erf``.
        """
        function_map = {
            "ln": "log",
            "acos": "arccos",
            "asin": "arcsin",
            "atan": "arctan",
            "atan2": "arctan2",
            "acosh": "arccosh",
            "asinh": "arcsinh",
            "atanh": "arctanh",
        }
        function = function_map.get(f.function, f.function)
        args = [self(arg) for arg in f.args]
        argstr = ", ".join(args)
        # Bessel functions must be *called* with their (order, argument) list;
        # emitting only the function name would leave an uncalled reference.
        if "bessel_y" in function:
            return f"scipy.special.yn({argstr})"
        if "bessel_j" in function:
            return f"scipy.special.jn({argstr})"
        if function == "erf":
            return f"math.erf({args[0]})"
        return f"np.{function}({argstr})"
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc