• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ufl / 18629405325

19 Oct 2025 10:56AM UTC coverage: 77.06% (+0.4%) from 76.622%
18629405325

Pull #401

github

schnellerhase
Ruff
Pull Request #401: Removal of custom type system

494 of 533 new or added lines in 41 files covered. (92.68%)

6 existing lines in 2 files now uncovered.

9325 of 12101 relevant lines covered (77.06%)

0.77 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.31
/ufl/algorithms/apply_derivatives.py
1
"""Apply derivatives algorithm which computes the derivatives of a form of expression."""
2

3
# Copyright (C) 2008-2016 Martin Sandve Alnæs
4
#
5
# This file is part of UFL (https://www.fenicsproject.org)
6
#
7
# SPDX-License-Identifier:    LGPL-3.0-or-later
8

9
from __future__ import annotations
1✔
10

11
import warnings
1✔
12
from functools import singledispatchmethod
1✔
13
from math import pi
1✔
14

15
import numpy as np
1✔
16

17
from ufl.action import Action
1✔
18
from ufl.algorithms.analysis import extract_arguments, extract_coefficients
1✔
19
from ufl.algorithms.map_integrands import map_integrands
1✔
20
from ufl.algorithms.replace_derivative_nodes import replace_derivative_nodes
1✔
21
from ufl.argument import Argument, BaseArgument, Coargument
1✔
22
from ufl.averaging import CellAvg, FacetAvg
1✔
23
from ufl.checks import is_cellwise_constant
1✔
24
from ufl.classes import (
1✔
25
    Abs,
26
    CellCoordinate,
27
    Coefficient,
28
    Cofunction,
29
    ComponentTensor,
30
    Conj,
31
    Constant,
32
    ConstantValue,
33
    Division,
34
    Expr,
35
    ExprList,
36
    ExprMapping,
37
    FacetNormal,
38
    FloatValue,
39
    FormArgument,
40
    GeometricQuantity,
41
    Grad,
42
    Identity,
43
    Imag,
44
    Indexed,
45
    IndexSum,
46
    Jacobian,
47
    JacobianDeterminant,
48
    JacobianInverse,
49
    Label,
50
    ListTensor,
51
    Power,
52
    Product,
53
    Real,
54
    ReferenceGrad,
55
    ReferenceValue,
56
    SpatialCoordinate,
57
    Sum,
58
    Variable,
59
    Zero,
60
)
61
from ufl.conditional import BinaryCondition, Conditional, NotCondition
1✔
62
from ufl.constantvalue import is_true_ufl_scalar, is_ufl_scalar
1✔
63
from ufl.core.base_form_operator import BaseFormOperator
1✔
64
from ufl.core.expr import ufl_err_str
1✔
65
from ufl.core.external_operator import ExternalOperator
1✔
66
from ufl.core.interpolate import Interpolate
1✔
67
from ufl.core.multiindex import FixedIndex, MultiIndex, indices
1✔
68
from ufl.core.terminal import Terminal
1✔
69
from ufl.corealg.dag_traverser import DAGTraverser
1✔
70
from ufl.differentiation import (
1✔
71
    BaseFormCoordinateDerivative,
72
    BaseFormOperatorDerivative,
73
    CoefficientDerivative,
74
    CoordinateDerivative,
75
    Derivative,
76
    VariableDerivative,
77
)
78
from ufl.domain import MeshSequence, extract_unique_domain
1✔
79
from ufl.form import BaseForm, Form, ZeroBaseForm
1✔
80
from ufl.mathfunctions import (
1✔
81
    Acos,
82
    Asin,
83
    Atan,
84
    Atan2,
85
    BesselI,
86
    BesselJ,
87
    BesselK,
88
    BesselY,
89
    Cos,
90
    Cosh,
91
    Erf,
92
    Exp,
93
    Ln,
94
    MathFunction,
95
    Sin,
96
    Sinh,
97
    Sqrt,
98
    Tan,
99
    Tanh,
100
)
101
from ufl.matrix import Matrix
1✔
102
from ufl.operators import (
1✔
103
    MaxValue,
104
    MinValue,
105
    bessel_I,
106
    bessel_J,
107
    bessel_K,
108
    bessel_Y,
109
    cell_avg,
110
    conditional,
111
    cos,
112
    cosh,
113
    exp,
114
    facet_avg,
115
    ln,
116
    sign,
117
    sin,
118
    sinh,
119
    sqrt,
120
)
121
from ufl.pullback import CustomPullback, PhysicalPullback
1✔
122
from ufl.restriction import Restricted
1✔
123
from ufl.tensors import as_scalar, as_scalars, as_tensor, unit_indexed_tensor, unwrap_list_tensor
1✔
124

125
# TODO: Add more rulesets?
126
# - DivRuleset
127
# - CurlRuleset
128
# - ReferenceGradRuleset
129
# - ReferenceDivRuleset
130

131

132
def flatten_domain_element(domain, element):
    """Return the flattened (domain, element) pairs for mixed domain problems.

    Args:
        domain: `Mesh` or `MeshSequence`.
        element: `FiniteElement`. When ``domain`` is a `MeshSequence`, its
            ``sub_elements`` are paired one-to-one with the entries of ``domain``.

    Returns:
        A flat tuple of (domain, element) pairs, obtained by recursing into
        (possibly nested) `MeshSequence` objects; just ((domain, element),)
        if domain is a `Mesh` (and not a `MeshSequence`).

    Raises:
        ValueError: If a `MeshSequence` and its element have a different
            number of sub-domains and sub-elements.

    """
    if not isinstance(domain, MeshSequence):
        return ((domain, element),)
    sub_elements = element.sub_elements
    if len(domain) != len(sub_elements):
        # Explicit check instead of ``assert``: under ``python -O`` an assert
        # is stripped and ``zip`` would silently drop the unmatched entries.
        raise ValueError(
            f"Number of domains ({len(domain)}) does not match "
            f"number of sub-elements ({len(sub_elements)})."
        )
    return tuple(
        pair
        for d, e in zip(domain, sub_elements)
        for pair in flatten_domain_element(d, e)
    )
151

152

153
class GenericDerivativeRuleset(DAGTraverser):
    """A generic derivative.

    Holds the differentiation rules that do not depend on what the
    derivative is taken with respect to: algebra, indexing, math functions,
    restrictions and conditionals. Rules for grad, averages, form arguments
    and geometric quantities raise here and must be overridden in a
    specialized subclass.
    """

    def __init__(
        self,
        var_shape: tuple,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise.

        Args:
            var_shape: Shape of the differentiation variable. Zero derivatives
                created by this rule set have shape ``o.ufl_shape + var_shape``.
            compress: Forwarded to `DAGTraverser`.
            visited_cache: Forwarded to `DAGTraverser`.
            result_cache: Forwarded to `DAGTraverser`.

        """
        super().__init__(compress=compress, visited_cache=visited_cache, result_cache=result_cache)
        self._var_shape = var_shape

    def unexpected(self, o):
        """Raise error about unexpected type."""
        raise ValueError(f"Unexpected type {type(o).__name__} in AD rules.")

    def override(self, o):
        """Raise error about overriding."""
        raise ValueError(f"Type {type(o).__name__} must be overridden in specialized AD rule set.")

    # --- Some types just don't have any derivative, this is just to
    # --- make algorithm structure generic

    def non_differentiable_terminal(self, o):
        """Return the non-differentiated object.

        Labels and indices are not differentiable: it's convenient to
        return the non-differentiated object.
        """
        return o

    # --- Helper functions for creating zeros with the right shapes

    def independent_terminal(self, o):
        """A zero with correct shape for terminals independent of diff. variable."""
        return Zero(o.ufl_shape + self._var_shape)

    def independent_operator(self, o):
        """A zero with correct shape and indices for operators independent of diff. variable."""
        return Zero(o.ufl_shape + self._var_shape, o.ufl_free_indices, o.ufl_index_dimensions)

    # --- Error checking for missing handlers and unexpected types

    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Process ``o``.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.

        """
        return super().process(o)

    @process.register(Expr)
    def _(self, o: Expr) -> Expr:
        """Raise error: no rule is registered for this expression type."""
        raise ValueError(
            f"Missing differentiation handler for type {type(o).__name__}. "
            "Have you added a new type?"
        )

    @process.register(Derivative)
    def _(self, o: Expr) -> Expr:
        """Raise error: inner derivatives should have been applied already."""
        raise ValueError(
            f"Unhandled derivative type {type(o).__name__}, nested differentiation has failed."
        )

    @process.register(Label)
    @process.register(MultiIndex)
    def _(self, o: Expr) -> Expr:
        """Return labels and multi-indices unchanged."""
        return self.non_differentiable_terminal(o)

    # --- All derivatives need to define grad and averaging

    @process.register(Grad)
    @process.register(CellAvg)
    @process.register(FacetAvg)
    def _(self, o: Expr) -> Expr:
        """Raise error: must be overridden in a specialized rule set."""
        return self.override(o)

    # --- Default rules for terminals

    # Literals are by definition independent of any differentiation variable
    @process.register(ConstantValue)
    # Constants are independent of any differentiation
    @process.register(Constant)
    def _(self, o: Expr) -> Expr:
        """Return a zero derivative for literals and constants."""
        return self.independent_terminal(o)

    # Zero may have free indices
    @process.register(Zero)
    def _(self, o: Expr) -> Expr:
        """Return a zero derivative that keeps the free indices of ``o``."""
        return self.independent_operator(o)

    # Rules for form arguments must be specified in specialized rule set
    @process.register(FormArgument)
    # Rules for geometric quantities must be specified in specialized rule set
    @process.register(GeometricQuantity)
    def _(self, o: Expr) -> Expr:
        """Raise error: must be overridden in a specialized rule set."""
        return self.override(o)

    # These types are currently assumed independent, but for non-affine domains
    # this no longer holds and we want to implement rules for them.
    # facet_normal = independent_terminal
    # spatial_coordinate = independent_terminal
    # cell_coordinate = independent_terminal

    # Measures of cell entities, assuming independent although
    # this will not be true for all of these for non-affine domains
    # cell_volume = independent_terminal
    # circumradius = independent_terminal
    # facet_area = independent_terminal
    # cell_surface_area = independent_terminal
    # min_cell_edge_length = independent_terminal
    # max_cell_edge_length = independent_terminal
    # min_facet_edge_length = independent_terminal
    # max_facet_edge_length = independent_terminal

    # Other stuff
    # cell_orientation = independent_terminal
    # quadrature_weight = independent_terminal

    # These types are currently not expected to show up in AD pass.
    # To make some of these available to the end-user, they need to be
    # implemented here.
    # facet_coordinate = unexpected
    # cell_origin = unexpected
    # facet_origin = unexpected
    # cell_facet_origin = unexpected
    # jacobian = unexpected
    # jacobian_determinant = unexpected
    # jacobian_inverse = unexpected
    # facet_jacobian = unexpected
    # facet_jacobian_determinant = unexpected
    # facet_jacobian_inverse = unexpected
    # cell_facet_jacobian = unexpected
    # cell_facet_jacobian_determinant = unexpected
    # cell_facet_jacobian_inverse = unexpected
    # cell_vertices = unexpected
    # cell_edge_vectors = unexpected
    # facet_edge_vectors = unexpected
    # reference_cell_edge_vectors = unexpected
    # reference_facet_edge_vectors = unexpected
    # cell_normal = unexpected # TODO: Expecting rename
    # cell_normals = unexpected
    # facet_tangents = unexpected
    # cell_tangents = unexpected
    # cell_midpoint = unexpected
    # facet_midpoint = unexpected

    # --- Default rules for operators

    @process.register(Variable)
    def _(self, o: Expr) -> Expr:
        """Differentiate a variable by differentiating the wrapped expression."""
        op, _ = o.ufl_operands
        return self(op)

    # --- Indexing and component handling

    @process.register(Indexed)
    @DAGTraverser.postorder
    def _(self, o: Indexed, Ap: Expr, ii: MultiIndex) -> Indexed:
        """Differentiate an indexed.

        ``Ap`` is the derivative of the indexed operand. It has extra trailing
        axes from the differentiation variable, which are kept free here and
        re-wrapped as tensor axes.
        """
        # Propagate zeros
        if isinstance(Ap, Zero):
            return self.independent_operator(o)
        r = len(Ap.ufl_shape) - len(ii)
        if r:
            kk = indices(r)
            op = Indexed(Ap, MultiIndex(ii.indices() + kk))
            op = as_tensor(op, kk)
        else:
            op = Indexed(Ap, ii)
        return op

    @process.register(ListTensor)
    def _(self, o: Expr) -> Expr:
        """Differentiate a list_tensor component-wise."""
        return ListTensor(*(self(op) for op in o.ufl_operands))

    @process.register(ComponentTensor)
    @DAGTraverser.postorder
    def _(self, o: ComponentTensor, Ap: Expr, ii: MultiIndex) -> Expr:
        """Differentiate a component_tensor."""
        if isinstance(Ap, Zero):
            op = self.independent_operator(o)
        else:
            Ap, jj = as_scalar(Ap)
            op = as_tensor(Ap, ii.indices() + jj)
        return op

    # --- Algebra operators

    @process.register(IndexSum)
    @DAGTraverser.postorder
    def _(self, o: Expr, Ap: Expr, ii: MultiIndex) -> Expr:
        """Differentiate an index_sum (differentiation commutes with summation)."""
        return IndexSum(Ap, ii)

    @process.register(Sum)
    @DAGTraverser.postorder
    def _(self, o: Expr, da: Expr, db: Expr) -> Expr:
        """Differentiate a sum."""
        return da + db

    @process.register(Product)
    @DAGTraverser.postorder
    def _(self, o: Expr, da: Expr, db: Expr) -> Expr:
        """Differentiate a product with the product rule."""
        # Even though arguments to o are scalar, da and db may be
        # tensor valued
        a, b = o.ufl_operands
        (da, db), ii = as_scalars(da, db)
        pa = Product(da, b)
        pb = Product(a, db)
        s = Sum(pa, pb)
        if ii:
            s = as_tensor(s, ii)
        return s

    @process.register(Division)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr, gp: Expr) -> Expr:
        """Differentiate a division."""
        f, g = o.ufl_operands
        if not is_ufl_scalar(f):
            raise ValueError("Not expecting nonscalar nominator")
        if not is_true_ufl_scalar(g):
            raise ValueError("Not expecting nonscalar denominator")
        # With o = f/g:
        #   do/df = 1/g
        #   do/dg = -f/g**2 = -o/g
        #   op = do/df*fp + do/dg*gp = (fp - o*gp) / g
        # Get o and gp as scalars, multiply, then wrap as a tensor
        # again
        so, oi = as_scalar(o)
        sgp, gi = as_scalar(gp)
        o_gp = so * sgp
        if oi or gi:
            o_gp = as_tensor(o_gp, oi + gi)
        op = (fp - o_gp) / g
        return op

    @process.register(Power)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr, gp: Expr) -> Expr:
        """Differentiate a power."""
        f, g = o.ufl_operands
        if not is_true_ufl_scalar(f):
            raise ValueError("Expecting scalar expression f in f**g.")
        if not is_true_ufl_scalar(g):
            raise ValueError("Expecting scalar expression g in f**g.")
        # Derivation of the general case: o = f(x)**g(x)
        # do/df  = g * f**(g-1) = g / f * o
        # do/dg  = ln(f) * f**g = ln(f) * o
        # do/df * df + do/dg * dg = o * (g / f * df + ln(f) * dg)
        if isinstance(gp, Zero):
            # This probably produces better results for the common
            # case of f**constant
            op = fp * g * f ** (g - 1)
        else:
            # Note: This produces expressions like (1/w)*w**5 instead of w**4
            # op = o * (fp * g / f + gp * ln(f)) # This reuses o
            op = f ** (g - 1) * (
                g * fp + f * ln(f) * gp
            )  # This gives better accuracy in dolfin integration test
        # Example: d/dx[x**(x**3)]:
        # f = x
        # g = x**3
        # df = 1
        # dg = 3*x**2
        # op1 = o * (fp * g / f + gp * ln(f))
        #     = x**(x**3)   * (x**3/x + 3*x**2*ln(x))
        # op2 = f**(g-1) * (g*fp + f*ln(f)*gp)
        #     = x**(x**3-1) * (x**3 + x*3*x**2*ln(x))
        return op

    @process.register(Abs)
    @DAGTraverser.postorder
    def _(self, o: Expr, df: Expr) -> Expr:
        """Differentiate an abs: d|f| = sign(f) * df."""
        (f,) = o.ufl_operands
        # Alternative form: conditional(eq(f, 0), 0, Product(sign(f), df)).
        # abs is not complex differentiable, so we work around the case of a
        # real f in complex mode by defensively casting to real inside the
        # sign.
        return sign(Real(f)) * df

    # --- Complex algebra

    @process.register(Conj)
    @DAGTraverser.postorder
    def _(self, o: Expr, df: Expr) -> Expr:
        """Differentiate a conj."""
        return Conj(df)

    @process.register(Real)
    @DAGTraverser.postorder
    def _(self, o: Expr, df: Expr) -> Expr:
        """Differentiate a real."""
        return Real(df)

    @process.register(Imag)
    @DAGTraverser.postorder
    def _(self, o: Expr, df: Expr) -> Expr:
        """Differentiate an imag."""
        return Imag(df)

    # --- Mathfunctions

    @process.register(MathFunction)
    @DAGTraverser.postorder
    def _(self, o: Expr, df: Expr) -> Expr:
        """Differentiate a math_function without a dedicated rule.

        Uses ``o.derivative()`` if the object provides one.
        """
        # FIXME: Introduce a UserOperator type instead of this hack
        # and define user derivative() function properly
        if hasattr(o, "derivative"):
            return df * o.derivative()
        else:
            raise ValueError("Unknown math function.")

    @process.register(Sqrt)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a sqrt: d sqrt(f) = fp / (2 sqrt(f))."""
        return fp / (2 * o)

    @process.register(Exp)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate an exp: d exp(f) = fp * exp(f)."""
        return fp * o

    @process.register(Ln)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a ln: d ln(f) = fp / f.

        Raises:
            ZeroDivisionError: If the operand is the literal `Zero`.
        """
        (f,) = o.ufl_operands
        if isinstance(f, Zero):
            raise ZeroDivisionError("Cannot differentiate ln(f) for f == 0: derivative fp/f is undefined.")
        return fp / f

    @process.register(Cos)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a cos."""
        (f,) = o.ufl_operands
        return fp * -sin(f)

    @process.register(Sin)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a sin."""
        (f,) = o.ufl_operands
        return fp * cos(f)

    @process.register(Tan)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a tan."""
        (f,) = o.ufl_operands
        # sec(f)**2 = 1/cos(f)**2 = 2 / (cos(2f) + 1)
        return 2.0 * fp / (cos(2.0 * f) + 1.0)

    @process.register(Cosh)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a cosh."""
        (f,) = o.ufl_operands
        return fp * sinh(f)

    @process.register(Sinh)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a sinh."""
        (f,) = o.ufl_operands
        return fp * cosh(f)

    @process.register(Tanh)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a tanh: d tanh(f) = fp * sech(f)**2."""
        (f,) = o.ufl_operands

        def sech(y):
            # 1/cosh(y) written as 2 cosh(y) / (cosh(2y) + 1)
            return (2.0 * cosh(y)) / (cosh(2.0 * y) + 1.0)

        return fp * sech(f) ** 2

    @process.register(Acos)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate an acos."""
        (f,) = o.ufl_operands
        return -fp / sqrt(1.0 - f**2)

    @process.register(Asin)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate an asin."""
        (f,) = o.ufl_operands
        return fp / sqrt(1.0 - f**2)

    @process.register(Atan)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate an atan."""
        (f,) = o.ufl_operands
        return fp / (1.0 + f**2)

    @process.register(Atan2)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr, gp: Expr) -> Expr:
        """Differentiate an atan2: d atan2(f, g) = (g fp - f gp) / (f**2 + g**2)."""
        f, g = o.ufl_operands
        return (g * fp - f * gp) / (f**2 + g**2)

    @process.register(Erf)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate an erf: d erf(f) = fp * 2/sqrt(pi) * exp(-f**2)."""
        (f,) = o.ufl_operands
        return fp * (2.0 / sqrt(pi) * exp(-(f**2)))

    # --- Bessel functions

    @staticmethod
    def _check_bessel_order_independent(nup: Expr | None) -> None:
        """Raise if the Bessel order ``nu`` depends on the differentiation variable.

        Args:
            nup: Derivative of the order operand; ``None`` or `Zero` means the
                order is independent of the differentiation variable.

        Raises:
            NotImplementedError: If ``nup`` is a nonzero derivative.
        """
        if not (nup is None or isinstance(nup, Zero)):
            raise NotImplementedError(
                "Differentiation of bessel function w.r.t. nu is not supported."
            )

    @process.register(BesselJ)
    @DAGTraverser.postorder
    def _(self, o: Expr, nup: Expr, fp: Expr) -> Expr:
        """Differentiate a bessel_j."""
        nu, f = o.ufl_operands
        self._check_bessel_order_independent(nup)
        if isinstance(nu, Zero):
            # J_0' = -J_1
            op = -bessel_J(1, f)
        else:
            # J_nu' = (J_{nu-1} - J_{nu+1}) / 2
            op = 0.5 * (bessel_J(nu - 1, f) - bessel_J(nu + 1, f))
        return op * fp

    @process.register(BesselY)
    @DAGTraverser.postorder
    def _(self, o: Expr, nup: Expr, fp: Expr) -> Expr:
        """Differentiate a bessel_y."""
        nu, f = o.ufl_operands
        self._check_bessel_order_independent(nup)
        if isinstance(nu, Zero):
            # Y_0' = -Y_1
            op = -bessel_Y(1, f)
        else:
            # Y_nu' = (Y_{nu-1} - Y_{nu+1}) / 2
            op = 0.5 * (bessel_Y(nu - 1, f) - bessel_Y(nu + 1, f))
        return op * fp

    @process.register(BesselI)
    @DAGTraverser.postorder
    def _(self, o: Expr, nup: Expr, fp: Expr) -> Expr:
        """Differentiate a bessel_i."""
        nu, f = o.ufl_operands
        self._check_bessel_order_independent(nup)
        if isinstance(nu, Zero):
            # I_0' = I_1
            op = bessel_I(1, f)
        else:
            # I_nu' = (I_{nu-1} + I_{nu+1}) / 2
            op = 0.5 * (bessel_I(nu - 1, f) + bessel_I(nu + 1, f))
        return op * fp

    @process.register(BesselK)
    @DAGTraverser.postorder
    def _(self, o: Expr, nup: Expr, fp: Expr) -> Expr:
        """Differentiate a bessel_k."""
        nu, f = o.ufl_operands
        self._check_bessel_order_independent(nup)
        if isinstance(nu, Zero):
            # K_0' = -K_1
            op = -bessel_K(1, f)
        else:
            # K_nu' = -(K_{nu-1} + K_{nu+1}) / 2
            op = -0.5 * (bessel_K(nu - 1, f) + bessel_K(nu + 1, f))
        return op * fp

    # --- Restrictions

    @process.register(Restricted)
    @DAGTraverser.postorder
    def _(self, o: Restricted, fp: Expr) -> Expr:
        """Differentiate a restricted."""
        # Restriction and differentiation commutes
        if isinstance(fp, ConstantValue):
            return fp  # TODO: Add simplification to Restricted instead?
        else:
            return fp(o._side)  # (f+-)' == (f')+-

    # --- Conditionals

    @process.register(BinaryCondition)
    def _(self, o: BinaryCondition) -> Expr:
        """Differentiate a binary_condition."""
        raise RuntimeError("Can not differentiate a binary_condition.")

    @process.register(NotCondition)
    def _(self, o: NotCondition) -> Expr:
        """Differentiate a not_condition."""
        raise RuntimeError("Can not differentiate a not_condition.")

    @process.register(Conditional)
    @DAGTraverser.postorder_only_children([1, 2])
    def _(self, o: Expr, dt: Expr, df: Expr) -> Expr:
        """Differentiate a conditional.

        Only the true/false branches (operands 1 and 2) are differentiated;
        the condition itself is kept as is.
        """
        if isinstance(dt, Zero) and isinstance(df, Zero):
            # Assuming dt and df have the same indices here, which
            # should be the case
            return dt
        else:
            # Not placing t[1],f[1] outside, allowing arguments inside
            # conditionals.  This will make legacy ffc fail, but
            # should work with uflacs.
            c = o.ufl_operands[0]
            return conditional(c, dt, df)

    @process.register(MaxValue)
    @DAGTraverser.postorder
    def _(self, o: Expr, df: Expr, dg: Expr) -> Expr:
        """Differentiate a max_value."""
        # d/dx max(f, g) =
        #   f > g: df/dx
        #   else:  dg/dx
        # Placing df,dg outside here to avoid getting arguments inside
        # conditionals
        f, g = o.ufl_operands
        dc = conditional(f > g, 1, 0)
        return dc * df + (1.0 - dc) * dg

    @process.register(MinValue)
    @DAGTraverser.postorder
    def _(self, o: Expr, df: Expr, dg: Expr) -> Expr:
        """Differentiate a min_value."""
        # d/dx min(f, g) =
        #   f < g: df/dx
        #   else:  dg/dx
        # Placing df,dg outside here to avoid getting arguments
        # inside conditionals
        f, g = o.ufl_operands
        dc = conditional(f < g, 1, 0)
        return dc * df + (1.0 - dc) * dg

714

715
class GradRuleset(GenericDerivativeRuleset):
    """Take the grad derivative.

    The differentiation variable has shape ``(geometric_dimension,)``; for a
    ``SpatialCoordinate`` the derivative is the identity (dx/dx = I).
    """

    def __init__(
        self,
        geometric_dimension: int,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise.

        Args:
            geometric_dimension: Size of the gradient direction; the variable
                shape passed to the base ruleset is ``(geometric_dimension,)``.
            compress: Forwarded to ``GenericDerivativeRuleset``.
            visited_cache: Forwarded to ``GenericDerivativeRuleset``.
            result_cache: Forwarded to ``GenericDerivativeRuleset``.
        """
        super().__init__(
            (geometric_dimension,),
            compress=compress,
            visited_cache=visited_cache,
            result_cache=result_cache,
        )
        # dx/dx, returned by the SpatialCoordinate rule.
        self._Id = Identity(geometric_dimension)

    # Work around singledispatchmethod inheritance issue;
    # see https://bugs.python.org/issue36457.
    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Process ``o``.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.

        """
        return super().process(o)

    # --- Specialized rules for geometric quantities

    @process.register(GeometricQuantity)
    def _(self, o: Expr) -> Expr:
        """Differentiate a geometric_quantity.

        Default for geometric quantities is do/dx = 0 if piecewise constant,
        otherwise transform derivatives to reference derivatives.
        Override for specific types if other behaviour is needed.
        """
        if is_cellwise_constant(o):
            return self.independent_terminal(o)
        else:
            domain = extract_unique_domain(o)
            K = JacobianInverse(domain)
            Do = grad_to_reference_grad(o, K)
            return Do

    @process.register(JacobianInverse)
    def _(self, o: JacobianInverse) -> Expr:
        """Differentiate a jacobian_inverse.

        Raises:
            ValueError: If ``o`` is not a terminal.
        """
        # grad(K) == K_ji rgrad(K)_rj
        if is_cellwise_constant(o):
            return self.independent_terminal(o)
        if not o._ufl_is_terminal_:
            # The message previously (wrongly) referred to ReferenceValue.
            raise ValueError("JacobianInverse is expected to be a terminal")
        # K itself provides the Jacobian inverse needed for the chain rule.
        Do = grad_to_reference_grad(o, o)
        return Do

    # TODO: Add more explicit geometry type handlers here, with
    # non-affine domains several should be non-zero.

    @process.register(SpatialCoordinate)
    def _(self, o: Expr) -> Expr:
        """Differentiate a spatial_coordinate.

        dx/dx = I.
        """
        return self._Id

    @process.register(CellCoordinate)
    def _(self, o: Expr) -> Expr:
        """Differentiate a cell_coordinate.

        dX/dx = inv(dx/dX) = inv(J) = K.
        """
        # FIXME: Is this true for manifolds? What about orientation?
        return JacobianInverse(extract_unique_domain(o))

    # --- Specialized rules for form arguments

    @process.register(BaseFormOperator)
    def _(self, o: Expr) -> Expr:
        """Differentiate a base_form_operator."""
        # Pushing the grad through the operator is not legal in most cases:
        #    -> Not enough regularity for the chain rule to hold!
        # By the time we evaluate `grad(o)`, the operator `o` will have
        # been assembled and substituted by its output.
        return Grad(o)

    @process.register(Coefficient)
    def _(self, o: Expr) -> Expr:
        """Differentiate a coefficient."""
        if is_cellwise_constant(o):
            return self.independent_terminal(o)
        return Grad(o)

    @process.register(Argument)
    def _(self, o: Expr) -> Expr:
        """Differentiate an argument."""
        # TODO: Enable this after fixing issue#13, unless we move
        # simplification to a separate stage?
        # if is_cellwise_constant(o):
        #     # Collapse gradient of cellwise constant function to zero
        #     # TODO: Missing this type
        #     return AnnotatedZero(o.ufl_shape + self._var_shape, arguments=(o,))
        return Grad(o)

    # --- Rules for values or derivatives in reference frame

    @process.register(ReferenceValue)
    def _(self, o: ReferenceValue) -> Expr:
        """Differentiate a reference_value.

        Raises:
            ValueError: If the wrapped operand is not a terminal.
            RuntimeError: On inconsistent sub-element/domain counts or dimensions.
            NotImplementedError: For a physical pullback whose reference dimension
                differs from the physical dimension.
        """
        # grad(o) == grad(rv(f)) -> K_ji*rgrad(rv(f))_rj
        f = o.ufl_operands[0]
        if not f._ufl_is_terminal_:
            raise ValueError("ReferenceValue can only wrap a terminal")
        domain = extract_unique_domain(f, expand_mesh_sequence=False)
        if isinstance(domain, MeshSequence):
            element = f.ufl_function_space().ufl_element()  # type: ignore
            if element.num_sub_elements != len(domain):
                raise RuntimeError(f"{element.num_sub_elements} != {len(domain)}")
            # Get monolithic representation of rgrad(o); o might live in a mixed space.
            rgrad = ReferenceGrad(o)
            ref_dim = rgrad.ufl_shape[-1]
            # Apply K_ji(d) to the corresponding components of rgrad, store them in a list,
            # and put them back together at the end using as_tensor().
            components = []
            dofoffset = 0
            for d, e in flatten_domain_element(domain, element):
                esh = e.reference_value_shape
                ndof = int(np.prod(esh))
                assert ndof > 0
                if isinstance(e.pullback, PhysicalPullback):
                    if ref_dim != self._var_shape[0]:
                        # Must be an f-string so the dimensions are interpolated.
                        raise NotImplementedError(
                            "PhysicalPullback not handled for immersed domain: "
                            f"reference dim ({ref_dim}) != physical dim ({self._var_shape[0]})"
                        )
                    # Components are copied without applying K (valid because the
                    # reference and physical dimensions agree, checked above).
                    for idx in range(ndof):
                        for i in range(ref_dim):
                            components.append(rgrad[(dofoffset + idx,) + (i,)])
                else:
                    K = JacobianInverse(d)
                    rdim, gdim = K.ufl_shape
                    if rdim != ref_dim:
                        raise RuntimeError(f"{rdim} != {ref_dim}")
                    if gdim != self._var_shape[0]:
                        raise RuntimeError(f"{gdim} != {self._var_shape[0]}")
                    # Note that rgrad[dofoffset + [0,ndof), [0,rdim)] are the components
                    # corresponding to (d, e).
                    # For each row, rgrad[dofoffset + idx, [0,rdim)], we apply
                    # K_ji(d)[[0,rdim), [0,gdim)].
                    for idx in range(ndof):
                        for i in range(gdim):
                            temp = Zero()
                            for j in range(rdim):
                                temp += rgrad[(dofoffset + idx,) + (j,)] * K[j, i]
                            components.append(temp)
                dofoffset += ndof
            return as_tensor(np.asarray(components).reshape(rgrad.ufl_shape[:-1] + self._var_shape))
        else:
            if isinstance(f.ufl_element().pullback, PhysicalPullback):  # type: ignore
                # TODO: Do we need to be more careful for immersed things?
                return ReferenceGrad(o)
            else:
                K = JacobianInverse(domain)
                return grad_to_reference_grad(o, K)

    @process.register(ReferenceGrad)
    def _(self, o: Expr) -> Expr:
        """Differentiate a reference_grad.

        Raises:
            ValueError: If the operand is not a reference frame type or one of the
                supported geometric quantities.
            RuntimeError: On inconsistent sub-element/domain counts or dimensions.
        """
        if is_cellwise_constant(o):
            return self.independent_terminal(o)
        # grad(o) == grad(rgrad(rv(f))) -> K_ji*rgrad(rgrad(rv(f)))_rj
        f = o.ufl_operands[0]
        valid_operand = f._ufl_is_in_reference_frame_ or isinstance(
            f, JacobianInverse | SpatialCoordinate | Jacobian | JacobianDeterminant | FacetNormal
        )
        if not valid_operand:
            raise ValueError("ReferenceGrad can only wrap a reference frame type!")
        domain = extract_unique_domain(f, expand_mesh_sequence=False)
        if isinstance(domain, MeshSequence):
            if not f._ufl_is_in_reference_frame_:
                raise RuntimeError("Expecting a reference frame type")
            # Unwrap nested operators down to the terminal to find its element.
            while not f._ufl_is_terminal_:
                (f,) = f.ufl_operands  # type: ignore
            element = f.ufl_function_space().ufl_element()  # type: ignore
            if element.num_sub_elements != len(domain):
                raise RuntimeError(f"{element.num_sub_elements} != {len(domain)}")
            # Get monolithic representation of rgrad(o); o might live in a mixed space.
            rgrad = ReferenceGrad(o)
            ref_dim = rgrad.ufl_shape[-1]
            # Apply K_ji(d) to the corresponding components of rgrad, store them in a list,
            # and put them back together at the end using as_tensor().
            components = []
            dofoffset = 0
            for d, e in flatten_domain_element(domain, element):
                esh = e.reference_value_shape
                ndof = int(np.prod(esh))
                assert ndof > 0
                K = JacobianInverse(d)
                rdim, gdim = K.ufl_shape
                if rdim != ref_dim:
                    raise RuntimeError(f"{rdim} != {ref_dim}")
                if gdim != self._var_shape[0]:
                    raise RuntimeError(f"{gdim} != {self._var_shape[0]}")
                # Note that rgrad[dofoffset + [0,ndof), [0,rdim), [0,rdim)] are the components
                # corresponding to (d, e).
                # For each row, rgrad[dofoffset + idx, [0,rdim), [0,rdim)], we apply
                # K_ji(d)[[0,rdim), [0,gdim)].
                for idx in range(ndof):
                    for midx in np.ndindex(rgrad.ufl_shape[1:-1]):
                        for i in range(gdim):
                            temp = Zero()
                            for j in range(rdim):
                                temp += rgrad[(dofoffset + idx,) + midx + (j,)] * K[j, i]
                            components.append(temp)
                dofoffset += ndof
            if rgrad.ufl_shape[0] != dofoffset:
                raise RuntimeError(f"{rgrad.ufl_shape[0]} != {dofoffset}")
            return as_tensor(np.asarray(components).reshape(rgrad.ufl_shape[:-1] + self._var_shape))
        else:
            K = JacobianInverse(domain)
            return grad_to_reference_grad(o, K)

    # --- Nesting of gradients

    @process.register(Grad)
    def _(self, o: Expr) -> Expr:
        """Differentiate a grad.

        Represent grad(grad(f)) as Grad(Grad(f)).

        Raises:
            ValueError: If the operand is neither a Grad nor a Terminal.
        """
        # Check that o is a "differential terminal"
        if not isinstance(o.ufl_operands[0], Grad | Terminal):
            raise ValueError("Expecting only grads applied to a terminal.")
        return Grad(o)

    def _grad(self, o):
        """Placeholder for a future nested-gradient rule.

        Not registered with ``process``; it currently does nothing and returns None.
        """
        pass
        # TODO: Not sure how to detect that gradient of f is cellwise constant.
        #       Can we trust element degrees?
        # if is_cellwise_constant(o):
        #     return self.terminal(o)
        # TODO: Maybe we can ask "f.has_derivatives_of_order(n)" to check
        #       if we should make a zero here?
        # 1) n = count number of Grads, get f
        # 2) if not f.has_derivatives(n): return zero(...)

    @process.register(CellAvg)
    @process.register(FacetAvg)
    def _(self, o: Expr) -> Expr:
        """Differentiate a cell_avg or facet_avg via ``independent_operator``."""
        return self.independent_operator(o)
×
973

974

975
def grad_to_reference_grad(o, K):
    """Rewrite grad(o) via the reference gradient and the Jacobian inverse.

    Applies the chain rule grad(o)[r..., i] = K[j, i] * reference_grad(o)[r..., j]
    (summed over j).

    Args:
        o: Operand
        K: Jacobian inverse

    Returns:
        grad(o) written in terms of reference_grad(o) and K
    """
    value_idx = indices(len(o.ufl_shape))
    i, j = indices(2)
    summand = K[j, i] * ReferenceGrad(o)[value_idx + (j,)]
    return as_tensor(summand, value_idx + (i,))
1✔
989

990

991
class ReferenceGradRuleset(GenericDerivativeRuleset):
    """Apply the reference grad derivative.

    The differentiation variable has shape ``(topological_dimension,)``; for a
    ``CellCoordinate`` the derivative is the identity (dX/dX = I).
    """

    def __init__(
        self,
        topological_dimension: int,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise.

        Args:
            topological_dimension: Size of the reference gradient direction.
            compress: Forwarded to ``GenericDerivativeRuleset``.
            visited_cache: Forwarded to ``GenericDerivativeRuleset``.
            result_cache: Forwarded to ``GenericDerivativeRuleset``.
        """
        var_shape = (topological_dimension,)
        super().__init__(
            var_shape,
            compress=compress,
            visited_cache=visited_cache,
            result_cache=result_cache,
        )
        # dX/dX, returned by the CellCoordinate rule.
        self._Id = Identity(topological_dimension)

    # Work around singledispatchmethod inheritance issue;
    # see https://bugs.python.org/issue36457.
    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Dispatch ``o`` to its type-specific reference-grad rule.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.
        """
        return super().process(o)

    # --- Specialized rules for geometric quantities

    @process.register(GeometricQuantity)
    def _(self, o: Expr) -> Expr:
        """Differentiate a geometric_quantity.

        dg/dX is zero for cellwise-constant quantities, ReferenceGrad(g) otherwise.
        """
        if not is_cellwise_constant(o):
            # TODO: Which types does this involve? I don't think the
            # form compilers will handle this.
            return ReferenceGrad(o)
        return self.independent_terminal(o)

    @process.register(SpatialCoordinate)
    def _(self, o: Expr) -> Expr:
        """Differentiate a spatial_coordinate (dx/dX = J).

        The result is kept as ReferenceGrad(x); converting it back to J would loop.
        """
        return ReferenceGrad(o)

    @process.register(CellCoordinate)
    def _(self, o: Expr) -> Expr:
        """Differentiate a cell_coordinate (dX/dX = I)."""
        return self._Id

    # TODO: Add more geometry types here, with non-affine domains
    # several should be non-zero.

    # --- Specialized rules for form arguments

    @process.register(ReferenceValue)
    def _(self, o: Expr) -> Expr:
        """Differentiate a reference_value.

        Raises:
            ValueError: If the wrapped operand is not a terminal.
        """
        wrapped = o.ufl_operands[0]
        if wrapped._ufl_is_terminal_:
            return ReferenceGrad(o)
        raise ValueError("ReferenceValue can only wrap a terminal")

    @process.register(Coefficient)
    def _(self, o: Expr) -> Expr:
        """Reject a bare coefficient; it must already be wrapped in ReferenceValue."""
        raise ValueError("Coefficient should be wrapped in ReferenceValue by now")

    @process.register(Argument)
    def _(self, o: Expr) -> Expr:
        """Reject a bare argument; it must already be wrapped in ReferenceValue."""
        raise ValueError("Argument should be wrapped in ReferenceValue by now")

    # --- Nesting of gradients

    @process.register(Grad)
    def _(self, o: Expr) -> Expr:
        """Reject a physical Grad, which should have been rewritten earlier."""
        type_name = type(o).__name__
        raise ValueError(
            f"Grad should have been transformed by this point, but got {type_name}."
        )

    @process.register(ReferenceGrad)
    def _(self, o: Expr) -> Expr:
        """Differentiate a reference_grad.

        Nested reference gradients are represented as RefGrad(RefGrad(f)).

        Raises:
            ValueError: If the operand is not a ReferenceGrad, ReferenceValue or Terminal.
        """
        inner = o.ufl_operands[0]
        # Only "differential terminals" may appear under a ReferenceGrad.
        is_differential_terminal = isinstance(inner, ReferenceGrad | ReferenceValue | Terminal)
        if is_differential_terminal:
            return ReferenceGrad(o)
        raise ValueError("Expecting only grads applied to a terminal.")

    @process.register(CellAvg)
    @process.register(FacetAvg)
    def _(self, o: Expr) -> Expr:
        """Differentiate a cell_avg or facet_avg via ``independent_operator``."""
        return self.independent_operator(o)
×
1103

1104

1105
class VariableRuleset(GenericDerivativeRuleset):
    """Differentiate with respect to a variable."""

    def __init__(
        self,
        var: Expr,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise.

        Args:
            var: Expression to differentiate with respect to (the rules below
                special-case ``Variable`` and ``Coefficient``).
            compress: Forwarded to ``GenericDerivativeRuleset``.
            visited_cache: Forwarded to ``GenericDerivativeRuleset``.
            result_cache: Forwarded to ``GenericDerivativeRuleset``.

        Raises:
            ValueError: If ``var`` has free indices.
        """
        super().__init__(
            var.ufl_shape,
            compress=compress,
            visited_cache=visited_cache,
            result_cache=result_cache,
        )
        if var.ufl_free_indices:
            raise ValueError("Differentiation variable cannot have free indices.")
        self._variable = var
        self._Id = self._make_identity(self._var_shape)

    def _make_identity(self, sh):
        """Build the identity tensor that represents dv/dv for ``v`` of shape ``sh``.

        Scalar ``v`` gives ``1.0``, vector ``v`` gives ``Identity(n)``; for higher
        ranks the result is a product of Kronecker deltas
        II[i0, ..., j0, ...] = delta(i0, j0) * delta(i1, j1) * ...
        """
        if sh == ():
            return FloatValue(1.0)
        if len(sh) == 1:
            return Identity(sh[0])
        # TODO: Add a type for this higher order identity?
        product = None
        left = ()
        right = ()
        for extent in sh:
            i, j = indices(2)
            delta = Identity(extent)[i, j]
            product = delta if product is None else product * delta
            left += (i,)
            right += (j,)
        return as_tensor(product, left + right)

    # Work around singledispatchmethod inheritance issue;
    # see https://bugs.python.org/issue36457.
    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Dispatch ``o`` to its type-specific variable-derivative rule.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.
        """
        return super().process(o)

    @process.register(GeometricQuantity)
    def _(self, o: Expr) -> Expr:
        """Geometric quantities do not depend on the variable: dg/dv = 0."""
        return self.independent_terminal(o)

    @process.register(Argument)
    def _(self, o: Expr) -> Expr:
        """Arguments do not depend on the variable: da/dv = 0."""
        return self.independent_terminal(o)

    @process.register(Coefficient)
    def _(self, o: Expr) -> Expr:
        """Differentiate a coefficient.

        df/dv is the identity when the variable itself is this coefficient and
        zero otherwise. In particular, for v = variable(f) the result is zero;
        only when v is the Coefficient f itself is df/dv = df/df = Id.
        """
        target = self._variable
        if not (isinstance(target, Coefficient) and o == target):
            return self.independent_terminal(o)
        # Identity of rank 2*rank(v).
        return self._Id

    @process.register(Variable)
    @DAGTraverser.postorder
    def _(self, o: Expr, df: Expr, a: Expr) -> Expr:
        """Differentiate a variable.

        Returns the identity when this Variable's label matches the target
        variable's label, and otherwise passes through ``df``, the derivative
        of the wrapped expression.
        """
        target = self._variable
        is_target = isinstance(target, Variable) and target.label() == a
        return self._Id if is_target else df

    @process.register(Grad)
    def _(self, o: Expr) -> Expr:
        """Differentiate a grad; the gradient of a terminal gives zero.

        Raises:
            ValueError: If the operand is neither a Grad nor a Terminal.
        """
        inner = o.ufl_operands[0]
        if isinstance(inner, Grad | Terminal):
            return self.independent_terminal(o)
        raise ValueError("Expecting only grads applied to a terminal.")

    # --- Rules for values or derivatives in reference frame

    @process.register(ReferenceValue)
    def _(self, o: Expr) -> Expr:
        """Differentiate a reference_value.

        d/dv rv(f) is the identity when v is the Coefficient f (identity pullback
        only) and zero otherwise.

        Raises:
            ValueError: If v is f but f's element has a non-identity pullback.
        """
        target = self._variable
        if not (isinstance(target, Coefficient) and o.ufl_operands[0] == target):
            return self.independent_terminal(o)
        if not target.ufl_element().pullback.is_identity:
            # FIXME: This is a bit tricky, instead of Identity it is
            #   actually inverse(transform), or we should rather not
            #   convert to reference frame in the first place
            raise ValueError(
                "Missing implementation: To handle derivatives of rv(f) w.r.t. f for "
                "mapped elements, rewriting to reference frame should not happen first..."
            )
        return self._Id

    @process.register(ReferenceGrad)
    def _(self, o: Expr) -> Expr:
        """Differentiate a reference_grad; the result is always zero here."""
        return self.independent_terminal(o)

    @process.register(CellAvg)
    @process.register(FacetAvg)
    def _(self, o: Expr) -> Expr:
        """Differentiate a cell_avg or facet_avg via ``independent_operator``."""
        return self.independent_operator(o)
×
1254

1255

1256
class GateauxDerivativeRuleset(GenericDerivativeRuleset):
1✔
1257
    """Apply AFD (Automatic Functional Differentiation) to expression.
1258

1259
    Implements rules for the Gateaux derivative D_w[v](...) defined as
1260
    D_w[v](e) = d/dtau e(w+tau v)|tau=0.
1261
    """
1262

1263
    def __init__(
        self,
        coefficients: ExprList,
        arguments: ExprList,
        coefficient_derivatives: ExprMapping,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise.

        Args:
            coefficients: Coefficient(s) ``w`` to differentiate with respect to.
            arguments: Direction(s) ``v``, paired positionally with ``coefficients``.
            coefficient_derivatives: Operands alternate ``f, df/dw, f, df/dw, ...``
                and give known derivatives of other coefficients.
            compress: Forwarded to ``GenericDerivativeRuleset``.
            visited_cache: Forwarded to ``GenericDerivativeRuleset``.
            result_cache: Forwarded to ``GenericDerivativeRuleset``.

        Raises:
            ValueError: If one of the first three arguments has the wrong type.
        """
        super().__init__(
            (), compress=compress, visited_cache=visited_cache, result_cache=result_cache
        )
        if not isinstance(coefficients, ExprList):
            raise ValueError("Expecting a ExprList of coefficients.")
        if not isinstance(arguments, ExprList):
            raise ValueError("Expecting a ExprList of arguments.")
        if not isinstance(coefficient_derivatives, ExprMapping):
            raise ValueError("Expecting a coefficient-coefficient ExprMapping.")
        # D_w[v](e) = d/dtau e(w + tau v)|tau=0
        self._w = coefficients.ufl_operands
        self._v = arguments.ufl_operands
        self._w2v = dict(zip(self._w, self._v))
        # Pair up the alternating key/value operands into {f: df/dw}.
        flat = coefficient_derivatives.ufl_operands
        self._cd = dict(zip(flat[0::2], flat[1::2]))
        # Operations delayed to the derivative expansion phase, e.g. dN(u)/du
        # where `N` is an ExternalOperator and `u` a Coefficient.
        self.pending_operations = BaseFormOperatorDerivativeRecorder(
            coefficients,
            arguments=arguments,
            coefficient_derivatives=coefficient_derivatives,
        )
1299

1300
    # Work around singledispatchmethod inheritance issue;
1301
    # see https://bugs.python.org/issue36457.
1302
    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Dispatch ``o`` to its type-specific Gateaux-derivative rule.

        Types without a registered rule fall back to the base ruleset.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.
        """
        result = super().process(o)
        return result
1✔
1314

1315
    # --- Specialized rules for geometric quantities
1316

1317
    @process.register(GeometricQuantity)
    def _(self, o: Expr) -> Expr:
        """Geometric quantities do not depend on w: dg/dw = 0."""
        zero = self.independent_terminal(o)
        return zero
1✔
1321

1322
    @process.register(CellAvg)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a cell_avg.

        Cell averaging of a single function commutes with differentiation:
        D_f[v](cell_avg(f)) = cell_avg(v). ``fp`` is the operand's derivative.
        """
        averaged = cell_avg(fp)
        return averaged
×
1329

1330
    @process.register(FacetAvg)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a facet_avg.

        Facet averaging of a single function commutes with differentiation:
        D_f[v](facet_avg(f)) = facet_avg(v). ``fp`` is the operand's derivative.
        """
        averaged = facet_avg(fp)
        return averaged
×
1337

1338
    @process.register(Argument)
    def _(self, o: Argument) -> Expr:
        """Arguments do not depend on w: da/dw = 0 (see ``_process_argument``)."""
        result = self._process_argument(o)
        return result
1✔
1342

1343
    def _process_argument(self, o: Argument | Coargument) -> Zero:
        """Return the zero derivative of an argument or coargument with respect to w."""
        zero = self.independent_terminal(o)
        return zero
1✔
1345

1346
    @process.register(Coefficient)
    def _(self, o: Coefficient) -> Expr:
        """Differentiate a coefficient by delegating to ``_process_coefficient``."""
        result = self._process_coefficient(o)
        return result
1✔
1349

1350
    def _process_coefficient(self, o: Expr) -> Expr:
        """Return the Gateaux derivative of a coefficient-like expression ``o``.

        * If ``o`` is one of the coefficients w, return its paired direction v
          (dw/dw := d/ds [w + s v] = v).
        * Else if derivatives df/dw were supplied for ``o``, contract each one with
          the matching v and sum the terms.
        * Otherwise return a zero with the shape of ``o``.

        Raises:
            ValueError: If the number of supplied derivatives differs from the
                number of directions v.
        """
        direction = self._w2v.get(o)  # type: ignore
        if direction is not None:
            return direction

        derivatives = self._cd.get(o)  # type: ignore
        if derivatives is None:
            return Zero(o.ufl_shape)

        # Mixed spaces may supply a tuple of derivatives matching the tuple of
        # directions; sum over all of them to get the complete inner product,
        # written with indices rather than a compound inner product.
        # Example: (f:g) -> (dfdu:v):g + f:(dgdu:v), where
        # shape(dfdu) == shape(f) + shape(v) and shape(f) == shape(dfdu : v).
        if not isinstance(derivatives, tuple):
            derivatives = (derivatives,)
        if len(derivatives) != len(self._v):
            raise ValueError(
                "Got a tuple of arguments, expecting a "
                "matching tuple of coefficient derivatives."
            )
        total = Zero(o.ufl_shape)
        for derivative, v in zip(derivatives, self._v):
            scalar, idx = as_scalar(derivative)
            # Leading indices belong to o's shape; trailing ones contract with v.
            split = len(idx) - len(v.ufl_shape)
            free_idx = idx[:split]
            contracted_idx = idx[split:]
            term = scalar * v[contracted_idx]
            total += as_tensor(term, free_idx) if free_idx else term
        return total
1✔
1397

1398
    @process.register(ReferenceValue)
    def _(self, o: Expr) -> Expr:
        """Reject Gateaux differentiation of a reference_value.

        Raises:
            NotImplementedError: Always.
        """
        # TODO: This is implementable for a regular derivative(M(f), f, v), but it
        #       gets too messy once user-supplied coefficient derivative relations
        #       are involved. It would only be needed so that users can write
        #       derivative(...ReferenceValue..., ...).
        #       Sketch: take f = o.ufl_operands[0] (it must be a terminal); if f is
        #       w, return ReferenceValue(v) (this requires v to be an Argument with
        #       the same element mapping), otherwise return
        #       self.independent_terminal(o). All cases handled for coefficients
        #       would still need checking.
        raise NotImplementedError(
            "Currently no support for ReferenceValue in CoefficientDerivative."
        )
1418

1419
    @process.register(ReferenceGrad)
    def _(self, o: Expr) -> Expr:
        """Differentiate a reference_grad.

        Returns a zero with the shape of ``o`` when ``o`` contains no coefficients.

        Raises:
            NotImplementedError: If ``o`` contains any coefficient.
        """
        # TODO: This is implementable for a regular derivative(M(f), f, v), but it
        #       gets too messy once user-supplied coefficient derivative relations
        #       are involved. It would only be needed so that users can write
        #       derivative(...ReferenceValue..., ...).
        if extract_coefficients(o):
            raise NotImplementedError(
                "Currently no support for ReferenceGrad in CoefficientDerivative."
            )
        return Zero(o.ufl_shape)
1433

1434
    @process.register(Grad)
    def _(self, g: Expr) -> Expr:
        """Differentiate a grad.

        ``g`` is expected to be one or more nested ``Grad`` nodes around a
        FormArgument or BaseFormOperator ``o``.  The result applies the same
        number of gradients to the variation associated with ``o``.  When the
        variation is a ListTensor, or the differentiation variable is an
        Indexed component, the result is placed into the matching tensor
        components.  If ``o`` matches no differentiation variable, the result
        is a Zero of shape ``g.ufl_shape``.

        Args:
            g: The gradient expression to differentiate.

        Returns:
            The Gateaux derivative of ``g``.

        Raises:
            ValueError: If the innermost operand is not a FormArgument or
                BaseFormOperator, or if a variation or differentiation
                variable has an unsupported form.
        """
        # If we hit this type, it has already been propagated to a
        # coefficient (or grad of a coefficient) or a base form operator
        # (checked by the isinstance test below), so we need to take the
        # gradient of the variation or return zero.  Complications occur
        # when dealing with derivatives w.r.t. single components...

        # Figure out how many gradients are around the inner terminal
        ngrads = 0
        o = g
        while isinstance(o, Grad):
            (o,) = o.ufl_operands
            ngrads += 1
        # `grad(N)` where N is a BaseFormOperator is treated as if `N` was a Coefficient.
        if not isinstance(o, FormArgument | BaseFormOperator):
            raise ValueError(f"Expecting gradient of a FormArgument, not {ufl_err_str(o)}.")

        def apply_grads(f):
            # Wrap ``f`` in as many Grad nodes as were peeled off ``g``.
            for i in range(ngrads):
                f = Grad(f)
            return f

        # Find o among all w without any indexing, which makes this
        # easy
        for w, v in zip(self._w, self._v):
            if o == w and isinstance(v, FormArgument):
                # Case: d/dt [w + t v]
                return apply_grads(v)

        # If o is not among coefficient derivatives, return do/dw=0
        gprimesum = Zero(g.ufl_shape)

        def analyse_variation_argument(v):
            # Split a variation into (value, fixed component tuple).
            # Analyse variation argument
            if isinstance(v, FormArgument):
                # Case: d/dt [w[...] + t v]
                vval, vcomp = v, ()
            elif isinstance(v, Indexed):
                # Case: d/dt [w + t v[...]]
                # Case: d/dt [w[...] + t v[...]]
                vval, vcomp = v.ufl_operands
                vcomp = tuple(vcomp)
            else:
                raise ValueError("Expecting argument or component of argument.")
            if not all(isinstance(k, FixedIndex) for k in vcomp):
                raise ValueError("Expecting only fixed indices in variation.")
            return vval, vcomp

        def compute_gprimeterm(ngrads, vval, vcomp, wshape, wcomp):
            # Apply gradients directly to argument vval, and get the
            # right indexed scalar component(s)
            kk = indices(ngrads)
            Dvkk = apply_grads(vval)[vcomp + kk]
            # Place scalar component(s) Dvkk into the right tensor
            # positions
            if wshape:
                Ejj, jj = unit_indexed_tensor(wshape, wcomp)
            else:
                Ejj, jj = 1, ()
            gprimeterm = as_tensor(Ejj * Dvkk, jj + kk)
            return gprimeterm

        # Accumulate contributions from variations in different
        # components
        for w, v in zip(self._w, self._v):
            # -- Analyse differentiation variable coefficient -- #

            # Can differentiate a Form wrt a BaseFormOperator
            if isinstance(w, FormArgument | BaseFormOperator):
                if not w == o:
                    continue
                wshape = w.ufl_shape

                if isinstance(v, FormArgument):
                    # Case: d/dt [w + t v]
                    # (Already handled by the first loop above.)
                    return apply_grads(v)

                elif isinstance(v, ListTensor):
                    # Case: d/dt [w + t <...,v,...>]
                    for wcomp, vsub in unwrap_list_tensor(v):
                        if not isinstance(vsub, Zero):
                            vval, vcomp = analyse_variation_argument(vsub)
                            gprimesum = gprimesum + compute_gprimeterm(
                                ngrads, vval, vcomp, wshape, wcomp
                            )
                elif isinstance(v, Zero):
                    pass

                else:
                    if wshape != ():
                        raise ValueError("Expecting scalar coefficient in this branch.")
                    # Case: d/dt [w + t v[...]]
                    wval, wcomp = w, ()

                    vval, vcomp = analyse_variation_argument(v)
                    gprimesum = gprimesum + compute_gprimeterm(ngrads, vval, vcomp, wshape, wcomp)

            elif isinstance(
                w, Indexed
            ):  # This path is tested in unit tests, but not actually used?
                # Case: d/dt [w[...] + t v[...]]
                # Case: d/dt [w[...] + t v]
                wval, wcomp = w.ufl_operands
                if not wval == o:
                    continue
                assert isinstance(wval, FormArgument)
                if not all(isinstance(k, FixedIndex) for k in wcomp):
                    raise ValueError("Expecting only fixed indices in differentiation variable.")
                wshape = wval.ufl_shape

                vval, vcomp = analyse_variation_argument(v)
                gprimesum = gprimesum + compute_gprimeterm(ngrads, vval, vcomp, wshape, wcomp)

            else:
                raise ValueError("Expecting coefficient or component of coefficient.")

        # FIXME: Handle other coefficient derivatives: oprimes =
        # self._cd.get(o)

        # NOTE(review): this `if 0:` block is dead code and never runs.
        # Even if it were enabled it has known problems:
        #  - the warning f-string renders `{{{0}}}` as the literal text
        #    "{0}" (probably `o` was intended);
        #  - the length check is nested under the `not isinstance(oprimes,
        #    tuple)` branch, so it only runs after wrapping into a 1-tuple;
        #  - the loop raises NotImplementedError before any term is added.
        if 0:
            oprimes = self._cd.get(o)
            if oprimes is None:
                if self._cd:
                    # TODO: Make it possible to silence this message
                    #       in particular?  It may be good to have for
                    #       debugging...
                    warnings.warn(f"Assuming d{{{0}}}/d{{{self._w}}} = 0.")
            else:
                # Make sure we have a tuple to match the self._v tuple
                if not isinstance(oprimes, tuple):
                    oprimes = (oprimes,)
                    if len(oprimes) != len(self._v):
                        raise ValueError(
                            "Got a tuple of arguments, expecting a"
                            " matching tuple of coefficient derivatives."
                        )

                # Compute dg/dw_j = dg/dw_h : v.
                # Since we may actually have a tuple of oprimes and vs
                # in a 'mixed' space, sum over them all to get the
                # complete inner product. Using indices to define a
                # non-compound inner product.
                for oprime, v in zip(oprimes, self._v):
                    raise NotImplementedError("FIXME: Figure out how to do this with ngrads")
                    so, oi = as_scalar(oprime)
                    rv = len(v.ufl_shape)
                    oi1 = oi[:-rv]
                    oi2 = oi[-rv:]
                    prod = so * v[oi2]
                    if oi1:
                        gprimesum += as_tensor(prod, oi1)
                    else:
                        gprimesum += prod

        return gprimesum
1✔
1591

1592
    @process.register(CoordinateDerivative)
    @DAGTraverser.postorder_only_children([0])
    def _(self, o: Expr, o0: Expr) -> Expr:
        """Differentiate a coordinate_derivative.

        Only operand 0 is traversed (``o0`` is its processed form).  A new
        CoordinateDerivative is rebuilt from ``o0``, and the three remaining
        operands of ``o`` are passed through unchanged.
        """
        _, w, v, cd = o.ufl_operands
        rebuilt = CoordinateDerivative(o0, w, v, cd)
        return rebuilt
×
1598

1599
    @process.register(BaseFormOperator)
    @DAGTraverser.postorder
    def _(self, o: BaseFormOperator, *dfs) -> Expr:
        """Differentiate a base_form_operator.

        ``o`` is first differentiated as if it were a Coefficient.

        - A non-zero result means the differentiation is with respect to the
          BaseFormOperator itself (dF/dN[Nhat]); it is returned directly.
        - A zero result means the derivative is taken with respect to some
          other variable. ``o`` is then appended to ``pending_operations``
          so its own derivative can be handled later, and the zero is
          returned.
        """
        derivative = self._process_coefficient(o)
        # The comparison with 0 also handles the non-scalar case.
        if derivative == 0:
            self.pending_operations += (o,)
        return derivative
1✔
1614

1615
    # -- Handlers for BaseForm objects -- #
1616

1617
    @process.register(Cofunction)
    def _(self, o: Cofunction) -> Expr:
        """Differentiate a cofunction.

        The rule is the same as for a Coefficient, except that the variation
        is a Coargument. That coargument is already stored in ``self._v``,
        which ``_process_coefficient`` relies on.  A zero derivative is
        returned as a ZeroBaseForm over ``o``'s arguments followed by
        ``self._v``.
        """
        derivative = self._process_coefficient(o)  # type: ignore
        # Convert ufl.Zero into ZeroBaseForm
        return (
            ZeroBaseForm(o.arguments() + self._v)  # type: ignore
            if derivative == 0
            else derivative
        )
1✔
1628

1629
    @process.register(Coargument)
    def _(self, o: Coargument) -> Expr:
        """Differentiate a coargument.

        The rule is the same as for an Argument (da/dw == 0).  A zero
        derivative is returned as a ZeroBaseForm over ``o``'s arguments
        followed by ``self._v``.
        """
        derivative = self._process_argument(o)
        # Convert ufl.Zero into ZeroBaseForm
        return (
            ZeroBaseForm(o.arguments() + self._v)  # type: ignore
            if derivative == 0
            else derivative
        )
×
1638

1639
    @process.register(Matrix)  # type: ignore
    def _(self, M: Matrix) -> BaseForm:
        """Differentiate a matrix.

        Matrix rule: D_w[v](M) = v if M == w else 0.  Differentiating with
        respect to a matrix is not supported, so the result is always a
        ZeroBaseForm over the matrix arguments followed by the variation
        arguments ``self._v``.
        """
        zero_form_arguments = M.arguments() + self._v
        return ZeroBaseForm(zero_form_arguments)
1✔
1646

1647

1648
class BaseFormOperatorDerivativeRuleset(GateauxDerivativeRuleset):
    """Apply AFD (Automatic Functional Differentiation) to BaseFormOperator.

    Implements rules for the Gateaux derivative D_w[v](...) defined as
    D_w[v](B) = d/dtau B(w+tau v)|tau=0 where B is a ufl.BaseFormOperator.
    """

    @staticmethod
    def pending_operations_recording(base_form_operator_handler):
        """Decorate a BaseFormOperator handler to record pending operations.

        The returned wrapper calls ``base_form_operator_handler`` only when
        the visited operator is ``self.outer_base_form_op``, the operator
        being differentiated.  Any other BaseFormOperator met during the
        traversal is appended to ``self.pending_operations`` and
        differentiated as if it were a Coefficient.  Its own derivative,
        d(expression)/d(base_form_op), is computed later.

        The wrapper keeps the handler's name and docstring
        (``functools.wraps``).
        """
        from functools import wraps

        @wraps(base_form_operator_handler)
        def wrapper(self, base_form_op, *dfs):
            # Get the outer `BaseFormOperator` expression, i.e. the
            # operator that is being differentiated.
            expression = self.outer_base_form_op
            # If the base form operator we observe is different from the
            # outer `BaseFormOperator`:
            # -> Record that `BaseFormOperator` so that
            # `d(expression)/d(base_form_op)` can then be computed
            # later.
            # Else:
            # -> Compute the Gateaux derivative of `base_form_ops` by
            # calling the appropriate handler.
            if expression != base_form_op:
                self.pending_operations += (base_form_op,)
                return self._process_coefficient(base_form_op)
            return base_form_operator_handler(self, base_form_op, *dfs)

        return wrapper

    def __init__(
        self,
        coefficients: ExprList,
        arguments: ExprList,
        coefficient_derivatives: ExprMapping,
        outer_base_form_op: Expr,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise.

        Args:
            coefficients: Forwarded to GateauxDerivativeRuleset.
            arguments: Forwarded to GateauxDerivativeRuleset.
            coefficient_derivatives: Forwarded to GateauxDerivativeRuleset.
            outer_base_form_op: The BaseFormOperator being differentiated.
                Other BaseFormOperators met during traversal are recorded
                as pending operations instead.
            compress: Forwarded to the base class.
            visited_cache: Forwarded to the base class.
            result_cache: Forwarded to the base class.
        """
        super().__init__(
            coefficients,
            arguments,
            coefficient_derivatives,
            compress=compress,
            visited_cache=visited_cache,
            result_cache=result_cache,
        )
        self.outer_base_form_op = outer_base_form_op

    # Work around singledispatchmethod inheritance issue;
    # see https://bugs.python.org/issue36457.
    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Process ``o``.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.

        """
        return super().process(o)

    @process.register(Interpolate)
    @DAGTraverser.postorder
    @pending_operations_recording
    def _(self, i_op: Interpolate, dw: Expr) -> Expr:
        """Differentiate an interpolate."""
        # Interpolate rule: D_w[v](i_op(w, v*)) = i_op(v, v*), by linearity of Interpolate!
        if not dw:
            # i_op doesn't depend on w:
            #  -> It also covers the Hessian case since Interpolate is linear,
            #     e.g. D_w[v](D_w[v](i_op(w, v*))) = D_w[v](i_op(v, v*)) = 0 (since w not found).
            return ZeroBaseForm(i_op.arguments() + self._v)  # type: ignore
        return i_op._ufl_expr_reconstruct_(expr=dw)

    @process.register(ExternalOperator)
    @DAGTraverser.postorder
    @pending_operations_recording
    def external_operator(self, N: ExternalOperator, *dfs) -> Expr:
        """Differentiate an external_operator.

        One term is built per operand derivative ``df`` (in ``dfs`` order)
        and the terms are summed.

        Raises:
            NotImplementedError: If an operand derivative contains no
                arguments and is not zero.
        """
        result: tuple[Expr, ...] = ()
        for i, df in enumerate(dfs):
            # Increment the derivative multi-index in the i-th slot only.
            derivatives = tuple(dj + int(i == j) for j, dj in enumerate(N.derivatives))
            if len(extract_arguments(df)) != 0:
                # Handle the symbolic differentiation of external operators.
                # This bit returns:
                #
                #   `\sum_{i} dNdOi(..., Oi, ...; DOi(u)[v], ..., v*)`
                #
                # where we differentiate wrt u, Oi is the i-th operand,
                # N(..., Oi, ...; ..., v*) an ExternalOperator and v the
                # direction (Argument). dNdOi(..., Oi, ...; DOi(u)[v])
                # is an ExternalOperator representing the
                # Gateaux-derivative of N. For example:
                #  -> From N(u) = u**2, we get `dNdu(u; uhat, v*) = 2 * u * uhat`.
                new_args = N.argument_slots() + (df,)
                extop = N._ufl_expr_reconstruct_(
                    *N.ufl_operands, derivatives=derivatives, argument_slots=new_args
                )
            elif df == 0:
                # NOTE(review): unlike the Interpolate handler, ``self._v``
                # is not appended to the arguments here — confirm intended.
                extop = ZeroBaseForm(N.arguments())
            else:
                raise NotImplementedError(
                    "Frechet derivative of external operators need to be provided!"
                )
            result += (extop,)
        return sum(result)  # type: ignore
1✔
1760

1761

1762
class DerivativeRuleDispatcher(DAGTraverser):
    """Dispatch a derivative rule.

    The dispatcher traverses an expression.  At each derivative node (Grad,
    ReferenceGrad, VariableDerivative, CoefficientDerivative,
    BaseFormOperatorDerivative) it applies the matching ruleset to the
    node's processed operand.

    Ruleset instances are cached per configuration in
    ``_dag_traverser_cache``.  BaseFormOperators whose differentiation is
    delayed are accumulated in ``pending_operations``.
    """

    def __init__(
        self,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise."""
        super().__init__(compress=compress, visited_cache=visited_cache, result_cache=result_cache)
        # Record the operations delayed to the derivative expansion phase:
        # Example: dN(u)/du where `N` is a BaseFormOperator and `u` a Coefficient
        self.pending_operations = ()
        # Create DAGTraverser caches, keyed by ``(ruleset class, *ruleset arguments)``
        # (e.g. ``(GradRuleset, gdim)`` or ``(GateauxDerivativeRuleset, w, v, cd)``).
        self._dag_traverser_cache: dict[tuple, DAGTraverser] = {}

    def _get_dag_traverser(self, key: tuple, ruleset: type, *args) -> DAGTraverser:
        """Return the cached traverser for ``key``, creating it on first use.

        ``ruleset(*args)`` is only called on a cache miss.  By contrast,
        ``dict.setdefault(key, ruleset(*args))`` would construct and then
        discard a new ruleset on every visit.
        """
        dag_traverser = self._dag_traverser_cache.get(key)
        if dag_traverser is None:
            dag_traverser = ruleset(*args)
            self._dag_traverser_cache[key] = dag_traverser
        return dag_traverser

    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Process ``o``.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.

        """
        return super().process(o)

    @process.register(Expr)
    @process.register(BaseForm)  # type: ignore
    def _(self, o: Expr | BaseForm) -> Expr | BaseForm:
        """Apply to expr and base form."""
        return self.reuse_if_untouched(o)

    @process.register(Terminal)
    def _(self, o: Terminal) -> Terminal:
        """Apply to a terminal."""
        return o

    @process.register(Derivative)
    def _(self, o: Derivative) -> Expr:
        """Apply to a derivative."""
        raise NotImplementedError(f"Missing derivative handler for {type(o).__name__}.")

    @process.register(Grad)
    @DAGTraverser.postorder
    def _(self, o: Grad, f: Expr | BaseForm) -> Expr:
        """Apply to a grad."""
        gdim = o.ufl_shape[-1]
        dag_traverser = self._get_dag_traverser((GradRuleset, gdim), GradRuleset, gdim)
        return dag_traverser(f)  # type: ignore

    @process.register(ReferenceGrad)
    @DAGTraverser.postorder
    def _(self, o: ReferenceGrad, f: Expr | BaseForm) -> Expr | BaseForm:
        """Apply to a reference_grad."""
        tdim = o.ufl_shape[-1]
        dag_traverser = self._get_dag_traverser(
            (ReferenceGradRuleset, tdim), ReferenceGradRuleset, tdim
        )
        return dag_traverser(f)  # type: ignore

    @process.register(VariableDerivative)
    @DAGTraverser.postorder_only_children([0])
    def _(self, o: Expr, f: Expr | BaseForm) -> Expr | BaseForm:
        """Apply to a variable_derivative."""
        _, op = o.ufl_operands
        dag_traverser = self._get_dag_traverser((VariableRuleset, op), VariableRuleset, op)
        return dag_traverser(f)  # type: ignore

    @process.register(CoefficientDerivative)
    @DAGTraverser.postorder_only_children([0])
    def _(self, o: CoefficientDerivative, f: Expr | BaseForm) -> Expr | BaseForm:
        """Apply to a coefficient_derivative."""
        _, w, v, cd = o.ufl_operands
        key = (GateauxDerivativeRuleset, w, v, cd)
        # We need to go through the dag first to record the pending
        # operations
        dag_traverser = self._get_dag_traverser(key, GateauxDerivativeRuleset, w, v, cd)
        # If f has been seen by the traverser, it immediately returns
        # the cached value.
        mapped_expr = dag_traverser(f)  # type: ignore
        # Need to account for pending operations that have been stored
        # in other integrands
        self.pending_operations += dag_traverser.pending_operations  # type: ignore
        return mapped_expr

    @process.register(BaseFormOperatorDerivative)
    @DAGTraverser.postorder_only_children([0])
    def _(self, o: BaseFormOperatorDerivative, f: Expr | BaseForm) -> Expr | BaseForm:
        """Apply to a base_form_operator_derivative."""
        _, w, v, cd = o.ufl_operands
        if isinstance(f, ZeroBaseForm):
            (arg,) = v.ufl_operands  # type: ignore
            arguments = f.arguments()
            # derivative(F, u, du) with `du` a Coefficient
            # is equivalent to taking the action of the derivative.
            # In that case, we don't add arguments to `ZeroBaseForm`.
            if isinstance(arg, BaseArgument):
                arguments += (arg,)
            return ZeroBaseForm(arguments)
        # Need a BaseFormOperatorDerivativeRuleset object
        # for each outer_base_form_op (= f).
        key = (BaseFormOperatorDerivativeRuleset, w, v, cd, f)
        # We need to go through the dag first to record the pending operations
        dag_traverser = self._get_dag_traverser(
            key, BaseFormOperatorDerivativeRuleset, w, v, cd, f
        )
        # If f has been seen by the traverser, it immediately returns
        # the cached value.
        mapped_expr = dag_traverser(f)  # type: ignore
        mapped_f = dag_traverser._process_coefficient(f)  # type: ignore
        if mapped_f != 0:
            # If dN/dN needs to return an Argument in N space
            # with N a BaseFormOperator.
            return mapped_f
        # Need to account for pending operations that have been stored in other integrands
        self.pending_operations += dag_traverser.pending_operations  # type: ignore
        return mapped_expr

    @process.register(CoordinateDerivative)
    @DAGTraverser.postorder_only_children([0])
    def _(self, o: Expr, f: Expr | BaseForm) -> CoordinateDerivative:
        """Apply to a coordinate_derivative."""
        _, o1, o2, o3 = o.ufl_operands
        return CoordinateDerivative(f, o1, o2, o3)

    @process.register(BaseFormCoordinateDerivative)
    @DAGTraverser.postorder_only_children([0])
    def _(self, o: Expr, f: Expr | BaseForm) -> BaseFormCoordinateDerivative:
        """Apply to a base_form_coordinate_derivative."""
        _, o1, o2, o3 = o.ufl_operands
        return BaseFormCoordinateDerivative(f, o1, o2, o3)

    @process.register(Indexed)
    @DAGTraverser.postorder
    def _(self, o: Indexed, Ap: Expr, ii: MultiIndex) -> Expr | BaseForm:
        """Apply to an indexed."""
        # Reuse if untouched
        if Ap is o.ufl_operands[0]:
            return o
        # Free (unindexed) trailing dimensions of the processed operand,
        # e.g. when differentiation added axes; re-wrap them as a tensor.
        r = len(Ap.ufl_shape) - len(ii)
        if r:
            kk = indices(r)
            op = Indexed(Ap, MultiIndex(ii.indices() + kk))
            op = as_tensor(op, kk)
        else:
            op = Indexed(Ap, ii)
        return op
1✔
1921

1922

1923
class BaseFormOperatorDerivativeRecorder:
1✔
1924
    """A derivative recorded for a base form operator."""
1925

1926
    def __init__(self, var, **kwargs):
1✔
1927
        """Initialise."""
1928
        base_form_ops = kwargs.pop("base_form_ops", ())
1✔
1929

1930
        if kwargs.keys() != {"arguments", "coefficient_derivatives"}:
1✔
1931
            raise ValueError(
×
1932
                "Only `arguments` and `coefficient_derivatives` are "
1933
                "allowed as derivative arguments."
1934
            )
1935

1936
        self.var = var
1✔
1937
        self.der_kwargs = kwargs
1✔
1938
        self.base_form_ops = base_form_ops
1✔
1939

1940
    def __len__(self):
1✔
1941
        """Get the length."""
1942
        return len(self.base_form_ops)
×
1943

1944
    def __bool__(self):
1✔
1945
        """Convert to a bool."""
1946
        return bool(self.base_form_ops)
1✔
1947

1948
    def __add__(self, other):
1✔
1949
        """Add."""
1950
        if isinstance(other, list | tuple):
1✔
1951
            base_form_ops = self.base_form_ops + other
1✔
1952
        elif isinstance(other, BaseFormOperatorDerivativeRecorder):
×
1953
            if self.der_kwargs != other.der_kwargs:
×
1954
                raise ValueError(
×
1955
                    f"Derivative arguments must match when summing {type(self).__name__} objects."
1956
                )
1957
            base_form_ops = self.base_form_ops + other.base_form_ops
×
1958
        else:
1959
            raise NotImplementedError(
×
1960
                f"Sum of {type(self)} and {type(other)} objects is not supported."
1961
            )
1962

1963
        return BaseFormOperatorDerivativeRecorder(
1✔
1964
            self.var, base_form_ops=base_form_ops, **self.der_kwargs
1965
        )
1966

1967
    def __radd__(self, other):
1✔
1968
        """Add."""
1969
        # Recording order doesn't matter as collected
1970
        # `BaseFormOperator`s are sorted later on.
1971
        return self.__add__(other)
1✔
1972

1973
    def __iadd__(self, other):
1✔
1974
        """Add."""
1975
        if isinstance(other, list | tuple):
1✔
1976
            self.base_form_ops += other
1✔
1977
        elif isinstance(other, BaseFormOperatorDerivativeRecorder):
1✔
1978
            self.base_form_ops += other.base_form_ops
1✔
1979
        else:
1980
            raise NotImplementedError
×
1981
        return self
1✔
1982

1983

1984
def apply_derivatives(expression):
    """Apply derivatives to an expression.

    Derivative nodes are first expanded by a DerivativeRuleDispatcher.  If
    that pass recorded pending BaseFormOperators ``N``, each one adds a term
    ``Action(dexpr/dN, dN/dvar)``.  All terms are then summed:
    dF/du = ∂F/∂u + sum_i Action(∂F/∂Ni, dNi/du).

    Args:
        expression: A Form, an Expr or a BaseFormOperator to be differentiated

    Returns:
        A differentiated expression
    """
    # Notation: Let `var` be the thing we are differentiating with respect to.

    dag_traverser = DerivativeRuleDispatcher()

    # If we hit a base form operator (bfo), then if `var` is:
    #    - a BaseFormOperator → Return `d(expression)/dw` where `w` is
    #      the coefficient produced by the bfo `var`.
    #    - else → Record the bfo on the DAGTraverser object and return 0.
    # Example:
    #    → If derivative(F(u, N(u); v), u) was taken the following line would compute `∂F/∂u`.
    dexpression_dvar = map_integrands(dag_traverser, expression)
    if (
        isinstance(expression, BaseForm)
        and isinstance(dexpression_dvar, int)
        and dexpression_dvar == 0
    ):
        # The arguments got lost, just keep an empty Form
        dexpression_dvar = Form([])

    # Get the recorded delayed operations.  This is the initial empty tuple
    # when nothing was recorded; otherwise it is an object exposing `var`,
    # `base_form_ops` and `der_kwargs` (used below).
    pending_operations = dag_traverser.pending_operations
    if not pending_operations:
        return dexpression_dvar

    # Don't take into account empty Forms
    if isinstance(dexpression_dvar, Form) and dexpression_dvar.empty():
        dexpression_dvar = []
    else:
        dexpression_dvar = [dexpression_dvar]

    # Retrieve the base form operators, var, and the argument and
    # coefficient_derivatives for `derivative`
    var = pending_operations.var
    base_form_ops = pending_operations.base_form_ops
    der_kwargs = pending_operations.der_kwargs
    # Deduplicate the recorded operators and process them in a
    # deterministic order given by `count()`.
    for N in sorted(set(base_form_ops), key=lambda x: x.count()):
        # -- Replace dexpr/dvar by dexpr/dN -- #
        # We don't use `apply_derivatives` since the differentiation is
        # done via `\partial` and not `d`.
        dexpr_dN = map_integrands(
            dag_traverser, replace_derivative_nodes(expression, {var.ufl_operands[0]: N})
        )
        # Don't take into account empty Forms
        if isinstance(dexpr_dN, Form) and dexpr_dN.empty():
            continue

        # -- Add the BaseFormOperatorDerivative node -- #
        # Exactly one Gateaux direction is expected (the unpacking fails otherwise).
        (var_arg,) = der_kwargs["arguments"].ufl_operands
        cd = der_kwargs["coefficient_derivatives"]
        # Not always the case since `derivative`'s syntax enables one to
        # use a Coefficient as the Gateaux direction
        if isinstance(var_arg, BaseArgument):
            # Construct the argument number based on the
            # BaseFormOperator arguments instead of naively using
            # `var_arg`. This is critical when BaseFormOperators are
            # used inside 0-forms.
            #
            # Example: F = 0.5 * u** 2 * dx + 0.5 * N(u; v*)** 2 * dx
            #    -> dFdu[vhat] = <u, vhat> + Action(<N(u; v*), v0>, dNdu(u; v1, v*))
            # with `vhat` a 0-numbered argument, and where `v1` and
            # `vhat` have the same function space but a different
            # number. Here, applying `vhat` (`var_arg`) naively would
            # result in `dNdu(u; vhat, v*)`, i.e. the 2-forms `dNdu`
            # would have two 0-numbered arguments. Instead we increment
            # the argument number of `vhat` to form `v1`.
            var_arg = type(var_arg)(
                var_arg.ufl_function_space(), number=len(N.arguments()), part=var_arg.part()
            )
        dN_dvar = apply_derivatives(BaseFormOperatorDerivative(N, var, ExprList(var_arg), cd))
        # -- Sum the Action: dF/du = ∂F/∂u + \sum_{i=1,...} Action(∂F/∂Ni, dNi/du) -- #
        # In this case: Action <=> ufl.action since `dN_dvar` has 2 arguments.
        # We use Action to handle the trivial case `dN_dvar` = 0.
        dexpression_dvar.append(Action(dexpr_dN, dN_dvar))
    # NOTE(review): if every term was skipped as an empty Form, the list is
    # empty and `sum` returns the int 0 rather than `Form([])` — confirm
    # callers handle this.
    return sum(dexpression_dvar)
1✔
2068

2069

2070
class CoordinateDerivativeRuleset(GenericDerivativeRuleset):
    """Apply AFD (Automatic Functional Differentiation) to expression.

    Implements rules for the Gateaux derivative D_w[v](...) defined as
    D_w[v](e) = d/dtau e(w+tau v)|tau=0
    where 'e' is a ufl form after pullback and w is a SpatialCoordinate.
    """

    def __init__(
        self,
        coefficients: ExprList,
        arguments: ExprList,
        coefficient_derivatives: ExprMapping,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise.

        Args:
            coefficients: The coefficient(s) ``w`` to differentiate with respect to.
            arguments: The direction(s) ``v`` of the derivative, paired
                positionally with ``coefficients``.
            coefficient_derivatives: Flattened pairs ``(f0, df0/dw, f1, df1/dw, ...)``
                of known nonzero coefficient derivatives.
            compress: Forwarded to the base class.
            visited_cache: Forwarded to the base class.
            result_cache: Forwarded to the base class.

        Raises:
            ValueError: If an argument is not of the expected container type.
        """
        # The empty tuple is the base ruleset's first positional argument
        # (presumably the shape of the differentiation variable; tau is scalar).
        super().__init__(
            (), compress=compress, visited_cache=visited_cache, result_cache=result_cache
        )
        # Type checking
        if not isinstance(coefficients, ExprList):
            raise ValueError("Expecting a ExprList of coefficients.")
        if not isinstance(arguments, ExprList):
            raise ValueError("Expecting a ExprList of arguments.")
        if not isinstance(coefficient_derivatives, ExprMapping):
            raise ValueError("Expecting a coefficient-coefficient ExprMapping.")
        # The coefficient(s) to differentiate w.r.t. and the
        # argument(s) s.t. D_w[v](e) = d/dtau e(w+tau v)|tau=0
        self._w = coefficients.ufl_operands
        self._v = arguments.ufl_operands
        self._w2v = dict(zip(self._w, self._v))
        # Build more convenient dict {f: df/dw} for each coefficient f
        # where df/dw is nonzero. The operands are stored flattened as
        # (f0, df0/dw, f1, df1/dw, ...); a trailing unpaired entry is ignored.
        cd = coefficient_derivatives.ufl_operands
        self._cd = dict(zip(cd[::2], cd[1::2]))

    # Work around singledispatchmethod inheritance issue;
    # see https://bugs.python.org/issue36457.
    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Process ``o``.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.

        """
        return super().process(o)

    @process.register(GeometricQuantity)
    @process.register(Argument)
    def _(self, o: Expr) -> Expr:
        """Differentiate a geometric quantity or an argument.

        Explicitly defines dg/dw == 0 and da/dw == 0. More specific
        geometric quantities with their own handlers (SpatialCoordinate,
        Jacobian) are dispatched to those instead.
        """
        return self.independent_terminal(o)

    @process.register(Coefficient)
    def _(self, o: Expr) -> Expr:
        """Differentiate a coefficient."""
        raise NotImplementedError(
            "CoordinateDerivative of coefficient in physical space is not implemented."
        )

    @process.register(Grad)
    def _(self, o: Expr) -> Expr:
        """Differentiate a grad."""
        raise NotImplementedError("CoordinateDerivative grad in physical space is not implemented.")

    @process.register(SpatialCoordinate)
    def _(self, o: Expr) -> Expr:
        """Differentiate a spatial_coordinate.

        d x / d x => the direction ``v`` paired with ``x``. Any other
        SpatialCoordinate is rejected.
        """
        do = self._w2v.get(o)  # type: ignore
        if do is None:
            raise NotImplementedError(
                "CoordinateDerivative found a SpatialCoordinate that is different "
                "from the one being differentiated."
            )
        return do

    @process.register(ReferenceValue)
    def _(self, o: Expr) -> Expr:
        """Differentiate a reference_value.

        Returns the registered coefficient derivative if one exists, else zero.
        """
        do = self._cd.get(o)  # type: ignore
        if do is None:
            return self.independent_terminal(o)
        return do

    @process.register(ReferenceGrad)
    def _(self, g: Expr) -> Expr:
        """Differentiate a reference_grad.

        d (grad_X(...(x)) / dx => grad_X(...(Argument(x.function_space()))

        Raises:
            ValueError: If the innermost operand is neither a
                SpatialCoordinate nor an expression whose first operand is a
                FormArgument.
        """
        # Strip all nested ReferenceGrads, counting them so they can be
        # re-applied to the direction afterwards.
        o = g
        ngrads = 0
        while isinstance(o, ReferenceGrad):
            (o,) = o.ufl_operands
            ngrads += 1
        # Guard the index: an operand-less `o` (e.g. a terminal) must reach the
        # intended ValueError rather than raising IndexError.
        if not (
            isinstance(o, SpatialCoordinate)
            or (o.ufl_operands and isinstance(o.ufl_operands[0], FormArgument))
        ):
            raise ValueError(f"Expecting gradient of a FormArgument, not {ufl_err_str(o)}")

        def apply_grads(f):
            # Wrap `f` in the same number of ReferenceGrads that were stripped.
            for _ in range(ngrads):
                f = ReferenceGrad(f)
            return f

        # Find o among all w without any indexing, which makes this
        # easy
        for w, v in zip(self._w, self._v):
            if (
                o == w
                and isinstance(v, ReferenceValue)
                and isinstance(v.ufl_operands[0], FormArgument)
            ):
                # Case: d/dt [w + t v]
                return apply_grads(v)
        # `g` does not depend on any w. The zero must carry the shape of the
        # whole expression `g` (including its gradient axes), not the shape of
        # the inner operand `o`.
        return self.independent_terminal(g)

    @process.register(Jacobian)
    def _(self, o: Expr) -> Expr:
        """Differentiate a jacobian.

        d (grad_X(x))/d x => grad_X(Argument(x.function_space())
        """
        # The Jacobian's domain is loop-invariant; look it up once.
        domain = extract_unique_domain(o)
        for w, v in zip(self._w, self._v):
            if extract_unique_domain(w) == domain and isinstance(
                v.ufl_operands[0],  # type: ignore
                FormArgument,
            ):
                return ReferenceGrad(v)
        return self.independent_terminal(o)
2207

2208

2209
class CoordinateDerivativeRuleDispatcher(DAGTraverser):
    """Dispatcher for CoordinateDerivative nodes.

    Traverses an expression or base form. Each CoordinateDerivative has its
    first operand differentiated by a CoordinateDerivativeRuleset built from
    the node's remaining operands ``(w, v, cd)``. Other non-terminal nodes go
    through ``reuse_if_untouched``. Terminals, grads and coefficient
    derivatives are returned unchanged.
    """

    def __init__(
        self,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise.

        Args:
            compress: Forwarded to DAGTraverser.
            visited_cache: Forwarded to DAGTraverser.
            result_cache: Forwarded to DAGTraverser.
        """
        super().__init__(compress=compress, visited_cache=visited_cache, result_cache=result_cache)
        # One ruleset per distinct (ruleset type, w, v, cd) key, reused by all
        # CoordinateDerivative nodes that share those operands.
        self._dag_traverser_cache: dict[tuple[type, Expr, Expr, Expr], DAGTraverser] = {}

    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Process ``o``.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.

        """
        return super().process(o)

    @process.register(Expr)
    @process.register(BaseForm)  # type: ignore
    def _(self, o: Expr | BaseForm) -> Expr | BaseForm:
        """Apply to expr and base form."""
        return self.reuse_if_untouched(o)

    @process.register(Terminal)
    @process.register(Grad)
    @process.register(ReferenceGrad)
    @process.register(CoefficientDerivative)
    def _(self, o: Expr) -> Expr:
        """Return ``o`` unchanged.

        Applies to terminals, grads, reference grads and coefficient
        derivatives. These are not traversed further by this dispatcher.
        """
        return o

    @process.register(Derivative)
    def _(self, o: Expr) -> Expr:
        """Apply to a derivative without a dedicated handler."""
        raise NotImplementedError(f"Missing derivative handler for {type(o).__name__}.")

    @process.register(CoordinateDerivative)
    @DAGTraverser.postorder_only_children([0])
    def _(self, o: Expr, f: Expr) -> Expr:
        """Apply to a coordinate_derivative.

        Args:
            o: The CoordinateDerivative node, with operands
                ``(expression, w, v, cd)``.
            f: The already-processed first operand of ``o``.

        Returns:
            ``f`` differentiated by the ruleset for ``(w, v, cd)``.

        Raises:
            NotImplementedError: If any element in ``o`` uses a custom pull back.
        """
        # Local import, kept as in the original (presumably avoids an import cycle).
        from ufl.algorithms import extract_unique_elements

        if any(
            isinstance(element.pullback, CustomPullback)
            for element in extract_unique_elements(o)
        ):
            raise NotImplementedError(
                "CoordinateDerivative is not supported for elements with custom pull back."
            )
        _, w, v, cd = o.ufl_operands
        key = (CoordinateDerivativeRuleset, w, v, cd)
        # Construct the ruleset only on a cache miss. dict.setdefault would
        # evaluate its default eagerly and build (and type-check) a new
        # ruleset on every call, even when one is already cached.
        dag_traverser = self._dag_traverser_cache.get(key)
        if dag_traverser is None:
            dag_traverser = CoordinateDerivativeRuleset(w, v, cd)  # type: ignore
            self._dag_traverser_cache[key] = dag_traverser
        return dag_traverser(f)
2284

2285

2286
def apply_coordinate_derivatives(expression):
    """Apply coordinate derivatives to an expression.

    A fresh CoordinateDerivativeRuleDispatcher is created for this call and
    handed to ``map_integrands`` together with ``expression``. The result of
    ``map_integrands`` is returned as-is.
    """
    return map_integrands(CoordinateDerivativeRuleDispatcher(), expression)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc