• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 7299320904

22 Dec 2023 11:18AM UTC coverage: 80.35% (+0.4%) from 79.915%
7299320904

push

github

web-flow
[Sum Factorization - 1] Transition to 1D basis evaluation matrices for hexes and quads (#642)

* add multi index and tensor tables

* update parameters

* Updates to MultiIndex

* Add a simple formatting rule in C

* Replace FlattenedArray

* Minor fixes

* minor fix to pass the tests

* hopefully fix some tests

* update quadrature rule

* Fix MultiIndex when zero size

* Minor tweak

* minor update to documentation

* update quadrature

* update entity dofs

* fix tensor factors

* update code

* add mass action

* update mass action

* create quadrature rule with tensor factors

* fix tables

* use tensor structure for matrices as well

* improve tests

* minor improvements

* Fix error

* Simplify and fix mypy

* Fix options parser for bool types

* update quadrature generation

* fixes for piecewise tabledata

* add precedence to MultiIndex

* fix quadrature permutation

* update loop hoisting

* Try with np.prod

* add test

* update test

* update test

* update test for hexes

* use property for dim

* use argument_loop_index

* remove comments

* add doc strings

* add license

* fix documentation

* remove extra if

---------

Co-authored-by: Chris Richardson <chris@bpi.cam.ac.uk>

157 of 170 new or added lines in 9 files covered. (92.35%)

4 existing lines in 1 file now uncovered.

3676 of 4575 relevant lines covered (80.35%)

0.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

76.86
/ffcx/codegeneration/lnodes.py
1
# Copyright (C) 2013-2023 Martin Sandve Alnæs, Chris Richardson
2
#
3
# This file is part of FFCx.(https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6

7
import numbers
1✔
8
import ufl
1✔
9
import numpy as np
1✔
10
from enum import Enum
1✔
11

12

13
class PRECEDENCE:
    """An enum-like class for operator precedence levels.

    Smaller numbers bind more tightly (e.g. MUL = 4 binds tighter than
    ADD = 5). Operators sharing a number share a precedence level.
    """

    # Atoms and subscripting
    HIGHEST = LITERAL = SYMBOL = 0
    SUBSCRIPT = 2

    # Unary prefix operators
    NOT = NEG = 3

    # Arithmetic
    MUL = DIV = 4
    ADD = SUB = 5

    # Comparison and logic
    LT = LE = GT = GE = 7
    EQ = NE = 8
    AND = 11
    OR = 12

    # Ternary conditional and assignment
    CONDITIONAL = ASSIGN = 13
    LOWEST = 15
1✔
41

42

43
"""LNodes is intended as a minimal generic language description.
44
Formatting is done later, depending on the target language.
45

46
Supported:
47
 Floating point (and complex) and integer variables and multidimensional arrays
48
 Range loops
49
 Simple arithmetic, +-*/
50
 Math operations
51
 Logic conditions
52
 Comments
53
Not supported:
54
 Pointers
55
 Function calls (other than math functions, see MathFunction)
56
 Flow control (if, switch, while)
57
 Booleans
58
 Strings
59
"""
60

61

62
def is_zero_lexpr(lexpr):
    """Return True if ``lexpr`` is a float or integer literal equal to zero."""
    return isinstance(lexpr, (LiteralFloat, LiteralInt)) and lexpr.value == 0
66

67

68
def is_one_lexpr(lexpr):
    """Return True if ``lexpr`` is a float or integer literal equal to one."""
    return isinstance(lexpr, (LiteralFloat, LiteralInt)) and lexpr.value == 1
72

73

74
def is_negative_one_lexpr(lexpr):
    """Return True if ``lexpr`` is a float or integer literal equal to minus one."""
    return isinstance(lexpr, (LiteralFloat, LiteralInt)) and lexpr.value == -1
78

79

80
def float_product(factors):
    """Build product of float factors, simplifying ones and zeros and returning 1.0 if empty sequence.

    Literal ones are dropped first. An empty result gives ``LiteralFloat(1.0)``
    and a single remaining factor is returned unchanged. With two or more
    remaining factors, the first literal zero among them is returned if
    present; otherwise a ``Product`` node is built.
    """
    remaining = [factor for factor in factors if not is_one_lexpr(factor)]
    if not remaining:
        return LiteralFloat(1.0)
    if len(remaining) == 1:
        return remaining[0]
    first_zero = next((factor for factor in remaining if is_zero_lexpr(factor)), None)
    if first_zero is not None:
        return first_zero
    return Product(remaining)
1✔
92

93

94
class DataType(Enum):
    """Data type tags for variables and expressions in LNodes.

    REAL matches the geometry type, SCALAR matches the tensor type, INT is
    used for entity indices etc., and NONE marks "no type assigned".
    """

    REAL = 0  # same type as geometry
    SCALAR = 1  # same type as tensor
    INT = 2  # entity indices etc.
    NONE = 3  # unset; the class-level default on LExpr
1✔
105

106

107
def merge_dtypes(dtype0, dtype1):
    """Return the DataType resulting from a binary operation on two operands.

    SCALAR takes priority over REAL, which takes priority over INT; two INT
    operands give INT.

    Raises:
        ValueError: if either dtype is NONE, or the pair cannot be combined.
    """
    pair = (dtype0, dtype1)
    if DataType.NONE in pair:
        raise ValueError(f"Invalid DataType in LNodes {dtype0, dtype1}")
    # Promote dtype to SCALAR or REAL if either argument matches
    for promoted in (DataType.SCALAR, DataType.REAL):
        if promoted in pair:
            return promoted
    if dtype0 == DataType.INT and dtype1 == DataType.INT:
        return DataType.INT
    raise ValueError(f"Can't get dtype for binary operation with {dtype0, dtype1}")
×
119

120

121
class LNode:
    """Base class for all AST nodes.

    Both comparison hooks return ``NotImplemented``, so unless a subclass
    overrides them Python falls back to identity comparison. Because
    ``__eq__`` is defined here without ``__hash__``, this class (and any
    subclass that does not define ``__hash__``) is unhashable.
    """

    def __eq__(self, _other):
        return NotImplemented

    __ne__ = __eq__
×
129

130

131
class LExpr(LNode):
    """Base class for all expressions.

    All subtypes should define a 'precedence' class attribute.

    The arithmetic operator overloads do light simplification while building
    the tree: additive zeros and multiplicative ones are dropped, literal
    zeros absorb products, ``-1`` factors become ``Neg``, adding a ``Neg``
    becomes a subtraction (and vice versa), and some integer-literal
    arithmetic is folded.
    """

    dtype = DataType.NONE

    def __getitem__(self, indices):
        return ArrayAccess(self, indices)

    def __neg__(self):
        # Fold negation of literals; wrap anything else in a Neg node.
        for literal_type in (LiteralFloat, LiteralInt):
            if isinstance(self, literal_type):
                return literal_type(-self.value)
        return Neg(self)

    def __add__(self, other):
        rhs = as_lexpr(other)
        if is_zero_lexpr(self):
            return rhs
        if is_zero_lexpr(rhs):
            return self
        # a + (-b) -> a - b
        return Sub(self, rhs.arg) if isinstance(rhs, Neg) else Add(self, rhs)

    def __radd__(self, other):
        lhs = as_lexpr(other)
        if is_zero_lexpr(self):
            return lhs
        if is_zero_lexpr(lhs):
            return self
        # a + (-b) -> a - b, with self on the right
        return Sub(lhs, self.arg) if isinstance(self, Neg) else Add(lhs, self)

    def __sub__(self, other):
        rhs = as_lexpr(other)
        if is_zero_lexpr(self):
            return -rhs
        if is_zero_lexpr(rhs):
            return self
        if isinstance(rhs, Neg):
            # a - (-b) -> a + b
            return Add(self, rhs.arg)
        both_int_literals = isinstance(self, LiteralInt) and isinstance(rhs, LiteralInt)
        if both_int_literals:
            return LiteralInt(self.value - rhs.value)
        return Sub(self, rhs)

    def __rsub__(self, other):
        lhs = as_lexpr(other)
        if is_zero_lexpr(self):
            return lhs
        if is_zero_lexpr(lhs):
            return -self
        # a - (-b) -> a + b, with self on the right
        return Add(lhs, self.arg) if isinstance(self, Neg) else Sub(lhs, self)

    def __mul__(self, other):
        rhs = as_lexpr(other)
        # Zeros absorb the product (the zero operand itself is returned).
        if is_zero_lexpr(self):
            return self
        if is_zero_lexpr(rhs):
            return rhs
        # Ones are neutral.
        if is_one_lexpr(self):
            return rhs
        if is_one_lexpr(rhs):
            return self
        # Minus one becomes a negation of the other factor.
        if is_negative_one_lexpr(rhs):
            return Neg(self)
        if is_negative_one_lexpr(self):
            return Neg(rhs)
        if isinstance(self, LiteralInt) and isinstance(rhs, LiteralInt):
            return LiteralInt(self.value * rhs.value)
        return Mul(self, rhs)

    def __rmul__(self, other):
        lhs = as_lexpr(other)
        if is_zero_lexpr(self):
            return self
        if is_zero_lexpr(lhs):
            return lhs
        if is_one_lexpr(self):
            return lhs
        if is_one_lexpr(lhs):
            return self
        if is_negative_one_lexpr(lhs):
            return Neg(self)
        if is_negative_one_lexpr(self):
            return Neg(lhs)
        return Mul(lhs, self)

    def __div__(self, other):
        denominator = as_lexpr(other)
        if is_zero_lexpr(denominator):
            raise ValueError("Division by zero!")
        if is_zero_lexpr(self):
            return self
        return Div(self, denominator)

    def __rdiv__(self, other):
        numerator = as_lexpr(other)
        if is_zero_lexpr(self):
            raise ValueError("Division by zero!")
        if is_zero_lexpr(numerator):
            return numerator
        return Div(numerator, self)

    # TODO: Error check types?
    # NB: floor division is mapped onto the same Div node as true division.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__
    __floordiv__ = __div__
    __rfloordiv__ = __rdiv__
1✔
246

247

248
class LExprOperator(LExpr):
    """Base class for all expression operators.

    Operators have no side effect by default; assignment operators set
    ``sideeffect = True``, which allows them to be used as statements.
    """

    sideeffect = False
1✔
252

253

254
class LExprTerminal(LExpr):
    """Base class for all expression terminals (literals and symbols).

    Terminals never have side effects.
    """

    sideeffect = False
1✔
258

259

260
# LExprTerminal types
261

262

263
class LiteralFloat(LExprTerminal):
    """A floating point (or complex) literal value.

    Real values get ``DataType.REAL``; complex values get ``DataType.SCALAR``.
    """

    precedence = PRECEDENCE.LITERAL

    def __init__(self, value):
        assert isinstance(value, (float, complex))
        self.value = value
        if isinstance(value, complex):
            self.dtype = DataType.SCALAR
        else:
            self.dtype = DataType.REAL

    def __eq__(self, other):
        return isinstance(other, LiteralFloat) and self.value == other.value

    def __hash__(self):
        # Defining __eq__ alone makes a class unhashable. Float literals appear
        # as operands of BinOp/NaryOp, whose __hash__ hashes their operands,
        # so hash by value (consistent with __eq__).
        return hash(self.value)

    def __float__(self):
        return float(self.value)

    def __repr__(self):
        return str(self.value)
×
284

285

286
class LiteralInt(LExprTerminal):
    """An integer literal value (a Python int or a NumPy number)."""

    precedence = PRECEDENCE.LITERAL

    def __init__(self, value):
        assert isinstance(value, (int, np.number))
        self.value = value
        self.dtype = DataType.INT

    def __eq__(self, other):
        if not isinstance(other, LiteralInt):
            return False
        return self.value == other.value

    def __hash__(self):
        # Consistent with __eq__: equal values hash equally.
        return hash(self.value)

    def __repr__(self):
        return str(self.value)
1✔
304

305

306
class Symbol(LExprTerminal):
    """A named symbol carrying a DataType.

    Equality and hashing depend on the name only.
    """

    precedence = PRECEDENCE.SYMBOL

    def __init__(self, name: str, dtype):
        assert isinstance(name, str)
        # Names may only contain alphanumeric characters and underscores.
        assert name.replace("_", "").isalnum()
        self.name, self.dtype = name, dtype

    def __eq__(self, other):
        if isinstance(other, Symbol):
            return self.name == other.name
        return False

    def __hash__(self):
        return hash(self.name)

    def __repr__(self):
        return self.name
1✔
325

326

327
class MultiIndex(LExpr):
    """A multi-index for accessing tensors flattened in memory.

    ``symbols[i]`` indexes an axis of extent ``sizes[i]``. ``global_index`` is
    the row-major flat offset ``sum_i symbols[i] * prod(sizes[i+1:])``
    (``LiteralInt(0)`` for an empty multi-index).
    """

    precedence = PRECEDENCE.SYMBOL

    def __init__(self, symbols: list, sizes: list):
        self.dtype = DataType.INT
        self.sizes = sizes
        self.symbols = [as_lexpr(sym) for sym in symbols]
        for sym in self.symbols:
            assert sym.dtype == DataType.INT

        dim = len(sizes)
        if dim == 0:
            self.global_index: LExpr = LiteralInt(0)
        else:
            # Row-major strides prod(sizes[i+1:]); the last axis has stride 1.
            # Slicing stops before the last axis so np.prod never receives an
            # empty sequence (which would yield the float 1.0).
            strides = [np.prod(sizes[i + 1:]) for i in range(dim - 1)] + [1]
            # Use the normalized self.symbols (not the raw input) and wrap each
            # stride as an LExpr so LExpr.__mul__ does the simplification
            # rather than relying on NumPy scalar dispatch.
            self.global_index = Sum(
                as_lexpr(stride) * sym for stride, sym in zip(strides, self.symbols)
            )

    @property
    def dim(self):
        return len(self.sizes)

    def size(self):
        return np.prod(self.sizes)

    def local_index(self, idx):
        assert idx < len(self.symbols)
        return self.symbols[idx]

    def intersection(self, other):
        """Return the sub-index of symbols present in both, in self's order."""
        symbols = []
        sizes = []
        for (sym, size) in zip(self.symbols, self.sizes):
            if sym in other.symbols:
                i = other.symbols.index(sym)
                assert other.sizes[i] == size
                symbols.append(sym)
                sizes.append(size)
        return MultiIndex(symbols, sizes)

    def union(self, other):
        """Return self's symbols followed by other's symbols not already present."""
        # NB result may depend on order a.union(b) != b.union(a)
        # list() instead of .copy(): sizes/symbols may be tuples.
        symbols = list(self.symbols)
        sizes = list(self.sizes)
        for (sym, size) in zip(other.symbols, other.sizes):
            if sym in symbols:
                i = symbols.index(sym)
                assert sizes[i] == size
            else:
                symbols.append(sym)
                sizes.append(size)
        return MultiIndex(symbols, sizes)

    def difference(self, other):
        """Return the sub-index of self's symbols that are not in other."""
        symbols = []
        sizes = []
        for (idx, size) in zip(self.symbols, self.sizes):
            if idx not in other.symbols:
                symbols.append(idx)
                sizes.append(size)
        return MultiIndex(symbols, sizes)

    def __eq__(self, other):
        # Content equality, consistent with the content-based __hash__ below.
        return (
            isinstance(other, MultiIndex)
            and list(self.symbols) == list(other.symbols)
            and list(self.sizes) == list(other.sizes)
        )

    def __hash__(self):
        # Hash the repr *string*; hashing the bound method
        # ``global_index.__repr__`` only reflects object identity.
        return hash(repr(self.global_index))
×
392

393

394
class PrefixUnaryOp(LExprOperator):
    """Base class for unary operators.

    Subclasses provide the ``op`` string and ``precedence``.
    """

    def __init__(self, arg):
        self.arg = as_lexpr(arg)

    def __eq__(self, other):
        return isinstance(other, type(self)) and self.arg == other.arg

    def __hash__(self):
        # Defining __eq__ without __hash__ leaves the class unhashable, yet
        # unary nodes are operands of BinOp/NaryOp whose __hash__ hashes their
        # operands. Hashing the argument is consistent with __eq__.
        return hash(self.arg)
×
402

403

404
class BinOp(LExprOperator):
    """Base class for binary operators ``lhs op rhs``.

    Subclasses provide the ``op`` string and ``precedence``. The hash is
    symmetric in the operands (sum of both hashes).
    """

    def __init__(self, lhs, rhs):
        self.lhs = as_lexpr(lhs)
        self.rhs = as_lexpr(rhs)

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return False
        return self.lhs == other.lhs and self.rhs == other.rhs

    def __hash__(self):
        return hash(self.lhs) + hash(self.rhs)

    def __repr__(self):
        return "(%s %s %s)" % (self.lhs, self.op, self.rhs)
×
421

422

423
class ArithmeticBinOp(BinOp):
    """Binary arithmetic operator whose dtype is promoted from its operands."""

    def __init__(self, lhs, rhs):
        super().__init__(lhs, rhs)
        self.dtype = merge_dtypes(self.lhs.dtype, self.rhs.dtype)
1✔
428

429

430
class NaryOp(LExprOperator):
    """Base class for special n-ary operators ``a0 op a1 op ... op an``.

    The dtype is the ``merge_dtypes`` promotion of all argument dtypes.
    """

    op = ""

    def __init__(self, args):
        self.args = [as_lexpr(arg) for arg in args]
        self.dtype = self.args[0].dtype
        # Merge every argument (including the first) so that a NONE dtype
        # anywhere is rejected by merge_dtypes.
        for arg in self.args:
            self.dtype = merge_dtypes(self.dtype, arg.dtype)

    def __eq__(self, other):
        return (
            isinstance(other, type(self))
            and len(self.args) == len(other.args)
            and all(a == b for a, b in zip(self.args, other.args))
        )

    def __repr__(self) -> str:
        # Parenthesize like BinOp.__repr__ so nested sums/products render
        # unambiguously, and avoid the stray trailing space.
        return "(" + f" {self.op} ".join(str(arg) for arg in self.args) + ")"

    def __hash__(self):
        return hash(tuple(self.args))
1✔
453

454

455
class Neg(PrefixUnaryOp):
    """Unary negation ``-arg``; takes the dtype of its operand."""

    op = "-"
    precedence = PRECEDENCE.NEG

    def __init__(self, arg):
        super().__init__(arg)
        self.dtype = self.arg.dtype
1✔
462

463

464
class Not(PrefixUnaryOp):
    """Logical negation ``!arg``."""

    op = "!"
    precedence = PRECEDENCE.NOT
1✔
467

468

469
# Binary operators
470
# Arithmetic operators preserve the dtype of their operands
471
# The other operations (logical) do not need a dtype
472

473
class Add(ArithmeticBinOp):
    """Addition ``lhs + rhs``."""

    op = "+"
    precedence = PRECEDENCE.ADD


class Sub(ArithmeticBinOp):
    """Subtraction ``lhs - rhs``."""

    op = "-"
    precedence = PRECEDENCE.SUB


class Mul(ArithmeticBinOp):
    """Multiplication ``lhs * rhs``."""

    op = "*"
    precedence = PRECEDENCE.MUL


class Div(ArithmeticBinOp):
    """Division ``lhs / rhs``."""

    op = "/"
    precedence = PRECEDENCE.DIV
1✔
491

492

493
class EQ(BinOp):
    """Equality comparison ``lhs == rhs``."""

    op = "=="
    precedence = PRECEDENCE.EQ


class NE(BinOp):
    """Inequality comparison ``lhs != rhs``."""

    op = "!="
    precedence = PRECEDENCE.NE


class LT(BinOp):
    """Less-than comparison ``lhs < rhs``."""

    op = "<"
    precedence = PRECEDENCE.LT


class GT(BinOp):
    """Greater-than comparison ``lhs > rhs``."""

    op = ">"
    precedence = PRECEDENCE.GT


class LE(BinOp):
    """Less-or-equal comparison ``lhs <= rhs``."""

    op = "<="
    precedence = PRECEDENCE.LE


class GE(BinOp):
    """Greater-or-equal comparison ``lhs >= rhs``."""

    op = ">="
    precedence = PRECEDENCE.GE


class And(BinOp):
    """Logical conjunction ``lhs && rhs``."""

    op = "&&"
    precedence = PRECEDENCE.AND


class Or(BinOp):
    """Logical disjunction ``lhs || rhs``."""

    op = "||"
    precedence = PRECEDENCE.OR
1✔
531

532

533
class Sum(NaryOp):
    """Sum of any number of operands, rendered with ``+``."""

    op = "+"
    precedence = PRECEDENCE.ADD


class Product(NaryOp):
    """Product of any number of operands, rendered with ``*``."""

    op = "*"
    precedence = PRECEDENCE.MUL
1✔
545

546

547
class MathFunction(LExprOperator):
    """A Math Function, with any arguments.

    ``function`` is the function name; the dtype is taken from the first
    argument.
    """

    precedence = PRECEDENCE.HIGHEST

    def __init__(self, func, args):
        self.function = func
        self.args = [as_lexpr(arg) for arg in args]
        self.dtype = self.args[0].dtype

    def __eq__(self, other):
        return (
            isinstance(other, type(self))
            and self.function == other.function
            and len(self.args) == len(other.args)
            and all(a == b for a, b in zip(self.args, other.args))
        )

    def __hash__(self):
        # __eq__ is defined, so an explicit hash is needed for the node to be
        # usable as an operand of BinOp/NaryOp (which hash their operands).
        return hash((self.function, tuple(self.args)))
564

565

566
class AssignOp(BinOp):
    """Base class for assignment operators.

    Assignments have a side effect, so ``as_statement`` accepts them
    directly as statements.
    """

    sideeffect = True
    precedence = PRECEDENCE.ASSIGN

    def __init__(self, lhs, rhs):
        assert isinstance(lhs, LNode)
        super().__init__(lhs, rhs)
1✔
575

576

577
class Assign(AssignOp):
    """Plain assignment ``lhs = rhs``."""

    op = "="


class AssignAdd(AssignOp):
    """Compound assignment ``lhs += rhs``."""

    op = "+="


class AssignSub(AssignOp):
    """Compound assignment ``lhs -= rhs``."""

    op = "-="


class AssignMul(AssignOp):
    """Compound assignment ``lhs *= rhs``."""

    op = "*="


class AssignDiv(AssignOp):
    """Compound assignment ``lhs /= rhs``."""

    op = "/="
1✔
595

596

597
class ArrayAccess(LExprOperator):
    """Subscript ``array[i0, i1, ...]`` of a Symbol or a declared array.

    Raises:
        ValueError: for an unsupported array type, a negative literal index,
            or (when an ArrayDecl is given) a wrong number of indices or a
            literal index not smaller than the declared size.
    """

    precedence = PRECEDENCE.SUBSCRIPT

    def __init__(self, array, indices):
        # Typecheck array argument
        if isinstance(array, Symbol):
            self.array = array
            self.dtype = array.dtype
        elif isinstance(array, ArrayDecl):
            self.array = array.symbol
            self.dtype = array.symbol.dtype
        else:
            raise ValueError("Unexpected array type %s." % (type(array).__name__,))

        # Allow expressions or literals as indices
        if not isinstance(indices, (list, tuple)):
            indices = (indices,)
        self.indices = tuple(as_lexpr(i) for i in indices)

        # Early error checking for negative array dimensions. The indices are
        # already wrapped by as_lexpr, so literal values must be read from
        # LiteralInt nodes (a plain ``isinstance(i, int)`` never matches).
        literal_values = [self._literal_int(i) for i in self.indices]
        if any(v is not None and v < 0 for v in literal_values):
            raise ValueError("Index value < 0.")

        # Additional dimension checks possible if we get an ArrayDecl instead of just a name
        if isinstance(array, ArrayDecl):
            if len(self.indices) != len(array.sizes):
                raise ValueError("Invalid number of indices.")
            for value, size in zip(literal_values, array.sizes):
                bound = self._literal_int(size)
                if value is not None and bound is not None and value >= bound:
                    raise ValueError("Index value >= array dimension.")

    @staticmethod
    def _literal_int(node):
        """Return the value of an integer or LiteralInt, or None for anything else."""
        # LiteralInt has no __int__, so its value is read directly.
        if isinstance(node, LiteralInt):
            return node.value
        if isinstance(node, numbers.Integral):
            return node
        return None

    def __getitem__(self, indices):
        """Handle nested expr[i][j]."""
        if isinstance(indices, list):
            indices = tuple(indices)
        elif not isinstance(indices, tuple):
            indices = (indices,)
        return ArrayAccess(self.array, self.indices + indices)

    def __eq__(self, other):
        return (
            isinstance(other, type(self))
            and self.array == other.array
            and self.indices == other.indices
        )

    def __hash__(self):
        return hash(self.array)

    def __repr__(self):
        return str(self.array) + "[" + ", ".join(str(i) for i in self.indices) + "]"
1✔
651

652

653
class Conditional(LExprOperator):
    """Ternary expression ``condition ? true : false``.

    The dtype is the ``merge_dtypes`` promotion of the two branch dtypes.
    """

    precedence = PRECEDENCE.CONDITIONAL

    def __init__(self, condition, true, false):
        self.condition = as_lexpr(condition)
        self.true = as_lexpr(true)
        self.false = as_lexpr(false)
        self.dtype = merge_dtypes(self.true.dtype, self.false.dtype)

    def __eq__(self, other):
        return (
            isinstance(other, type(self))
            and self.condition == other.condition
            and self.true == other.true
            and self.false == other.false
        )

    def __hash__(self):
        # __eq__ is defined, so provide a matching hash; otherwise the node
        # cannot appear as an operand of BinOp/NaryOp, which hash operands.
        return hash((self.condition, self.true, self.false))
669

670

671
def as_lexpr(node):
    """Typechecks and wraps an object as a valid LExpr.

    Accepts LExpr nodes (returned unchanged). Integral numbers, including
    NumPy integers, become ``LiteralInt``. Other real numbers become a
    ``LiteralFloat`` holding a Python ``float``, and complex numbers a
    ``LiteralFloat`` holding a Python ``complex``. The conversion ensures
    LiteralFloat's float/complex check passes for values such as NumPy
    float32.

    Raises:
        RuntimeError: for any other type.
    """
    if isinstance(node, LExpr):
        return node
    elif isinstance(node, numbers.Integral):
        return LiteralInt(node)
    elif isinstance(node, numbers.Real):
        return LiteralFloat(float(node))
    elif isinstance(node, numbers.Complex):
        return LiteralFloat(complex(node))
    else:
        raise RuntimeError("Unexpected LExpr type %s:\n%s" % (type(node), str(node)))
×
685

686

687
class Statement(LNode):
    """Make an expression into a statement.

    The wrapped expression is normalized with ``as_lexpr``.
    """

    is_scoped = False

    def __init__(self, expr):
        self.expr = as_lexpr(expr)

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return False
        return self.expr == other.expr
×
697

698

699
def as_statement(node):
    """Perform type checking on node and wrap in a suitable statement type if necessary.

    - A single-statement StatementList is unwrapped to that statement.
    - Statements are returned unchanged.
    - Side-effecting operators (assignments) are wrapped in a Statement.
    - A list becomes a StatementList, or the single converted element.

    Raises:
        RuntimeError: for side-effect-free operators and unsupported types.
    """
    if isinstance(node, StatementList) and len(node.statements) == 1:
        return node.statements[0]
    if isinstance(node, Statement):
        return node
    if isinstance(node, LExprOperator):
        if not node.sideeffect:
            raise RuntimeError(
                "Trying to create a statement of lexprOperator type %s:\n%s"
                % (type(node), str(node))
            )
        return Statement(node)
    if isinstance(node, list):
        return as_statement(node[0]) if len(node) == 1 else StatementList(node)
    raise RuntimeError(
        "Unexpected Statement type %s:\n%s" % (type(node), str(node))
    )
727

728

729
class StatementList(LNode):
    """A simple sequence of statements. No new scopes are introduced.

    Every element is normalized through ``as_statement``.
    """

    def __init__(self, statements):
        self.statements = list(map(as_statement, statements))

    @property
    def is_scoped(self):
        return all(statement.is_scoped for statement in self.statements)

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return False
        return self.statements == other.statements
×
741

742

743
class Comment(Statement):
    """Line comment(s) used for annotating the generated code with human readable remarks.

    Unlike Statement, no expression is stored (Statement.__init__ is not
    called), only the comment text.
    """

    is_scoped = True

    def __init__(self, comment):
        assert isinstance(comment, str)
        self.comment = comment

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return False
        return self.comment == other.comment
×
754

755

756
def commented_code_list(code, comments):
    """Add comment to code list if the list is not empty.

    ``code`` may be a single LNode or a list of nodes; ``comments`` may be a
    single string or a list/tuple of strings. An empty code list is returned
    unchanged; otherwise a new list with the Comment nodes prepended is
    returned.
    """
    if isinstance(code, LNode):
        code = [code]
    assert isinstance(code, list)
    if not code:
        return code
    if not isinstance(comments, (list, tuple)):
        comments = [comments]
    return [Comment(text) for text in comments] + code
1✔
767

768

769
# Type and variable declarations
770

771

772
class VariableDecl(Statement):
    """Declare a variable, optionally define initial value.

    Args:
        symbol: the Symbol being declared; its dtype must be set.
        value: optional initial value, normalized with ``as_lexpr``.
    """

    is_scoped = False

    def __init__(self, symbol, value=None):

        assert isinstance(symbol, Symbol)
        assert symbol.dtype is not None
        self.symbol = symbol

        if value is not None:
            value = as_lexpr(value)
        self.value = value

    def __eq__(self, other):
        # There is no ``typename`` attribute (the previous comparison raised
        # AttributeError); the declared type is the symbol's dtype, which
        # Symbol.__eq__ (name only) does not compare, so check it explicitly.
        return (
            isinstance(other, type(self))
            and self.symbol == other.symbol
            and self.symbol.dtype == other.symbol.dtype
            and self.value == other.value
        )
794

795

796
class ArrayDecl(Statement):
    """A declaration or definition of an array.

    Note that just setting values=0 is sufficient to initialize the
    entire array to zero.

    Otherwise use nested lists of lists to represent multidimensional
    array values to initialize to.

    If ``sizes`` is omitted it is taken from the shape of ``values``; an
    integer size (including NumPy integers) is promoted to a 1-tuple.
    """

    is_scoped = False

    def __init__(self, symbol, sizes=None, values=None, const=False):
        assert isinstance(symbol, Symbol)
        self.symbol = symbol
        assert symbol.dtype

        # NB! No type checking, assuming nested lists of literal values. Not applying as_lexpr.
        # Convert lists/tuples before reading the shape: a plain list has no
        # ``.shape``, so sizes could not be inferred from it otherwise.
        if isinstance(values, (list, tuple)):
            values = np.asarray(values)
        self.values = values

        if sizes is None:
            assert values is not None
            sizes = values.shape
        if isinstance(sizes, numbers.Integral):
            sizes = (sizes,)
        self.sizes = tuple(sizes)

        self.const = const

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return False
        # The declared type is the symbol's dtype (there is no ``typename``);
        # Symbol.__eq__ compares names only, so check the dtype as well.
        if not (self.symbol == other.symbol and self.symbol.dtype == other.symbol.dtype):
            return False
        if not self.sizes == other.sizes:
            return False
        if self.values is None or other.values is None:
            return self.values is None and other.values is None
        # values are typically ndarrays, where == is elementwise.
        return bool(np.array_equal(self.values, other.values))
837

838

839
def is_simple_inner_loop(code):
    """Return True if ``code`` is a single assignment statement, or a
    (possibly nested) ForRange whose body is exactly one such statement.
    """
    if isinstance(code, StatementList):
        # ForRange always wraps its body in a StatementList; look through a
        # single-statement list so loop bodies can be recognized at all.
        return len(code.statements) == 1 and is_simple_inner_loop(code.statements[0])
    if isinstance(code, ForRange):
        return is_simple_inner_loop(code.body)
    if isinstance(code, Statement):
        # ForRange, Comment, VariableDecl and ArrayDecl subclass Statement
        # without setting ``expr``, so read it defensively.
        return isinstance(getattr(code, "expr", None), AssignOp)
    return False
×
845

846

847
class ForRange(Statement):
    """Slightly higher-level for loop assuming incrementing an index over a range.

    Args:
        index: a Symbol or MultiIndex used as loop index.
        begin, end: loop bounds, normalized with ``as_lexpr``.
        body: list of statements, wrapped in a StatementList.
    """

    is_scoped = True

    def __init__(self, index, begin, end, body):
        assert isinstance(index, Symbol) or isinstance(index, MultiIndex)
        self.index = index
        self.begin = as_lexpr(begin)
        self.end = as_lexpr(end)
        assert isinstance(body, list)
        self.body = StatementList(body)

    def __eq__(self, other):
        attributes = ("index", "begin", "end", "body")
        # Compare against *other* (the previous code compared self with itself).
        return isinstance(other, type(self)) and all(
            getattr(self, name) == getattr(other, name) for name in attributes
        )
865

866

867
def _math_function(op, *args):
    """Build an LNodes node for the UFL math operator ``op`` applied to ``args``.

    For REAL-typed first arguments, ``conj`` and ``real`` return the argument
    itself and ``imag`` returns ``LiteralFloat(0.0)``; everything else becomes
    a ``MathFunction`` named after the UFL handler.
    """
    name = op._ufl_handler_name_
    if args[0].dtype == DataType.REAL:
        if name in ("conj", "real"):
            assert len(args) == 1
            return args[0]
        if name == "imag":
            assert len(args) == 1
            return LiteralFloat(0.0)
    return MathFunction(name, args)
1✔
877

878

879
# Lookup table for handler to call when the ufl_to_lnodes method (below) is
880
# called, depending on the first argument type.
881
_ufl_call_lookup = {
    # Literal constants
    ufl.constantvalue.IntValue: lambda x: LiteralInt(int(x)),
    ufl.constantvalue.FloatValue: lambda x: LiteralFloat(float(x)),
    ufl.constantvalue.ComplexValue: lambda x: LiteralFloat(x.value()),
    ufl.constantvalue.Zero: lambda x: LiteralFloat(0.0),
    # Arithmetic goes through the LExpr operator overloads
    ufl.algebra.Product: lambda x, a, b: a * b,
    ufl.algebra.Sum: lambda x, a, b: a + b,
    ufl.algebra.Division: lambda x, a, b: a / b,
    **dict.fromkeys(
        (
            ufl.algebra.Abs,
            ufl.algebra.Power,
            ufl.algebra.Real,
            ufl.algebra.Imag,
            ufl.algebra.Conj,
        ),
        _math_function,
    ),
    # Comparisons and logical connectives map one-to-one onto LNodes binary
    # operator classes; the default argument binds each class per entry.
    **{
        ufl_type: (lambda x, a, b, _node=node_type: _node(a, b))
        for ufl_type, node_type in (
            (ufl.classes.GT, GT),
            (ufl.classes.GE, GE),
            (ufl.classes.EQ, EQ),
            (ufl.classes.NE, NE),
            (ufl.classes.LT, LT),
            (ufl.classes.LE, LE),
            (ufl.classes.AndCondition, And),
            (ufl.classes.OrCondition, Or),
        )
    },
    ufl.classes.NotCondition: lambda x, a: Not(a),
    ufl.classes.Conditional: lambda x, c, t, f: Conditional(c, t, f),
    # Remaining math functions
    **dict.fromkeys(
        (
            ufl.classes.MinValue,
            ufl.classes.MaxValue,
            ufl.mathfunctions.Sqrt,
            ufl.mathfunctions.Ln,
            ufl.mathfunctions.Exp,
            ufl.mathfunctions.Cos,
            ufl.mathfunctions.Sin,
            ufl.mathfunctions.Tan,
            ufl.mathfunctions.Cosh,
            ufl.mathfunctions.Sinh,
            ufl.mathfunctions.Tanh,
            ufl.mathfunctions.Acos,
            ufl.mathfunctions.Asin,
            ufl.mathfunctions.Atan,
            ufl.mathfunctions.Erf,
            ufl.mathfunctions.Atan2,
            ufl.mathfunctions.MathFunction,
            ufl.mathfunctions.BesselJ,
            ufl.mathfunctions.BesselY,
        ),
        _math_function,
    ),
}
923

924

925
def ufl_to_lnodes(operator, *args):
    """Translate a UFL operator and its already-translated operands to LNodes.

    The handler is selected by the exact type of ``operator`` (no subclass
    lookup) from ``_ufl_call_lookup``.

    Raises:
        RuntimeError: if no handler is registered for the operator type.
    """
    handler = _ufl_call_lookup.get(type(operator))
    if handler is None:
        raise RuntimeError(f"Missing lookup for expr type {type(operator)}.")
    return handler(operator, *args)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc