• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 19989253606

06 Dec 2025 01:35PM UTC coverage: 84.457% (+1.4%) from 83.044%
19989253606

Pull #801

github

schnellerhase
Revert for demos, cwd more important than coverage
Pull Request #801: Add `numba` backend

340 of 368 new or added lines in 22 files covered. (92.39%)

6 existing lines in 3 files now uncovered.

4059 of 4806 relevant lines covered (84.46%)

0.84 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

79.7
/ffcx/codegeneration/numba/implementation.py
1
# Copyright (C) 2025 Chris Richardson
2
#
3
# This file is part of FFCx. (https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6
"""Numba implementation for output."""
7

8
import ffcx.codegeneration.lnodes as L
1✔
9

10

11
def build_initializer_lists(values):
    """Render an array as nested, bracketed list-literal text.

    Args:
        values: Object exposing ``shape`` and iteration over its leading
            axis (for example a NumPy array).

    Returns:
        For a one-dimensional input, the ``str()`` of each entry joined by
        ``", "`` inside ``[...]``. For more dimensions, the rendering of each
        sub-array joined by ``",\\n"`` inside ``[...]``. For an input with an
        empty shape, the text ``"[]"``.
    """
    depth = len(values.shape)
    if depth == 1:
        # Innermost axis: entries are written on a single line.
        return f"[{', '.join(map(str, values))}]"
    # Outer axes recurse into each sub-array; an empty shape has no rows.
    rows = ",\n".join(map(build_initializer_lists, values)) if depth > 1 else ""
    return f"[{rows}]"
1✔
20

21

22
class Formatter:
    """Implementation for numba output backend.

    Converts a tree of LNodes objects into Python source text. ``format``
    selects a handler from ``impl`` using the class name of the node; the
    handlers call ``format`` again for their child nodes.
    """

    def __init__(self, scalar) -> None:
        """Initialise.

        Args:
            scalar: Stored unchanged as ``scalar_type``.
        """
        self.scalar_type = scalar

    @staticmethod
    def _parenthesize(text: str, child, parent) -> str:
        """Wrap text in parentheses when child.precedence >= parent.precedence."""
        if child.precedence >= parent.precedence:
            return f"({text})"
        return text

    def format_section(self, section: L.Section) -> str:
        """Format a section.

        The output is a comment header (section name, input names, output
        names), then the formatted declarations, then the formatted
        statements, then a closing comment rule.
        """
        comments = "# ------------------------ \n"
        comments += "# Section: " + section.name + "\n"
        comments += "# Inputs: " + ", ".join(w.name for w in section.input) + "\n"
        comments += "# Outputs: " + ", ".join(w.name for w in section.output) + "\n"
        declarations = "".join(self.format(s) for s in section.declarations)

        # Joining an empty sequence gives "", so no emptiness check is needed.
        body = "".join(self.format(s) for s in section.statements)

        body += "# ------------------------ \n"
        return comments + declarations + body

    def format_statement_list(self, slist: L.StatementList) -> str:
        """Format a list of statements by concatenating each formatted one."""
        return "".join(self.format(s) for s in slist.statements)

    def format_comment(self, c: L.Comment) -> str:
        """Format a comment as a single ``# ...`` line."""
        return "# " + c.comment + "\n"

    def format_array_decl(self, arr: L.ArrayDecl) -> str:
        """Format an array declaration.

        Emits ``np.empty`` when the declaration has no values, ``np.full``
        when it has exactly one value, and ``np.array`` with a nested list
        literal otherwise.

        Raises:
            RuntimeError: If the symbol dtype is not SCALAR, REAL or INT.
        """
        # The dtype is emitted as source text. "A" and "coordinate_dofs" are
        # names that must exist where the generated code runs.
        if arr.symbol.dtype == L.DataType.SCALAR:
            dtype = "A.dtype"
        elif arr.symbol.dtype == L.DataType.REAL:
            dtype = "coordinate_dofs.dtype"
        elif arr.symbol.dtype == L.DataType.INT:
            dtype = "np.int32"
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on ``dtype``.
            raise RuntimeError(f"Unsupported array dtype: {arr.symbol.dtype}")

        symbol = self.format(arr.symbol)
        if arr.values is None:
            return f"{symbol} = np.empty({arr.sizes}, dtype={dtype})\n"
        elif arr.values.size == 1:
            return f"{symbol} = np.full({arr.sizes}, {arr.values[0]}, dtype={dtype})\n"
        av = build_initializer_lists(arr.values)
        av = "np.array(" + av + f", dtype={dtype})"
        return f"{symbol} = {av}\n"

    def format_array_access(self, arr: L.ArrayAccess) -> str:
        """Format array access as ``array[i, j, ...]``."""
        array = self.format(arr.array)
        idx = ", ".join(self.format(ix) for ix in arr.indices)
        return f"{array}[{idx}]"

    def format_multi_index(self, index: L.MultiIndex) -> str:
        """Format a multi-index by formatting its ``global_index``."""
        return self.format(index.global_index)

    def format_variable_decl(self, v: L.VariableDecl) -> str:
        """Format a variable declaration as a plain assignment line."""
        sym = self.format(v.symbol)
        val = self.format(v.value)
        return f"{sym} = {val}\n"

    def format_nary_op(self, oper: L.NaryOp) -> str:
        """Format a n argument operation.

        Each argument is formatted, parenthesised where required, and the
        results are joined with the operator surrounded by spaces.
        """
        # Format children
        args = [self.format(arg) for arg in oper.args]

        # Apply parentheses
        args = [
            self._parenthesize(text, arg, oper) for text, arg in zip(args, oper.args)
        ]

        # Return combined string
        return f" {oper.op} ".join(args)

    def format_binary_op(self, oper: L.BinOp) -> str:
        """Format a binary operation as ``lhs op rhs``."""
        # Format children
        lhs = self.format(oper.lhs)
        rhs = self.format(oper.rhs)

        # Apply parentheses
        lhs = self._parenthesize(lhs, oper.lhs, oper)
        rhs = self._parenthesize(rhs, oper.rhs, oper)

        # Return combined string
        return f"{lhs} {oper.op} {rhs}"

    def format_neg(self, val: L.Neg) -> str:
        """Format unary negation by prefixing the formatted argument with ``-``."""
        arg = self.format(val.arg)
        return f"-{arg}"

    def format_not(self, val: L.Not) -> str:
        """Format not operation as ``not(arg)``."""
        arg = self.format(val.arg)
        return f"not({arg})"

    def format_andor(self, oper: L.And | L.Or) -> str:
        """Format and or or operation.

        The node's ``op`` is translated from ``&&``/``||`` to the Python
        keywords ``and``/``or``.

        Raises:
            KeyError: If ``oper.op`` is neither ``"||"`` nor ``"&&"``.
        """
        # Format children
        lhs = self.format(oper.lhs)
        rhs = self.format(oper.rhs)

        # Apply parentheses
        lhs = self._parenthesize(lhs, oper.lhs, oper)
        rhs = self._parenthesize(rhs, oper.rhs, oper)

        opstr = {"||": "or", "&&": "and"}[oper.op]

        # Return combined string
        return f"{lhs} {opstr} {rhs}"

    def format_literal_float(self, val: L.LiteralFloat) -> str:
        """Format a literal float using the default string form of its value."""
        return f"{val.value}"

    def format_literal_int(self, val: L.LiteralInt) -> str:
        """Format a literal int using the default string form of its value."""
        return f"{val.value}"

    def format_for_range(self, r: L.ForRange) -> str:
        """Format a loop over a range.

        Emits ``for index in range(begin, end):`` followed by the formatted
        body, with every line of the body indented by four spaces.
        """
        begin = self.format(r.begin)
        end = self.format(r.end)
        index = self.format(r.index)
        output = f"for {index} in range({begin}, {end}):\n"
        # When the body text ends in a newline, split() yields a final empty
        # element, which is emitted as an indentation-only line.
        body_lines = self.format(r.body).split("\n")
        output += "".join(f"    {line}\n" for line in body_lines)
        return output

    def format_statement(self, s: L.Statement) -> str:
        """Format a statement by formatting the expression it holds."""
        return self.format(s.expr)

    def format_assign(self, expr: L.Assign) -> str:
        """Format assignment as ``lhs op rhs`` on one line."""
        rhs = self.format(expr.rhs)
        lhs = self.format(expr.lhs)
        return f"{lhs} {expr.op} {rhs}\n"

    def format_conditional(self, s: L.Conditional) -> str:
        """Format a conditional as ``(true if condition else false)``."""
        # Format children
        c = self.format(s.condition)
        t = self.format(s.true)
        f = self.format(s.false)

        # Apply parentheses
        c = self._parenthesize(c, s.condition, s)
        t = self._parenthesize(t, s.true, s)
        f = self._parenthesize(f, s.false, s)

        # Return combined string
        return f"({t} if {c} else {f})"

    def format_symbol(self, s: L.Symbol) -> str:
        """Format a symbol as its name."""
        return f"{s.name}"

    def format_mathfunction(self, f: L.MathFunction) -> str:
        """Format a math function.

        Function names listed in the local table are replaced by the name
        used in the emitted call; all other names are emitted unchanged with
        an ``np.`` prefix. ``erf`` is emitted as ``math.erf`` with the first
        argument only.
        """
        function_map = {
            "ln": "log",
            "acos": "arccos",
            "asin": "arcsin",
            "atan": "arctan",
            "atan2": "arctan2",
            "acosh": "arccosh",
            "asinh": "arcsinh",
            "atanh": "arctanh",
        }
        function = function_map.get(f.function, f.function)
        args = [self.format(arg) for arg in f.args]
        if "bessel" in function:
            # NOTE(review): any function whose name contains "bessel" is
            # emitted as the literal 0, so generated code silently evaluates
            # it to zero. Looks like a placeholder - confirm this is intended.
            return "0"
        if function == "erf":
            return f"math.erf({args[0]})"
        argstr = ", ".join(args)
        return f"np.{function}({argstr})"

    # Dispatch table used by ``format``: node class name -> handler.
    impl = {
        "StatementList": format_statement_list,
        "Comment": format_comment,
        "Section": format_section,
        "ArrayDecl": format_array_decl,
        "ArrayAccess": format_array_access,
        "MultiIndex": format_multi_index,
        "VariableDecl": format_variable_decl,
        "ForRange": format_for_range,
        "Statement": format_statement,
        "Assign": format_assign,
        "AssignAdd": format_assign,
        "Product": format_nary_op,
        "Sum": format_nary_op,
        "Add": format_binary_op,
        "Sub": format_binary_op,
        "Mul": format_binary_op,
        "Div": format_binary_op,
        "Neg": format_neg,
        "Not": format_not,
        "LiteralFloat": format_literal_float,
        "LiteralInt": format_literal_int,
        "Symbol": format_symbol,
        "Conditional": format_conditional,
        "MathFunction": format_mathfunction,
        "And": format_andor,
        "Or": format_andor,
        "NE": format_binary_op,
        "EQ": format_binary_op,
        "GE": format_binary_op,
        "LE": format_binary_op,
        "GT": format_binary_op,
        "LT": format_binary_op,
    }

    def format(self, s: L.LNode) -> str:
        """Format output.

        Args:
            s: Node to format; its class name selects the handler in ``impl``.

        Returns:
            The text produced by the selected handler.

        Raises:
            RuntimeError: If no handler is registered for the node's class
                name.
        """
        name = s.__class__.__name__
        # Only the table lookup is guarded, so a KeyError raised inside a
        # handler propagates as itself instead of being reported as an
        # unknown statement.
        try:
            handler = self.impl[name]
        except KeyError:
            raise RuntimeError(f"Unknown statement: {name}") from None
        return handler(self, s)  # type: ignore
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc