• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 19770323370

28 Nov 2025 05:22PM UTC coverage: 77.933% (-5.1%) from 83.044%
19770323370

Pull #801

github

schnellerhase
Try with Path
Pull Request #801: Add `numba` backend

55 of 359 new or added lines in 21 files covered. (15.32%)

85 existing lines in 4 files now uncovered.

3740 of 4799 relevant lines covered (77.93%)

0.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/ffcx/codegeneration/numba/implementation.py
1
# Copyright (C) 2025 Chris Richardson
2
#
3
# This file is part of FFCx. (https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6
"""Numba implementation for output."""
7

NEW
8
import ffcx.codegeneration.lnodes as L
×
9

10

NEW
11
def build_initializer_lists(values):
    """Return Python source for a (nested) list literal holding ``values``.

    Args:
        values: NumPy array (anything exposing ``shape``, ``item()`` and
            iteration over its leading axis).

    Returns:
        The literal as a string:
        - a rank-0 array gives the bare scalar, e.g. ``"5"``;
        - a rank-1 array gives a flat list such as ``"[1, 2, 3]"``;
        - a higher-rank array gives one nested list per leading-axis entry,
          with entries separated by a comma and a newline.
    """
    rank = len(values.shape)
    if rank == 0:
        # There is no list to build for a 0-d array; emit the scalar itself
        # instead of silently producing an empty list.
        return str(values.item())
    if rank == 1:
        return "[" + ", ".join(str(v) for v in values) + "]"
    return "[" + ",\n".join(build_initializer_lists(v) for v in values) + "]"
×
20

21

NEW
22
class Formatter:
    """Numba output backend: turn an LNodes tree into Python source text.

    Each ``format_*`` method handles one LNodes node type. :meth:`format`
    selects the method from the node's class name via the :attr:`impl` table.
    Generated code refers to ``np`` (NumPy) and, for ``erf``, to ``math``.
    Those names must be in scope where the generated code is executed.
    """

    # LNodes math-function names whose NumPy spelling differs. Names not
    # listed here are emitted unchanged as ``np.<name>``.
    _MATH_FUNCTIONS = {
        "ln": "log",
        "acos": "arccos",
        "asin": "arcsin",
        "atan": "arctan",
        "atan2": "arctan2",
        "acosh": "arccosh",
        "asinh": "arcsinh",
        "atanh": "arctanh",
    }

    # C-style logical operators (as stored on And/Or nodes) -> Python keywords.
    _LOGICAL_OPERATORS = {"||": "or", "&&": "and"}

    def __init__(self, scalar) -> None:
        """Initialise.

        Args:
            scalar: Scalar type of the generated code; stored as
                ``self.scalar_type``.
        """
        self.scalar_type = scalar

    def _format_operand(self, node, parent_precedence):
        """Format ``node``, parenthesised unless it binds tighter than its parent.

        A larger ``precedence`` value means the node binds more loosely, so
        equal precedence is parenthesised as well (e.g. ``a - (b - c)``).
        """
        text = self.format(node)
        if node.precedence >= parent_precedence:
            return f"({text})"
        return text

    def format_section(self, section):
        """Format a section: a comment header, its declarations, then its statements.

        The header names the section and lists the names of its input and
        output nodes. The section is closed by a ``# ----`` comment line.
        """
        header = "# ------------------------ \n"
        header += "# Section: " + section.name + "\n"
        header += "# Inputs: " + ", ".join(w.name for w in section.input) + "\n"
        header += "# Outputs: " + ", ".join(w.name for w in section.output) + "\n"
        declarations = "".join(self.format(s) for s in section.declarations)
        body = "".join(self.format(s) for s in section.statements)
        return header + declarations + body + "# ------------------------ \n"

    def format_statement_list(self, slist):
        """Format a list of statements by concatenating their formatted text."""
        return "".join(self.format(s) for s in slist.statements)

    def format_comment(self, c):
        """Format a comment as a single ``#`` line."""
        return "# " + c.comment + "\n"

    def format_array_decl(self, arr):
        """Format an array declaration as a NumPy allocation.

        Arrays without values become ``np.empty``. Arrays with exactly one
        value become ``np.full`` with that value. All others become an
        ``np.array`` literal.

        Raises:
            RuntimeError: if the array's data type has no NumPy mapping here.
        """
        # The dtype expressions refer to names in the generated code's scope;
        # ``A`` and ``coordinate_dofs`` are presumably kernel arguments
        # (TODO confirm against the numba kernel templates).
        data_type = arr.symbol.dtype
        if data_type == L.DataType.SCALAR:
            dtype = "A.dtype"
        elif data_type == L.DataType.REAL:
            dtype = "coordinate_dofs.dtype"
        elif data_type == L.DataType.INT:
            dtype = "np.int32"
        else:
            raise RuntimeError(f"Unsupported array data type: {data_type}")

        symbol = self.format(arr.symbol)
        if arr.values is None:
            return f"{symbol} = np.empty({arr.sizes}, dtype={dtype})\n"
        if arr.values.size == 1:
            # ``flat[0]`` is the single element for any rank; ``values[0]``
            # raises on 0-d arrays and yields a sub-array for rank > 1.
            return f"{symbol} = np.full({arr.sizes}, {arr.values.flat[0]}, dtype={dtype})\n"
        literal = build_initializer_lists(arr.values)
        return f"{symbol} = np.array({literal}, dtype={dtype})\n"

    def format_array_access(self, arr):
        """Format array access as ``array[i, j, ...]``."""
        array = self.format(arr.array)
        idx = ", ".join(self.format(ix) for ix in arr.indices)
        return f"{array}[{idx}]"

    def format_multi_index(self, index):
        """Format a multi-index by formatting its ``global_index`` expression."""
        return self.format(index.global_index)

    def format_variable_decl(self, v):
        """Format a variable declaration as a plain assignment."""
        sym = self.format(v.symbol)
        val = self.format(v.value)
        return f"{sym} = {val}\n"

    def format_nary_op(self, oper):
        """Format an n-argument operation by joining its operands with ``oper.op``."""
        operands = [self._format_operand(arg, oper.precedence) for arg in oper.args]
        return f" {oper.op} ".join(operands)

    def format_binary_op(self, oper):
        """Format a binary (arithmetic or comparison) operation."""
        lhs = self._format_operand(oper.lhs, oper.precedence)
        rhs = self._format_operand(oper.rhs, oper.precedence)
        return f"{lhs} {oper.op} {rhs}"

    def format_neg(self, val):
        """Format unary negation.

        The operand is parenthesised when it binds no tighter than the
        negation. Otherwise ``-(a + b)`` would be emitted as ``-a + b``.
        """
        return "-" + self._format_operand(val.arg, val.precedence)

    def format_not(self, val):
        """Format logical negation as a fully parenthesised ``not`` expression.

        Python's ``not`` binds more loosely than comparison and arithmetic
        operators. The whole expression is therefore wrapped so that it stays
        intact when used as an operand.
        """
        return f"(not {self._format_operand(val.arg, val.precedence)})"

    def format_andor(self, oper):
        """Format a logical ``and``/``or`` operation using Python keywords."""
        lhs = self._format_operand(oper.lhs, oper.precedence)
        rhs = self._format_operand(oper.rhs, oper.precedence)
        keyword = self._LOGICAL_OPERATORS[oper.op]
        return f"{lhs} {keyword} {rhs}"

    def format_literal_float(self, val):
        """Format a literal float via ``str`` of its value."""
        return f"{val.value}"

    def format_literal_int(self, val):
        """Format a literal int via ``str`` of its value."""
        return f"{val.value}"

    def format_for_range(self, r):
        """Format a loop over ``range(begin, end)`` with an indented body.

        A body with no executable line (empty or comments only) gets a
        ``pass`` so the emitted loop is valid Python.
        """
        begin = self.format(r.begin)
        end = self.format(r.end)
        index = self.format(r.index)
        body = self.format(r.body).rstrip("\n")
        lines = body.split("\n") if body else []
        if not any(ln.strip() and not ln.lstrip().startswith("#") for ln in lines):
            lines.append("pass")
        # Blank lines are kept blank rather than padded with indentation.
        indented = "".join(f"    {ln}\n" if ln.strip() else "\n" for ln in lines)
        return f"for {index} in range({begin}, {end}):\n{indented}"

    def format_statement(self, s):
        """Format a statement by formatting its expression."""
        return self.format(s.expr)

    def format_assign(self, expr):
        """Format an assignment (``=`` or augmented, per ``expr.op``)."""
        rhs = self.format(expr.rhs)
        lhs = self.format(expr.lhs)
        return f"{lhs} {expr.op} {rhs}\n"

    def format_conditional(self, s):
        """Format a conditional as a parenthesised Python ``x if c else y`` expression."""
        c = self._format_operand(s.condition, s.precedence)
        t = self._format_operand(s.true, s.precedence)
        f = self._format_operand(s.false, s.precedence)
        return f"({t} if {c} else {f})"

    def format_symbol(self, s):
        """Format a symbol as its name."""
        return f"{s.name}"

    def format_mathfunction(self, f):
        """Format a math function call, mostly as ``np.<name>(args)``."""
        function = self._MATH_FUNCTIONS.get(f.function, f.function)
        args = [self.format(arg) for arg in f.args]
        if "bessel" in function:
            # NOTE(review): Bessel functions are not implemented for this
            # backend and silently evaluate to 0, so results for forms using
            # them are wrong.
            return "0"
        if function == "erf":
            # NumPy has no ``erf``; use the standard-library scalar version.
            return f"math.erf({args[0]})"
        return f"np.{function}({', '.join(args)})"

    impl = {
        "StatementList": format_statement_list,
        "Comment": format_comment,
        "Section": format_section,
        "ArrayDecl": format_array_decl,
        "ArrayAccess": format_array_access,
        "MultiIndex": format_multi_index,
        "VariableDecl": format_variable_decl,
        "ForRange": format_for_range,
        "Statement": format_statement,
        "Assign": format_assign,
        "AssignAdd": format_assign,
        "Product": format_nary_op,
        "Sum": format_nary_op,
        "Add": format_binary_op,
        "Sub": format_binary_op,
        "Mul": format_binary_op,
        "Div": format_binary_op,
        "Neg": format_neg,
        "Not": format_not,
        "LiteralFloat": format_literal_float,
        "LiteralInt": format_literal_int,
        "Symbol": format_symbol,
        "Conditional": format_conditional,
        "MathFunction": format_mathfunction,
        "And": format_andor,
        "Or": format_andor,
        "NE": format_binary_op,
        "EQ": format_binary_op,
        "GE": format_binary_op,
        "LE": format_binary_op,
        "GT": format_binary_op,
        "LT": format_binary_op,
    }

    def format(self, s):
        """Format any supported LNodes node, dispatching on its class name.

        Raises:
            RuntimeError: if the node's class name has no entry in ``impl``.
        """
        name = s.__class__.__name__
        # Only the table lookup is guarded. A KeyError raised inside a
        # handler must propagate rather than be reported as an unknown node.
        try:
            handler = self.impl[name]
        except KeyError:
            raise RuntimeError(f"Unknown statement: {name}") from None
        return handler(self, s)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc