• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ffcx / 18834018400

23 Oct 2025 12:21PM UTC coverage: 83.0%. Remained the same
18834018400

push

github

web-flow
`AbstractCell` members are now properties (#789)

* Adapt to cell members now properties

* Change ufl branch

* One more

* Change UFL branch for dolfinx CI

* Adapt demos

* Last?

* Change ref branch

* Apply suggestions from code review

26 of 35 new or added lines in 7 files covered. (74.29%)

2 existing lines in 1 file now uncovered.

3730 of 4494 relevant lines covered (83.0%)

0.83 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.39
/ffcx/ir/elementtables.py
1
# Copyright (C) 2013-2017 Martin Sandve Alnæs
2
#
3
# This file is part of FFCx. (https://www.fenicsproject.org)
4
#
5
# SPDX-License-Identifier:    LGPL-3.0-or-later
6
"""Tools for precomputed tables of terminal values."""
7

8
import logging
1✔
9
import typing
1✔
10

11
import basix.ufl
1✔
12
import numpy as np
1✔
13
import numpy.typing as npt
1✔
14
import ufl
1✔
15

16
from ffcx.definitions import entity_types
1✔
17
from ffcx.element_interface import basix_index
1✔
18
from ffcx.ir.analysis.modified_terminals import ModifiedTerminal
1✔
19
from ffcx.ir.representationutils import (
1✔
20
    QuadratureRule,
21
    create_quadrature_points_and_weights,
22
    integral_type_to_entity_dim,
23
    map_integral_points,
24
)
25

26
logger = logging.getLogger("ffcx")

# Default tolerances for comparing and clamping table values. Note that these are
# tighter than the np.allclose defaults (rtol=1e-05, atol=1e-08).
default_rtol = 1e-6
default_atol = 1e-9

# Table types whose values do not vary over the points of an entity; such tables
# are reduced to a single point in build_optimized_tables.
piecewise_ttypes = ("piecewise", "fixed", "ones", "zeros")
# Table types whose values do not vary between entities; such tables are reduced
# to a single entity in build_optimized_tables.
uniform_ttypes = ("fixed", "ones", "zeros", "uniform")
34

35

36
class ModifiedTerminalElement(typing.NamedTuple):
    """Element data needed to tabulate a modified terminal.

    Fields:
        element: The element whose basis functions are tabulated for the terminal.
        averaged: ``"cell"`` or ``"facet"`` if the terminal is averaged over the
            cell/facet, otherwise ``None``.
        local_derivatives: Number of reference derivatives in each reference
            direction.
        fc: Flat component of ``element`` to tabulate.
    """

    element: basix.ufl._ElementBase
    averaged: str | None
    local_derivatives: tuple[int, ...]
    fc: int
43

44

45
class UniqueTableReferenceT(typing.NamedTuple):
1✔
46
    """Unique table reference."""
47

48
    name: str
1✔
49
    values: npt.NDArray[np.float64]
1✔
50
    offset: int | None
1✔
51
    block_size: int | None
1✔
52
    ttype: str | None
1✔
53
    is_piecewise: bool
1✔
54
    is_uniform: bool
1✔
55
    is_permuted: bool
1✔
56
    has_tensor_factorisation: bool
1✔
57
    tensor_factors: list[typing.Any] | None
1✔
58
    tensor_permutation: np.typing.NDArray[np.int32] | None
1✔
59

60

61
def equal_tables(a, b, rtol=default_rtol, atol=default_atol):
    """Check if two tables are equal.

    Both inputs are converted with ``np.asarray``. Tables of different shapes are
    never equal; otherwise they are compared element-wise with ``np.allclose``
    using ``rtol`` and ``atol``.
    """
    lhs = np.asarray(a)
    rhs = np.asarray(b)
    return lhs.shape == rhs.shape and np.allclose(lhs, rhs, rtol=rtol, atol=atol)
69

70

71
def clamp_table_small_numbers(
    table, rtol=default_rtol, atol=default_atol, numbers=(-1.0, 0.0, 1.0)
):
    """Clamp almost 0,1,-1 values to integers. Returns new table.

    Every entry that is close (per ``np.isclose`` with ``rtol``/``atol``) to one of
    ``numbers`` is replaced by exactly that number. The numbers are applied in order.

    Args:
        table: Array-like table of values.
        rtol: Relative tolerance passed to ``np.isclose``.
        atol: Absolute tolerance passed to ``np.isclose``.
        numbers: Values to clamp to.

    Returns:
        A new array holding the clamped values; ``table`` itself is not modified.
    """
    # Copy explicitly: np.asarray would hand back the caller's ndarray unchanged and
    # the in-place assignments below would then mutate the caller's data.
    clamped = np.array(table, copy=True)
    for n in numbers:
        clamped[np.isclose(clamped, n, rtol=rtol, atol=atol)] = n
    return clamped
80

81

82
def get_ffcx_table_values(
    points,
    cell,
    integral_type,
    element,
    avg,
    entity_type: entity_types,
    derivative_counts,
    flat_component,
    codim,
):
    """Extract values from FFCx element table.

    Args:
        points: Points to tabulate at. For ``codim == 0`` they are mapped onto each
            sub-entity of ``cell`` with ``map_integral_points``; for ``codim`` 1 or 2
            they are used as given. When ``avg`` is ``"cell"`` or ``"facet"`` they are
            replaced by a quadrature rule suitable for computing the average.
        cell: The cell of the integration domain.
        integral_type: The integral type. Custom integral types are treated as
            ``"cell"``; ``"expression"`` is treated as ``"cell"`` or
            ``"exterior_facet"`` depending on ``entity_type``.
        element: The element to tabulate.
        avg: ``None``, or ``"cell"``/``"facet"`` to tabulate the average of the
            basis functions over the cell/facet.
        entity_type: The entity type the table is evaluated for.
        derivative_counts: Number of reference derivatives in each direction.
        flat_component: Flat component of ``element`` to tabulate.
        codim: Codimension of the element's cell relative to ``cell`` (0, 1 or 2).

    Returns:
        A dict with ``"array"``: a 4D numpy array with axes
        (permutation, entity number, point number, dof number), whose permutation
        axis has length 1; and ``"offset"``/``"stride"``: the values returned by
        ``element.get_component_element(flat_component)``.

    Raises:
        RuntimeError: If ``codim`` is not 0, 1 or 2.
    """
    deriv_order = sum(derivative_counts)

    if integral_type in ufl.custom_integral_types:
        # Use quadrature points on cell for analysis in custom integral types
        integral_type = "cell"
        assert not avg

    if integral_type == "expression":
        # FFCx tables for expression are generated as either interior cell points
        # or points on a facet
        integral_type = "cell" if entity_type == "cell" else "exterior_facet"

    if avg in ("cell", "facet"):
        # Not expecting derivatives of averages
        assert not any(derivative_counts)
        assert deriv_order == 0

        # Doesn't matter if it's exterior or interior facet integral,
        # just need a valid integral type to create quadrature rule
        integral_type = "cell" if avg == "cell" else "exterior_facet"

        # Redefine points (and get weights) to compute average tables
        if isinstance(element, basix.ufl._QuadratureElement):
            points = element._points
            weights = element._weights
        else:
            points, weights = create_quadrature_points_and_weights(
                integral_type, cell, element.embedded_superdegree(), "default", [element]
            )

    # Tabulate table of basis functions and derivatives in points for each entity
    tdim = cell.topological_dimension
    entity_dim = integral_type_to_entity_dim(integral_type, tdim)
    num_entities = cell.num_sub_entities(entity_dim)

    # Extract arrays for the right scalar component
    component_element, offset, stride = element.get_component_element(flat_component)
    component_tables = []
    for entity in range(num_entities):
        if codim == 0:
            entity_points = map_integral_points(points, integral_type, cell, entity)
        elif codim in (1, 2):
            entity_points = points
        else:
            raise RuntimeError("Codimension > 2 isn't supported.")
        tbl = component_element.tabulate(deriv_order, entity_points)
        component_tables.append(tbl[basix_index(derivative_counts)])

    if avg in ("cell", "facet"):
        # Each table is (points, dofs): contract the point axis with the quadrature
        # weights to get the weighted average of every dof over the entity, leaving
        # a single row per entity.
        wsum = sum(weights)
        component_tables = [
            np.reshape(np.dot(weights, tbl) / wsum, (1, tbl.shape[1]))
            for tbl in component_tables
        ]

    # Fill the table blockwise, one (points, dofs) block per entity
    assert len(component_tables) == num_entities
    num_points, num_dofs = component_tables[0].shape
    res = np.zeros((1, num_entities, num_points, num_dofs))
    for entity, tbl in enumerate(component_tables):
        res[0, entity, :, :] = tbl

    return {"array": res, "offset": offset, "stride": stride}
177

178

179
def generate_psi_table_name(
    quadrature_rule: QuadratureRule,
    element_counter,
    averaged: str | None,
    entity_type: entity_types,
    derivative_counts,
    flat_component,
):
    """Generate a name for the psi table.

    Format:
        FE#[_C#][_D###][_AC|_AF][_F|_V|_R]_Q#, where '#' will be an integer value and:
        - FE is a simple counter to distinguish the various bases, it will be
          assigned in an arbitrary fashion.
        - C is the component number if any (this does not yet take into account
          tensor valued functions)
        - D is the number of derivatives in each reference direction if any.
          If the element is defined in 3D, then D012 means d^3(*)/dydz^2.
        - AC marks that the element values are averaged over the cell
        - AF marks that the element values are averaged over the facet
        - F marks that the first array dimension enumerates facets on the cell
        - V marks that the first array dimension enumerates vertices on the cell
        - R marks that the first array dimension enumerates ridges on the cell
        - Q unique ID of quadrature rule, to distinguish between tables
          in a mixed quadrature rule setting

    Args:
        quadrature_rule: Quadrature rule; its ``id()`` is used for the Q suffix.
        element_counter: Integer identifying the element.
        averaged: ``None``, ``"cell"`` or ``"facet"``.
        entity_type: ``"cell"``, ``"facet"``, ``"vertex"`` or ``"ridge"``.
        derivative_counts: Number of derivatives in each reference direction.
        flat_component: Flat component number, or ``None``.

    Raises:
        KeyError: If ``averaged`` or ``entity_type`` is not one of the values above.
    """
    name = f"FE{element_counter:d}"
    if flat_component is not None:
        name += f"_C{flat_component:d}"
    if any(derivative_counts):
        name += "_D" + "".join(str(d) for d in derivative_counts)
    name += {None: "", "cell": "_AC", "facet": "_AF"}[averaged]
    name += {"cell": "", "facet": "_F", "vertex": "_V", "ridge": "_R"}[entity_type]
    name += f"_Q{quadrature_rule.id()}"
    return name
213

214

215
def get_modified_terminal_element(mt) -> ModifiedTerminalElement | None:
    """Get modified terminal element.

    Determines which element must be tabulated to evaluate the modified terminal:
    - form arguments use the element of their function space;
    - spatial coordinates use the coordinate element of their mesh;
    - Jacobians J[i, d] use the coordinate element, with component i and an extra
      reference derivative in direction d.

    Returns:
        A ``ModifiedTerminalElement`` holding the element, the averaging, the number
        of reference derivatives per reference direction and the flat component, or
        ``None`` if the terminal is of any other kind.

    Raises:
        RuntimeError: For derivative/reference-value combinations that are not
            defined for the terminal.
    """
    global_derivs = mt.global_derivatives
    ref_derivs = mt.local_derivatives
    terminal = mt.terminal
    domain = ufl.domain.extract_unique_domain(terminal)

    if isinstance(terminal, ufl.classes.FormArgument):
        if global_derivs and mt.reference_value:
            raise RuntimeError("Global derivatives of reference values not defined.")
        if ref_derivs and not mt.reference_value:
            raise RuntimeError("Local derivatives of global values not defined.")
        element = terminal.ufl_function_space().ufl_element()  # type: ignore
        fc = mt.flat_component
    elif isinstance(terminal, ufl.classes.SpatialCoordinate):
        if mt.reference_value:
            raise RuntimeError("Not expecting reference value of x.")
        if global_derivs:
            raise RuntimeError("Not expecting global derivatives of x.")
        assert isinstance(domain, ufl.Mesh)
        element = domain.ufl_coordinate_element()
        fc = mt.flat_component  # x-component
        if ref_derivs:
            # Actually the Jacobian expressed as reference_grad(x)
            assert len(mt.component) == 1
            assert mt.component[0] == mt.flat_component
    elif isinstance(terminal, ufl.classes.Jacobian):
        if mt.reference_value:
            raise RuntimeError("Not expecting reference value of J.")
        if global_derivs:
            raise RuntimeError("Not expecting global derivatives of J.")
        assert isinstance(domain, ufl.Mesh)
        element = domain.ufl_coordinate_element()
        assert len(mt.component) == 2
        # Translate component J[i, d] to the x element context rgrad(x[i])[d]
        fc, direction = mt.component
        ref_derivs = tuple(sorted((direction,) + ref_derivs))
    else:
        return None

    assert (mt.averaged is None) or not (ref_derivs or global_derivs)
    assert isinstance(domain, ufl.Mesh)

    # `ref_derivs` lists the direction index of each nested reference gradient, e.g.
    # (0, 1, 2) for reference_grad(reference_grad(reference_grad(expr)))[0][1][2].
    # Convert it into the number of derivatives taken in each reference direction,
    # which is how tables are looked up in basix. For a reference value this is all
    # zeros. A 0D domain still needs an index of length one, `(0,)`, because the basix
    # index does not exist for 0D domains.
    tdim = domain.topological_dimension
    num_directions = tdim or 1
    derivative_counts: tuple[int, ...] = tuple(
        ref_derivs.count(direction) for direction in range(num_directions)
    )
    return ModifiedTerminalElement(element, mt.averaged, derivative_counts, fc)
277

278

279
def permute_quadrature_interval(points, reflections=0):
    """Permute quadrature points for an interval.

    Args:
        points: Array-like of shape (num_points, dim) holding the interval coordinate
            in column 0. Any further columns must be (close to) zero; they are left
            untouched.
        reflections: Number of times to apply the reflection x -> 1 - x.

    Returns:
        A new numpy array with the permuted points; ``points`` is not modified.
    """
    output = np.array(points, copy=True)
    num_cols = output.shape[1]
    assert num_cols < 2 or np.allclose(output[:, 1], 0)
    assert num_cols < 3 or np.allclose(output[:, 2], 0)
    for _ in range(reflections):
        # Only the interval coordinate is reflected; padding columns stay zero
        output[:, 0] = 1 - output[:, 0]
    return output
289

290

291
def permute_quadrature_triangle(points, reflections=0, rotations=0):
    """Permute quadrature points for a triangle.

    The rotations (x, y) -> (y, 1 - x - y) are applied first, then the reflections
    (x, y) -> (y, x).

    Args:
        points: Array-like of shape (num_points, dim) with the triangle coordinates
            in columns 0 and 1. A third column, if present, must be (close to) zero
            and is left untouched.
        reflections: Number of reflections to apply.
        rotations: Number of rotations to apply.

    Returns:
        A new numpy array with the permuted points; ``points`` is not modified.
    """
    output = np.array(points, copy=True)
    assert output.shape[1] < 3 or np.allclose(output[:, 2], 0)
    for _ in range(rotations):
        x = output[:, 0].copy()
        y = output[:, 1].copy()
        output[:, 0] = y
        output[:, 1] = 1 - x - y
    for _ in range(reflections):
        # Fancy indexing on the right-hand side makes a copy, so the swap is safe
        output[:, [0, 1]] = output[:, [1, 0]]
    return output
303

304

305
def permute_quadrature_quadrilateral(points, reflections=0, rotations=0):
    """Permute quadrature points for a quadrilateral.

    The rotations (x, y) -> (y, 1 - x) are applied first, then the reflections
    (x, y) -> (y, x).

    Args:
        points: Array-like of shape (num_points, dim) with the quadrilateral
            coordinates in columns 0 and 1. A third column, if present, must be
            (close to) zero and is left untouched.
        reflections: Number of reflections to apply.
        rotations: Number of rotations to apply.

    Returns:
        A new numpy array with the permuted points; ``points`` is not modified.
    """
    output = np.array(points, copy=True)
    assert output.shape[1] < 3 or np.allclose(output[:, 2], 0)
    for _ in range(rotations):
        x = output[:, 0].copy()
        y = output[:, 1].copy()
        output[:, 0] = y
        output[:, 1] = 1 - x
    for _ in range(reflections):
        # Fancy indexing on the right-hand side makes a copy, so the swap is safe
        output[:, [0, 1]] = output[:, [1, 0]]
    return output
317

318

319
def _permuted_point_sets(points, cell, integral_type, entity_type, tdim, codim, is_mixed_dim):
    """Return the point sets to tabulate at, one per orientation of the integration entity.

    A single-element list means that no permutation is applied.
    """
    # Only permute quadrature rules for interior facet and ridge integrals and for
    # the codim zero element in mixed-dimensional integrals. The latter is needed
    # because a cell may see its sub-entities as being oriented differently to their
    # global orientation.
    needs_permutation = integral_type in ("interior_facet", "ridge") or (
        is_mixed_dim and codim == 0
    )
    if not needs_permutation:
        return [points]

    if entity_type == "facet":
        if tdim == 1 or codim == 1:
            # Do not add permutations if codim-1 as facets have already gotten a global
            # orientation in DOLFINx
            return [points]
        if tdim == 2:
            return [permute_quadrature_interval(points, ref) for ref in range(2)]
        if tdim == 3:
            if cell.cellname == "tetrahedron":
                return [
                    permute_quadrature_triangle(points, ref, rot)
                    for rot in range(3)
                    for ref in range(2)
                ]
            if cell.cellname == "hexahedron":
                return [
                    permute_quadrature_quadrilateral(points, ref, rot)
                    for rot in range(4)
                    for ref in range(2)
                ]
        raise RuntimeError(
            f"Facet quadrature permutations are not supported for cell type '{cell.cellname}'."
        )

    if entity_type == "ridge":
        if tdim < 3 or codim == 2:
            # If ridge integral over vertex no permutation is needed,
            # or if it is a single domain ridge integral,
            # as ridges has a global orientation in DOLFINx.
            return [points]
        return [permute_quadrature_interval(points, ref) for ref in range(2)]

    # NOTE(review): no permutation is applied for other entity types (e.g. "cell" or
    # "vertex"); confirm this is correct for codim-0 elements in mixed-dimensional forms.
    return [points]


def _tabulate_point_sets(
    point_sets,
    cell,
    integral_type,
    element,
    avg,
    entity_type,
    local_derivatives,
    flat_component,
    codim,
):
    """Tabulate at every point set and stack the tables along the permutation axis.

    Returns the table dict of the first point set, with its ``"array"`` replaced by
    the stacked arrays. Axis 0 enumerates ``point_sets`` in order.
    """
    tables = [
        get_ffcx_table_values(
            points,
            cell,
            integral_type,
            element,
            avg,
            entity_type,
            local_derivatives,
            flat_component,
            codim,
        )
        for points in point_sets
    ]
    result = tables[0]
    result["array"] = np.vstack([table["array"] for table in tables])
    return result


def build_optimized_tables(
    quadrature_rule: QuadratureRule,
    cell: ufl.Cell,
    integral_type: typing.Literal["interior_facet", "exterior_facet", "ridge", "cell", "vertex"],
    entity_type: entity_types,
    modified_terminals: typing.Iterable[ModifiedTerminal],
    existing_tables: dict[str, npt.NDArray[np.float64]],
    use_sum_factorization: bool,
    is_mixed_dim: bool,
    rtol: float = default_rtol,
    atol: float = default_atol,
) -> dict[str | ModifiedTerminal, UniqueTableReferenceT]:
    """Build the element tables needed for a list of modified terminals.

    Args:
        quadrature_rule: The quadrature rule used for the tables.
        cell: The cell type of the domain the tables will be used with.
        integral_type: The type of integral the tables are used for.
        entity_type: The entity type (vertex,edge,facet,cell) that the tables are evaluated for.
        modified_terminals: Ordered iterable of unique modified terminals. It is
            consumed only once.
        existing_tables: Register of tables that already exist and can be reused.
            A copy is used internally, so the caller's dict is not modified.
        use_sum_factorization: Use sum factorization for tensor product elements.
        is_mixed_dim: Mixed dimensionality of the domain.
        rtol: Relative tolerance for clamping tables to -1,0 or 1
        atol: Absolute tolerance for clamping tables to -1,0 or 1

    Returns:
        Dictionary mapping each modified terminal to a unique table reference.
        If ``use_sum_factorization`` is turned on, the map also contains entries
        mapping the names of the tensor-product factor tables (``FE_TF#``) to their
        table references.

    Raises:
        RuntimeError: If an element has codimension greater than 2, if sum
            factorisation is requested for a rule without tensor factors, or if
            facet permutations are needed for an unsupported cell type.
    """
    # Analyse the terminals once. This also materialises the iterable, so a one-shot
    # iterator still produces tables. Terminals without an element are dropped.
    analysis = {}
    for mt in modified_terminals:
        res = get_modified_terminal_element(mt)
        if res:
            analysis[mt] = res

    # Build element numbering using topological ordering so subelements
    # get priority
    all_elements = [res.element for res in analysis.values()]
    unique_elements = ufl.algorithms.sort_elements(
        set(ufl.algorithms.analysis.extract_sub_elements(all_elements))
    )
    element_numbers = {element: i for i, element in enumerate(unique_elements)}
    mt_tables: dict[str | ModifiedTerminal, UniqueTableReferenceT] = {}

    _existing_tables = existing_tables.copy()

    all_tensor_factors: list[UniqueTableReferenceT] = []
    tensor_n = 0

    tdim = cell.topological_dimension
    points = quadrature_rule.points

    for mt, (element, avg, local_derivatives, flat_component) in analysis.items():
        # Build name for this particular table
        element_number = element_numbers[element]
        name = generate_psi_table_name(
            quadrature_rule, element_number, avg, entity_type, local_derivatives, flat_component
        )

        # FIXME - currently just recalculate the tables every time,
        # only reusing them if they match numerically.
        # It should be possible to reuse the cached tables by name, but
        # the dofmap offset may differ due to restriction.

        codim = tdim - element.cell.topological_dimension
        assert codim >= 0
        if codim > 2:
            raise RuntimeError("Codimension > 2 isn't supported.")

        point_sets = _permuted_point_sets(
            points, cell, integral_type, entity_type, tdim, codim, is_mixed_dim
        )
        t = _tabulate_point_sets(
            point_sets,
            cell,
            integral_type,
            element,
            avg,
            entity_type,
            local_derivatives,
            flat_component,
            codim,
        )

        # Clean up table
        tbl = clamp_table_small_numbers(t["array"], rtol=rtol, atol=atol)
        tabletype = analyse_table_type(tbl)

        if tabletype in piecewise_ttypes:
            # Reduce table to dimension 1 along num_points axis in generated code
            tbl = tbl[:, :, :1, :]
        if tabletype in uniform_ttypes:
            # Reduce table to dimension 1 along num_entities axis in generated code
            tbl = tbl[:, :1, :, :]
        is_permuted = is_permuted_table(tbl)
        if not is_permuted:
            # Reduce table along num_perms axis
            tbl = tbl[:1, :, :, :]

        # Reuse an existing identical table if there is one, otherwise register this one
        # FIXME: should we pass in atol and rtol here?
        for table_name, existing in _existing_tables.items():
            if equal_tables(tbl, existing):
                name = table_name
                tbl = existing
                break
        else:
            _existing_tables[name] = tbl

        if use_sum_factorization and (not quadrature_rule.has_tensor_factors):
            raise RuntimeError("Sum factorization not available for this quadrature rule.")

        tensor_factors: list[UniqueTableReferenceT] | None = None
        tensor_perm = None
        if (
            use_sum_factorization
            and element.has_tensor_product_factorisation
            and len(element.get_tensor_product_representation()) == 1
            and quadrature_rule.has_tensor_factors
        ):
            factors = element.get_tensor_product_representation()

            tensor_factors = []
            for i, j in enumerate(factors[0]):
                pts = quadrature_rule.tensor_factors[i][0]
                d = local_derivatives[i]
                sub_tbl = j.tabulate(d, pts)[d]
                sub_tbl = sub_tbl.reshape(1, 1, sub_tbl.shape[0], sub_tbl.shape[1])
                for tensor_factor in all_tensor_factors:
                    if tensor_factor.values.shape == sub_tbl.shape and np.allclose(
                        tensor_factor.values, sub_tbl
                    ):
                        tensor_factors.append(tensor_factor)
                        break
                else:
                    # FIXME: The inputs here does not match the type-hints of
                    # unique_table_reference
                    ut = UniqueTableReferenceT(
                        f"FE_TF{tensor_n}",
                        sub_tbl,
                        None,
                        None,
                        None,
                        False,
                        False,
                        False,
                        False,
                        None,
                        None,
                    )
                    all_tensor_factors.append(ut)
                    tensor_factors.append(ut)
                    mt_tables[ut.name] = ut
                    tensor_n += 1

            tensor_perm = factors[0][1]

        cell_offset = 0
        if mt.restriction == "-" and isinstance(mt.terminal, ufl.classes.FormArgument):
            # offset = 0 or number of element dofs, if restricted to "-"
            cell_offset = element.dim

        offset = cell_offset + t["offset"]
        block_size = t["stride"]
        # tables is just np.arrays, mt_tables hold metadata too
        # FIXME: type-hinting of tensor factors is not correct
        mt_tables[mt] = UniqueTableReferenceT(
            name,
            tbl,
            offset,
            block_size,
            tabletype,
            tabletype in piecewise_ttypes,
            tabletype in uniform_ttypes,
            is_permuted,
            tensor_factors is not None,
            tensor_factors,
            tensor_perm,
        )

    return mt_tables
631

632

633
def is_zeros_table(table, rtol=default_rtol, atol=default_atol):
    """Return True if ``table`` is empty or all of its entries are (close to) zero."""
    if table.size == 0:
        return True
    zeros = np.zeros_like(table, dtype=np.float64)
    return np.allclose(table, zeros, rtol=rtol, atol=atol)
638

639

640
def is_ones_table(table, rtol=default_rtol, atol=default_atol):
    """Return True if all entries of ``table`` are (close to) one."""
    ones = np.ones_like(table, dtype=np.float64)
    return np.allclose(table, ones, atol=atol, rtol=rtol)
643

644

645
def is_quadrature_table(table, rtol=default_rtol, atol=default_atol):
    """Return True if the first permutation maps points to dofs by identity on every entity.

    Requires as many dofs as points, and ``table[0, e]`` to be close to the identity
    matrix for every entity ``e``.
    """
    _, num_entities, num_points, num_dofs = table.shape
    identity = np.eye(num_points)
    if num_points != num_dofs:
        return False
    for entity in range(num_entities):
        if not np.allclose(table[0, entity, :, :], identity, rtol=rtol, atol=atol):
            return False
    return True
652

653

654
def is_permuted_table(table, rtol=default_rtol, atol=default_atol):
    """Return True if some permutation ``table[i]`` (i > 0) differs from ``table[0]``."""
    num_perms = table.shape[0]
    return any(
        not np.allclose(table[0, :, :, :], table[perm, :, :, :], rtol=rtol, atol=atol)
        for perm in range(1, num_perms)
    )
660

661

662
def is_piecewise_table(table, rtol=default_rtol, atol=default_atol):
    """Return True if, in the first permutation, every point matches the first point."""
    for point in range(1, table.shape[2]):
        if not np.allclose(table[0, :, 0, :], table[0, :, point, :], rtol=rtol, atol=atol):
            return False
    return True
668

669

670
def is_uniform_table(table, rtol=default_rtol, atol=default_atol):
    """Return True if, in the first permutation, every entity matches the first entity."""
    num_entities = table.shape[1]
    return not any(
        not np.allclose(table[0, 0, :, :], table[0, entity, :, :], rtol=rtol, atol=atol)
        for entity in range(1, num_entities)
    )
676

677

678
def analyse_table_type(table, rtol=default_rtol, atol=default_atol):
    """Classify a table by the structure of its values.

    Returns:
        ``"zeros"`` if the table is empty or all zero, ``"ones"`` if all values are 1,
        ``"quadrature"`` if it is the identity (points to dofs) on each entity,
        ``"fixed"`` if constant over points and entities, ``"piecewise"`` if constant
        over the points of each entity, ``"uniform"`` if equal on all entities, and
        ``"varying"`` otherwise.
    """
    tolerances = {"rtol": rtol, "atol": atol}
    if is_zeros_table(table, **tolerances):
        return "zeros"
    if is_ones_table(table, **tolerances):
        return "ones"
    if is_quadrature_table(table, **tolerances):
        return "quadrature"

    piecewise = is_piecewise_table(table, **tolerances)
    uniform = is_uniform_table(table, **tolerances)
    if piecewise:
        return "fixed" if uniform else "piecewise"
    return "uniform" if uniform else "varying"
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc