• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 19770323370

28 Nov 2025 05:22PM UTC coverage: 77.933% (-5.1%) from 83.044%
19770323370

Pull #801

github

schnellerhase
Try with Path
Pull Request #801: Add `numba` backend

55 of 359 new or added lines in 21 files covered. (15.32%)

85 existing lines in 4 files now uncovered.

3740 of 4799 relevant lines covered (77.93%)

0.78 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.58
/ffcx/codegeneration/lnodes.py
1
# Copyright (C) 2013-2023 Martin Sandve Alnæs, Chris Richardson
2
#
3
# This file is part of FFCx.(https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6
"""LNodes.

LNodes is intended as a minimal generic language description.
Formatting is done later, depending on the target language.

Supported:
 Floating point (and complex) and integer variables and multidimensional arrays
 Range loops
 Simple arithmetic, +-*/
 Math operations (represented as MathFunction nodes)
 Logic conditions (comparisons, and/or/not, and conditional expressions)
 Comments
Not supported:
 Pointers
 General function calls (only math functions are represented)
 Flow control statements (if, switch, while)
 Booleans
 Strings
"""
25

26
import numbers
1✔
27
from collections.abc import Sequence
1✔
28
from enum import Enum
1✔
29

30
import numpy as np
1✔
31
import ufl
1✔
32

33

34
class PRECEDENCE:
    """Operator precedence levels; a smaller number binds more tightly."""

    # Atoms: literals and plain symbols.
    HIGHEST = LITERAL = SYMBOL = 0
    SUBSCRIPT = 2

    # Unary prefix operators.
    NOT = NEG = 3

    # Multiplicative and additive binary operators.
    MUL = DIV = 4
    ADD = SUB = 5

    # Relational and equality comparisons.
    LT = LE = GT = GE = 7
    EQ = NE = 8

    # Logical connectives.
    AND = 11
    OR = 12

    # Ternary conditional and assignment share one level.
    CONDITIONAL = ASSIGN = 13
    LOWEST = 15
1✔
62

63

64
def _is_literal_with_value(lexpr, value):
    """Return whether *lexpr* is a float or integer literal equal to *value*."""
    if not isinstance(lexpr, (LiteralFloat, LiteralInt)):
        return False
    return lexpr.value == value


def is_zero_lexpr(lexpr):
    """Check if an expression is a literal zero (float or integer)."""
    return _is_literal_with_value(lexpr, 0)


def is_one_lexpr(lexpr):
    """Check if an expression is a literal one (float or integer)."""
    return _is_literal_with_value(lexpr, 1)


def is_negative_one_lexpr(lexpr):
    """Check if an expression is a literal minus one (float or integer)."""
    return _is_literal_with_value(lexpr, -1)
83

84

85
def float_product(factors):
    """Build the product of float factors, dropping literal ones.

    Returns ``LiteralFloat(1.0)`` when no factor remains, the sole factor
    when exactly one remains, and a ``Product`` node otherwise.
    """
    remaining = [factor for factor in factors if not is_one_lexpr(factor)]
    if not remaining:
        return LiteralFloat(1.0)
    if len(remaining) == 1:
        return remaining[0]
    return Product(remaining)
1✔
97

98

99
class DataType(Enum):
    """Data types of variables in LNodes.

    REAL matches the geometry type, SCALAR matches the tensor type, and INT
    is used for entity indices and the like. BOOL and NONE complete the set.
    """

    REAL = 0  # same type as the geometry
    SCALAR = 1  # same type as the tensor
    INT = 2  # entity indices etc.
    BOOL = 3
    NONE = 4  # default for nodes without a dtype; merge_dtypes rejects it
1✔
111

112

113
def merge_dtypes(dtypes: list[DataType]):
    """Return the most general dtype present in *dtypes*.

    Precedence is SCALAR, then REAL, then INT, then BOOL.

    Raises:
        ValueError: If NONE is present, or if no known dtype is present.
    """
    if DataType.NONE in dtypes:
        raise ValueError(f"Invalid DataType in LNodes {dtypes}")
    for candidate in (DataType.SCALAR, DataType.REAL, DataType.INT, DataType.BOOL):
        if candidate in dtypes:
            return candidate
    raise ValueError(f"Can't get dtype for operation with {dtypes}")
×
127

128

129
class LNode:
    """Base class for all AST nodes."""

    def __eq__(self, other):
        """Check for equality.

        The base class has no structural notion of equality, so it defers and
        Python falls back to an identity comparison. Subclasses override this.
        """
        return NotImplemented

    def __ne__(self, other):
        """Check for inequality as the negation of ``__eq__``.

        Returning ``NotImplemented`` unconditionally would make ``a != b``
        fall back to an identity check, so two structurally equal nodes
        would compare as unequal. Delegating to ``__eq__`` means subclasses
        only need to define equality once.
        """
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal
×
139

140

141
class LExpr(LNode):
    """Base class for all expressions.

    All subtypes should define a 'precedence' class attribute.

    The Python operator overloads below build LNode expression trees instead
    of computing values. Along the way they fold a few trivial cases on
    literal operands: zero, one, minus one, and some integer-literal
    arithmetic.
    """

    # Default dtype; many subclasses assign a concrete dtype in __init__.
    dtype = DataType.NONE

    def __getitem__(self, indices):
        """Get an item: ``expr[i]`` builds an ArrayAccess node.

        Note that ArrayAccess only accepts a Symbol or ArrayDecl as the array.
        """
        return ArrayAccess(self, indices)

    def __neg__(self):
        """Negate, folding the sign directly into float/int literals."""
        if isinstance(self, LiteralFloat):
            return LiteralFloat(-self.value)
        if isinstance(self, LiteralInt):
            return LiteralInt(-self.value)
        return Neg(self)

    def __add__(self, other):
        """Add: ``self + other``.

        Drops a literal-zero operand and rewrites ``a + (-b)`` as ``a - b``.
        """
        other = as_lexpr(other)
        if is_zero_lexpr(self):
            return other
        if is_zero_lexpr(other):
            return self
        if isinstance(other, Neg):
            return Sub(self, other.arg)
        return Add(self, other)

    def __radd__(self, other):
        """Add (reflected): ``other + self``.

        Drops a literal-zero operand and rewrites ``a + (-b)`` as ``a - b``.
        """
        other = as_lexpr(other)
        if is_zero_lexpr(self):
            return other
        if is_zero_lexpr(other):
            return self
        if isinstance(self, Neg):
            return Sub(other, self.arg)
        return Add(other, self)

    def __sub__(self, other):
        """Subtract: ``self - other``.

        Folds literal zeros, rewrites ``a - (-b)`` as ``a + b``, and evaluates
        the difference of two integer literals directly.
        """
        other = as_lexpr(other)
        if is_zero_lexpr(self):
            return -other
        if is_zero_lexpr(other):
            return self
        if isinstance(other, Neg):
            return Add(self, other.arg)
        if isinstance(self, LiteralInt) and isinstance(other, LiteralInt):
            return LiteralInt(self.value - other.value)
        return Sub(self, other)

    def __rsub__(self, other):
        """Subtract (reflected): ``other - self``.

        Folds literal zeros and rewrites ``a - (-b)`` as ``a + b``.
        """
        other = as_lexpr(other)
        if is_zero_lexpr(self):
            return other
        if is_zero_lexpr(other):
            return -self
        if isinstance(self, Neg):
            return Add(other, self.arg)
        return Sub(other, self)

    def __mul__(self, other):
        """Multiply: ``self * other``.

        A literal zero operand is returned as the result. Multiplying by a
        literal one returns the other operand, and multiplying by a literal
        minus one yields a Neg node. Two integer literals are multiplied
        directly.
        """
        other = as_lexpr(other)
        if is_zero_lexpr(self):
            return self
        if is_zero_lexpr(other):
            return other
        if is_one_lexpr(self):
            return other
        if is_one_lexpr(other):
            return self
        if is_negative_one_lexpr(other):
            return Neg(self)
        if is_negative_one_lexpr(self):
            return Neg(other)
        if isinstance(self, LiteralInt) and isinstance(other, LiteralInt):
            return LiteralInt(self.value * other.value)
        return Mul(self, other)

    def __rmul__(self, other):
        """Multiply (reflected): ``other * self``.

        Same zero/one/minus-one folding as ``__mul__``. Unlike ``__mul__``,
        integer literals are not multiplied out here.
        """
        other = as_lexpr(other)
        if is_zero_lexpr(self):
            return self
        if is_zero_lexpr(other):
            return other
        if is_one_lexpr(self):
            return other
        if is_one_lexpr(other):
            return self
        if is_negative_one_lexpr(other):
            return Neg(self)
        if is_negative_one_lexpr(self):
            return Neg(other)
        return Mul(other, self)

    def __div__(self, other):
        """Divide: ``self / other``.

        Raises:
            ValueError: If the divisor is a literal zero.
        """
        other = as_lexpr(other)
        if is_zero_lexpr(other):
            raise ValueError("Division by zero!")
        if is_zero_lexpr(self):
            return self
        return Div(self, other)

    def __rdiv__(self, other):
        """Divide (reflected): ``other / self``.

        Raises:
            ValueError: If the divisor (``self``) is a literal zero.
        """
        other = as_lexpr(other)
        if is_zero_lexpr(self):
            raise ValueError("Division by zero!")
        if is_zero_lexpr(other):
            return other
        return Div(other, self)

    # TODO: Error check types?
    # Python 3 never calls __div__/__rdiv__ itself; they are reached through
    # the true/floor division hooks below.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__
    # NOTE(review): floor division also maps to Div, so how integer division
    # rounds is left to the target-language formatter — confirm this is intended.
    __floordiv__ = __div__
    __rfloordiv__ = __rdiv__
1✔
266

267

268
class LExprOperator(LExpr):
    """Common base for expression nodes built from other expressions."""

    # Plain operators have no side effects; AssignOp overrides this.
    sideeffect = False


class LExprTerminal(LExpr):
    """Common base for leaf expressions such as literals and symbols."""

    # Leaves never have side effects.
    sideeffect = False
1✔
278

279

280
class LiteralFloat(LExprTerminal):
    """A floating point (real or complex) literal value."""

    precedence = PRECEDENCE.LITERAL

    def __init__(self, value):
        """Initialise from a Python ``float`` or ``complex``.

        Complex values get the SCALAR dtype; real values get REAL.
        """
        assert isinstance(value, float | complex)
        self.value = value
        if isinstance(value, complex):
            self.dtype = DataType.SCALAR
        else:
            self.dtype = DataType.REAL

    def __eq__(self, other):
        """Check equality."""
        return isinstance(other, LiteralFloat) and self.value == other.value

    def __hash__(self):
        """Hash by value, consistently with ``__eq__``.

        Defining ``__eq__`` alone sets ``__hash__`` to None. That made float
        literals unhashable, which in turn broke hashing of any BinOp/NaryOp
        containing one (those hash their operands).
        """
        return hash(self.value)

    def __float__(self):
        """Convert to float."""
        return float(self.value)

    def __repr__(self):
        """Representation."""
        return str(self.value)
×
305

306

307
class LiteralInt(LExprTerminal):
    """An integer literal holding a Python ``int`` or a NumPy number."""

    precedence = PRECEDENCE.LITERAL

    def __init__(self, value):
        """Store *value* and tag the node with the INT dtype."""
        assert isinstance(value, int | np.number)
        self.dtype = DataType.INT
        self.value = value

    def __eq__(self, other):
        """Equal to another LiteralInt holding an equal value."""
        if not isinstance(other, LiteralInt):
            return False
        return self.value == other.value

    def __hash__(self):
        """Hash by the stored value."""
        value = self.value
        return hash(value)

    def __repr__(self):
        """Render as the plain value."""
        value = self.value
        return str(value)
×
329

330

331
class Symbol(LExprTerminal):
    """A named variable; equality and hashing depend on the name only."""

    precedence = PRECEDENCE.SYMBOL

    def __init__(self, name: str, dtype):
        """Initialise from an alphanumeric/underscore *name* and a *dtype*."""
        assert isinstance(name, str)
        assert name.replace("_", "").isalnum()
        self.name, self.dtype = name, dtype

    def __eq__(self, other):
        """Equal to another Symbol with the same name."""
        if isinstance(other, Symbol):
            return self.name == other.name
        return False

    def __hash__(self):
        """Hash by name."""
        key = self.name
        return hash(key)

    def __repr__(self):
        """Render as the symbol name."""
        name = self.name
        return name
×
354

355

356
class MultiIndex(LExpr):
    """A multi-index for accessing tensors flattened in memory.

    Note: MultiIndex defines no ``__eq__``, so equality is object identity.
    """

    precedence = PRECEDENCE.SYMBOL

    def __init__(self, symbols: list, sizes: list):
        """Initialise.

        Args:
            symbols: One integer index expression (or int) per dimension.
            sizes: Extent of each dimension.
        """
        # Materialise once so a one-shot iterable is not exhausted by the
        # first pass below before the flattened index is built from it.
        symbols = list(symbols)
        self.dtype = DataType.INT
        self.sizes = sizes
        self.symbols = [as_lexpr(sym) for sym in symbols]
        for sym in self.symbols:
            assert sym.dtype == DataType.INT

        dim = len(sizes)
        if dim == 0:
            self.global_index: LExpr = LiteralInt(0)
        else:
            # Row-major: symbol j is scaled by the product of sizes[j + 1:]
            # (stride[j + 1]); the last symbol is scaled by LiteralInt(1).
            stride = [np.prod(sizes[i:]) for i in range(dim)] + [LiteralInt(1)]
            self.global_index = Sum(n * sym for n, sym in zip(stride[1:], symbols))

    @property
    def dim(self):
        """Dimension of the multi-index."""
        return len(self.sizes)

    def size(self):
        """Size of the multi-index (product of all sizes)."""
        return np.prod(self.sizes)

    def local_index(self, idx):
        """Get the index expression for dimension *idx*."""
        assert idx < len(self.symbols)
        return self.symbols[idx]

    def intersection(self, other):
        """Get the indices of *self* that also appear in *other*, in self's order."""
        symbols = []
        sizes = []
        for sym, size in zip(self.symbols, self.sizes):
            if sym in other.symbols:
                i = other.symbols.index(sym)
                assert other.sizes[i] == size
                symbols.append(sym)
                sizes.append(size)
        return MultiIndex(symbols, sizes)

    def union(self, other):
        """Get the union.

        Note:
            Result may depend on order a.union(b) != b.union(a)
        """
        # list() rather than .copy(): sizes may be a tuple, which has no copy().
        symbols = list(self.symbols)
        sizes = list(self.sizes)
        for sym, size in zip(other.symbols, other.sizes):
            if sym in symbols:
                i = symbols.index(sym)
                assert sizes[i] == size
            else:
                symbols.append(sym)
                sizes.append(size)
        return MultiIndex(symbols, sizes)

    def difference(self, other):
        """Get the indices of *self* that do not appear in *other*."""
        symbols = []
        sizes = []
        for idx, size in zip(self.symbols, self.sizes):
            if idx not in other.symbols:
                symbols.append(idx)
                sizes.append(size)
        return MultiIndex(symbols, sizes)

    def __hash__(self):
        """Hash by the textual form of the flattened index.

        The previous code hashed the bound method ``__repr__`` instead of
        calling it.
        """
        return hash(repr(self.global_index))
×
432

433

434
class PrefixUnaryOp(LExprOperator):
    """Base class for unary prefix operators (e.g. Neg, Not)."""

    def __init__(self, arg):
        """Initialise with the single operand *arg*, wrapped via as_lexpr."""
        self.arg = as_lexpr(arg)

    def __eq__(self, other):
        """Check equality: same operator type and equal operands."""
        return isinstance(other, type(self)) and self.arg == other.arg

    def __hash__(self):
        """Hash by operand, consistently with ``__eq__``.

        Defining ``__eq__`` alone sets ``__hash__`` to None. That made unary
        operator nodes, and any BinOp/NaryOp containing one, unhashable.
        """
        return hash(self.arg)
×
444

445

446
class BinOp(LExprOperator):
    """A binary operator node holding a left and a right operand."""

    def __init__(self, lhs, rhs):
        """Wrap both operands with as_lexpr and store them."""
        self.lhs, self.rhs = as_lexpr(lhs), as_lexpr(rhs)

    def __eq__(self, other):
        """Equal when of the same type with pairwise-equal operands."""
        if not isinstance(other, type(self)):
            return False
        return self.lhs == other.lhs and self.rhs == other.rhs

    def __hash__(self):
        """Hash as the sum of the operand hashes."""
        left_hash = hash(self.lhs)
        right_hash = hash(self.rhs)
        return left_hash + right_hash

    def __repr__(self):
        """Render as ``(lhs op rhs)``."""
        return f"({self.lhs} {self.op} {self.rhs})"
×
465

466

467
class ArithmeticBinOp(BinOp):
    """An arithmetic binary operator whose dtype is promoted from its operands."""

    def __init__(self, lhs, rhs):
        """Store both operands and derive the result dtype via merge_dtypes."""
        super().__init__(lhs, rhs)
        self.dtype = merge_dtypes([self.lhs.dtype, self.rhs.dtype])
1✔
475

476

477
class NaryOp(LExprOperator):
    """Base class for operators taking an arbitrary number of operands."""

    op = ""

    def __init__(self, args):
        """Wrap every operand with as_lexpr and promote their dtypes."""
        self.args = [as_lexpr(a) for a in args]
        merged = self.args[0].dtype
        for operand in self.args:
            merged = merge_dtypes([merged, operand.dtype])
        self.dtype = merged

    def __eq__(self, other):
        """Equal when of the same type with pairwise-equal operands."""
        if not isinstance(other, type(self)):
            return False
        if len(self.args) != len(other.args):
            return False
        return all(a == b for a, b in zip(self.args, other.args))

    def __repr__(self) -> str:
        """Render operands separated by the operator symbol."""
        separator = f"{self.op} "
        return separator.join([f"{arg} " for arg in self.args])

    def __hash__(self):
        """Hash by the tuple of operands."""
        operands = tuple(self.args)
        return hash(operands)
1✔
504

505

506
class Neg(PrefixUnaryOp):
    """Arithmetic negation; the result has the operand's dtype."""

    op = "-"
    precedence = PRECEDENCE.NEG

    def __init__(self, arg):
        """Store the operand and inherit its dtype."""
        super().__init__(arg)
        self.dtype = self.arg.dtype


class Not(PrefixUnaryOp):
    """Logical negation."""

    op = "!"
    precedence = PRECEDENCE.NOT
1✔
523

524

525
class Add(ArithmeticBinOp):
    """Binary addition ``lhs + rhs``."""

    op = "+"
    precedence = PRECEDENCE.ADD


class Sub(ArithmeticBinOp):
    """Binary subtraction ``lhs - rhs``."""

    op = "-"
    precedence = PRECEDENCE.SUB


class Mul(ArithmeticBinOp):
    """Binary multiplication ``lhs * rhs``."""

    op = "*"
    precedence = PRECEDENCE.MUL


class Div(ArithmeticBinOp):
    """Binary division ``lhs / rhs``."""

    op = "/"
    precedence = PRECEDENCE.DIV
1✔
551

552

553
class EQ(BinOp):
    """Equality comparison ``lhs == rhs``."""

    op = "=="
    precedence = PRECEDENCE.EQ


class NE(BinOp):
    """Inequality comparison ``lhs != rhs``."""

    op = "!="
    precedence = PRECEDENCE.NE


class LT(BinOp):
    """Strict less-than comparison ``lhs < rhs``."""

    op = "<"
    precedence = PRECEDENCE.LT


class GT(BinOp):
    """Strict greater-than comparison ``lhs > rhs``."""

    op = ">"
    precedence = PRECEDENCE.GT


class LE(BinOp):
    """Less-than-or-equal comparison ``lhs <= rhs``."""

    op = "<="
    precedence = PRECEDENCE.LE


class GE(BinOp):
    """Greater-than-or-equal comparison ``lhs >= rhs``."""

    op = ">="
    precedence = PRECEDENCE.GE


class And(BinOp):
    """Logical conjunction of two conditions."""

    op = "&&"
    precedence = PRECEDENCE.AND


class Or(BinOp):
    """Logical disjunction of two conditions."""

    op = "||"
    precedence = PRECEDENCE.OR
1✔
607

608

609
class Sum(NaryOp):
    """Sum of an arbitrary number of operands."""

    op = "+"
    precedence = PRECEDENCE.ADD


class Product(NaryOp):
    """Product of an arbitrary number of operands."""

    op = "*"
    precedence = PRECEDENCE.MUL
1✔
621

622

623
class MathFunction(LExprOperator):
    """A Math Function, with any arguments."""

    precedence = PRECEDENCE.HIGHEST

    def __init__(self, func, args):
        """Initialise.

        Args:
            func: Identifier of the function (type not fixed here — presumably
                a name understood by the formatters; confirm).
            args: Operands, each wrapped via as_lexpr. The node takes the
                dtype of the first argument.
        """
        self.function = func
        self.args = [as_lexpr(arg) for arg in args]
        self.dtype = self.args[0].dtype

    def __eq__(self, other):
        """Check equality."""
        return (
            isinstance(other, type(self))
            and self.function == other.function
            and len(self.args) == len(other.args)
            and all(a == b for a, b in zip(self.args, other.args))
        )

    def __hash__(self):
        """Hash by the arguments, consistently with ``__eq__``.

        Defining ``__eq__`` alone sets ``__hash__`` to None. ``function`` is
        left out of the hash because its type is not guaranteed hashable;
        equal nodes still hash equally.
        """
        return hash(tuple(self.args))
        )
642

643

644
class AssignOp(BinOp):
    """Base class for assignment operators; these have a side effect."""

    sideeffect = True
    precedence = PRECEDENCE.ASSIGN

    def __init__(self, lhs, rhs):
        """Initialise; the assignment target *lhs* must already be an LNode."""
        assert isinstance(lhs, LNode)
        super().__init__(lhs, rhs)


class Assign(AssignOp):
    """Plain assignment ``lhs = rhs``."""

    op = "="


class AssignAdd(AssignOp):
    """In-place addition ``lhs += rhs``."""

    op = "+="


class AssignSub(AssignOp):
    """In-place subtraction ``lhs -= rhs``."""

    op = "-="


class AssignMul(AssignOp):
    """In-place multiplication ``lhs *= rhs``."""

    op = "*="


class AssignDiv(AssignOp):
    """In-place division ``lhs /= rhs``."""

    op = "/="
1✔
684

685

686
class ArrayAccess(LExprOperator):
    """Array access."""

    precedence = PRECEDENCE.SUBSCRIPT

    def __init__(self, array, indices):
        """Initialise.

        Args:
            array: The accessed array, given as its Symbol or its ArrayDecl.
            indices: A single index or a list/tuple of indices; each is
                wrapped with as_lexpr.

        Raises:
            ValueError: If the array type is unsupported or a literal index is
                negative. When *array* is an ArrayDecl, also if the number of
                indices does not match the declaration or a literal index is
                not below the declared size.
        """
        # Typecheck array argument
        if isinstance(array, Symbol):
            self.array = array
            self.dtype = array.dtype
        elif isinstance(array, ArrayDecl):
            self.array = array.symbol
            self.dtype = array.symbol.dtype
        else:
            raise ValueError(f"Unexpected array type {type(array).__name__}")

        # Allow expressions or literals as indices
        if not isinstance(indices, list | tuple):
            indices = (indices,)
        self.indices = tuple(as_lexpr(i) for i in indices)

        # Early error checking for negative array dimensions. Indices are
        # already wrapped by as_lexpr, so literal values live inside
        # LiteralInt nodes; a check against plain ``int`` could never fire.
        literal_indices = [self._literal_int_value(i) for i in self.indices]
        if any(value is not None and value < 0 for value in literal_indices):
            raise ValueError("Index value < 0.")

        # Additional dimension checks possible if we get an ArrayDecl instead of just a name
        if isinstance(array, ArrayDecl):
            if len(self.indices) != len(array.sizes):
                raise ValueError("Invalid number of indices.")
            for value, size in zip(literal_indices, array.sizes):
                extent = self._literal_int_value(size)
                if value is not None and extent is not None and value >= extent:
                    raise ValueError("Index value >= array dimension.")

    @staticmethod
    def _literal_int_value(node):
        """Return the int value of a literal index/size, or None if not a literal."""
        if isinstance(node, LiteralInt):
            node = node.value
        if isinstance(node, numbers.Integral):
            return int(node)
        return None

    def __getitem__(self, indices):
        """Handle nested expr[i][j]."""
        if isinstance(indices, list):
            indices = tuple(indices)
        elif not isinstance(indices, tuple):
            indices = (indices,)
        return ArrayAccess(self.array, self.indices + indices)

    def __eq__(self, other):
        """Check equality."""
        return (
            isinstance(other, type(self))
            and self.array == other.array
            and self.indices == other.indices
        )

    def __hash__(self):
        """Hash."""
        return hash(self.array)

    def __repr__(self):
        """Representation."""
        return str(self.array) + "[" + ", ".join(str(i) for i in self.indices) + "]"
×
746

747

748
class Conditional(LExprOperator):
    """Conditional (ternary) expression: ``condition ? true : false``."""

    precedence = PRECEDENCE.CONDITIONAL

    def __init__(self, condition, true, false):
        """Initialise; the dtype is merged from the two branch expressions."""
        self.condition = as_lexpr(condition)
        self.true = as_lexpr(true)
        self.false = as_lexpr(false)
        self.dtype = merge_dtypes([self.true.dtype, self.false.dtype])

    def __eq__(self, other):
        """Check equality."""
        return (
            isinstance(other, type(self))
            and self.condition == other.condition
            and self.true == other.true
            and self.false == other.false
        )

    def __hash__(self):
        """Hash by all three operands, consistently with ``__eq__``.

        Defining ``__eq__`` alone sets ``__hash__`` to None, which made
        conditionals (and expressions containing them) unhashable.
        """
        return hash((self.condition, self.true, self.false))
768

769

770
def as_lexpr(node):
    """Typecheck and wrap an object as a valid LExpr.

    LExpr nodes are returned unchanged. Integral numbers become LiteralInt,
    and real or complex numbers become LiteralFloat.

    Raises:
        RuntimeError: If *node* is none of the above.
    """
    if isinstance(node, LExpr):
        return node
    if isinstance(node, numbers.Integral):
        return LiteralInt(node)
    if isinstance(node, numbers.Real):
        # LiteralFloat only accepts float/complex, so convert other real
        # types (e.g. numpy.float32, fractions.Fraction) explicitly.
        return LiteralFloat(node if isinstance(node, float) else float(node))
    if isinstance(node, numbers.Complex):
        # LiteralFloat supports complex values (dtype SCALAR).
        return LiteralFloat(node if isinstance(node, complex) else complex(node))
    raise RuntimeError(f"Unexpected LExpr type {type(node)}:\n{node}")
×
784

785

786
class Statement(LNode):
    """Wrap an expression so it can be used as a statement."""

    def __init__(self, expr):
        """Store *expr* after converting it with as_lexpr."""
        self.expr = as_lexpr(expr)

    def __eq__(self, other):
        """Equal when of the same type with equal wrapped expressions."""
        if isinstance(other, type(self)):
            return self.expr == other.expr
        return False

    def __hash__(self) -> int:
        """Hash via the wrapped expression."""
        wrapped = self.expr
        return hash(wrapped)
×
800

801

802
def as_statement(node):
    """Perform type checking on node and wrap in a suitable statement type if necessary.

    Returns:
        - the only statement of a one-element StatementList or list;
        - a StatementList, Statement or Section unchanged;
        - a Statement wrapping a side-effecting operator (e.g. an assignment);
        - a new StatementList built from a list of several items.

    Raises:
        RuntimeError: For operators without side effects or unsupported types.
    """
    if isinstance(node, StatementList):
        if len(node.statements) == 1:
            # Cleans up the expression tree a bit
            return node.statements[0]
        # StatementList is not a Statement subclass, but it is a valid
        # statement container (lists are turned into one below). Pass it on
        # instead of falling through to the "Unexpected Statement type" error.
        return node
    elif isinstance(node, Statement):
        # No-op
        return node
    elif isinstance(node, LExprOperator):
        if node.sideeffect:
            # Special case for using assignment expressions as statements
            return Statement(node)
        else:
            raise RuntimeError(
                f"Trying to create a statement of lexprOperator type {type(node)}:\n{node}"
            )

    elif isinstance(node, list):
        # Convenience case for list of statements
        if len(node) == 1:
            # Cleans up the expression tree a bit
            return as_statement(node[0])
        else:
            return StatementList(node)
    elif isinstance(node, Section):
        return node
    else:
        raise RuntimeError(f"Unexpected Statement type {type(node)}:\n{node}")
×
830

831

832
class Annotation(Enum):
    """Optimisation hints that can be attached to a Section."""

    fuse = 1  # fuse loops in section
    unroll = 2  # unroll loop in section
    licm = 3  # loop invariant code motion
    factorize = 4  # apply sum factorization
1✔
839

840

841
class Declaration(Statement):
    """Base class for all declarations."""

    def __init__(self, symbol):
        """Initialise with the declared *symbol*."""
        self.symbol = symbol

    def __eq__(self, other):
        """Check equality."""
        return isinstance(other, type(self)) and self.symbol == other.symbol

    def __hash__(self) -> int:
        """Hash by the declared symbol, consistently with ``__eq__``.

        Overriding ``__eq__`` alone sets ``__hash__`` to None. The inherited
        ``Statement.__hash__`` would not work either, since it reads
        ``self.expr``, which a Declaration never sets.
        """
        return hash(self.symbol)
×
851

852

853
def is_declaration(node) -> bool:
    """Return whether *node* declares a variable or an array."""
    return isinstance(node, (VariableDecl, ArrayDecl))
1✔
856

857

858
class Section(LNode):
    """A section of code with a name and a list of statements."""

    def __init__(
        self,
        name: str,
        statements: list[LNode],
        declarations: Sequence[Declaration],
        input: list[Symbol] | None = None,
        output: list[Symbol] | None = None,
        annotations: list[Annotation] | None = None,
    ):
        """Initialise.

        Args:
            name: Section name.
            statements: Statements; each is passed through as_statement.
            declarations: Declarations (VariableDecl/ArrayDecl) in this
                section. Their symbols are added to the outputs if missing.
            input: Input symbols of the section.
            output: Output symbols of the section.
            annotations: Optimisation annotations for the section.
        """
        self.name = name
        self.statements = [as_statement(st) for st in statements]
        self.annotations = annotations or []
        self.input = input or []
        self.declarations = declarations or []
        # Copy so that appending declared symbols below never mutates a list
        # owned by the caller.
        self.output = list(output) if output else []

        for decl in self.declarations:
            assert is_declaration(decl)
            if decl.symbol not in self.output:
                self.output.append(decl.symbol)

    def __eq__(self, other):
        """Check equality of name, inputs, outputs, annotations and statements."""
        attributes = ("name", "input", "output", "annotations", "statements")
        return isinstance(other, type(self)) and all(
            getattr(self, name) == getattr(other, name) for name in attributes
        )
889

890

891
class StatementList(LNode):
    """An ordered sequence of statements."""

    def __init__(self, statements):
        """Convert every element with as_statement and store the results."""
        self.statements = list(map(as_statement, statements))

    def __eq__(self, other):
        """Equal when of the same type with equal statement lists."""
        if isinstance(other, type(self)):
            return self.statements == other.statements
        return False

    def __hash__(self) -> int:
        """Hash by the tuple of statements."""
        frozen = tuple(self.statements)
        return hash(frozen)

    def __repr__(self):
        """Render as ``StatementList([...])``."""
        return f"StatementList({self.statements})"
×
909

910

911
class Comment(Statement):
    """Line comment(s) used for annotating the generated code with human readable remarks.

    Note: Statement.__init__ is not called, so a Comment has no ``expr``.
    """

    def __init__(self, comment: str):
        """Initialise with the comment text."""
        assert isinstance(comment, str)
        self.comment = comment

    def __eq__(self, other):
        """Check equality."""
        return isinstance(other, type(self)) and self.comment == other.comment

    def __hash__(self) -> int:
        """Hash by comment text, consistently with ``__eq__``.

        Without this, a Comment was unhashable (overriding ``__eq__`` sets
        ``__hash__`` to None). That broke hashing of any StatementList or
        ForRange containing one.
        """
        return hash(self.comment)
×
922

923

924
def commented_code_list(code, comments):
    """Prepend Comment nodes to *code* unless it is empty.

    Args:
        code: A single LNode or a list of nodes.
        comments: A comment string, or a list/tuple of comment strings.

    Returns:
        A new list of the comments followed by the code, or the (possibly
        wrapped) empty list unchanged.
    """
    code_list = [code] if isinstance(code, LNode) else code
    assert isinstance(code_list, list)
    if not code_list:
        return code_list
    texts = comments if isinstance(comments, list | tuple) else [comments]
    return [Comment(text) for text in texts] + code_list
1✔
935

936

937
# Type and variable declarations
938

939

940
class VariableDecl(Declaration):
    """Declare a variable, optionally define initial value."""

    def __init__(self, symbol, value=None):
        """Initialise.

        Args:
            symbol: The declared Symbol; its dtype must not be None.
            value: Optional initial value, wrapped via as_lexpr.
        """
        assert isinstance(symbol, Symbol)
        assert symbol.dtype is not None
        self.symbol = symbol

        if value is not None:
            value = as_lexpr(value)
        self.value = value

    def __eq__(self, other):
        """Check equality of symbol, declared dtype and initial value.

        There is no ``typename`` attribute; the declared type is the symbol's
        dtype. It is compared explicitly because Symbol equality only
        considers the name.
        """
        return (
            isinstance(other, type(self))
            and self.symbol == other.symbol
            and self.symbol.dtype == other.symbol.dtype
            and self.value == other.value
        )

    def __hash__(self) -> int:
        """Hash by symbol, consistently with ``__eq__``."""
        return hash(self.symbol)
961

962

963
class ArrayDecl(Declaration):
    """A declaration or definition of an array.

    Note that just setting values=0 is sufficient to initialize the
    entire array to zero.

    Otherwise use nested lists of lists to represent multidimensional
    array values to initialize to.

    """

    def __init__(self, symbol, sizes=None, values=None, const=False):
        """Initialise.

        Args:
            symbol: The array Symbol; it must carry a dtype.
            sizes: Array extents (an integer or a sequence). If None, taken
                from ``values.shape``.
            values: Optional initial values; lists/tuples become ndarrays.
            const: Whether the array is declared constant.
        """
        assert isinstance(symbol, Symbol)
        self.symbol = symbol
        assert symbol.dtype

        if sizes is None:
            assert values is not None
            sizes = values.shape
        # numbers.Integral also covers NumPy integers, which are not ``int``.
        if isinstance(sizes, numbers.Integral):
            sizes = (sizes,)
        self.sizes = tuple(sizes)

        # NB! No type checking, assuming nested lists of literal values. Not applying as_lexpr.
        if isinstance(values, list | tuple):
            self.values = np.asarray(values)
        else:
            self.values = values

        self.const = const
        self.dtype = symbol.dtype

    def __eq__(self, other):
        """Check equality of dtype, symbol, sizes and values with *other*.

        The previous implementation compared every attribute of ``self`` with
        itself, so any two ArrayDecls compared equal (or the comparison raised
        for array values).
        """
        if not isinstance(other, type(self)):
            return False
        if (self.dtype, self.symbol, self.sizes) != (other.dtype, other.symbol, other.sizes):
            return False
        return self._values_equal(self.values, other.values)

    @staticmethod
    def _values_equal(a, b):
        """Compare initial values; ndarrays need element-wise comparison."""
        if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
            return bool(np.array_equal(a, b))
        return bool(a == b)

    def __hash__(self) -> int:
        """Hash by symbol, consistently with ``__eq__``."""
        return hash(self.symbol)
×
1009

1010

1011
def is_simple_inner_loop(code):
    """Check if code is a simple inner loop.

    A simple inner loop is an assignment statement, or a (possibly nested)
    ForRange whose body consists of exactly one statement that is itself a
    simple inner loop.

    ForRange always wraps its body in a StatementList, so the list has to be
    unwrapped explicitly. Otherwise the recursion never reaches the
    assignment. ForRange is also handled before the generic Statement check
    because it does not set an ``expr`` attribute.
    """
    if isinstance(code, ForRange):
        return is_simple_inner_loop(code.body)
    if isinstance(code, StatementList):
        return len(code.statements) == 1 and is_simple_inner_loop(code.statements[0])
    return isinstance(code, Statement) and isinstance(code.expr, AssignOp)
×
1018

1019

1020
def depth(code) -> int:
    """Get the loop-nesting depth of code.

    A ForRange contributes one level plus the depth of its body. A
    StatementList has the depth of its deepest statement, or 0 when it is
    empty. Any other node has depth 0.
    """
    if isinstance(code, ForRange):
        return 1 + depth(code.body)
    if isinstance(code, StatementList):
        # default=0: an empty StatementList would otherwise raise ValueError.
        return max((depth(c) for c in code.statements), default=0)
    return 0
1✔
1027

1028

1029
class ForRange(Statement):
    """Slightly higher-level for loop assuming incrementing an index over a range."""

    def __init__(self, index, begin, end, body):
        """Initialise.

        Args:
            index: Loop index, a Symbol or MultiIndex.
            begin: Start of the range; converted with ``as_lexpr``.
            end: End of the range; converted with ``as_lexpr``.
            body: List of statements forming the loop body; stored as a
                StatementList.
        """
        assert isinstance(index, Symbol | MultiIndex)
        self.index = index
        self.begin = as_lexpr(begin)
        self.end = as_lexpr(end)
        assert isinstance(body, list)
        self.body = StatementList(body)

    def as_tuple(self):
        """Convert to a tuple of (index, begin, end, body)."""
        return (self.index, self.begin, self.end, self.body)

    def __eq__(self, other):
        """Check equality of index, range bounds and body against another ForRange."""
        attributes = ("index", "begin", "end", "body")
        return isinstance(other, type(self)) and all(
            getattr(self, name) == getattr(other, name) for name in attributes
        )

    def __hash__(self) -> int:
        """Hash."""
        return hash(self.as_tuple())
×
1055

1056

1057
def _math_function(op, *args):
    """Build the LNodes node for a UFL math operator.

    For real operands (judged by the first argument's dtype), ``conj`` and
    ``real`` are the identity and ``imag`` is the literal ``0.0``, so no
    function call is emitted. Every other case becomes a MathFunction
    named after the operator's UFL handler name.
    """
    fname = op._ufl_handler_name_
    if args[0].dtype == DataType.REAL:
        if fname in ("conj", "real"):
            assert len(args) == 1
            return args[0]
        if fname == "imag":
            assert len(args) == 1
            return LiteralFloat(0.0)
    return MathFunction(fname, args)
1✔
1068

1069

1070
# Dispatch table consulted by ufl_to_lnodes (below), keyed on the exact type
# of the UFL operator. Each handler receives the UFL operator itself followed
# by its operands, which have already been converted to LNodes.
_ufl_call_lookup = {
    # Literal constants: only the operator itself is passed.
    ufl.constantvalue.IntValue: lambda op: LiteralInt(int(op)),
    ufl.constantvalue.FloatValue: lambda op: LiteralFloat(float(op)),
    ufl.constantvalue.ComplexValue: lambda op: LiteralFloat(op.value()),
    ufl.constantvalue.Zero: lambda op: LiteralFloat(0.0),
    # Arithmetic maps onto the LNodes operator overloads.
    ufl.algebra.Product: lambda op, lhs, rhs: lhs * rhs,
    ufl.algebra.Sum: lambda op, lhs, rhs: lhs + rhs,
    ufl.algebra.Division: lambda op, lhs, rhs: lhs / rhs,
    ufl.algebra.Abs: _math_function,
    ufl.algebra.Power: _math_function,
    ufl.algebra.Real: _math_function,
    ufl.algebra.Imag: _math_function,
    ufl.algebra.Conj: _math_function,
    # Comparisons and logic build explicit LNodes condition nodes.
    ufl.classes.GT: lambda op, lhs, rhs: GT(lhs, rhs),
    ufl.classes.GE: lambda op, lhs, rhs: GE(lhs, rhs),
    ufl.classes.EQ: lambda op, lhs, rhs: EQ(lhs, rhs),
    ufl.classes.NE: lambda op, lhs, rhs: NE(lhs, rhs),
    ufl.classes.LT: lambda op, lhs, rhs: LT(lhs, rhs),
    ufl.classes.LE: lambda op, lhs, rhs: LE(lhs, rhs),
    ufl.classes.AndCondition: lambda op, lhs, rhs: And(lhs, rhs),
    ufl.classes.OrCondition: lambda op, lhs, rhs: Or(lhs, rhs),
    ufl.classes.NotCondition: lambda op, operand: Not(operand),
    ufl.classes.Conditional: lambda op, cond, if_true, if_false: Conditional(
        cond, if_true, if_false
    ),
    # Everything below is emitted as a named math function call.
    ufl.classes.MinValue: _math_function,
    ufl.classes.MaxValue: _math_function,
    ufl.mathfunctions.Sqrt: _math_function,
    ufl.mathfunctions.Ln: _math_function,
    ufl.mathfunctions.Exp: _math_function,
    ufl.mathfunctions.Cos: _math_function,
    ufl.mathfunctions.Sin: _math_function,
    ufl.mathfunctions.Tan: _math_function,
    ufl.mathfunctions.Cosh: _math_function,
    ufl.mathfunctions.Sinh: _math_function,
    ufl.mathfunctions.Tanh: _math_function,
    ufl.mathfunctions.Acos: _math_function,
    ufl.mathfunctions.Asin: _math_function,
    ufl.mathfunctions.Atan: _math_function,
    ufl.mathfunctions.Erf: _math_function,
    ufl.mathfunctions.Atan2: _math_function,
    ufl.mathfunctions.MathFunction: _math_function,
    ufl.mathfunctions.BesselJ: _math_function,
    ufl.mathfunctions.BesselY: _math_function,
}
1115

1116

1117
def ufl_to_lnodes(operator, *args):
    """Translate a UFL operator to LNodes using the handler for its exact type.

    Args:
        operator: UFL expression node. Its exact ``type`` (subclasses are not
            matched) selects the handler in ``_ufl_call_lookup``.
        *args: Operands already converted to LNodes, forwarded to the handler.

    Raises:
        RuntimeError: If no handler is registered for ``type(operator)``.
    """
    optype = type(operator)
    handler = _ufl_call_lookup.get(optype)
    if handler is None:
        raise RuntimeError(f"Missing lookup for expr type {optype}.")
    return handler(operator, *args)
×
1124

1125

1126
def create_nested_for_loops(indices: list[MultiIndex], body):
    """Create nested for loops over list of indices.

    The depth of the nested for loops is equal to the sub-indices for all
    MultiIndex combined. Loops are ordered so that the first sub-index of the
    first MultiIndex is the outermost loop, and each loop runs from 0 to the
    corresponding entry of ``MultiIndex.sizes``. If there are no sub-indices,
    ``body`` is returned unchanged.
    """
    # Distinct local names: the parameter keeps its annotated type, and the
    # module-level ``depth`` function is not shadowed.
    loop_ranges = [size for mi in indices for size in mi.sizes]
    loop_indices = [mi.local_index(i) for mi in indices for i in range(len(mi.sizes))]
    # Build from the innermost loop outwards, wrapping the body each time.
    for loop_index, loop_range in zip(reversed(loop_indices), reversed(loop_ranges)):
        body = ForRange(loop_index, 0, loop_range, body=[body])
    return body
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc