• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 11642291501

02 Nov 2024 11:16AM UTC coverage: 81.168% (+0.5%) from 80.657%
11642291501

push

github

web-flow
Upload to coveralls and docs from CI job running against python 3.12 (#726)

3474 of 4280 relevant lines covered (81.17%)

0.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.57
/ffcx/codegeneration/C/c_implementation.py
1
# Copyright (C) 2023 Chris Richardson
2
#
3
# This file is part of FFCx. (https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6
"""C implementation."""
7

8
import warnings
1✔
9

10
import numpy as np
1✔
11
import numpy.typing as npt
1✔
12

13
import ffcx.codegeneration.lnodes as L
1✔
14
from ffcx.codegeneration.utils import dtype_to_c_type, dtype_to_scalar_dtype
1✔
15

16
# C math-library names for real double precision, keyed by the FFCx function
# name. The float32 and long double variants follow the C99 <math.h> naming
# convention of appending "f" / "l" to the double-precision name.
_real_math_names = {
    "sqrt": "sqrt",
    "abs": "fabs",
    "cos": "cos",
    "sin": "sin",
    "tan": "tan",
    "acos": "acos",
    "asin": "asin",
    "atan": "atan",
    "cosh": "cosh",
    "sinh": "sinh",
    "tanh": "tanh",
    "acosh": "acosh",
    "asinh": "asinh",
    "atanh": "atanh",
    "power": "pow",
    "exp": "exp",
    "ln": "log",
    "erf": "erf",
    "atan_2": "atan2",
    "min_value": "fmin",
    "max_value": "fmax",
}

# C <complex.h> names for double-precision complex values; the complex64
# variants likewise append "f". max/min map to the real fmax/fmin functions.
_complex_math_names = {
    "sqrt": "csqrt",
    "abs": "cabs",
    "cos": "ccos",
    "sin": "csin",
    "tan": "ctan",
    "acos": "cacos",
    "asin": "casin",
    "atan": "catan",
    "cosh": "ccosh",
    "sinh": "csinh",
    "tanh": "ctanh",
    "acosh": "cacosh",
    "asinh": "casinh",
    "atanh": "catanh",
    "power": "cpow",
    "exp": "cexp",
    "ln": "clog",
    "real": "creal",
    "imag": "cimag",
    "conj": "conj",
    "max_value": "fmax",
    "min_value": "fmin",
}

# Bessel functions always map to the double-precision yn/jn, whatever the
# precision of the table they are added to (the long double table has none).
_bessel_math_names = {
    "bessel_y": "yn",
    "bessel_j": "jn",
}


def _suffixed(names, suffix):
    """Return a copy of *names* with *suffix* appended to every C function name."""
    return {key: c_name + suffix for key, c_name in names.items()}


# Per-dtype lookup tables: dtype name -> {FFCx function name -> C function name}.
math_table = {
    "float64": {**_real_math_names, **_bessel_math_names},
    "float32": {**_suffixed(_real_math_names, "f"), **_bessel_math_names},
    "longdouble": _suffixed(_real_math_names, "l"),
    "complex128": {**_complex_math_names, **_bessel_math_names},
    "complex64": {**_suffixed(_complex_math_names, "f"), **_bessel_math_names},
}
143

144

145
class CFormatter:
    """Render FFCx LNodes syntax items as C source code.

    Each syntax item is dispatched on its class name through ``c_impl`` (see
    ``c_format``) to one of the ``format_*`` methods, which return C text.
    C type names and math-function names are chosen from the dtype given at
    construction.
    """

    # NumPy dtype of the generated code's scalar values (may be complex).
    scalar_type: np.dtype
    # Dtype derived from ``scalar_type`` via ``dtype_to_scalar_dtype``
    # (presumably its real counterpart for complex types -- TODO confirm).
    real_type: np.dtype

    def __init__(self, dtype: npt.DTypeLike) -> None:
        """Initialise.

        Args:
            dtype: Scalar type of the generated code. Stored as ``scalar_type``;
                ``real_type`` is derived from it with ``dtype_to_scalar_dtype``.
        """
        self.scalar_type = np.dtype(dtype)
        self.real_type = dtype_to_scalar_dtype(dtype)

    def _dtype_to_name(self, dtype) -> str:
        """Convert an ``L.DataType`` tag to a C type name.

        Raises:
            ValueError: if ``dtype`` is not SCALAR, REAL, INT or BOOL.
        """
        if dtype == L.DataType.SCALAR:
            return dtype_to_c_type(self.scalar_type)
        if dtype == L.DataType.REAL:
            return dtype_to_c_type(self.real_type)
        if dtype == L.DataType.INT:
            return "int"
        if dtype == L.DataType.BOOL:
            return "bool"
        raise ValueError(f"Invalid dtype: {dtype}")

    def _format_number(self, x):
        """Format a Python number as C literal text.

        Floats, and the real and imaginary parts of complex values, are written
        with ``repr``. That yields the shortest decimal string that converts
        back to exactly the same IEEE-754 double. A fixed 16 significant digits
        is not enough for float64 (e.g. ``0.1 + 0.2`` would print as ``0.3``).
        Complex values become ``(re+I*im)``; anything else goes through ``str``.
        """
        if isinstance(x, complex):
            # float() strips NumPy scalar reprs such as "np.float64(...)".
            return f"({float(x.real)!r}+I*{float(x.imag)!r})"
        elif isinstance(x, float):
            return repr(float(x))
        return str(x)

    def _build_initializer_lists(self, values):
        """Build a (nested) C brace initializer from an array of values.

        One-dimensional arrays become ``{a, b, ...}``. Higher-rank arrays
        recurse over the leading axis, with each sub-list on its own line.
        """
        arr = "{"
        if len(values.shape) == 1:
            arr += ", ".join(self._format_number(v) for v in values)
        elif len(values.shape) > 1:
            arr += ",\n  ".join(self._build_initializer_lists(v) for v in values)
        arr += "}"
        return arr

    def format_statement_list(self, slist) -> str:
        """Concatenate the formatted statements of a statement list."""
        return "".join(self.c_format(s) for s in slist.statements)

    def format_section(self, section) -> str:
        """Format a section.

        Emits a comment banner with the section's name, inputs and outputs,
        then its declarations. If there are statements, they follow as a
        braced block indented by two spaces. A closing banner ends the output.
        """
        comments = "// ------------------------ \n"
        comments += "// Section: " + section.name + "\n"
        comments += "// Inputs: " + ", ".join(w.name for w in section.input) + "\n"
        comments += "// Outputs: " + ", ".join(w.name for w in section.output) + "\n"
        declarations = "".join(self.c_format(s) for s in section.declarations)

        body = ""
        if len(section.statements) > 0:
            declarations += "{\n  "
            body = "".join(self.c_format(s) for s in section.statements)
            body = body.replace("\n", "\n  ")
            # Assumes the formatted statements end in "\n" (as format_assign's
            # output does): drop the indent added after that final newline.
            body = body[:-2] + "}\n"

        body += "// ------------------------ \n"
        return comments + declarations + body

    def format_comment(self, c) -> str:
        """Format a single-line ``//`` comment."""
        return "// " + c.comment + "\n"

    def format_array_decl(self, arr) -> str:
        """Format an array declaration, with an initializer when values are given.

        Arrays with values are initialized from a nested brace list and are
        emitted ``static const`` when ``arr.const`` is set.
        """
        dtype = arr.symbol.dtype
        typename = self._dtype_to_name(dtype)

        symbol = self.c_format(arr.symbol)
        dims = "".join([f"[{i}]" for i in arr.sizes])
        if arr.values is None:
            # A const array without initial values would be useless in C.
            assert arr.const is False
            return f"{typename} {symbol}{dims};\n"

        vals = self._build_initializer_lists(arr.values)
        cstr = "static const " if arr.const else ""
        return f"{cstr}{typename} {symbol}{dims} = {vals};\n"

    def format_array_access(self, arr) -> str:
        """Format an array access as ``name[i][j]...``."""
        name = self.c_format(arr.array)
        indices = f"[{']['.join(self.c_format(i) for i in arr.indices)}]"
        return f"{name}{indices}"

    def format_variable_decl(self, v) -> str:
        """Format a variable declaration with an initial value."""
        val = self.c_format(v.value)
        symbol = self.c_format(v.symbol)
        typename = self._dtype_to_name(v.symbol.dtype)
        return f"{typename} {symbol} = {val};\n"

    def format_nary_op(self, oper) -> str:
        """Format an n-ary operation, joining operands with ``oper.op``.

        An operand is parenthesised when its ``precedence`` value is greater
        than or equal to the operator's.
        """
        args = []
        for arg in oper.args:
            text = self.c_format(arg)
            if arg.precedence >= oper.precedence:
                text = f"({text})"
            args.append(text)
        return f" {oper.op} ".join(args)

    def format_binary_op(self, oper) -> str:
        """Format a binary operation.

        Each side is parenthesised when its ``precedence`` value is greater
        than or equal to the operator's.
        """
        lhs = self.c_format(oper.lhs)
        rhs = self.c_format(oper.rhs)

        if oper.lhs.precedence >= oper.precedence:
            lhs = f"({lhs})"
        if oper.rhs.precedence >= oper.precedence:
            rhs = f"({rhs})"

        return f"{lhs} {oper.op} {rhs}"

    def format_unary_op(self, oper) -> str:
        """Format a prefix unary operation, parenthesising the operand if needed."""
        arg = self.c_format(oper.arg)
        if oper.arg.precedence >= oper.precedence:
            return f"{oper.op}({arg})"
        return f"{oper.op}{arg}"

    def format_literal_float(self, val) -> str:
        """Format a literal float (or complex) value."""
        value = self._format_number(val.value)
        return f"{value}"

    def format_literal_int(self, val) -> str:
        """Format a literal int."""
        return f"{val.value}"

    def format_for_range(self, r) -> str:
        """Format a C ``for`` loop over ``[begin, end)`` with an ``int`` index.

        Each non-empty line of the body is indented by two spaces.
        """
        begin = self.c_format(r.begin)
        end = self.c_format(r.end)
        index = self.c_format(r.index)
        output = f"for (int {index} = {begin}; {index} < {end}; ++{index})\n"
        output += "{\n"
        body = self.c_format(r.body)
        for line in body.split("\n"):
            if len(line) > 0:
                output += f"  {line}\n"
        output += "}\n"
        return output

    def format_statement(self, s) -> str:
        """Format a statement by formatting its wrapped expression."""
        return self.c_format(s.expr)

    def format_assign(self, expr) -> str:
        """Format an assignment (``=`` or compound, per ``expr.op``)."""
        rhs = self.c_format(expr.rhs)
        lhs = self.c_format(expr.lhs)
        return f"{lhs} {expr.op} {rhs};\n"

    def format_conditional(self, s) -> str:
        """Format a ternary ``c ? t : f`` expression.

        Each part is parenthesised when its ``precedence`` value is greater than
        or equal to the conditional's.
        """
        c = self.c_format(s.condition)
        t = self.c_format(s.true)
        f = self.c_format(s.false)

        if s.condition.precedence >= s.precedence:
            c = "(" + c + ")"
        if s.true.precedence >= s.precedence:
            t = "(" + t + ")"
        if s.false.precedence >= s.precedence:
            f = "(" + f + ")"

        return c + " ? " + t + " : " + f

    def format_symbol(self, s) -> str:
        """Format a symbol as its name."""
        return f"{s.name}"

    def format_multi_index(self, mi) -> str:
        """Format a multi-index via its flattened ``global_index`` expression."""
        return self.c_format(mi.global_index)

    def format_math_function(self, c) -> str:
        """Format a call to a mathematical function.

        The C function name is looked up in ``math_table`` at the argument's
        precision: ``real_type`` when the first argument is tagged
        ``L.DataType.REAL``, otherwise ``scalar_type``. Function names missing
        from the table are emitted unchanged.

        Raises:
            RuntimeError: if ``math_table`` has no entry for the selected dtype.
        """
        arg_type = self.scalar_type
        if hasattr(c.args[0], "dtype"):
            if c.args[0].dtype == L.DataType.REAL:
                arg_type = self.real_type
        else:
            warnings.warn(f"Syntax item without dtype {c.args[0]}")

        table_key = arg_type.name
        # NumPy names long double after its storage width (e.g. "float128" on
        # x86-64 Linux), so it never matches the "longdouble" key directly.
        if table_key not in math_table and arg_type == np.dtype(np.longdouble):
            table_key = "longdouble"
        if table_key not in math_table:
            raise RuntimeError(f"No C math functions known for dtype: {arg_type}")
        dtype_math_table = math_table[table_key]

        func = dtype_math_table.get(c.function, c.function)
        args = ", ".join(self.c_format(arg) for arg in c.args)
        return f"{func}({args})"

    # Dispatch table: LNodes class name -> formatting method (unbound; called
    # as handler(self, node) by c_format).
    c_impl = {
        "Section": format_section,
        "StatementList": format_statement_list,
        "Comment": format_comment,
        "ArrayDecl": format_array_decl,
        "ArrayAccess": format_array_access,
        "MultiIndex": format_multi_index,
        "VariableDecl": format_variable_decl,
        "ForRange": format_for_range,
        "Statement": format_statement,
        "Assign": format_assign,
        "AssignAdd": format_assign,
        "Product": format_nary_op,
        "Neg": format_unary_op,
        "Sum": format_nary_op,
        "Add": format_binary_op,
        "Sub": format_binary_op,
        "Mul": format_binary_op,
        "Div": format_binary_op,
        "Not": format_unary_op,
        "LiteralFloat": format_literal_float,
        "LiteralInt": format_literal_int,
        "Symbol": format_symbol,
        "Conditional": format_conditional,
        "MathFunction": format_math_function,
        "And": format_binary_op,
        "Or": format_binary_op,
        "NE": format_binary_op,
        "EQ": format_binary_op,
        "GE": format_binary_op,
        "LE": format_binary_op,
        "GT": format_binary_op,
        "LT": format_binary_op,
    }

    def c_format(self, s) -> str:
        """Format a syntax item as C, dispatching on its class name.

        Raises:
            RuntimeError: if no formatter is registered for ``type(s).__name__``.
        """
        name = s.__class__.__name__
        try:
            formatter = self.c_impl[name]
        except KeyError:
            raise RuntimeError(f"Unknown statement: {name}") from None
        # Called outside the try: a KeyError raised while formatting (e.g. in a
        # child node) must not be misreported as an unknown statement.
        return formatter(self, s)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc