• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 19801120260

30 Nov 2025 03:36PM UTC coverage: 73.753% (-4.2%) from 77.933%
19801120260

Pull #803

github

schnellerhase
passing compiles?
Pull Request #803: Add `C++` backend

0 of 272 new or added lines in 10 files covered. (0.0%)

3740 of 5071 relevant lines covered (73.75%)

0.74 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/ffcx/codegeneration/cpp/cpp_implementation.py
1
# Copyright (C) 2023 Chris Richardson
2
#
3
# This file is part of FFCx. (https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6
"""C++ implementation."""
7

NEW
8
import ffcx.codegeneration.lnodes as L
×
9

NEW
10
# Map FFCx math-function names (as used by ``MathFunction.function``) to the
# corresponding C++ standard-library callables. Names not listed here are
# emitted verbatim by ``CppFormatter.format_math_function``.
math_table = {
    "sqrt": "std::sqrt",
    "abs": "std::abs",
    "cos": "std::cos",
    "sin": "std::sin",
    "tan": "std::tan",
    "acos": "std::acos",
    "asin": "std::asin",
    "atan": "std::atan",
    "cosh": "std::cosh",
    "sinh": "std::sinh",
    "tanh": "std::tanh",
    "acosh": "std::acosh",
    "asinh": "std::asinh",
    "atanh": "std::atanh",
    "power": "std::pow",
    "exp": "std::exp",
    "ln": "std::log",
    "erf": "std::erf",
    "atan_2": "std::atan2",
    "min_value": "std::fmin",
    "max_value": "std::fmax",
    # C++17 special functions (<cmath>). Bessel Y (second kind) is
    # std::cyl_neumann; std::cyl_bessel_i is the *modified* Bessel function
    # of the first kind and is only correct for "bessel_i".
    "bessel_y": "std::cyl_neumann",
    "bessel_j": "std::cyl_bessel_j",
    "bessel_i": "std::cyl_bessel_i",
    "bessel_k": "std::cyl_bessel_k",
    "conj": "std::conj",
    "real": "std::real",
    "imag": "std::imag",
}
38

39

NEW
40
def build_initializer_lists(values):
    """Build initializer lists.

    Render an array-like ``values`` (anything with ``.shape`` that iterates
    over its leading axis) as a C/C++ brace initializer. A 1-D input becomes
    ``{a, b, c}``. Higher-rank inputs recurse over the leading axis, and the
    nested lists are joined with ``",\\n"``. A 0-d input yields ``"{}"``.
    """
    rank = len(values.shape)
    if rank == 1:
        inner = ", ".join(str(item) for item in values)
    elif rank > 1:
        inner = ",\n".join(build_initializer_lists(row) for row in values)
    else:
        # 0-d input: nothing is placed between the braces.
        inner = ""
    return "{" + inner + "}"
×
49

50

NEW
51
class CppFormatter:
    """C++ formatter.

    Turns an LNodes tree (statements and expressions) into C++ source text.
    Nodes are dispatched on their Python class name through ``c_impl``, and
    every ``format_*`` method returns the generated code as a string.
    """

    def __init__(self, scalar) -> None:
        """Initialise.

        Args:
            scalar: Requested scalar type. Currently not used: the emitted
                code always names the scalar type ``T`` and the real type
                ``U`` (presumably template parameters of the generated C++
                kernel -- TODO confirm).
        """
        self.scalar_type = "T"
        self.real_type = "U"

    def format_statement_list(self, slist) -> str:
        """Format a statement list by concatenating its formatted statements."""
        return "".join(self.c_format(s) for s in slist.statements)

    def format_section(self, section) -> str:
        """Format a section.

        Emits a comment banner naming the section and its input/output
        symbols, then the section's declarations. If the section has
        statements, they follow in a ``{ ... }`` scope indented by two spaces.
        A closing banner line ends the section.
        """
        # add new line before section
        comments = "// ------------------------ \n"
        comments += "// Section: " + section.name + "\n"
        comments += "// Inputs: " + ", ".join(w.name for w in section.input) + "\n"
        comments += "// Outputs: " + ", ".join(w.name for w in section.output) + "\n"
        declarations = "".join(self.c_format(s) for s in section.declarations)

        body = ""
        if len(section.statements) > 0:
            declarations += "{\n  "
            body = "".join(self.c_format(s) for s in section.statements)
            # Indent every line by two spaces. Formatted statements end in
            # "\n", so the replace leaves a trailing "\n  "; the slice drops
            # those two spaces before the scope is closed.
            body = body.replace("\n", "\n  ")
            body = body[:-2] + "}\n"

        body += "// ------------------------ \n"
        return comments + declarations + body

    def format_comment(self, c) -> str:
        """Format a comment as a single ``//`` line."""
        return "// " + c.comment + "\n"

    def format_array_decl(self, arr) -> str:
        """Format an array declaration.

        Without initial values this emits ``<type> name[d0][d1]...;``.
        With values it emits ``<type> name[...] = {...};``, prefixed by
        ``static const`` when ``arr.const`` is set.

        Raises:
            ValueError: If the symbol's dtype is not SCALAR, REAL or INT.
        """
        dtype = arr.symbol.dtype
        assert dtype is not None

        if dtype == L.DataType.SCALAR:
            typename = self.scalar_type
        elif dtype == L.DataType.REAL:
            typename = self.real_type
        elif dtype == L.DataType.INT:
            typename = "int"
        else:
            raise ValueError("Invalid datatype")

        symbol = self.c_format(arr.symbol)
        dims = "".join([f"[{i}]" for i in arr.sizes])
        if arr.values is None:
            # An uninitialised array cannot sensibly be const.
            assert arr.const is False
            return f"{typename} {symbol}{dims};\n"

        vals = build_initializer_lists(arr.values)
        cstr = "static const " if arr.const else ""
        return f"{cstr}{typename} {symbol}{dims} = {vals};\n"

    def format_array_access(self, arr) -> str:
        """Format array access as ``name[i][j]...``."""
        name = self.c_format(arr.array)
        indices = f"[{']['.join(self.c_format(i) for i in arr.indices)}]"
        return f"{name}{indices}"

    def format_multi_index(self, index) -> str:
        """Format a multi-index as its flattened global index expression."""
        return self.c_format(index.global_index)

    def format_variable_decl(self, v) -> str:
        """Format a variable declaration ``<type> name = value;``."""
        val = self.c_format(v.value)
        symbol = self.c_format(v.symbol)
        assert v.symbol.dtype
        # TODO: move to _dtype_to_name
        # NOTE(review): a dtype outside SCALAR/REAL/INT/BOOL falls through with
        # an empty typename, so this emits " name = value;" (an assignment,
        # not a declaration). format_array_decl raises ValueError instead.
        typename = ""  # tmp fix!!
        if v.symbol.dtype == L.DataType.SCALAR:
            typename = self.scalar_type
        elif v.symbol.dtype == L.DataType.REAL:
            typename = self.real_type
        elif v.symbol.dtype == L.DataType.INT:
            typename = "std::int32_t"
        elif v.symbol.dtype == L.DataType.BOOL:
            typename = "bool"
        return f"{typename} {symbol} = {val};\n"

    def format_nary_op(self, oper) -> str:
        """Format an n-argument operation such as ``a + b + c``.

        An operand is parenthesised when it binds no tighter than ``oper``
        (its precedence value is >= that of ``oper``).
        """
        args = []
        for arg in oper.args:
            text = self.c_format(arg)
            if arg.precedence >= oper.precedence:
                text = f"({text})"
            args.append(text)
        return f" {oper.op} ".join(args)

    def format_binary_op(self, oper) -> str:
        """Format a binary operation ``lhs op rhs``.

        Each side is parenthesised when it binds no tighter than ``oper``.
        """
        # Format children
        lhs = self.c_format(oper.lhs)
        rhs = self.c_format(oper.rhs)

        # Apply parentheses
        if oper.lhs.precedence >= oper.precedence:
            lhs = f"({lhs})"
        if oper.rhs.precedence >= oper.precedence:
            rhs = f"({rhs})"

        # Return combined string
        return f"{lhs} {oper.op} {rhs}"

    def format_neg(self, val) -> str:
        """Format negation.

        The operand is parenthesised when it binds no tighter than the
        negation (e.g. a sum, or a nested negation) or when its text already
        starts with ``-``. Without the parentheses, ``-(a + b)`` would come
        out as ``-a + b`` and ``-(-x)`` as ``--x``, which C++ parses as
        pre-decrement.
        """
        arg = self.c_format(val.arg)
        if val.arg.precedence >= val.precedence or arg.startswith("-"):
            arg = f"({arg})"
        return f"-{arg}"

    def format_not(self, val) -> str:
        """Format 'not' statement as ``<op>(arg)``."""
        arg = self.c_format(val.arg)
        return f"{val.op}({arg})"

    def format_literal_float(self, val) -> str:
        """Format a literal float number using Python's ``str`` formatting."""
        return f"{val.value}"

    def format_literal_int(self, val) -> str:
        """Format a literal int number."""
        return f"{val.value}"

    def format_for_range(self, r) -> str:
        """Format a loop over a range.

        Emits ``for (int i = begin; i < end; ++i)`` followed by the body in
        braces. Non-empty body lines are indented by two spaces and empty
        lines are dropped.
        """
        begin = self.c_format(r.begin)
        end = self.c_format(r.end)
        index = self.c_format(r.index)
        output = f"for (int {index} = {begin}; {index} < {end}; ++{index})\n"
        output += "{\n"
        body = self.c_format(r.body)
        for line in body.split("\n"):
            if len(line) > 0:
                output += f"  {line}\n"
        output += "}\n"
        return output

    def format_statement(self, s) -> str:
        """Format a statement by formatting its wrapped expression."""
        return self.c_format(s.expr)

    def format_assign(self, expr) -> str:
        """Format an assignment statement (``=`` or a compound form like ``+=``)."""
        rhs = self.c_format(expr.rhs)
        lhs = self.c_format(expr.lhs)
        return f"{lhs} {expr.op} {rhs};\n"

    def format_conditional(self, s) -> str:
        """Format a conditional as the ternary ``c ? t : f``.

        Each part is parenthesised when it binds no tighter than the
        conditional itself.
        """
        # Format children
        c = self.c_format(s.condition)
        t = self.c_format(s.true)
        f = self.c_format(s.false)

        # Apply parentheses
        if s.condition.precedence >= s.precedence:
            c = "(" + c + ")"
        if s.true.precedence >= s.precedence:
            t = "(" + t + ")"
        if s.false.precedence >= s.precedence:
            f = "(" + f + ")"

        # Return combined string
        return c + " ? " + t + " : " + f

    def format_symbol(self, s) -> str:
        """Format a symbol as its name."""
        return f"{s.name}"

    def format_math_function(self, c) -> str:
        """Format a math function call ``func(arg0, arg1, ...)``."""
        # Get a function from the table, if available, else just use bare name
        func = math_table.get(c.function, c.function)
        args = ", ".join(self.c_format(arg) for arg in c.args)
        return f"{func}({args})"

    # Dispatch table: LNodes class name -> formatter (stored as plain
    # functions, hence called as ``handler(self, node)`` in c_format).
    c_impl = {
        "Section": format_section,
        "StatementList": format_statement_list,
        "Comment": format_comment,
        "ArrayDecl": format_array_decl,
        "ArrayAccess": format_array_access,
        "MultiIndex": format_multi_index,
        "VariableDecl": format_variable_decl,
        "ForRange": format_for_range,
        "Statement": format_statement,
        "Assign": format_assign,
        "AssignAdd": format_assign,
        "Product": format_nary_op,
        "Neg": format_neg,
        "Sum": format_nary_op,
        "Add": format_binary_op,
        "Sub": format_binary_op,
        "Mul": format_binary_op,
        "Div": format_binary_op,
        "Not": format_not,
        "LiteralFloat": format_literal_float,
        "LiteralInt": format_literal_int,
        "Symbol": format_symbol,
        "Conditional": format_conditional,
        "MathFunction": format_math_function,
        "And": format_binary_op,
        "Or": format_binary_op,
        "NE": format_binary_op,
        "EQ": format_binary_op,
        "GE": format_binary_op,
        "LE": format_binary_op,
        "GT": format_binary_op,
        "LT": format_binary_op,
    }

    def c_format(self, s) -> str:
        """Formatting function: dispatch ``s`` to its formatter by class name.

        Raises:
            RuntimeError: If no formatter is registered for the node's class.
        """
        name = s.__class__.__name__
        # Only the table lookup is guarded. A KeyError raised while formatting
        # a known node must surface as-is, not be misreported as "unknown".
        try:
            handler = self.c_impl[name]
        except KeyError:
            raise RuntimeError(f"Unknown statement: {name}") from None
        return handler(self, s)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc