• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ufl / 17525683651

07 Sep 2025 07:41AM UTC coverage: 75.929% (+0.01%) from 75.917%
17525683651

Pull #416

github

web-flow
Update comment
Pull Request #416: Extend `ufl.extract_blocks` to preserve the initial arguments of an element with sub-spaces (mixed-element)

19 of 21 new or added lines in 2 files covered. (90.48%)

9 existing lines in 1 file now uncovered.

8949 of 11786 relevant lines covered (75.93%)

0.76 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.34
/ufl/algorithms/formsplitter.py
1
"""Extract part of a form in a mixed FunctionSpace."""
2

3
# Copyright (C) 2016-2024 Chris Richardson and Lawrence Mitchell
4
#
5
# This file is part of UFL (https://www.fenicsproject.org)
6
#
7
# SPDX-License-Identifier:    LGPL-3.0-or-later
8
#
9
# Modified by Cecile Daversin-Catty, 2018
10
# Modified by Jørgen S. Dokken, 2025
11

12
import numpy as np
1✔
13

14
from ufl.algorithms.map_integrands import map_expr_dag, map_integrand_dags
1✔
15
from ufl.argument import Argument
1✔
16
from ufl.classes import FixedIndex, ListTensor
1✔
17
from ufl.constantvalue import Zero
1✔
18
from ufl.corealg.multifunction import MultiFunction
1✔
19
from ufl.functionspace import FunctionSpace
1✔
20
from ufl.tensors import as_vector
1✔
21

22

23
class FormSplitter(MultiFunction):
    """Form splitter.

    Rewrites the arguments of a form so that only the requested block survives;
    arguments (or argument components) belonging to other blocks become zeros.
    """

    def __init__(self, replace_argument: bool = True):
        """Initialize form splitter.

        Args:
            replace_argument: If True, replace the argument by a new argument
                in the sub-function space. If False, keep the original argument.
                This is useful for instance when diagonalizing a form with a mixed-element
                form, where we want to keep the original argument.
        """
        MultiFunction.__init__(self)
        # Block index to keep for argument number 0 and 1; ``None`` keeps no block.
        self.idx = [None, None]
        self.replace_argument = replace_argument

    def split(self, form, ix, iy=None):
        """Split form based on the argument part/number.

        Args:
            form: The form to split.
            ix: Block index to keep for argument number 0.
            iy: Block index to keep for argument number 1 (``None`` keeps no block).

        Returns:
            The result of mapping this splitter over every integrand of ``form``.
        """
        # Remember which block to extract
        self.idx = [ix, iy]
        return map_integrand_dags(self, form)

    def argument(self, obj):
        """Apply to argument.

        Arguments from a ``MixedFunctionSpace`` (``obj.part()`` is set) are kept when
        their part is the requested block and replaced by zero otherwise. Arguments on
        a mixed element are rebuilt as a flat vector in which only the components of
        the requested sub-element are non-zero.
        """
        selected = self.idx[obj.number()]

        if obj.part() is not None:
            # Mixed element built from MixedFunctionSpace,
            # whose sub-function spaces are indexed by obj.part()
            if obj.part() == selected:
                return obj
            return Zero(obj.ufl_shape)

        # Mixed element built from MixedElement,
        # whose sub-elements need their function space to be created
        Q = obj.ufl_function_space()
        dom = Q.ufl_domain()
        sub_elements = obj.ufl_element().sub_elements

        # If not a mixed element, do nothing
        if len(sub_elements) == 0:
            return obj

        args = []
        # Position of the current sub-element inside the flattened value of ``obj``
        # (only needed when the original argument is kept).
        offset = 0
        for i, sub_elem in enumerate(sub_elements):
            a = Argument(FunctionSpace(dom, sub_elem), obj.number(), part=obj.part())
            # np.ndindex(()) yields a single empty tuple, so scalars contribute one entry.
            components = list(np.ndindex(a.ufl_shape))
            if i != selected:
                args.extend(Zero() for _ in components)
            elif self.replace_argument:
                args.extend(a[c] for c in components)
            else:
                # Mixed elements are flattened: pick this sub-element's slice of ``obj``.
                args.extend(obj[offset + k] for k in range(len(components)))
            offset += len(components)
        return as_vector(args)

    def indexed(self, o, child, multiindex):
        """Extract indexed entry if multiindices are fixed.

        This avoids tensors like (v_0, 0)[1] to be created.
        """
        indices = multiindex.indices()
        if isinstance(child, ListTensor) and all(isinstance(i, FixedIndex) for i in indices):
            # A fixed multi-index on an ``Indexed`` addresses a single component of
            # ``child`` (one index per rank), so descend into the nested ListTensor
            # operands one index at a time instead of selecting whole rows.
            entry = child
            for depth, index in enumerate(indices):
                if not isinstance(entry, ListTensor):
                    # Not a ListTensor any more: index the remaining ranks directly.
                    return entry[indices[depth:]]
                entry = entry.ufl_operands[int(index)]
            return entry
        return self.expr(o, child, multiindex)

    def multi_index(self, obj):
        """Apply to multi_index (returned unchanged)."""
        return obj

    def restricted(self, o):
        """Apply to a restricted function.

        The restricted operand is split first; a zero result is returned as-is
        (unrestricted), otherwise the original restriction side is re-applied.
        """
        # If we hit a restriction first apply form splitter to argument, then check for zero
        op_split = map_expr_dag(self, o.ufl_operands[0])
        if isinstance(op_split, Zero):
            return op_split
        return op_split(o._side)

    expr = MultiFunction.reuse_if_untouched
125

126

127
def extract_blocks(
    form,
    i: int | None = None,
    j: int | None = None,
    arity: int | None = None,
    replace_argument: bool = True,
):
    """Extract blocks of a form.

    If arity is 0, returns ``(form,)``.
    If arity is 1, return the ith block. If ``i`` is ``None``, return all blocks.
    If arity is 2, return the ``(i,j)`` entry. If ``j`` is ``None``, return the ith row.

    If neither `i` nor `j` are set, return all blocks (as a scalar, vector or tensor).

    Blocks are indexed by argument ``part`` when the arguments come from a
    ``MixedFunctionSpace``, and by sub-element index when the arguments live on a
    single space with a mixed element. In tuple results, blocks that are empty after
    splitting (or whose number of arguments differs from the arity) are ``None``.

    Args:
        form: A form
        i: Index of the block to extract. If set to ``None``, ``j`` must be None.
        j: Index of the block to extract.
        arity: Arity of the form. If not set, it will be inferred from the form.
        replace_argument: If True, replace the argument by a new argument
            in the (collapsed) sub-function space. If False, keep the original argument.

    Returns:
        A single block, a tuple of blocks (one row), or a tuple of rows, depending
        on ``i``, ``j`` and the arity. A form whose (single-space) arguments have no
        sub-elements is returned unchanged when ``i`` is ``None``.

    Raises:
        RuntimeError: If ``j`` is given without ``i``, or if a requested block index
            is out of range for a ``MixedFunctionSpace`` form.
    """
    if i is None and j is not None:
        raise RuntimeError(f"Cannot extract block with {j=} and {i=}.")

    fs = FormSplitter(replace_argument=replace_argument)
    arguments = form.arguments()

    if arity is None:
        # One distinct argument number per argument slot (test, trial).
        arity = len({a.number() for a in arguments})

    assert arity <= 2
    if arity == 0:
        return (form,)

    def block_or_none(pi, pj=None):
        # Split out block (pi, pj); report empty blocks, or blocks whose number of
        # arguments differs from the arity, as None.
        f = fs.split(form, pi, pj)
        if f.empty() or len(f.arguments()) != arity:
            return None
        return f

    # Arguments from a MixedFunctionSpace carry a part number; arguments on a single
    # (possibly mixed-element) space do not.
    parts = tuple(sorted({part for a in arguments if (part := a.part()) is not None}))

    if not parts:
        # Single space: blocks are the sub-elements of the (mixed) element.
        if i is None:
            num_sub_elements = arguments[0].ufl_element().num_sub_elements
            # If form has no sub elements, return the form itself.
            if num_sub_elements == 0:
                return form
            blocks = range(num_sub_elements)
            if arity == 1:
                return tuple(block_or_none(pi) for pi in blocks)
            return tuple(tuple(block_or_none(pi, pj) for pj in blocks) for pi in blocks)
        if arity > 1 and j is None:
            # Return the ith row, as documented (and as done for MixedFunctionSpace forms).
            num_sub_elements = arguments[0].ufl_element().num_sub_elements
            if num_sub_elements > 0:
                return tuple(block_or_none(i, pj) for pj in range(num_sub_elements))
        return fs.split(form, i, j)

    # MixedFunctionSpace: block k collects the terms whose arguments have part k.
    # Size by the largest part (not the number of distinct parts) so that block k
    # stays aligned with part k even when some part does not appear in the form.
    num_parts = max(parts) + 1
    if arity > 1:
        forms_tuple = tuple(
            tuple(block_or_none(pi, pj) for pj in range(num_parts)) for pi in range(num_parts)
        )
    else:
        forms_tuple = tuple(block_or_none(pi) for pi in range(num_parts))

    if i is None:
        return forms_tuple
    if (num_rows := len(forms_tuple)) <= i:
        raise RuntimeError(f"Cannot extract block {i} from form with {num_rows} blocks.")
    if arity > 1 and j is not None:
        if (num_cols := len(forms_tuple[i])) <= j:
            raise RuntimeError(
                f"Cannot extract block {i},{j} from form with {num_rows}x{num_cols} blocks."
            )
        return forms_tuple[i][j]
    return forms_tuple[i]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc