• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 20023188942

08 Dec 2025 09:28AM UTC coverage: 84.443% (+1.4%) from 83.044%
20023188942

Pull #801

github

schnellerhase
format
Pull Request #801: Add `numba` backend

352 of 383 new or added lines in 22 files covered. (91.91%)

2 existing lines in 1 file now uncovered.

4071 of 4821 relevant lines covered (84.44%)

0.84 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

79.31
/ffcx/codegeneration/numba/implementation.py
1
# Copyright (C) 2025 Chris Richardson and Paul T. Kühner
2
#
3
# This file is part of FFCx. (https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6
"""Numba implementation for output."""
7

8
import numpy as np
1✔
9
from numpy import typing as npt
1✔
10

11
import ffcx.codegeneration.lnodes as L
1✔
12
from ffcx.codegeneration.utils import dtype_to_scalar_dtype
1✔
13

14

15
def build_initializer_lists(values):
    """Render an array as a (nested) Python list literal string.

    A rank-1 array becomes ``"[v0, v1, ...]"`` using ``str`` of each element.
    A higher-rank array is rendered recursively row by row, with the rows
    separated by ``",\\n"``. An array of rank 0 renders as ``"[]"``.
    """
    ndim = len(values.shape)
    if ndim == 1:
        items = ", ".join(str(element) for element in values)
    elif ndim > 1:
        items = ",\n".join(build_initializer_lists(row) for row in values)
    else:
        items = ""
    return f"[{items}]"
24

25

26
class Formatter:
    """Implementation for numba output backend.

    Translates LNodes trees into Python source text. Each ``format_*`` method
    renders one LNodes class, and :meth:`format` dispatches on the node's class
    name through :attr:`impl`. The generated text refers to NumPy as ``np``
    (and to ``math`` for ``erf``), so whatever executes it must have those
    names in scope.
    """

    scalar_type: np.dtype
    real_type: np.dtype

    # LNodes math-function names whose NumPy spelling differs; any other name
    # is emitted verbatim as ``np.<name>``.
    _math_function_map = {
        "ln": "log",
        "acos": "arccos",
        "asin": "arcsin",
        "atan": "arctan",
        "atan2": "arctan2",
        "acosh": "arccosh",
        "asinh": "arcsinh",
        "atanh": "arctanh",
    }

    def __init__(self, dtype: npt.DTypeLike) -> None:
        """Initialise.

        Args:
            dtype: Scalar type of the generated kernel. Its real counterpart is
                derived with ``dtype_to_scalar_dtype``.
        """
        self.scalar_type = np.dtype(dtype)
        self.real_type = dtype_to_scalar_dtype(dtype)

    def _dtype_to_name(self, dtype) -> str:
        """Return the ``np.<name>`` spelling of an ``L.DataType``.

        Raises:
            ValueError: If ``dtype`` is not SCALAR, REAL, INT or BOOL.
        """
        if dtype == L.DataType.SCALAR:
            return f"np.{self.scalar_type}"
        if dtype == L.DataType.REAL:
            return f"np.{self.real_type}"
        if dtype == L.DataType.INT:
            # Spell the name literally: formatting the type object itself
            # (f"{np.int32}") yields "<class 'numpy.int32'>".
            return "np.int32"
        if dtype == L.DataType.BOOL:
            # np.bool_ exists in every NumPy version, whereas np.bool was removed
            # in NumPy 1.24 and only reintroduced in 2.0.
            return "np.bool_"
        raise ValueError(f"Invalid dtype: {dtype}")

    def format_section(self, section: L.Section) -> str:
        """Format a section as a comment banner, its declarations, then its statements.

        The banner names the section and lists the names of its input and output
        symbols. A closing rule comment ends the section.
        """
        header = (
            "# ------------------------ \n"
            "# Section: " + section.name + "\n"
            "# Inputs: " + ", ".join(w.name for w in section.input) + "\n"
            "# Outputs: " + ", ".join(w.name for w in section.output) + "\n"
        )
        declarations = "".join(self.format(s) for s in section.declarations)
        statements = "".join(self.format(s) for s in section.statements)
        return header + declarations + statements + "# ------------------------ \n"

    def format_statement_list(self, slist: L.StatementList) -> str:
        """Format a list of statements by concatenating their formatted text."""
        return "".join(self.format(s) for s in slist.statements)

    def format_comment(self, c: L.Comment) -> str:
        """Format a comment as a single ``#`` line."""
        return "# " + c.comment + "\n"

    def format_array_decl(self, arr: L.ArrayDecl) -> str:
        """Format an array declaration as a NumPy array construction.

        The output has three forms:

        * No values: ``np.empty`` of the declared sizes.
        * A single value: ``np.full`` with that value.
        * Otherwise: ``np.array`` of a nested list literal.
        """
        typename = self._dtype_to_name(arr.symbol.dtype)
        symbol = self.format(arr.symbol)
        if arr.values is None:
            return f"{symbol} = np.empty({arr.sizes}, dtype={typename})\n"
        if arr.values.size == 1:
            # flat[0] reads the single element for any rank; values[0] would give
            # a sub-array for rank > 1 and raise for a 0-d array.
            fill = arr.values.flat[0]
            return f"{symbol} = np.full({arr.sizes}, {fill}, dtype={typename})\n"
        initializer = build_initializer_lists(arr.values)
        return f"{symbol} = np.array({initializer}, dtype={typename})\n"

    def format_array_access(self, arr: L.ArrayAccess) -> str:
        """Format array access as ``array[i, j, ...]``."""
        array = self.format(arr.array)
        idx = ", ".join(self.format(ix) for ix in arr.indices)
        return f"{array}[{idx}]"

    def format_multi_index(self, index: L.MultiIndex) -> str:
        """Format a multi-index via its flattened global index."""
        return self.format(index.global_index)

    def format_variable_decl(self, v: L.VariableDecl) -> str:
        """Format a variable declaration as a plain assignment."""
        sym = self.format(v.symbol)
        val = self.format(v.value)
        return f"{sym} = {val}\n"

    def format_nary_op(self, oper: L.NaryOp) -> str:
        """Format an n-argument operation, parenthesising looser-binding operands."""
        args = [self.format(arg) for arg in oper.args]
        for i, arg_node in enumerate(oper.args):
            if arg_node.precedence >= oper.precedence:
                args[i] = "(" + args[i] + ")"
        return f" {oper.op} ".join(args)

    def format_binary_op(self, oper: L.BinOp) -> str:
        """Format a binary operation, parenthesising looser-binding operands."""
        lhs = self.format(oper.lhs)
        rhs = self.format(oper.rhs)
        if oper.lhs.precedence >= oper.precedence:
            lhs = f"({lhs})"
        if oper.rhs.precedence >= oper.precedence:
            rhs = f"({rhs})"
        return f"{lhs} {oper.op} {rhs}"

    def format_neg(self, val: L.Neg) -> str:
        """Format unary negation."""
        arg = self.format(val.arg)
        # The operand must be parenthesised when it binds no tighter than unary
        # minus, e.g. Neg(Sum(a, b)) has to render as "-(a + b)", not "-a + b".
        if val.arg.precedence >= val.precedence:
            arg = f"({arg})"
        return f"-{arg}"

    def format_not(self, val: L.Not) -> str:
        """Format not operation."""
        arg = self.format(val.arg)
        # Python's `not` binds more loosely than comparisons (unlike C's `!`, whose
        # precedence LNodes models). Without the outer parentheses, a Not used as
        # an operand of e.g. EQ would render as "not(a) == b", i.e. not (a == b).
        return f"(not ({arg}))"

    def format_andor(self, oper: L.And | L.Or) -> str:
        """Format and or or operation, mapping ``&&``/``||`` to ``and``/``or``."""
        lhs = self.format(oper.lhs)
        rhs = self.format(oper.rhs)
        if oper.lhs.precedence >= oper.precedence:
            lhs = f"({lhs})"
        if oper.rhs.precedence >= oper.precedence:
            rhs = f"({rhs})"
        opstr = {"||": "or", "&&": "and"}[oper.op]
        return f"{lhs} {opstr} {rhs}"

    def format_literal_float(self, val: L.LiteralFloat) -> str:
        """Format a literal float."""
        return f"{val.value}"

    def format_literal_int(self, val: L.LiteralInt) -> str:
        """Format a literal int."""
        return f"{val.value}"

    def format_for_range(self, r: L.ForRange) -> str:
        """Format a loop over ``range(begin, end)`` with a four-space indented body."""
        begin = self.format(r.begin)
        end = self.format(r.end)
        index = self.format(r.index)
        body = self.format(r.body)
        if not body.strip():
            # A Python loop needs at least one statement in its body.
            body = "pass\n"
        output = f"for {index} in range({begin}, {end}):\n"
        for line in body.split("\n"):
            output += f"    {line}\n"
        return output

    def format_statement(self, s: L.Statement) -> str:
        """Format a statement by formatting its wrapped expression."""
        return self.format(s.expr)

    def format_assign(self, expr: L.Assign) -> str:
        """Format assignment (``=`` or an augmented operator such as ``+=``)."""
        rhs = self.format(expr.rhs)
        lhs = self.format(expr.lhs)
        return f"{lhs} {expr.op} {rhs}\n"

    def format_conditional(self, s: L.Conditional) -> str:
        """Format a conditional as a parenthesised ``t if c else f`` expression."""
        c = self.format(s.condition)
        t = self.format(s.true)
        f = self.format(s.false)
        if s.condition.precedence >= s.precedence:
            c = "(" + c + ")"
        if s.true.precedence >= s.precedence:
            t = "(" + t + ")"
        if s.false.precedence >= s.precedence:
            f = "(" + f + ")"
        return f"({t} if {c} else {f})"

    def format_symbol(self, s: L.Symbol) -> str:
        """Format a symbol by its name."""
        return f"{s.name}"

    def format_mathfunction(self, f: L.MathFunction) -> str:
        """Format a math function as a call to its NumPy equivalent."""
        function = self._math_function_map.get(f.function, f.function)
        args = [self.format(arg) for arg in f.args]
        if "bessel" in function:
            # NOTE(review): Bessel functions are not implemented in this backend and
            # are silently replaced by the constant 0. Results for forms using them
            # are wrong. TODO: implement or raise.
            return "0"
        if function == "erf":
            # NumPy has no erf, so use the math module; the generated code must
            # have `math` in scope.
            return f"math.erf({args[0]})"
        return f"np.{function}({', '.join(args)})"

    impl = {
        "StatementList": format_statement_list,
        "Comment": format_comment,
        "Section": format_section,
        "ArrayDecl": format_array_decl,
        "ArrayAccess": format_array_access,
        "MultiIndex": format_multi_index,
        "VariableDecl": format_variable_decl,
        "ForRange": format_for_range,
        "Statement": format_statement,
        "Assign": format_assign,
        "AssignAdd": format_assign,
        "Product": format_nary_op,
        "Sum": format_nary_op,
        "Add": format_binary_op,
        "Sub": format_binary_op,
        "Mul": format_binary_op,
        "Div": format_binary_op,
        "Neg": format_neg,
        "Not": format_not,
        "LiteralFloat": format_literal_float,
        "LiteralInt": format_literal_int,
        "Symbol": format_symbol,
        "Conditional": format_conditional,
        "MathFunction": format_mathfunction,
        "And": format_andor,
        "Or": format_andor,
        "NE": format_binary_op,
        "EQ": format_binary_op,
        "GE": format_binary_op,
        "LE": format_binary_op,
        "GT": format_binary_op,
        "LT": format_binary_op,
    }

    def format(self, s: L.LNode) -> str:
        """Format output for any supported LNodes node.

        Raises:
            RuntimeError: If the node's class has no entry in :attr:`impl`.
        """
        name = s.__class__.__name__
        # Look up the handler separately from calling it, so that a KeyError
        # raised inside a handler is not misreported as an unknown node type.
        try:
            handler = self.impl[name]
        except KeyError:
            raise RuntimeError(f"Unknown statement: {name}") from None
        return handler(self, s)  # type: ignore
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc