• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 13486553783

23 Feb 2025 08:13PM UTC coverage: 82.188% (-0.05%) from 82.235%
13486553783

push

github

web-flow
Resolve jacobians from duplicate meshes (#733)

* Fix multiple jacobians

* add test that currently fails on main

* ruff

---------

Co-authored-by: Matthew Scroggs <matthew.w.scroggs@gmail.com>

1 of 1 new or added line in 1 file covered. (100.0%)

3 existing lines in 1 file now uncovered.

3516 of 4278 relevant lines covered (82.19%)

0.82 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.45
/ffcx/codegeneration/lnodes.py
1
# Copyright (C) 2013-2023 Martin Sandve Alnæs, Chris Richardson
2
#
3
# This file is part of FFCx.(https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6
"""LNodes.
7

8
LNodes is intended as a minimal generic language description.
9
Formatting is done later, depending on the target language.
10

11
Supported:
12
 Floating point (and complex) and integer variables and multidimensional arrays
13
 Range loops
14
 Simple arithmetic, +-*/
15
 Math operations
16
 Logic conditions
17
 Comments
18
Not supported:
19
 Pointers
20
 Function Calls
21
 Flow control (if, switch, while)
22
 Booleans
23
 Strings
24
"""
25

26
import numbers
1✔
27
from collections.abc import Sequence
1✔
28
from enum import Enum
1✔
29
from typing import Optional
1✔
30

31
import numpy as np
1✔
32
import ufl
1✔
33

34

35
class PRECEDENCE:
    """Operator precedence levels, used as a plain namespace of integers.

    ``HIGHEST`` is the smallest number and ``LOWEST`` the largest; operators
    that share a level (e.g. ``MUL`` and ``DIV``) have the same value.
    """

    # Terminals and anything that never needs parenthesising
    HIGHEST = LITERAL = SYMBOL = 0
    SUBSCRIPT = 2

    # Unary prefix operators
    NOT = NEG = 3

    # Multiplicative, then additive operators
    MUL = DIV = 4
    ADD = SUB = 5

    # Relational and equality comparisons
    LT = LE = GT = GE = 7
    EQ = NE = 8

    # Logical connectives
    AND = 11
    OR = 12

    # Ternary conditional and assignment
    CONDITIONAL = ASSIGN = 13
    LOWEST = 15
1✔
63

64

65
def is_zero_lexpr(lexpr):
    """Tell whether ``lexpr`` is a float or integer literal with value zero."""
    if isinstance(lexpr, LiteralFloat):
        return lexpr.value == 0.0
    if isinstance(lexpr, LiteralInt):
        return lexpr.value == 0
    return False
70

71

72
def is_one_lexpr(lexpr):
    """Tell whether ``lexpr`` is a float or integer literal with value one."""
    if isinstance(lexpr, LiteralFloat):
        return lexpr.value == 1.0
    if isinstance(lexpr, LiteralInt):
        return lexpr.value == 1
    return False
77

78

79
def is_negative_one_lexpr(lexpr):
    """Tell whether ``lexpr`` is a float or integer literal with value minus one."""
    if isinstance(lexpr, LiteralFloat):
        return lexpr.value == -1.0
    if isinstance(lexpr, LiteralInt):
        return lexpr.value == -1
    return False
84

85

86
def float_product(factors):
    """Build the product of a sequence of factors.

    Literal ones are dropped. An empty remainder yields ``LiteralFloat(1.0)``,
    a single remaining factor is returned as-is, and anything longer is
    wrapped in a ``Product`` node.
    """
    remaining = [factor for factor in factors if not is_one_lexpr(factor)]
    if not remaining:
        return LiteralFloat(1.0)
    if len(remaining) == 1:
        return remaining[0]
    return Product(remaining)
1✔
98

99

100
class DataType(Enum):
    """Data types a variable or expression can carry in LNodes.

    ``REAL`` follows the geometry type, ``SCALAR`` follows the tensor type and
    ``INT`` is used for entity indices and similar. ``NONE`` is the
    placeholder used by expressions that have not been given a type.
    """

    REAL = 0  # same type as the geometry
    SCALAR = 1  # same type as the tensor
    INT = 2  # integer quantities such as indices
    BOOL = 3
    NONE = 4  # no type assigned
1✔
112

113

114
def merge_dtypes(dtypes: list[DataType]):
    """Pick the result type of an operation from its operand types.

    The first of SCALAR, REAL, INT, BOOL that occurs in ``dtypes`` wins.

    Raises:
        ValueError: if ``DataType.NONE`` is present, or if none of the
            recognised types occurs (e.g. an empty list).
    """
    if DataType.NONE in dtypes:
        raise ValueError(f"Invalid DataType in LNodes {dtypes}")

    # Ordered from the widest type to the narrowest.
    promotion_order = (DataType.SCALAR, DataType.REAL, DataType.INT, DataType.BOOL)
    for candidate in promotion_order:
        if candidate in dtypes:
            return candidate

    raise ValueError(f"Can't get dtype for operation with {dtypes}")
×
128

129

130
class LNode:
    """Base class for all AST nodes."""

    def __eq__(self, other):
        """Check for equality.

        The base class has no notion of structural equality, so it returns
        ``NotImplemented`` and lets Python fall back to identity comparison.
        Subclasses override this with a structural comparison.
        """
        return NotImplemented

    def __ne__(self, other):
        """Check for inequality as the negation of ``__eq__``.

        Returning ``NotImplemented`` unconditionally here would make
        ``a != b`` fall back to ``a is not b`` even for subclasses that
        define a structural ``__eq__``, so ``a == b`` and ``a != b`` could
        both be true. Delegating to ``__eq__`` keeps the two consistent.
        """
        result = self.__eq__(other)
        if result is NotImplemented:
            return NotImplemented
        return not result
×
140

141

142
class LExpr(LNode):
    """Common base for every expression node.

    Concrete subclasses are expected to define a ``precedence`` class
    attribute. The arithmetic operators are overloaded so that expressions
    can be combined with ordinary Python syntax; trivial cases (adding zero,
    multiplying by one or minus one, ...) are simplified while the tree is
    being built.
    """

    dtype = DataType.NONE

    def __getitem__(self, indices):
        """Subscript this expression, producing an ``ArrayAccess`` node."""
        return ArrayAccess(self, indices)

    def __neg__(self):
        """Negate; literals get their sign flipped instead of being wrapped."""
        for literal_type in (LiteralFloat, LiteralInt):
            if isinstance(self, literal_type):
                return literal_type(-self.value)
        return Neg(self)

    def __add__(self, other):
        """Return ``self + other``, dropping zeros and folding ``a + (-b)``."""
        rhs = as_lexpr(other)
        if is_zero_lexpr(self):
            return rhs
        if is_zero_lexpr(rhs):
            return self
        return Sub(self, rhs.arg) if isinstance(rhs, Neg) else Add(self, rhs)

    def __radd__(self, other):
        """Return ``other + self``, dropping zeros and folding ``a + (-b)``."""
        lhs = as_lexpr(other)
        if is_zero_lexpr(self):
            return lhs
        if is_zero_lexpr(lhs):
            return self
        return Sub(lhs, self.arg) if isinstance(self, Neg) else Add(lhs, self)

    def __sub__(self, other):
        """Return ``self - other`` with zero, double-negation and int folding."""
        rhs = as_lexpr(other)
        if is_zero_lexpr(self):
            return -rhs
        if is_zero_lexpr(rhs):
            return self
        if isinstance(rhs, Neg):
            # a - (-b) == a + b
            return Add(self, rhs.arg)
        both_int_literals = isinstance(self, LiteralInt) and isinstance(rhs, LiteralInt)
        if both_int_literals:
            return LiteralInt(self.value - rhs.value)
        return Sub(self, rhs)

    def __rsub__(self, other):
        """Return ``other - self`` with zero and double-negation folding."""
        lhs = as_lexpr(other)
        if is_zero_lexpr(self):
            return lhs
        if is_zero_lexpr(lhs):
            return -self
        return Add(lhs, self.arg) if isinstance(self, Neg) else Sub(lhs, self)

    def __mul__(self, other):
        """Return ``self * other``, simplifying factors of 0, 1 and -1."""
        rhs = as_lexpr(other)
        # A zero factor absorbs the product; the zero operand itself is returned.
        if is_zero_lexpr(self):
            return self
        if is_zero_lexpr(rhs):
            return rhs
        # Unit factors disappear.
        if is_one_lexpr(self):
            return rhs
        if is_one_lexpr(rhs):
            return self
        # Multiplication by -1 becomes a negation node.
        if is_negative_one_lexpr(rhs):
            return Neg(self)
        if is_negative_one_lexpr(self):
            return Neg(rhs)
        both_int_literals = isinstance(self, LiteralInt) and isinstance(rhs, LiteralInt)
        if both_int_literals:
            return LiteralInt(self.value * rhs.value)
        return Mul(self, rhs)

    def __rmul__(self, other):
        """Return ``other * self``, simplifying factors of 0, 1 and -1."""
        lhs = as_lexpr(other)
        if is_zero_lexpr(self):
            return self
        if is_zero_lexpr(lhs):
            return lhs
        if is_one_lexpr(self):
            return lhs
        if is_one_lexpr(lhs):
            return self
        if is_negative_one_lexpr(lhs):
            return Neg(self)
        if is_negative_one_lexpr(self):
            return Neg(lhs)
        return Mul(lhs, self)

    def __div__(self, other):
        """Return ``self / other``.

        Raises:
            ValueError: if ``other`` is a literal zero.
        """
        denominator = as_lexpr(other)
        if is_zero_lexpr(denominator):
            raise ValueError("Division by zero!")
        if is_zero_lexpr(self):
            return self
        return Div(self, denominator)

    def __rdiv__(self, other):
        """Return ``other / self``.

        Raises:
            ValueError: if ``self`` is a literal zero.
        """
        numerator = as_lexpr(other)
        if is_zero_lexpr(self):
            raise ValueError("Division by zero!")
        if is_zero_lexpr(numerator):
            return numerator
        return Div(numerator, self)

    # TODO: Error check types?
    # True and floor division both map onto the same Div node.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__
    __floordiv__ = __div__
    __rfloordiv__ = __rdiv__
1✔
267

268

269
class LExprOperator(LExpr):
    """Common base for expression nodes that apply an operator."""

    # Operators have no side effect unless a subclass says otherwise.
    sideeffect = False
1✔
273

274

275
class LExprTerminal(LExpr):
    """Common base for expression terminals (leaves of the expression tree)."""

    # Terminals never have side effects.
    sideeffect = False
1✔
279

280

281
class LiteralFloat(LExprTerminal):
    """A floating point literal value.

    A ``float`` value gets dtype ``REAL``; a ``complex`` value gets dtype
    ``SCALAR``.
    """

    precedence = PRECEDENCE.LITERAL

    def __init__(self, value):
        """Initialise from a ``float`` or ``complex`` value."""
        assert isinstance(value, (float, complex))
        self.value = value
        self.dtype = DataType.SCALAR if isinstance(value, complex) else DataType.REAL

    def __eq__(self, other):
        """Check equality: same node type and same value."""
        return isinstance(other, LiteralFloat) and self.value == other.value

    def __hash__(self):
        """Hash on the value, consistent with ``__eq__``.

        Defining ``__eq__`` alone would make instances unhashable, which in
        turn breaks hashing of any operator node holding a float literal.
        """
        return hash(self.value)

    def __float__(self):
        """Convert to float."""
        return float(self.value)

    def __repr__(self):
        """Representation."""
        return str(self.value)
×
306

307

308
class LiteralInt(LExprTerminal):
    """An integer literal; its dtype is always ``DataType.INT``."""

    precedence = PRECEDENCE.LITERAL

    def __init__(self, value):
        """Store ``value``, which must be an ``int`` or a NumPy number."""
        assert isinstance(value, (int, np.number))
        self.dtype = DataType.INT
        self.value = value

    def __eq__(self, other):
        """Two integer literals are equal when their values are equal."""
        if not isinstance(other, LiteralInt):
            return False
        return self.value == other.value

    def __hash__(self):
        """Hash on the stored value."""
        return hash(self.value)

    def __repr__(self):
        """Show the value as text."""
        return str(self.value)
×
330

331

332
class Symbol(LExprTerminal):
    """A named symbol with an associated data type.

    Equality and hashing depend on the name only, not on the dtype.
    """

    precedence = PRECEDENCE.SYMBOL

    def __init__(self, name: str, dtype):
        """Store the name and dtype.

        The name must be a string made of alphanumeric characters and
        underscores.
        """
        assert isinstance(name, str)
        stripped = name.replace("_", "")
        assert stripped.isalnum()
        self.name = name
        self.dtype = dtype

    def __eq__(self, other):
        """Symbols compare equal when they share a name."""
        if not isinstance(other, Symbol):
            return False
        return self.name == other.name

    def __hash__(self):
        """Hash on the name."""
        return hash(self.name)

    def __repr__(self):
        """The representation is the bare name."""
        return self.name
×
355

356

357
class MultiIndex(LExpr):
    """A multi-index for accessing tensors flattened in memory.

    Holds one index symbol and one size per dimension, plus ``global_index``,
    the expression for the flattened position: each symbol is multiplied by
    the product of the sizes of all later dimensions and the terms are summed.
    """

    precedence = PRECEDENCE.SYMBOL

    def __init__(self, symbols: list, sizes: list):
        """Initialise.

        Args:
            symbols: One index per dimension; each is wrapped with
                ``as_lexpr`` and must have dtype ``INT``.
            sizes: Extent of each dimension.
        """
        self.dtype = DataType.INT
        self.sizes = sizes
        self.symbols = [as_lexpr(sym) for sym in symbols]
        for sym in self.symbols:
            assert sym.dtype == DataType.INT

        dim = len(sizes)
        if dim == 0:
            self.global_index: LExpr = LiteralInt(0)
        else:
            # stride[k] is the product of sizes[k:]; the trailing literal 1 is
            # the stride of the last dimension once the list is shifted by one.
            stride = [np.prod(sizes[i:]) for i in range(dim)] + [LiteralInt(1)]
            self.global_index = Sum(n * sym for n, sym in zip(stride[1:], symbols))

    @property
    def dim(self):
        """Dimension of the multi-index."""
        return len(self.sizes)

    def size(self):
        """Size of the multi-index (product of all sizes)."""
        return np.prod(self.sizes)

    def local_index(self, idx):
        """Get the index symbol of dimension ``idx``."""
        assert idx < len(self.symbols)
        return self.symbols[idx]

    def intersection(self, other):
        """Get the multi-index of the symbols present in both operands.

        The order of ``self`` is kept; sizes of shared symbols must agree.
        """
        symbols = []
        sizes = []
        for sym, size in zip(self.symbols, self.sizes):
            if sym in other.symbols:
                i = other.symbols.index(sym)
                assert other.sizes[i] == size
                symbols.append(sym)
                sizes.append(size)
        return MultiIndex(symbols, sizes)

    def union(self, other):
        """Get the union.

        Note:
            Result may depend on order a.union(b) != b.union(a)
        """
        symbols = self.symbols.copy()
        sizes = self.sizes.copy()
        for sym, size in zip(other.symbols, other.sizes):
            if sym in symbols:
                i = symbols.index(sym)
                assert sizes[i] == size
            else:
                symbols.append(sym)
                sizes.append(size)
        return MultiIndex(symbols, sizes)

    def difference(self, other):
        """Get the multi-index of the symbols of ``self`` not in ``other``."""
        symbols = []
        sizes = []
        for idx, size in zip(self.symbols, self.sizes):
            if idx not in other.symbols:
                symbols.append(idx)
                sizes.append(size)
        return MultiIndex(symbols, sizes)

    def __hash__(self):
        """Hash on the textual form of the flattened index expression.

        ``repr`` must actually be called: hashing the bound method object
        ``self.global_index.__repr__`` does not depend on the text at all.
        """
        return hash(repr(self.global_index))
×
433

434

435
class PrefixUnaryOp(LExprOperator):
    """Base class for unary operators written before their operand."""

    def __init__(self, arg):
        """Initialise with the operand, wrapped as an expression."""
        self.arg = as_lexpr(arg)

    def __eq__(self, other):
        """Check equality: same operator type and equal operands."""
        return isinstance(other, type(self)) and self.arg == other.arg

    def __hash__(self):
        """Hash on operator type and operand, consistent with ``__eq__``.

        Without this, defining ``__eq__`` leaves unary nodes unhashable, and
        hashing any binary or n-ary node that contains one fails.
        """
        return hash((type(self).__name__, self.arg))
×
445

446

447
class BinOp(LExprOperator):
    """Base for operators with a left and a right operand.

    Subclasses supply the ``op`` string used by ``__repr__``.
    """

    def __init__(self, lhs, rhs):
        """Wrap both operands as expressions and store them."""
        left, right = as_lexpr(lhs), as_lexpr(rhs)
        self.lhs = left
        self.rhs = right

    def __eq__(self, other):
        """Equal when the operator type and both operands match."""
        if not isinstance(other, type(self)):
            return False
        return self.lhs == other.lhs and self.rhs == other.rhs

    def __hash__(self):
        """Combine the operand hashes by addition."""
        left_hash = hash(self.lhs)
        right_hash = hash(self.rhs)
        return left_hash + right_hash

    def __repr__(self):
        """Parenthesised infix form."""
        return f"({self.lhs} {self.op} {self.rhs})"
×
466

467

468
class ArithmeticBinOp(BinOp):
    """An arithmetic binary operator whose dtype is merged from its operands."""

    def __init__(self, lhs, rhs):
        """Store the operands and derive the result dtype from them."""
        BinOp.__init__(self, lhs, rhs)
        operand_dtypes = [self.lhs.dtype, self.rhs.dtype]
        self.dtype = merge_dtypes(operand_dtypes)
1✔
476

477

478
class NaryOp(LExprOperator):
    """Base for operators that take an arbitrary number of operands."""

    op = ""

    def __init__(self, args):
        """Wrap all operands as expressions and merge their dtypes.

        ``args`` may be any iterable; it must yield at least one operand.
        """
        self.args = [as_lexpr(arg) for arg in args]
        merged = self.args[0].dtype
        for operand in self.args:
            merged = merge_dtypes([merged, operand.dtype])
        self.dtype = merged

    def __eq__(self, other):
        """Equal when type, operand count and each operand pair match."""
        if not isinstance(other, type(self)):
            return False
        if len(self.args) != len(other.args):
            return False
        return all(mine == theirs for mine, theirs in zip(self.args, other.args))

    def __repr__(self) -> str:
        """Operands joined by the operator symbol."""
        separator = f"{self.op} "
        pieces = [f"{operand} " for operand in self.args]
        return separator.join(pieces)

    def __hash__(self):
        """Hash on the ordered tuple of operands."""
        return hash(tuple(self.args))
1✔
505

506

507
class Neg(PrefixUnaryOp):
    """Arithmetic negation; the result has the dtype of its operand."""

    op = "-"
    precedence = PRECEDENCE.NEG

    def __init__(self, arg):
        """Store the operand and inherit its dtype."""
        operand = as_lexpr(arg)
        self.arg = operand
        self.dtype = operand.dtype
1✔
517

518

519
class Not(PrefixUnaryOp):
    """Logical negation operator."""

    op = "!"
    precedence = PRECEDENCE.NOT
1✔
524

525

526
class Add(ArithmeticBinOp):
    """Binary addition, ``lhs + rhs``."""

    op = "+"
    precedence = PRECEDENCE.ADD


class Sub(ArithmeticBinOp):
    """Binary subtraction, ``lhs - rhs``."""

    op = "-"
    precedence = PRECEDENCE.SUB


class Mul(ArithmeticBinOp):
    """Binary multiplication, ``lhs * rhs``."""

    op = "*"
    precedence = PRECEDENCE.MUL


class Div(ArithmeticBinOp):
    """Binary division, ``lhs / rhs``."""

    op = "/"
    precedence = PRECEDENCE.DIV
1✔
552

553

554
class EQ(BinOp):
    """Comparison ``lhs == rhs``."""

    op = "=="
    precedence = PRECEDENCE.EQ


class NE(BinOp):
    """Comparison ``lhs != rhs``."""

    op = "!="
    precedence = PRECEDENCE.NE


class LT(BinOp):
    """Comparison ``lhs < rhs``."""

    op = "<"
    precedence = PRECEDENCE.LT


class GT(BinOp):
    """Comparison ``lhs > rhs``."""

    op = ">"
    precedence = PRECEDENCE.GT


class LE(BinOp):
    """Comparison ``lhs <= rhs``."""

    op = "<="
    precedence = PRECEDENCE.LE


class GE(BinOp):
    """Comparison ``lhs >= rhs``."""

    op = ">="
    precedence = PRECEDENCE.GE
1✔
594

595

596
class And(BinOp):
    """Logical conjunction, ``lhs && rhs``."""

    op = "&&"
    precedence = PRECEDENCE.AND


class Or(BinOp):
    """Logical disjunction, ``lhs || rhs``."""

    op = "||"
    precedence = PRECEDENCE.OR
1✔
608

609

610
class Sum(NaryOp):
    """Sum over an arbitrary number of operands."""

    op = "+"
    precedence = PRECEDENCE.ADD


class Product(NaryOp):
    """Product over an arbitrary number of operands."""

    op = "*"
    precedence = PRECEDENCE.MUL
1✔
622

623

624
class MathFunction(LExprOperator):
    """A math function applied to one or more arguments.

    The result dtype is taken from the first argument.
    """

    precedence = PRECEDENCE.HIGHEST

    def __init__(self, func, args):
        """Initialise.

        Args:
            func: Identifier of the function being applied.
            args: Non-empty sequence of arguments, each wrapped as an
                expression.
        """
        self.function = func
        self.args = [as_lexpr(arg) for arg in args]
        self.dtype = self.args[0].dtype

    def __eq__(self, other):
        """Check equality: same function and pairwise equal arguments."""
        return (
            isinstance(other, type(self))
            and self.function == other.function
            and len(self.args) == len(other.args)
            and all(a == b for a, b in zip(self.args, other.args))
        )

    def __hash__(self):
        """Hash on function and arguments, consistent with ``__eq__``.

        Defining ``__eq__`` alone would leave these nodes unhashable.
        """
        return hash((self.function, tuple(self.args)))
643

644

645
class AssignOp(BinOp):
    """Common base of the assignment operators.

    Assignments are the operators flagged as having a side effect.
    """

    sideeffect = True
    precedence = PRECEDENCE.ASSIGN

    def __init__(self, lhs, rhs):
        """Store target and value; the target must already be an ``LNode``."""
        assert isinstance(lhs, LNode)
        BinOp.__init__(self, lhs, rhs)
1✔
655

656

657
class Assign(AssignOp):
    """Plain assignment, ``lhs = rhs``."""

    op = "="


class AssignAdd(AssignOp):
    """In-place addition, ``lhs += rhs``."""

    op = "+="


class AssignSub(AssignOp):
    """In-place subtraction, ``lhs -= rhs``."""

    op = "-="


class AssignMul(AssignOp):
    """In-place multiplication, ``lhs *= rhs``."""

    op = "*="


class AssignDiv(AssignOp):
    """In-place division, ``lhs /= rhs``."""

    op = "/="
1✔
685

686

687
class ArrayAccess(LExprOperator):
    """Access to an element of an array, ``array[i, j, ...]``.

    The array may be given as a ``Symbol`` or as an ``ArrayDecl``; in the
    latter case the declared sizes are used for extra checks on literal
    indices.
    """

    precedence = PRECEDENCE.SUBSCRIPT

    def __init__(self, array, indices):
        """Initialise.

        Args:
            array: ``Symbol`` or ``ArrayDecl`` naming the array.
            indices: A single index or a list/tuple of indices; each may be
                an expression or a plain number.

        Raises:
            ValueError: if ``array`` has an unexpected type, a literal index
                is negative, the number of indices does not match the
                declaration, or a literal index is out of bounds.
        """
        # Typecheck array argument
        if isinstance(array, Symbol):
            self.array = array
            self.dtype = array.dtype
        elif isinstance(array, ArrayDecl):
            self.array = array.symbol
            self.dtype = array.symbol.dtype
        else:
            raise ValueError(f"Unexpected array type {type(array).__name__}")

        # Allow expressions or literals as indices
        if not isinstance(indices, (list, tuple)):
            indices = (indices,)
        self.indices = tuple(as_lexpr(i) for i in indices)

        # Early error checking for negative array dimensions. The indices
        # have been wrapped by as_lexpr, so plain ints are LiteralInt by now.
        if any(isinstance(i, LiteralInt) and i.value < 0 for i in self.indices):
            raise ValueError("Index value < 0.")

        # Additional dimension checks possible if we get an ArrayDecl instead of just a name
        if isinstance(array, ArrayDecl):
            if len(self.indices) != len(array.sizes):
                raise ValueError("Invalid number of indices.")
            for index, extent in zip(self.indices, array.sizes):
                i = self._literal_int(index)
                d = self._literal_int(extent)
                # Only literal index/extent pairs can be checked here.
                if i is not None and d is not None and i >= d:
                    raise ValueError("Index value >= array dimension.")

    @staticmethod
    def _literal_int(obj):
        """Return the integer behind an ``int`` or ``LiteralInt``, else None."""
        if isinstance(obj, LiteralInt):
            return int(obj.value)
        if isinstance(obj, int):
            return int(obj)
        return None

    def __getitem__(self, indices):
        """Handle nested expr[i][j]."""
        if isinstance(indices, list):
            indices = tuple(indices)
        elif not isinstance(indices, tuple):
            indices = (indices,)
        return ArrayAccess(self.array, self.indices + indices)

    def __eq__(self, other):
        """Check equality: same array symbol and same indices."""
        return (
            isinstance(other, type(self))
            and self.array == other.array
            and self.indices == other.indices
        )

    def __hash__(self):
        """Hash on the array symbol only."""
        return hash(self.array)

    def __repr__(self):
        """Representation."""
        return str(self.array) + "[" + ", ".join(str(i) for i in self.indices) + "]"
×
747

748

749
class Conditional(LExprOperator):
    """Conditional expression choosing between two values.

    The result dtype is merged from the two branch dtypes.
    """

    precedence = PRECEDENCE.CONDITIONAL

    def __init__(self, condition, true, false):
        """Initialise.

        Args:
            condition: Expression deciding which branch is taken.
            true: Value used when the condition holds.
            false: Value used otherwise.
        """
        self.condition = as_lexpr(condition)
        self.true = as_lexpr(true)
        self.false = as_lexpr(false)
        self.dtype = merge_dtypes([self.true.dtype, self.false.dtype])

    def __eq__(self, other):
        """Check equality: condition and both branches match."""
        return (
            isinstance(other, type(self))
            and self.condition == other.condition
            and self.true == other.true
            and self.false == other.false
        )

    def __hash__(self):
        """Hash on condition and branches, consistent with ``__eq__``.

        Defining ``__eq__`` alone would leave these nodes unhashable.
        """
        return hash((self.condition, self.true, self.false))
769

770

771
def as_lexpr(node):
    """Return ``node`` as an ``LExpr``, wrapping plain numbers as literals.

    Existing ``LExpr`` nodes pass through untouched, integral numbers become
    ``LiteralInt`` and other real numbers become ``LiteralFloat``.

    Raises:
        RuntimeError: for any other kind of object.
    """
    if isinstance(node, LExpr):
        return node
    # Integral must be tested first: every Integral is also a Real.
    if isinstance(node, numbers.Integral):
        return LiteralInt(node)
    if isinstance(node, numbers.Real):
        return LiteralFloat(node)
    raise RuntimeError(f"Unexpected LExpr type {type(node)}:\n{node}")
×
785

786

787
class Statement(LNode):
    """Wrapper turning an expression into a statement."""

    def __init__(self, expr):
        """Store the expression after wrapping it with ``as_lexpr``."""
        self.expr = as_lexpr(expr)

    def __eq__(self, other):
        """Equal when both are the same statement type with equal expressions."""
        if not isinstance(other, type(self)):
            return False
        return self.expr == other.expr

    def __hash__(self) -> int:
        """Hash on the wrapped expression."""
        return hash(self.expr)
×
801

802

803
def as_statement(node):
    """Perform type checking on node and wrap in a suitable statement type if necessary.

    Accepted inputs are statements, statement lists, sections, Python lists
    of those, and expression operators that have a side effect (wrapped in a
    ``Statement``). Single-element lists are unwrapped.

    Raises:
        RuntimeError: for a side-effect-free expression operator or any
            other unsupported object.
    """
    if isinstance(node, StatementList):
        if len(node.statements) == 1:
            # Cleans up the expression tree a bit
            return node.statements[0]
        # This function itself returns StatementList for multi-element
        # lists, so such a list must be accepted when passed back in.
        return node
    elif isinstance(node, Statement):
        # No-op
        return node
    elif isinstance(node, LExprOperator):
        if node.sideeffect:
            # Special case for using assignment expressions as statements
            return Statement(node)
        else:
            raise RuntimeError(
                f"Trying to create a statement of lexprOperator type {type(node)}:\n{node}"
            )

    elif isinstance(node, list):
        # Convenience case for list of statements
        if len(node) == 1:
            # Cleans up the expression tree a bit
            return as_statement(node[0])
        else:
            return StatementList(node)
    elif isinstance(node, Section):
        return node
    else:
        raise RuntimeError(f"Unexpected Statement type {type(node)}:\n{node}")
×
831

832

833
class Annotation(Enum):
    """Optimisation hints that can be attached to a section."""

    # fuse loops in section
    fuse = 1
    # unroll loop in section
    unroll = 2
    # loop invariant code motion
    licm = 3
    # apply sum factorization
    factorize = 4
1✔
840

841

842
class Declaration(Statement):
    """Common base of the declaration statements; holds the declared symbol."""

    def __init__(self, symbol):
        """Remember the symbol being declared."""
        self.symbol = symbol

    def __eq__(self, other):
        """Equal when both are the same declaration type for equal symbols."""
        if not isinstance(other, type(self)):
            return False
        return self.symbol == other.symbol
×
852

853

854
def is_declaration(node) -> bool:
    """Tell whether ``node`` is a variable or array declaration."""
    return isinstance(node, (VariableDecl, ArrayDecl))
1✔
857

858

859
class Section(LNode):
    """A named group of statements with its declarations and input/output symbols."""

    def __init__(
        self,
        name: str,
        statements: list[LNode],
        declarations: Sequence[Declaration],
        input: Optional[list[Symbol]] = None,
        output: Optional[list[Symbol]] = None,
        annotations: Optional[list[Annotation]] = None,
    ):
        """Build the section.

        Every entry of ``statements`` is normalised with ``as_statement``.
        Missing or empty optional lists are replaced by fresh empty lists.
        The symbol of each declaration is appended to ``output`` unless it is
        already there.
        """
        self.name = name
        self.statements = list(map(as_statement, statements))
        self.annotations = annotations if annotations else []
        self.input = input if input else []
        self.declarations = declarations if declarations else []
        # NOTE: a non-empty ``output`` argument is stored as-is, so the
        # appends below extend the caller's list.
        self.output = output if output else []

        for declaration in self.declarations:
            assert is_declaration(declaration)
            declared = declaration.symbol
            if declared not in self.output:
                self.output.append(declared)

    def __eq__(self, other):
        """Equal when name, input, output, annotations and statements all match."""
        if not isinstance(other, type(self)):
            return False
        compared = ("name", "input", "output", "annotations", "statements")
        return all(getattr(self, attr) == getattr(other, attr) for attr in compared)
890

891

892
class StatementList(LNode):
    """An ordered sequence of statements."""

    def __init__(self, statements):
        """Normalise every entry with ``as_statement`` and store the result."""
        self.statements = list(map(as_statement, statements))

    def __eq__(self, other):
        """Equal when both hold equal statements in the same order."""
        if not isinstance(other, type(self)):
            return False
        return self.statements == other.statements

    def __hash__(self) -> int:
        """Hash on the ordered tuple of statements."""
        return hash(tuple(self.statements))

    def __repr__(self):
        """Show the contained statements."""
        return f"StatementList({self.statements})"
×
910

911

912
class Comment(Statement):
    """Line comment(s) annotating the generated code with human readable remarks."""

    def __init__(self, comment):
        """Store the comment text, which must be a string."""
        assert isinstance(comment, str)
        self.comment = comment

    def __eq__(self, other):
        """Equal when both are comments carrying the same text."""
        if not isinstance(other, type(self)):
            return False
        return self.comment == other.comment
×
923

924

925
def commented_code_list(code, comments):
    """Prefix a non-empty code list with comment statements.

    A single node is first wrapped in a list. An empty list is returned
    unchanged, without comments. ``comments`` may be one string or a
    list/tuple of strings; each becomes a ``Comment``.
    """
    if isinstance(code, LNode):
        code = [code]
    assert isinstance(code, list)

    if not code:
        return code

    if not isinstance(comments, (list, tuple)):
        comments = [comments]
    header = [Comment(text) for text in comments]
    return header + code
1✔
936

937

938
# Type and variable declarations
939

940

941
class VariableDecl(Declaration):
    """Declare a variable, optionally define initial value."""

    def __init__(self, symbol, value=None):
        """Initialise.

        Args:
            symbol: The ``Symbol`` being declared; it must carry a dtype.
            value: Optional initial value, wrapped as an expression.
        """
        assert isinstance(symbol, Symbol)
        assert symbol.dtype is not None
        self.symbol = symbol

        if value is not None:
            value = as_lexpr(value)
        self.value = value

    def __eq__(self, other):
        """Check equality: same symbol, same dtype and same initial value.

        The type is compared through ``symbol.dtype``; instances have no
        ``typename`` attribute, so reading one would raise AttributeError.
        """
        return (
            isinstance(other, type(self))
            and self.symbol == other.symbol
            and self.symbol.dtype == other.symbol.dtype
            and self.value == other.value
        )

    def __hash__(self) -> int:
        """Hash on the declared symbol, consistent with ``__eq__``."""
        return hash(self.symbol)
962

963

964
class ArrayDecl(Declaration):
    """A declaration or definition of an array.

    Note that just setting values=0 is sufficient to initialize the
    entire array to zero.

    Otherwise use nested lists of lists to represent multidimensional
    array values to initialize to.

    """

    def __init__(self, symbol, sizes=None, values=None, const=False):
        """Initialise.

        Args:
            symbol: The ``Symbol`` naming the array; it must carry a dtype.
            sizes: Extent per dimension (an int for one dimension). When
                omitted, ``values.shape`` is used, so ``values`` must then
                be given and have a ``shape``.
            values: Optional initial values; lists and tuples are converted
                to a NumPy array, anything else is stored unchanged.
            const: Whether the array is declared constant.
        """
        assert isinstance(symbol, Symbol)
        self.symbol = symbol
        assert symbol.dtype

        if sizes is None:
            assert values is not None
            sizes = values.shape
        if isinstance(sizes, int):
            sizes = (sizes,)
        self.sizes = tuple(sizes)

        if values is None:
            assert sizes is not None

        # NB! No type checking, assuming nested lists of literal values. Not applying as_lexpr.
        if isinstance(values, (list, tuple)):
            self.values = np.asarray(values)
        else:
            self.values = values

        self.const = const
        self.dtype = symbol.dtype

    @staticmethod
    def _same_values(a, b) -> bool:
        """Compare two ``values`` attributes, which may be None, scalars or arrays."""
        if a is None or b is None:
            return a is None and b is None
        # Element-wise ``==`` on arrays has no single truth value.
        return bool(np.array_equal(a, b))

    def __eq__(self, other):
        """Check equality of dtype, symbol, sizes and values against ``other``."""
        if not isinstance(other, type(self)):
            return False
        return (
            self.dtype == other.dtype
            and self.symbol == other.symbol
            and self.sizes == other.sizes
            and self._same_values(self.values, other.values)
        )

    def __hash__(self) -> int:
        """Hash on the declared symbol."""
        return hash(self.symbol)
×
1010

1011

1012
def is_simple_inner_loop(code):
    """Check if code is a simple inner loop.

    Returns True when ``code`` is a ``ForRange`` whose body is itself a
    simple inner loop, or a ``Statement`` whose expression is an
    ``AssignOp``. Returns False otherwise.
    """
    if isinstance(code, ForRange) and is_simple_inner_loop(code.body):
        return True
    # ForRange derives from Statement but its constructor does not set
    # ``expr``, so read the attribute defensively instead of raising.
    return isinstance(code, Statement) and isinstance(getattr(code, "expr", None), AssignOp)
×
1019

1020

1021
def depth(code) -> int:
    """Get depth of code.

    The depth is the maximum number of nested ``ForRange`` loops found in
    ``code``. Anything that is neither a ``ForRange`` nor a
    ``StatementList`` has depth 0.
    """
    if isinstance(code, ForRange):
        return 1 + depth(code.body)
    if isinstance(code, StatementList):
        # default=0 keeps a statement list without statements from raising.
        return max((depth(c) for c in code.statements), default=0)
    return 0
1✔
1028

1029

1030
class ForRange(Statement):
    """Slightly higher-level for loop assuming incrementing an index over a range."""

    def __init__(self, index, begin, end, body):
        """Initialise.

        Args:
            index: Loop index, a ``Symbol`` or a ``MultiIndex``.
            begin: Start of the range; converted with ``as_lexpr``.
            end: End of the range; converted with ``as_lexpr``.
            body: List of statements, wrapped in a ``StatementList``.
        """
        assert isinstance(index, (Symbol, MultiIndex))
        self.index = index
        self.begin = as_lexpr(begin)
        self.end = as_lexpr(end)
        assert isinstance(body, list)
        self.body = StatementList(body)

    def as_tuple(self):
        """Convert to a tuple."""
        return (self.index, self.begin, self.end, self.body)

    def __eq__(self, other):
        """Check equality.

        Two loops are equal when ``other`` is an instance of this object's
        type and index, begin, end and body all compare equal.
        """
        attributes = ("index", "begin", "end", "body")
        # Compare against ``other``; comparing self with self would make every
        # pair of loops equal.
        return isinstance(other, type(self)) and all(
            getattr(self, name) == getattr(other, name) for name in attributes
        )

    def __hash__(self) -> int:
        """Hash."""
        return hash(self.as_tuple())
×
1056

1057

1058
def _math_function(op, *args):
    """Get a math function.

    The function name is read from ``op._ufl_handler_name_``. When the first
    argument has real dtype, ``conj`` and ``real`` return that argument
    unchanged and ``imag`` returns a zero literal. Every other case produces
    a ``MathFunction`` node with the given arguments.
    """
    name = op._ufl_handler_name_
    dtype = args[0].dtype
    if name in ("conj", "real", "imag") and dtype == DataType.REAL:
        assert len(args) == 1
        return LiteralFloat(0.0) if name == "imag" else args[0]
    return MathFunction(name, args)
1✔
1069

1070

1071
# Dispatch table used by ufl_to_lnodes (below): maps the exact type of the
# first argument to the handler that is called with that argument followed
# by the remaining arguments.
_ufl_call_lookup = {
    # Constant values
    ufl.constantvalue.IntValue: lambda value: LiteralInt(int(value)),
    ufl.constantvalue.FloatValue: lambda value: LiteralFloat(float(value)),
    ufl.constantvalue.ComplexValue: lambda value: LiteralFloat(value.value()),
    ufl.constantvalue.Zero: lambda value: LiteralFloat(0.0),
    # Arithmetic
    ufl.algebra.Product: lambda op, lhs, rhs: lhs * rhs,
    ufl.algebra.Sum: lambda op, lhs, rhs: lhs + rhs,
    ufl.algebra.Division: lambda op, lhs, rhs: lhs / rhs,
    **dict.fromkeys(
        (
            ufl.algebra.Abs,
            ufl.algebra.Power,
            ufl.algebra.Real,
            ufl.algebra.Imag,
            ufl.algebra.Conj,
        ),
        _math_function,
    ),
    # Comparisons and logic
    ufl.classes.GT: lambda op, lhs, rhs: GT(lhs, rhs),
    ufl.classes.GE: lambda op, lhs, rhs: GE(lhs, rhs),
    ufl.classes.EQ: lambda op, lhs, rhs: EQ(lhs, rhs),
    ufl.classes.NE: lambda op, lhs, rhs: NE(lhs, rhs),
    ufl.classes.LT: lambda op, lhs, rhs: LT(lhs, rhs),
    ufl.classes.LE: lambda op, lhs, rhs: LE(lhs, rhs),
    ufl.classes.AndCondition: lambda op, lhs, rhs: And(lhs, rhs),
    ufl.classes.OrCondition: lambda op, lhs, rhs: Or(lhs, rhs),
    ufl.classes.NotCondition: lambda op, operand: Not(operand),
    ufl.classes.Conditional: lambda op, cond, if_true, if_false: Conditional(
        cond, if_true, if_false
    ),
    # Everything else goes through _math_function
    **dict.fromkeys(
        (
            ufl.classes.MinValue,
            ufl.classes.MaxValue,
            ufl.mathfunctions.Sqrt,
            ufl.mathfunctions.Ln,
            ufl.mathfunctions.Exp,
            ufl.mathfunctions.Cos,
            ufl.mathfunctions.Sin,
            ufl.mathfunctions.Tan,
            ufl.mathfunctions.Cosh,
            ufl.mathfunctions.Sinh,
            ufl.mathfunctions.Tanh,
            ufl.mathfunctions.Acos,
            ufl.mathfunctions.Asin,
            ufl.mathfunctions.Atan,
            ufl.mathfunctions.Erf,
            ufl.mathfunctions.Atan2,
            ufl.mathfunctions.MathFunction,
            ufl.mathfunctions.BesselJ,
            ufl.mathfunctions.BesselY,
        ),
        _math_function,
    ),
}
1116

1117

1118
def ufl_to_lnodes(operator, *args):
    """Call appropriate handler, depending on the type of operator.

    The handler is looked up in ``_ufl_call_lookup`` by the exact type of
    ``operator`` and called with ``operator`` followed by ``args``.

    Raises:
        RuntimeError: If no handler is registered for the operator type.
    """
    optype = type(operator)
    handler = _ufl_call_lookup.get(optype)
    if handler is None:
        raise RuntimeError(f"Missing lookup for expr type {optype}.")
    return handler(operator, *args)
×
1125

1126

1127
def create_nested_for_loops(indices: list[MultiIndex], body):
    """Create nested for loops over list of indices.

    One ``ForRange`` running from 0 to the size is created for every entry in
    the ``sizes`` of every MultiIndex. The first sub-index of the first
    MultiIndex becomes the outermost loop, and ``body`` ends up innermost.
    If there are no sub-indices, ``body`` is returned unchanged.
    """
    extents = [extent for multi_index in indices for extent in multi_index.sizes]
    loop_symbols = [
        multi_index.local_index(position)
        for multi_index in indices
        for position in range(len(multi_index.sizes))
    ]
    # Wrap from the innermost loop outwards.
    for loop_symbol, extent in zip(reversed(loop_symbols), reversed(extents)):
        body = ForRange(loop_symbol, 0, extent, body=[body])
    return body
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc