• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

FEniCS / ufl / 17557791029

08 Sep 2025 04:38PM UTC coverage: 76.363% (+0.4%) from 75.917%
17557791029

Pull #401

github

web-flow
Merge branch 'main' into schnellerhase/remove-type-system
Pull Request #401: Removal of custom type system

495 of 534 new or added lines in 42 files covered. (92.7%)

6 existing lines in 2 files now uncovered.

9133 of 11960 relevant lines covered (76.36%)

0.76 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.29
/ufl/algorithms/apply_derivatives.py
1
"""Apply derivatives algorithm which computes the derivatives of a form or expression."""
2

3
# Copyright (C) 2008-2016 Martin Sandve Alnæs
4
#
5
# This file is part of UFL (https://www.fenicsproject.org)
6
#
7
# SPDX-License-Identifier:    LGPL-3.0-or-later
8

9
import warnings
1✔
10
from functools import singledispatchmethod
1✔
11
from math import pi
1✔
12

13
import numpy as np
1✔
14

15
from ufl.action import Action
1✔
16
from ufl.algorithms.analysis import extract_arguments, extract_coefficients
1✔
17
from ufl.algorithms.map_integrands import map_integrands
1✔
18
from ufl.algorithms.replace_derivative_nodes import replace_derivative_nodes
1✔
19
from ufl.argument import Argument, BaseArgument, Coargument
1✔
20
from ufl.averaging import CellAvg, FacetAvg
1✔
21
from ufl.checks import is_cellwise_constant
1✔
22
from ufl.classes import (
1✔
23
    Abs,
24
    CellCoordinate,
25
    Coefficient,
26
    Cofunction,
27
    ComponentTensor,
28
    Conj,
29
    Constant,
30
    ConstantValue,
31
    Division,
32
    Expr,
33
    ExprList,
34
    ExprMapping,
35
    FacetNormal,
36
    FloatValue,
37
    FormArgument,
38
    GeometricQuantity,
39
    Grad,
40
    Identity,
41
    Imag,
42
    Indexed,
43
    IndexSum,
44
    Jacobian,
45
    JacobianDeterminant,
46
    JacobianInverse,
47
    Label,
48
    ListTensor,
49
    Power,
50
    Product,
51
    Real,
52
    ReferenceGrad,
53
    ReferenceValue,
54
    SpatialCoordinate,
55
    Sum,
56
    Variable,
57
    Zero,
58
)
59
from ufl.conditional import BinaryCondition, Conditional, NotCondition
1✔
60
from ufl.constantvalue import is_true_ufl_scalar, is_ufl_scalar
1✔
61
from ufl.core.base_form_operator import BaseFormOperator
1✔
62
from ufl.core.expr import ufl_err_str
1✔
63
from ufl.core.external_operator import ExternalOperator
1✔
64
from ufl.core.interpolate import Interpolate
1✔
65
from ufl.core.multiindex import FixedIndex, MultiIndex, indices
1✔
66
from ufl.core.terminal import Terminal
1✔
67
from ufl.corealg.dag_traverser import DAGTraverser
1✔
68
from ufl.differentiation import (
1✔
69
    BaseFormCoordinateDerivative,
70
    BaseFormOperatorDerivative,
71
    CoefficientDerivative,
72
    CoordinateDerivative,
73
    Derivative,
74
    VariableDerivative,
75
)
76
from ufl.domain import MeshSequence, extract_unique_domain
1✔
77
from ufl.form import BaseForm, Form, ZeroBaseForm
1✔
78
from ufl.mathfunctions import (
1✔
79
    Acos,
80
    Asin,
81
    Atan,
82
    Atan2,
83
    BesselI,
84
    BesselJ,
85
    BesselK,
86
    BesselY,
87
    Cos,
88
    Cosh,
89
    Erf,
90
    Exp,
91
    Ln,
92
    MathFunction,
93
    Sin,
94
    Sinh,
95
    Sqrt,
96
    Tan,
97
    Tanh,
98
)
99
from ufl.matrix import Matrix
1✔
100
from ufl.operators import (
1✔
101
    MaxValue,
102
    MinValue,
103
    bessel_I,
104
    bessel_J,
105
    bessel_K,
106
    bessel_Y,
107
    cell_avg,
108
    conditional,
109
    cos,
110
    cosh,
111
    exp,
112
    facet_avg,
113
    ln,
114
    sign,
115
    sin,
116
    sinh,
117
    sqrt,
118
)
119
from ufl.pullback import CustomPullback, PhysicalPullback
1✔
120
from ufl.restriction import Restricted
1✔
121
from ufl.tensors import as_scalar, as_scalars, as_tensor, unit_indexed_tensor, unwrap_list_tensor
1✔
122

123
# TODO: Add more rulesets?
124
# - DivRuleset
125
# - CurlRuleset
126
# - ReferenceGradRuleset
127
# - ReferenceDivRuleset
128

129

130
def flatten_domain_element(domain, element):
    """Return the flattened (domain, element) pairs for mixed domain problems.

    A `MeshSequence` is paired component-wise with the sub-elements of
    ``element``; each pair is then flattened recursively, so nested mesh
    sequences contribute all of their leaf pairs.

    Args:
        domain: `Mesh` or `MeshSequence`.
        element: `FiniteElement`. When ``domain`` is a `MeshSequence` it must
            have exactly one sub-element per component of ``domain``.

    Returns:
        A flat tuple of (domain, element) pairs in depth-first order;
        just ((domain, element),) if domain is a `Mesh` (and not a
        `MeshSequence`).

    Raises:
        ValueError: If ``domain`` is a `MeshSequence` whose length differs
            from the number of sub-elements of ``element``.

    """
    if not isinstance(domain, MeshSequence):
        return ((domain, element),)
    sub_elements = element.sub_elements
    # Explicit check instead of ``assert`` so it also runs under ``python -O``.
    if len(domain) != len(sub_elements):
        raise ValueError(
            f"MeshSequence has {len(domain)} meshes but the element has "
            f"{len(sub_elements)} sub-elements; they must match."
        )
    return tuple(
        pair
        for sub_domain, sub_element in zip(domain, sub_elements)
        for pair in flatten_domain_element(sub_domain, sub_element)
    )
1✔
149

150

151
class GenericDerivativeRuleset(DAGTraverser):
    """Generic rule set for differentiating UFL expressions.

    Implements the differentiation rules shared by all kinds of derivatives
    (algebra, math functions, indexing, complex parts, restrictions and
    conditionals). Rules that depend on the kind of derivative -- form
    arguments, geometric quantities, ``grad`` and cell/facet averages --
    raise an error here and must be provided by a specialized subclass.
    """

    def __init__(
        self,
        var_shape: tuple,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise.

        Args:
            var_shape: Shape of the differentiation variable; zero derivatives
                are built with shape ``o.ufl_shape + var_shape``.
            compress: Forwarded to `DAGTraverser`.
            visited_cache: Forwarded to `DAGTraverser`.
            result_cache: Forwarded to `DAGTraverser`.
        """
        super().__init__(compress=compress, visited_cache=visited_cache, result_cache=result_cache)
        self._var_shape = var_shape

    def unexpected(self, o):
        """Raise error about unexpected type."""
        raise ValueError(f"Unexpected type {type(o).__name__} in AD rules.")

    def override(self, o):
        """Raise error about a rule that a specialized rule set must override."""
        raise ValueError(f"Type {type(o).__name__} must be overridden in specialized AD rule set.")

    # --- Some types just don't have any derivative, this is just to
    # --- make algorithm structure generic

    def non_differentiable_terminal(self, o):
        """Return the non-differentiated object.

        Labels and indices are not differentiable: it's convenient to
        return the non-differentiated object.
        """
        return o

    # --- Helper functions for creating zeros with the right shapes

    def independent_terminal(self, o):
        """A zero with correct shape for terminals independent of diff. variable."""
        return Zero(o.ufl_shape + self._var_shape)

    def independent_operator(self, o):
        """A zero with correct shape and indices for operators independent of diff. variable."""
        return Zero(o.ufl_shape + self._var_shape, o.ufl_free_indices, o.ufl_index_dimensions)

    # --- Error checking for missing handlers and unexpected types

    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Process ``o``.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.

        """
        return super().process(o)

    @process.register(Expr)
    def _(self, o: Expr) -> Expr:
        """Raise error: no differentiation rule is registered for this type."""
        raise ValueError(
            f"Missing differentiation handler for type {type(o).__name__}. "
            "Have you added a new type?"
        )

    @process.register(Derivative)
    def _(self, o: Expr) -> Expr:
        """Raise error: nested derivatives should have been applied already."""
        raise ValueError(
            f"Unhandled derivative type {type(o).__name__}, nested differentiation has failed."
        )

    @process.register(Label)
    @process.register(MultiIndex)
    def _(self, o: Expr) -> Expr:
        """Return labels and multi-indices unchanged."""
        return self.non_differentiable_terminal(o)

    # --- All derivatives need to define grad and averaging

    @process.register(Grad)
    @process.register(CellAvg)
    @process.register(FacetAvg)
    def _(self, o: Expr) -> Expr:
        """Raise error: these rules must come from a specialized rule set."""
        return self.override(o)

    # --- Default rules for terminals

    # Literals are by definition independent of any differentiation variable
    @process.register(ConstantValue)
    # Constants are independent of any differentiation
    @process.register(Constant)
    def _(self, o: Expr) -> Expr:
        """Differentiate a literal or constant: the result is zero."""
        return self.independent_terminal(o)

    # Zero may have free indices
    @process.register(Zero)
    def _(self, o: Expr) -> Expr:
        """Differentiate a zero, keeping its free indices."""
        return self.independent_operator(o)

    # Rules for form arguments must be specified in specialized rule set
    @process.register(FormArgument)
    # Rules for geometric quantities must be specified in specialized rule set
    @process.register(GeometricQuantity)
    def _(self, o: Expr) -> Expr:
        """Raise error: these rules must come from a specialized rule set."""
        return self.override(o)

    # These types are currently assumed independent, but for non-affine domains
    # this no longer holds and we want to implement rules for them.
    # facet_normal = independent_terminal
    # spatial_coordinate = independent_terminal
    # cell_coordinate = independent_terminal

    # Measures of cell entities, assuming independent although
    # this will not be true for all of these for non-affine domains
    # cell_volume = independent_terminal
    # circumradius = independent_terminal
    # facet_area = independent_terminal
    # cell_surface_area = independent_terminal
    # min_cell_edge_length = independent_terminal
    # max_cell_edge_length = independent_terminal
    # min_facet_edge_length = independent_terminal
    # max_facet_edge_length = independent_terminal

    # Other stuff
    # cell_orientation = independent_terminal
    # quadrature_weight = independent_terminal

    # These types are currently not expected to show up in AD pass.
    # To make some of these available to the end-user, they need to be
    # implemented here.
    # facet_coordinate = unexpected
    # cell_origin = unexpected
    # facet_origin = unexpected
    # cell_facet_origin = unexpected
    # jacobian = unexpected
    # jacobian_determinant = unexpected
    # jacobian_inverse = unexpected
    # facet_jacobian = unexpected
    # facet_jacobian_determinant = unexpected
    # facet_jacobian_inverse = unexpected
    # cell_facet_jacobian = unexpected
    # cell_facet_jacobian_determinant = unexpected
    # cell_facet_jacobian_inverse = unexpected
    # cell_vertices = unexpected
    # cell_edge_vectors = unexpected
    # facet_edge_vectors = unexpected
    # reference_cell_edge_vectors = unexpected
    # reference_facet_edge_vectors = unexpected
    # cell_normal = unexpected # TODO: Expecting rename
    # cell_normals = unexpected
    # facet_tangents = unexpected
    # cell_tangents = unexpected
    # cell_midpoint = unexpected
    # facet_midpoint = unexpected

    # --- Default rules for operators

    @process.register(Variable)
    def _(self, o: Expr) -> Expr:
        """Differentiate a variable by differentiating the wrapped expression."""
        op, _ = o.ufl_operands
        return self(op)

    # --- Indexing and component handling

    @process.register(Indexed)
    @DAGTraverser.postorder
    def _(self, o: Indexed, Ap: Expr, ii: MultiIndex) -> Expr:
        """Differentiate an indexed.

        The derivative ``Ap`` carries extra trailing axes (the variable
        shape); those are kept free by indexing them with fresh indices
        and wrapping the result back into a tensor.
        """
        # Propagate zeros
        if isinstance(Ap, Zero):
            return self.independent_operator(o)
        r = len(Ap.ufl_shape) - len(ii)
        if r:
            kk = indices(r)
            op = Indexed(Ap, MultiIndex(ii.indices() + kk))
            op = as_tensor(op, kk)
        else:
            op = Indexed(Ap, ii)
        return op

    @process.register(ListTensor)
    def _(self, o: Expr) -> Expr:
        """Differentiate a list_tensor component-wise."""
        return ListTensor(*(self(op) for op in o.ufl_operands))

    @process.register(ComponentTensor)
    @DAGTraverser.postorder
    def _(self, o: ComponentTensor, Ap: Expr, ii: MultiIndex) -> Expr:
        """Differentiate a component_tensor."""
        if isinstance(Ap, Zero):
            op = self.independent_operator(o)
        else:
            Ap, jj = as_scalar(Ap)
            op = as_tensor(Ap, ii.indices() + jj)
        return op

    # --- Algebra operators

    @process.register(IndexSum)
    @DAGTraverser.postorder
    def _(self, o: Expr, Ap: Expr, ii: MultiIndex) -> Expr:
        """Differentiate an index_sum: summation and differentiation commute."""
        return IndexSum(Ap, ii)

    @process.register(Sum)
    @DAGTraverser.postorder
    def _(self, o: Expr, da: Expr, db: Expr) -> Expr:
        """Differentiate a sum."""
        return da + db

    @process.register(Product)
    @DAGTraverser.postorder
    def _(self, o: Expr, da: Expr, db: Expr) -> Expr:
        """Differentiate a product with the product rule."""
        # Even though arguments to o are scalar, da and db may be
        # tensor valued
        a, b = o.ufl_operands
        (da, db), ii = as_scalars(da, db)
        pa = Product(da, b)
        pb = Product(a, db)
        s = Sum(pa, pb)
        if ii:
            s = as_tensor(s, ii)
        return s

    @process.register(Division)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr, gp: Expr) -> Expr:
        """Differentiate a division with the quotient rule."""
        f, g = o.ufl_operands
        if not is_ufl_scalar(f):
            raise ValueError("Not expecting nonscalar nominator")
        if not is_true_ufl_scalar(g):
            raise ValueError("Not expecting nonscalar denominator")
        # With o = f/g:
        # do_df = 1/g
        # do_dg = -f/g**2 = -o/g
        # op = do_df*fp + do_dg*gp
        # op = (fp - o*gp) / g
        # Get o and gp as scalars, multiply, then wrap as a tensor
        # again
        so, oi = as_scalar(o)
        sgp, gi = as_scalar(gp)
        o_gp = so * sgp
        if oi or gi:
            o_gp = as_tensor(o_gp, oi + gi)
        op = (fp - o_gp) / g
        return op

    @process.register(Power)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr, gp: Expr) -> Expr:
        """Differentiate a power."""
        f, g = o.ufl_operands
        if not is_true_ufl_scalar(f):
            raise ValueError("Expecting scalar expression f in f**g.")
        if not is_true_ufl_scalar(g):
            raise ValueError("Expecting scalar expression g in f**g.")
        # Derivation of the general case: o = f(x)**g(x)
        # do/df  = g * f**(g-1) = g / f * o
        # do/dg  = ln(f) * f**g = ln(f) * o
        # do/df * df + do/dg * dg = o * (g / f * df + ln(f) * dg)
        if isinstance(gp, Zero):
            # This probably produces better results for the common
            # case of f**constant
            op = fp * g * f ** (g - 1)
        else:
            # Note: This produces expressions like (1/w)*w**5 instead of w**4
            # op = o * (fp * g / f + gp * ln(f)) # This reuses o
            op = f ** (g - 1) * (
                g * fp + f * ln(f) * gp
            )  # This gives better accuracy in dolfin integration test
        # Example: d/dx[x**(x**3)]:
        # f = x
        # g = x**3
        # df = 1
        # dg = 3*x**2
        # op1 = o * (fp * g / f + gp * ln(f))
        #     = x**(x**3)   * (x**3/x + 3*x**2*ln(x))
        # op2 = f**(g-1) * (g*fp + f*ln(f)*gp)
        #     = x**(x**3-1) * (x**3 + x*3*x**2*ln(x))
        return op

    @process.register(Abs)
    @DAGTraverser.postorder
    def _(self, o: Expr, df: Expr) -> Expr:
        """Differentiate an abs."""
        (f,) = o.ufl_operands
        # return conditional(eq(f, 0), 0, Product(sign(f), df)) abs is
        # not complex differentiable, so we workaround the case of a
        # real F in complex mode by defensively casting to real inside
        # the sign.
        return sign(Real(f)) * df

    # --- Complex algebra

    @process.register(Conj)
    @DAGTraverser.postorder
    def _(self, o: Expr, df: Expr) -> Expr:
        """Differentiate a conj."""
        return Conj(df)

    @process.register(Real)
    @DAGTraverser.postorder
    def _(self, o: Expr, df: Expr) -> Expr:
        """Differentiate a real."""
        return Real(df)

    @process.register(Imag)
    @DAGTraverser.postorder
    def _(self, o: Expr, df: Expr) -> Expr:
        """Differentiate an imag."""
        return Imag(df)

    # --- Mathfunctions

    @process.register(MathFunction)
    @DAGTraverser.postorder
    def _(self, o: Expr, df: Expr) -> Expr:
        """Differentiate a math_function without a dedicated rule.

        Uses ``o.derivative()`` when the object provides it.
        """
        # FIXME: Introduce a UserOperator type instead of this hack
        # and define user derivative() function properly
        if hasattr(o, "derivative"):
            return df * o.derivative()
        else:
            raise ValueError("Unknown math function.")

    @process.register(Sqrt)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a sqrt: d sqrt(f) = f' / (2 sqrt(f))."""
        return fp / (2 * o)

    @process.register(Exp)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate an exp: d exp(f) = f' exp(f)."""
        return fp * o

    @process.register(Ln)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a ln: d ln(f) = f' / f."""
        (f,) = o.ufl_operands
        if isinstance(f, Zero):
            raise ZeroDivisionError("Cannot differentiate ln(f) for f identically zero.")
        return fp / f

    @process.register(Cos)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a cos."""
        (f,) = o.ufl_operands
        return fp * -sin(f)

    @process.register(Sin)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a sin."""
        (f,) = o.ufl_operands
        return fp * cos(f)

    @process.register(Tan)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a tan."""
        (f,) = o.ufl_operands
        # sec(f)**2 == 1/cos(f)**2 == 2/(cos(2f) + 1)
        return 2.0 * fp / (cos(2.0 * f) + 1.0)

    @process.register(Cosh)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a cosh."""
        (f,) = o.ufl_operands
        return fp * sinh(f)

    @process.register(Sinh)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a sinh."""
        (f,) = o.ufl_operands
        return fp * cosh(f)

    @process.register(Tanh)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a tanh: d tanh(f) = f' sech(f)**2."""
        (f,) = o.ufl_operands

        def sech(y):
            # 2 cosh(y) / (cosh(2y) + 1) == 1 / cosh(y)
            return (2.0 * cosh(y)) / (cosh(2.0 * y) + 1.0)

        return fp * sech(f) ** 2

    @process.register(Acos)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate an acos."""
        (f,) = o.ufl_operands
        return -fp / sqrt(1.0 - f**2)

    @process.register(Asin)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate an asin."""
        (f,) = o.ufl_operands
        return fp / sqrt(1.0 - f**2)

    @process.register(Atan)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate an atan."""
        (f,) = o.ufl_operands
        return fp / (1.0 + f**2)

    @process.register(Atan2)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr, gp: Expr) -> Expr:
        """Differentiate an atan2."""
        f, g = o.ufl_operands
        return (g * fp - f * gp) / (f**2 + g**2)

    @process.register(Erf)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate an erf: d erf(f) = f' 2/sqrt(pi) exp(-f**2)."""
        (f,) = o.ufl_operands
        return fp * (2.0 / sqrt(pi) * exp(-(f**2)))

    # --- Bessel functions

    @staticmethod
    def _check_bessel_order_independent(nup: Expr | None) -> None:
        """Raise if the Bessel order ``nu`` depends on the differentiation variable.

        Args:
            nup: The already-differentiated order operand (``None`` or a
                `Zero` means the order is independent).

        Raises:
            NotImplementedError: If ``nup`` is a nonzero derivative.
        """
        if not (nup is None or isinstance(nup, Zero)):
            raise NotImplementedError(
                "Differentiation of bessel function w.r.t. nu is not supported."
            )

    @process.register(BesselJ)
    @DAGTraverser.postorder
    def _(self, o: Expr, nup: Expr, fp: Expr) -> Expr:
        """Differentiate a bessel_j: J_nu' = (J_{nu-1} - J_{nu+1}) / 2, J_0' = -J_1."""
        nu, f = o.ufl_operands
        self._check_bessel_order_independent(nup)
        if isinstance(nu, Zero):
            op = -bessel_J(1, f)
        else:
            op = 0.5 * (bessel_J(nu - 1, f) - bessel_J(nu + 1, f))
        return op * fp

    @process.register(BesselY)
    @DAGTraverser.postorder
    def _(self, o: Expr, nup: Expr, fp: Expr) -> Expr:
        """Differentiate a bessel_y: Y_nu' = (Y_{nu-1} - Y_{nu+1}) / 2, Y_0' = -Y_1."""
        nu, f = o.ufl_operands
        self._check_bessel_order_independent(nup)
        if isinstance(nu, Zero):
            op = -bessel_Y(1, f)
        else:
            op = 0.5 * (bessel_Y(nu - 1, f) - bessel_Y(nu + 1, f))
        return op * fp

    @process.register(BesselI)
    @DAGTraverser.postorder
    def _(self, o: Expr, nup: Expr, fp: Expr) -> Expr:
        """Differentiate a bessel_i: I_nu' = (I_{nu-1} + I_{nu+1}) / 2, I_0' = I_1."""
        nu, f = o.ufl_operands
        self._check_bessel_order_independent(nup)
        if isinstance(nu, Zero):
            op = bessel_I(1, f)
        else:
            op = 0.5 * (bessel_I(nu - 1, f) + bessel_I(nu + 1, f))
        return op * fp

    @process.register(BesselK)
    @DAGTraverser.postorder
    def _(self, o: Expr, nup: Expr, fp: Expr) -> Expr:
        """Differentiate a bessel_k: K_nu' = -(K_{nu-1} + K_{nu+1}) / 2, K_0' = -K_1."""
        nu, f = o.ufl_operands
        self._check_bessel_order_independent(nup)
        if isinstance(nu, Zero):
            op = -bessel_K(1, f)
        else:
            op = -0.5 * (bessel_K(nu - 1, f) + bessel_K(nu + 1, f))
        return op * fp

    # --- Restrictions

    @process.register(Restricted)
    @DAGTraverser.postorder
    def _(self, o: Restricted, fp: Expr) -> Expr:
        """Differentiate a restricted."""
        # Restriction and differentiation commutes
        if isinstance(fp, ConstantValue):
            return fp  # TODO: Add simplification to Restricted instead?
        else:
            return fp(o._side)  # (f+-)' == (f')+-

    # --- Conditionals

    @process.register(BinaryCondition)
    def _(self, o: BinaryCondition) -> Expr:
        """Raise error: a binary_condition is not differentiable."""
        raise RuntimeError("Can not differentiate a binary_condition.")

    @process.register(NotCondition)
    def _(self, o: NotCondition) -> Expr:
        """Raise error: a not_condition is not differentiable."""
        raise RuntimeError("Can not differentiate a not_condition.")

    @process.register(Conditional)
    @DAGTraverser.postorder_only_children([1, 2])
    def _(self, o: Expr, dt: Expr, df: Expr) -> Expr:
        """Differentiate a conditional.

        Only the true/false branches (operands 1 and 2) are differentiated;
        the condition itself is reused unchanged.
        """
        if isinstance(dt, Zero) and isinstance(df, Zero):
            # Assuming dt and df have the same indices here, which
            # should be the case
            return dt
        else:
            # Not placing t[1],f[1] outside, allowing arguments inside
            # conditionals.  This will make legacy ffc fail, but
            # should work with uflacs.
            c = o.ufl_operands[0]
            return conditional(c, dt, df)

    @process.register(MaxValue)
    @DAGTraverser.postorder
    def _(self, o: Expr, df: Expr, dg: Expr) -> Expr:
        """Differentiate a max_value."""
        # d/dx max(f, g) =
        #   f > g: df/dx
        #   else:  dg/dx
        # Placing df,dg outside here to avoid getting arguments inside
        # conditionals
        f, g = o.ufl_operands
        dc = conditional(f > g, 1, 0)
        return dc * df + (1.0 - dc) * dg

    @process.register(MinValue)
    @DAGTraverser.postorder
    def _(self, o: Expr, df: Expr, dg: Expr) -> Expr:
        """Differentiate a min_value."""
        # d/dx min(f, g) =
        #   f < g: df/dx
        #   else:  dg/dx
        # Placing df,dg outside here to avoid getting arguments inside
        # conditionals
        f, g = o.ufl_operands
        dc = conditional(f < g, 1, 0)
        return dc * df + (1.0 - dc) * dg

712

713
class GradRuleset(GenericDerivativeRuleset):
    """Take the grad derivative.

    Differentiates with respect to physical coordinates x, so the derivative
    variable has shape ``(geometric_dimension,)``.
    """

    def __init__(
        self,
        geometric_dimension: int,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise.

        Args:
            geometric_dimension: Physical dimension; the derivative variable
                shape is ``(geometric_dimension,)``.
            compress: Forwarded to ``GenericDerivativeRuleset``.
            visited_cache: Forwarded to ``GenericDerivativeRuleset``.
            result_cache: Forwarded to ``GenericDerivativeRuleset``.
        """
        super().__init__(
            (geometric_dimension,),
            compress=compress,
            visited_cache=visited_cache,
            result_cache=result_cache,
        )
        # dx/dx for the spatial coordinate x.
        self._Id = Identity(geometric_dimension)

    # Work around singledispatchmethod inheritance issue;
    # see https://bugs.python.org/issue36457.
    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Process ``o``.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.

        """
        return super().process(o)

    # --- Specialized rules for geometric quantities

    @process.register(GeometricQuantity)
    def _(self, o: Expr) -> Expr:
        """Differentiate a geometric_quantity.

        Default for geometric quantities is do/dx = 0 if piecewise constant,
        otherwise transform derivatives to reference derivatives.
        Override for specific types if other behaviour is needed.
        """
        if is_cellwise_constant(o):
            return self.independent_terminal(o)
        domain = extract_unique_domain(o)
        K = JacobianInverse(domain)
        return grad_to_reference_grad(o, K)

    @process.register(JacobianInverse)
    def _(self, o: JacobianInverse) -> Expr:
        """Differentiate a jacobian_inverse.

        grad(K) == K_ji rgrad(K)_rj, so ``o`` itself serves as the Jacobian
        inverse in the chain rule.

        Raises:
            ValueError: If ``o`` is not a terminal.
        """
        if is_cellwise_constant(o):
            return self.independent_terminal(o)
        if not o._ufl_is_terminal_:
            raise ValueError("JacobianInverse must be a terminal")
        return grad_to_reference_grad(o, o)

    # TODO: Add more explicit geometry type handlers here, with
    # non-affine domains several should be non-zero.

    @process.register(SpatialCoordinate)
    def _(self, o: Expr) -> Expr:
        """Differentiate a spatial_coordinate.

        dx/dx = I.
        """
        return self._Id

    @process.register(CellCoordinate)
    def _(self, o: Expr) -> Expr:
        """Differentiate a cell_coordinate.

        dX/dx = inv(dx/dX) = inv(J) = K.
        """
        # FIXME: Is this true for manifolds? What about orientation?
        return JacobianInverse(extract_unique_domain(o))

    # --- Specialized rules for form arguments

    @process.register(BaseFormOperator)
    def _(self, o: Expr) -> Expr:
        """Differentiate a base_form_operator."""
        # Pushing the grad through the operator is not legal in most cases:
        #    -> Not enough regularity for chain rule to hold!
        # By the time we evaluate `grad(o)`, the operator `o` will have
        # been assembled and substituted by its output.
        return Grad(o)

    @process.register(Coefficient)
    def _(self, o: Expr) -> Expr:
        """Differentiate a coefficient."""
        if is_cellwise_constant(o):
            return self.independent_terminal(o)
        return Grad(o)

    @process.register(Argument)
    def _(self, o: Expr) -> Expr:
        """Differentiate an argument."""
        # TODO: Enable this after fixing issue#13, unless we move
        # simplification to a separate stage?
        # if is_cellwise_constant(o):
        #     # Collapse gradient of cellwise constant function to zero
        #     # TODO: Missing this type
        #     return AnnotatedZero(o.ufl_shape + self._var_shape, arguments=(o,))
        return Grad(o)

    # --- Rules for values or derivatives in reference frame

    @process.register(ReferenceValue)
    def _(self, o: ReferenceValue) -> Expr:
        """Differentiate a reference_value.

        grad(rv(f)) is rewritten as K_ji * rgrad(rv(f))_rj.  On a
        ``MeshSequence`` the sub-elements of ``f`` are paired with their own
        domains by ``flatten_domain_element``.  The contraction is therefore
        done block by block, each block using the ``JacobianInverse`` of its
        own domain.  Blocks whose sub-element uses ``PhysicalPullback`` get
        no mapping.

        Raises:
            ValueError: If the wrapped operand is not a terminal.
            RuntimeError: If element/domain counts or dimensions disagree.
            NotImplementedError: For a ``PhysicalPullback`` sub-element whose
                reference dimension differs from the physical dimension.
        """
        # grad(o) == grad(rv(f)) -> K_ji*rgrad(rv(f))_rj
        f = o.ufl_operands[0]
        if not f._ufl_is_terminal_:
            raise ValueError("ReferenceValue can only wrap a terminal")
        domain = extract_unique_domain(f, expand_mesh_sequence=False)
        if isinstance(domain, MeshSequence):
            element = f.ufl_function_space().ufl_element()  # type: ignore
            if element.num_sub_elements != len(domain):
                raise RuntimeError(f"{element.num_sub_elements} != {len(domain)}")
            # Get monolithic representation of rgrad(o); o might live in a mixed space.
            rgrad = ReferenceGrad(o)
            ref_dim = rgrad.ufl_shape[-1]
            # Apply K_ji(d) to the corresponding components of rgrad, store them in a list,
            # and put them back together at the end using as_tensor().
            components = []
            dofoffset = 0
            for d, e in flatten_domain_element(domain, element):
                esh = e.reference_value_shape
                ndof = int(np.prod(esh))
                assert ndof > 0
                if isinstance(e.pullback, PhysicalPullback):
                    if ref_dim != self._var_shape[0]:
                        raise NotImplementedError(
                            "PhysicalPullback not handled for immersed domain: "
                            f"reference dim ({ref_dim}) != "
                            f"physical dim ({self._var_shape[0]})"
                        )
                    # No pullback mapping: the rgrad components are copied as-is.
                    for idx in range(ndof):
                        for i in range(ref_dim):
                            components.append(rgrad[(dofoffset + idx,) + (i,)])
                else:
                    K = JacobianInverse(d)
                    rdim, gdim = K.ufl_shape
                    if rdim != ref_dim:
                        raise RuntimeError(f"{rdim} != {ref_dim}")
                    if gdim != self._var_shape[0]:
                        raise RuntimeError(f"{gdim} != {self._var_shape[0]}")
                    # Note that rgrad[dofoffset + [0,ndof), [0,rdim)] are the components
                    # corresponding to (d, e).
                    # For each row, rgrad[dofoffset + idx, [0,rdim)], we apply
                    # K_ji(d)[[0,rdim), [0,gdim)].
                    for idx in range(ndof):
                        for i in range(gdim):
                            temp = Zero()
                            for j in range(rdim):
                                temp += rgrad[(dofoffset + idx,) + (j,)] * K[j, i]
                            components.append(temp)
                dofoffset += ndof
            return as_tensor(
                np.asarray(components).reshape(rgrad.ufl_shape[:-1] + self._var_shape)
            )
        if isinstance(f.ufl_element().pullback, PhysicalPullback):  # type: ignore
            # TODO: Do we need to be more careful for immersed things?
            return ReferenceGrad(o)
        K = JacobianInverse(domain)
        return grad_to_reference_grad(o, K)

    @process.register(ReferenceGrad)
    def _(self, o: Expr) -> Expr:
        """Differentiate a reference_grad.

        grad(rgrad(rv(f))) is rewritten as K_ji * rgrad(rgrad(rv(f)))_rj, block
        by block on a ``MeshSequence`` as in the reference_value rule.

        Raises:
            ValueError: If the operand is neither a reference-frame type nor
                one of the accepted geometric quantities.
            RuntimeError: If element/domain counts or dimensions disagree.
        """
        if is_cellwise_constant(o):
            return self.independent_terminal(o)
        # grad(o) == grad(rgrad(rv(f))) -> K_ji*rgrad(rgrad(rv(f)))_rj
        f = o.ufl_operands[0]
        valid_operand = f._ufl_is_in_reference_frame_ or isinstance(
            f, JacobianInverse | SpatialCoordinate | Jacobian | JacobianDeterminant | FacetNormal
        )
        if not valid_operand:
            raise ValueError("ReferenceGrad can only wrap a reference frame type!")
        domain = extract_unique_domain(f, expand_mesh_sequence=False)
        if isinstance(domain, MeshSequence):
            if not f._ufl_is_in_reference_frame_:
                raise RuntimeError("Expecting a reference frame type")
            # Unwrap nested reference operators down to the underlying terminal.
            while not f._ufl_is_terminal_:
                (f,) = f.ufl_operands  # type: ignore
            element = f.ufl_function_space().ufl_element()  # type: ignore
            if element.num_sub_elements != len(domain):
                raise RuntimeError(f"{element.num_sub_elements} != {len(domain)}")
            # Get monolithic representation of rgrad(o); o might live in a mixed space.
            rgrad = ReferenceGrad(o)
            ref_dim = rgrad.ufl_shape[-1]
            # Apply K_ji(d) to the corresponding components of rgrad, store them in a list,
            # and put them back together at the end using as_tensor().
            components = []
            dofoffset = 0
            for d, e in flatten_domain_element(domain, element):
                esh = e.reference_value_shape
                ndof = int(np.prod(esh))
                assert ndof > 0
                K = JacobianInverse(d)
                rdim, gdim = K.ufl_shape
                if rdim != ref_dim:
                    raise RuntimeError(f"{rdim} != {ref_dim}")
                if gdim != self._var_shape[0]:
                    raise RuntimeError(f"{gdim} != {self._var_shape[0]}")
                # Note that rgrad[dofoffset + [0,ndof), [0,rdim), [0,rdim)] are the components
                # corresponding to (d, e).
                # For each row, rgrad[dofoffset + idx, [0,rdim), [0,rdim)], we apply
                # K_ji(d)[[0,rdim), [0,gdim)].
                for idx in range(ndof):
                    for midx in np.ndindex(rgrad.ufl_shape[1:-1]):
                        for i in range(gdim):
                            temp = Zero()
                            for j in range(rdim):
                                temp += rgrad[(dofoffset + idx,) + midx + (j,)] * K[j, i]
                            components.append(temp)
                dofoffset += ndof
            if rgrad.ufl_shape[0] != dofoffset:
                raise RuntimeError(f"{rgrad.ufl_shape[0]} != {dofoffset}")
            return as_tensor(
                np.asarray(components).reshape(rgrad.ufl_shape[:-1] + self._var_shape)
            )
        K = JacobianInverse(domain)
        return grad_to_reference_grad(o, K)

    # --- Nesting of gradients

    @process.register(Grad)
    def _(self, o: Expr) -> Expr:
        """Differentiate a grad.

        Represent grad(grad(f)) as Grad(Grad(f)).

        Raises:
            ValueError: If the operand of ``o`` is neither a Grad nor a Terminal.
        """
        # Check that o is a "differential terminal"
        if not isinstance(o.ufl_operands[0], Grad | Terminal):
            raise ValueError("Expecting only grads applied to a terminal.")
        return Grad(o)

    def _grad(self, o):
        """Placeholder; not registered with ``process`` and returns None.

        Kept only to record the open questions below.
        """
        # TODO: Not sure how to detect that gradient of f is cellwise constant.
        #       Can we trust element degrees?
        # if is_cellwise_constant(o):
        #     return self.terminal(o)
        # TODO: Maybe we can ask "f.has_derivatives_of_order(n)" to check
        #       if we should make a zero here?
        # 1) n = count number of Grads, get f
        # 2) if not f.has_derivatives(n): return zero(...)

    @process.register(CellAvg)
    @process.register(FacetAvg)
    def _(self, o: Expr) -> Expr:
        """Differentiate a cell_avg or facet_avg via ``independent_operator``."""
        return self.independent_operator(o)
971

972

973
def grad_to_reference_grad(o, K):
    """Express grad(o) through reference_grad(o) and the Jacobian inverse.

    Applies the chain rule grad(o)[r..., i] = K[j, i] * reference_grad(o)[r..., j],
    summing over the repeated index j.

    Args:
        o: Operand whose physical gradient is wanted.
        K: Jacobian inverse to contract with.

    Returns:
        grad(o) written in terms of reference_grad(o) and K.
    """
    # Free indices for every value axis of o (created first, then the pair).
    value_idx = indices(len(o.ufl_shape))
    phys_idx, ref_idx = indices(2)
    rgrad_entry = ReferenceGrad(o)[value_idx + (ref_idx,)]
    return as_tensor(K[ref_idx, phys_idx] * rgrad_entry, value_idx + (phys_idx,))
987

988

989
class ReferenceGradRuleset(GenericDerivativeRuleset):
    """Apply the reference grad derivative.

    Differentiates with respect to reference coordinates X, so the derivative
    variable has shape ``(topological_dimension,)``.
    """

    def __init__(
        self,
        topological_dimension: int,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise.

        Args:
            topological_dimension: Reference-cell dimension; the derivative
                variable shape is ``(topological_dimension,)``.
            compress: Forwarded to ``GenericDerivativeRuleset``.
            visited_cache: Forwarded to ``GenericDerivativeRuleset``.
            result_cache: Forwarded to ``GenericDerivativeRuleset``.
        """
        var_shape = (topological_dimension,)
        super().__init__(
            var_shape,
            compress=compress,
            visited_cache=visited_cache,
            result_cache=result_cache,
        )
        # dX/dX for the cell coordinate X.
        self._Id = Identity(topological_dimension)

    # ``process`` is redeclared here to work around the singledispatchmethod
    # inheritance issue, https://bugs.python.org/issue36457.
    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Dispatch ``o`` to its registered rule.

        Types without a rule in this class fall through to the parent's
        ``process``.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.
        """
        return super().process(o)

    # --- Geometric quantities

    @process.register(GeometricQuantity)
    def _(self, o: Expr) -> Expr:
        """Differentiate a geometric_quantity.

        dg/dX vanishes for cellwise-constant g; otherwise it stays ReferenceGrad(g).
        """
        if not is_cellwise_constant(o):
            # TODO: Which types does this involve? I don't think the
            # form compilers will handle this.
            return ReferenceGrad(o)
        return self.independent_terminal(o)

    @process.register(SpatialCoordinate)
    def _(self, o: Expr) -> Expr:
        """Differentiate a spatial_coordinate (dx/dX = J)."""
        # Deliberately not rewritten as J: doing so would cause a loop.
        return ReferenceGrad(o)

    @process.register(CellCoordinate)
    def _(self, o: Expr) -> Expr:
        """Differentiate a cell_coordinate (dX/dX = I)."""
        return self._Id

    # TODO: Add more geometry types here, with non-affine domains
    # several should be non-zero.

    # --- Form arguments

    @process.register(ReferenceValue)
    def _(self, o: Expr) -> Expr:
        """Differentiate a reference_value by wrapping it in ReferenceGrad.

        Raises:
            ValueError: If the wrapped operand is not a terminal.
        """
        operand = o.ufl_operands[0]
        if operand._ufl_is_terminal_:
            return ReferenceGrad(o)
        raise ValueError("ReferenceValue can only wrap a terminal")

    @process.register(Coefficient)
    def _(self, o: Expr) -> Expr:
        """Reject a bare coefficient; it must already be inside ReferenceValue."""
        raise ValueError("Coefficient should be wrapped in ReferenceValue by now")

    @process.register(Argument)
    def _(self, o: Expr) -> Expr:
        """Reject a bare argument; it must already be inside ReferenceValue."""
        raise ValueError("Argument should be wrapped in ReferenceValue by now")

    # --- Nested gradients

    @process.register(Grad)
    def _(self, o: Expr) -> Expr:
        """Reject a physical Grad, which should already have been rewritten."""
        type_name = type(o).__name__
        raise ValueError(
            f"Grad should have been transformed by this point, but got {type_name}."
        )

    @process.register(ReferenceGrad)
    def _(self, o: Expr) -> Expr:
        """Differentiate a reference_grad as RefGrad(RefGrad(f)).

        Raises:
            ValueError: Unless the operand is a ReferenceGrad, ReferenceValue
                or Terminal.
        """
        inner = o.ufl_operands[0]
        if isinstance(inner, ReferenceGrad | ReferenceValue | Terminal):
            return ReferenceGrad(o)
        raise ValueError("Expecting only grads applied to a terminal.")

    @process.register(CellAvg)
    @process.register(FacetAvg)
    def _(self, o: Expr) -> Expr:
        """Differentiate a cell_avg or facet_avg via ``independent_operator``."""
        return self.independent_operator(o)
1101

1102

1103
class VariableRuleset(GenericDerivativeRuleset):
    """Differentiate with respect to a variable."""

    def __init__(
        self,
        var: Expr,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise.

        Args:
            var: Expression to differentiate with respect to; its shape is
                the derivative variable shape.
            compress: Forwarded to ``GenericDerivativeRuleset``.
            visited_cache: Forwarded to ``GenericDerivativeRuleset``.
            result_cache: Forwarded to ``GenericDerivativeRuleset``.

        Raises:
            ValueError: If ``var`` has free indices.
        """
        super().__init__(
            var.ufl_shape,
            compress=compress,
            visited_cache=visited_cache,
            result_cache=result_cache,
        )
        if var.ufl_free_indices:
            raise ValueError("Differentiation variable cannot have free indices.")
        self._variable = var
        self._Id = self._make_identity(self._var_shape)

    def _make_identity(self, sh):
        """Build the identity tensor representing dv/dv for a variable of shape ``sh``.

        Returns 1.0 for a scalar, ``Identity(n)`` for a vector of length n,
        and for higher ranks the product of per-axis Kronecker deltas,
        re-tensored over (all row indices) + (all column indices).
        """
        if sh == ():
            return FloatValue(1.0)
        if len(sh) == 1:
            return Identity(sh[0])
        # TODO: Add a type for this higher order identity?
        # II[i0,i1,i2,j0,j1,j2] = 1 if all((i0==j0, i1==j1, i2==j2)) else 0
        row_indices = ()
        col_indices = ()
        delta = None
        for extent in sh:
            i, j = indices(2)
            factor = Identity(extent)[i, j]
            delta = factor if delta is None else delta * factor
            row_indices += (i,)
            col_indices += (j,)
        return as_tensor(delta, row_indices + col_indices)

    # ``process`` is redeclared here to work around the singledispatchmethod
    # inheritance issue, https://bugs.python.org/issue36457.
    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Dispatch ``o`` to its registered rule.

        Types without a rule in this class fall through to the parent's
        ``process``.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.
        """
        return super().process(o)

    @process.register(GeometricQuantity)
    def _(self, o: Expr) -> Expr:
        """Geometry does not depend on the variable: dg/dw == 0."""
        return self.independent_terminal(o)

    @process.register(Argument)
    def _(self, o: Expr) -> Expr:
        """Arguments do not depend on the variable: da/dw == 0."""
        return self.independent_terminal(o)

    @process.register(Coefficient)
    def _(self, o: Expr) -> Expr:
        """Differentiate a coefficient.

        df/dv is the identity when the variable *is* this coefficient and is
        handled by ``independent_terminal`` otherwise.  Wrapping f as
        variable(f) therefore still gives the ``independent_terminal`` result;
        only v == f with v a Coefficient yields the identity.
        """
        target = self._variable
        if isinstance(target, Coefficient) and o == target:
            # Identity of rank 2*rank(v).
            return self._Id
        return self.independent_terminal(o)

    @process.register(Variable)
    @DAGTraverser.postorder
    def _(self, o: Expr, df: Expr, a: Expr) -> Expr:
        """Differentiate a variable.

        Returns the identity when the differentiation variable carries label
        ``a``, otherwise passes the operand derivative ``df`` through.
        """
        target = self._variable
        if isinstance(target, Variable) and target.label() == a:
            # Identity of rank 2*rank(v).
            return self._Id
        return df

    @process.register(Grad)
    def _(self, o: Expr) -> Expr:
        """Differentiate a grad; for a grad of a terminal this is ``independent_terminal``.

        Raises:
            ValueError: If the operand is neither a Grad nor a Terminal.
        """
        if isinstance(o.ufl_operands[0], Grad | Terminal):
            return self.independent_terminal(o)
        raise ValueError("Expecting only grads applied to a terminal.")

    # --- Values or derivatives in the reference frame

    @process.register(ReferenceValue)
    def _(self, o: Expr) -> Expr:
        """Differentiate a reference_value.

        d/dv rv(f) is the identity when v is the coefficient f (identity
        pullback only), and is handled by ``independent_terminal`` otherwise.

        Raises:
            ValueError: If v is the wrapped coefficient but its element's
                pullback is not the identity.
        """
        target = self._variable
        wraps_target = isinstance(target, Coefficient) and o.ufl_operands[0] == target
        if not wraps_target:
            return self.independent_terminal(o)
        if not target.ufl_element().pullback.is_identity:
            # FIXME: This is a bit tricky, instead of Identity it is
            #   actually inverse(transform), or we should rather not
            #   convert to reference frame in the first place
            raise ValueError(
                "Missing implementation: To handle derivatives of rv(f) w.r.t. f for "
                "mapped elements, rewriting to reference frame should not happen first..."
            )
        # Identity of rank 2*rank(v).
        return self._Id

    @process.register(ReferenceGrad)
    def _(self, o: Expr) -> Expr:
        """A reference gradient of a terminal is handled by ``independent_terminal``."""
        return self.independent_terminal(o)

    @process.register(CellAvg)
    @process.register(FacetAvg)
    def _(self, o: Expr) -> Expr:
        """Differentiate a cell_avg or facet_avg via ``independent_operator``."""
        return self.independent_operator(o)
1252

1253

1254
class GateauxDerivativeRuleset(GenericDerivativeRuleset):
1✔
1255
    """Apply AFD (Automatic Functional Differentiation) to expression.
1256

1257
    Implements rules for the Gateaux derivative D_w[v](...) defined as
1258
    D_w[v](e) = d/dtau e(w+tau v)|tau=0.
1259
    """
1260

1261
    def __init__(
        self,
        coefficients: ExprList,
        arguments: ExprList,
        coefficient_derivatives: ExprMapping,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise.

        Args:
            coefficients: Coefficients w to differentiate with respect to.
            arguments: Directions v, paired positionally with ``coefficients``.
            coefficient_derivatives: Flat mapping f0, df0/dw, f1, df1/dw, ...
            compress: Forwarded to ``GenericDerivativeRuleset``.
            visited_cache: Forwarded to ``GenericDerivativeRuleset``.
            result_cache: Forwarded to ``GenericDerivativeRuleset``.

        Raises:
            ValueError: If any of the first three arguments has the wrong type.
        """
        super().__init__(
            (), compress=compress, visited_cache=visited_cache, result_cache=result_cache
        )
        checks = (
            (coefficients, ExprList, "Expecting a ExprList of coefficients."),
            (arguments, ExprList, "Expecting a ExprList of arguments."),
            (
                coefficient_derivatives,
                ExprMapping,
                "Expecting a coefficient-coefficient ExprMapping.",
            ),
        )
        for value, expected_type, message in checks:
            if not isinstance(value, expected_type):
                raise ValueError(message)
        # w and v such that D_w[v](e) = d/dtau e(w+tau v)|tau=0.
        self._w = coefficients.ufl_operands
        self._v = arguments.ufl_operands
        self._w2v = dict(zip(self._w, self._v))
        # {f: df/dw} built from the flat (key, value, key, value, ...) operands.
        flat_pairs = coefficient_derivatives.ufl_operands
        self._cd = dict(zip(flat_pairs[::2], flat_pairs[1::2]))
        # Operations deferred to the derivative expansion phase, e.g. dN(u)/du
        # where `N` is an ExternalOperator and `u` a Coefficient.
        self.pending_operations = BaseFormOperatorDerivativeRecorder(
            coefficients,
            arguments=arguments,
            coefficient_derivatives=coefficient_derivatives,
        )
1297

1298
    # Work around singledispatchmethod inheritance issue;
1299
    # see https://bugs.python.org/issue36457.
1300
    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Dispatch ``o`` to its registered Gateaux-derivative rule.

        Types without a rule in this class fall through to the parent's
        ``process``.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.
        """
        return super().process(o)
1312

1313
    # --- Specialized rules for geometric quantities
1314

1315
    @process.register(GeometricQuantity)
    def _(self, o: Expr) -> Expr:
        """Geometry does not depend on the coefficients: dg/dw == 0."""
        return self.independent_terminal(o)
1319

1320
    @process.register(CellAvg)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a cell_avg.

        Cell averaging commutes with differentiation:
        D_f[v](cell_avg(f)) = cell_avg(v), so the operand derivative ``fp``
        is simply averaged.
        """
        averaged = cell_avg(fp)
        return averaged
1327

1328
    @process.register(FacetAvg)
    @DAGTraverser.postorder
    def _(self, o: Expr, fp: Expr) -> Expr:
        """Differentiate a facet_avg.

        Facet averaging commutes with differentiation:
        D_f[v](facet_avg(f)) = facet_avg(v), so the operand derivative ``fp``
        is simply averaged.
        """
        averaged = facet_avg(fp)
        return averaged
1335

1336
    @process.register(Argument)
    def _(self, o: Argument) -> Expr:
        """Arguments do not depend on the coefficients (da/dw == 0).

        Delegates to ``_process_argument`` so that the rule can be shared.
        """
        return self._process_argument(o)
1340

1341
    def _process_argument(self, o: Argument | Coargument) -> Zero:
        """Return the derivative of an argument or coargument via ``independent_terminal``."""
        return self.independent_terminal(o)
1343

1344
    @process.register(Coefficient)
    def _(self, o: Coefficient) -> Expr:
        """Differentiate a coefficient by delegating to ``_process_coefficient``."""
        return self._process_coefficient(o)
1347

1348
    def _process_coefficient(self, o: Expr) -> Expr:
        """Return the Gateaux derivative of coefficient ``o``.

        * If ``o`` is one of the coefficients w, the result is its paired
          direction v (dw/dw := d/ds [w + s v] = v).
        * Else, if no user-supplied derivative df/dw exists, the result is
          ``Zero(o.ufl_shape)``.
        * Otherwise each supplied derivative is contracted with its matching
          direction over v's axes, and the contractions are summed.  For
          example, (f:g) -> (dfdu:v):g + f:(dgdu:v), with
          shape(dfdu) == shape(f) + shape(v).

        Raises:
            ValueError: If the number of supplied derivatives differs from
                the number of directions.
        """
        direct = self._w2v.get(o)  # type: ignore
        if direct is not None:
            return direct

        dos = self._cd.get(o)  # type: ignore
        if dos is None:
            return Zero(o.ufl_shape)

        # Align with the self._v tuple (mixed spaces may give several parts).
        if not isinstance(dos, tuple):
            dos = (dos,)
        if len(dos) != len(self._v):
            raise ValueError(
                "Got a tuple of arguments, expecting a "
                "matching tuple of coefficient derivatives."
            )
        total = Zero(o.ufl_shape)
        for derivative, direction in zip(dos, self._v):
            scalar, idx = as_scalar(derivative)
            # Trailing indices contract with the direction; leading ones stay free.
            split = len(idx) - len(direction.ufl_shape)
            outer, inner = idx[:split], idx[split:]
            term = scalar * direction[inner]
            total += as_tensor(term, outer) if outer else term
        return total
1395

1396
    @process.register(ReferenceValue)
    def _(self, o: Expr) -> Expr:
        """Reject ReferenceValue, which this ruleset does not support.

        This would be implementable for a plain derivative(M(f), f, v): if the
        wrapped terminal is w, return ReferenceValue(v), provided v shares w's
        element mapping.  It becomes messy when the user supplies custom
        coefficient-derivative relations, and it is only needed to let users
        write derivative(...ReferenceValue..., ...).

        Raises:
            NotImplementedError: Unconditionally.
        """
        message = "Currently no support for ReferenceValue in CoefficientDerivative."
        raise NotImplementedError(message)
1416

1417
    @process.register(ReferenceGrad)
    def _(self, o: Expr) -> Expr:
        """Differentiate a reference_grad.

        Returns ``Zero(o.ufl_shape)`` when ``o`` contains no coefficients.
        Supporting the general case would be possible for a plain
        derivative(M(f), f, v), but it gets messy with user-supplied
        coefficient-derivative relations.

        Raises:
            NotImplementedError: If ``o`` contains any coefficient.
        """
        if extract_coefficients(o):
            raise NotImplementedError(
                "Currently no support for ReferenceGrad in CoefficientDerivative."
            )
        return Zero(o.ufl_shape)
1431

1432
    @process.register(Grad)
    def _(self, g: Expr) -> Expr:
        """Differentiate a grad.

        When this handler is reached, differentiation has already been
        propagated down to a form argument (or nested grads of one) or to a
        base form operator, so the result is either the gradient of the
        variation or zero. Derivatives with respect to single components are
        handled through ``Indexed`` differentiation variables / variations and
        ``ListTensor`` variations.

        Args:
            g: A ``Grad`` node. Any number of nested ``Grad`` nodes may wrap
                the innermost ``FormArgument`` / ``BaseFormOperator``.

        Returns:
            The derivative of ``g`` in the direction(s) ``self._v``, or a
            ``Zero`` of shape ``g.ufl_shape`` when ``g`` does not depend on any
            of the differentiation variables ``self._w``.

        Raises:
            ValueError: If the innermost operand is not a ``FormArgument`` or
                ``BaseFormOperator``, or if a differentiation variable or a
                variation has an unsupported form.
        """
        # Strip and count the gradients around the inner terminal.
        # FIXME: Assert that the inner operand has already been reached by
        # derivative propagation.
        ngrads = 0
        o = g
        while isinstance(o, Grad):
            (o,) = o.ufl_operands
            ngrads += 1
        # `grad(N)` where N is a BaseFormOperator is treated as if `N` was a Coefficient.
        if not isinstance(o, FormArgument | BaseFormOperator):
            raise ValueError(f"Expecting gradient of a FormArgument, not {ufl_err_str(o)}.")

        def apply_grads(f):
            """Re-apply the ``ngrads`` gradients that were stripped off ``g``."""
            for _ in range(ngrads):
                f = Grad(f)
            return f

        # Easy case: o is one of the differentiation variables (no indexing)
        # and its variation is a plain form argument.
        for w, v in zip(self._w, self._v):
            if o == w and isinstance(v, FormArgument):
                # Case: d/dt [w + t v]
                return apply_grads(v)

        # If o is not among coefficient derivatives, return do/dw=0
        gprimesum = Zero(g.ufl_shape)

        def analyse_variation_argument(v):
            """Split a variation into ``(argument, fixed-component tuple)``."""
            if isinstance(v, FormArgument):
                # Case: d/dt [w[...] + t v]
                vval, vcomp = v, ()
            elif isinstance(v, Indexed):
                # Case: d/dt [w + t v[...]]
                # Case: d/dt [w[...] + t v[...]]
                vval, vcomp = v.ufl_operands
                vcomp = tuple(vcomp)
            else:
                raise ValueError("Expecting argument or component of argument.")
            if not all(isinstance(k, FixedIndex) for k in vcomp):
                raise ValueError("Expecting only fixed indices in variation.")
            return vval, vcomp

        def compute_gprimeterm(ngrads, vval, vcomp, wshape, wcomp):
            """Place the gradient(s) of ``vval[vcomp]`` at position ``wcomp``."""
            # Apply gradients directly to argument vval, and get the
            # right indexed scalar component(s)
            kk = indices(ngrads)
            Dvkk = apply_grads(vval)[vcomp + kk]
            # Place scalar component(s) Dvkk into the right tensor
            # positions
            if wshape:
                Ejj, jj = unit_indexed_tensor(wshape, wcomp)
            else:
                Ejj, jj = 1, ()
            return as_tensor(Ejj * Dvkk, jj + kk)

        # Accumulate contributions from variations in different components
        for w, v in zip(self._w, self._v):
            # -- Analyse differentiation variable coefficient -- #

            # Can differentiate a Form wrt a BaseFormOperator
            if isinstance(w, FormArgument | BaseFormOperator):
                if not w == o:
                    continue
                wshape = w.ufl_shape

                if isinstance(v, FormArgument):
                    # Case: d/dt [w + t v]
                    # (Already returned by the easy-case loop above; kept as
                    # a safeguard.)
                    return apply_grads(v)

                elif isinstance(v, ListTensor):
                    # Case: d/dt [w + t <...,v,...>]
                    for wcomp, vsub in unwrap_list_tensor(v):
                        if not isinstance(vsub, Zero):
                            vval, vcomp = analyse_variation_argument(vsub)
                            gprimesum = gprimesum + compute_gprimeterm(
                                ngrads, vval, vcomp, wshape, wcomp
                            )
                elif isinstance(v, Zero):
                    # A zero variation contributes nothing.
                    pass

                else:
                    if wshape != ():
                        raise ValueError("Expecting scalar coefficient in this branch.")
                    # Case: d/dt [w + t v[...]]
                    wcomp = ()
                    vval, vcomp = analyse_variation_argument(v)
                    gprimesum = gprimesum + compute_gprimeterm(ngrads, vval, vcomp, wshape, wcomp)

            elif isinstance(w, Indexed):
                # This path is tested in unit tests, but not actually used?
                # Case: d/dt [w[...] + t v[...]]
                # Case: d/dt [w[...] + t v]
                wval, wcomp = w.ufl_operands
                if not wval == o:
                    continue
                assert isinstance(wval, FormArgument)
                if not all(isinstance(k, FixedIndex) for k in wcomp):
                    raise ValueError("Expecting only fixed indices in differentiation variable.")
                wshape = wval.ufl_shape

                vval, vcomp = analyse_variation_argument(v)
                gprimesum = gprimesum + compute_gprimeterm(ngrads, vval, vcomp, wshape, wcomp)

            else:
                raise ValueError("Expecting coefficient or component of coefficient.")

        # TODO: Handle other coefficient derivatives given in `self._cd`
        # (user-supplied df/dw relations). This is not supported in
        # combination with gradients yet.
        return gprimesum
1✔
1589

1590
    @process.register(CoordinateDerivative)
    @DAGTraverser.postorder_only_children([0])
    def _(self, o: Expr, o0: Expr) -> Expr:
        """Differentiate a coordinate_derivative.

        Only operand 0 is processed (it is received as ``o0``). The node is
        rebuilt around it, and the remaining three operands of ``o`` are
        passed through unchanged.
        """
        _, w, v, cd = o.ufl_operands
        return CoordinateDerivative(o0, w, v, cd)
×
1596

1597
    @process.register(BaseFormOperator)
    @DAGTraverser.postorder
    def _(self, o: BaseFormOperator, *dfs) -> Expr:
        """Differentiate a base_form_operator.

        The operator is first treated as a Coefficient. A nonzero result means
        the differentiation is done with respect to the BaseFormOperator itself
        (dF/dN[Nhat]). A zero result means the derivative is taken with
        respect to another variable: ``o`` is then recorded in
        ``self.pending_operations`` so that its contribution is handled later.
        """
        derivative = self._process_coefficient(o)
        # The comparison also covers the non-scalar case.
        needs_deferral = derivative == 0
        if needs_deferral:
            self.pending_operations += (o,)
        return derivative
1✔
1612

1613
    # -- Handlers for BaseForm objects -- #
1614

1615
    @process.register(Cofunction)
    def _(self, o: Cofunction) -> Expr:
        """Differentiate a cofunction.

        Uses the same rule as for a Coefficient. The matching Coargument
        variation is already stored in ``self._v``, which
        ``self._process_coefficient`` relies on. A zero derivative is returned
        as a ``ZeroBaseForm`` over the cofunction's arguments plus
        ``self._v``.
        """
        derivative = self._process_coefficient(o)  # type: ignore
        # Convert ufl.Zero into ZeroBaseForm
        return (
            ZeroBaseForm(o.arguments() + self._v)  # type: ignore
            if derivative == 0
            else derivative
        )
1✔
1626

1627
    @process.register(Coargument)
    def _(self, o: Coargument) -> Expr:
        """Differentiate a coargument.

        Uses the same rule as for an Argument (da/dw == 0). A zero result is
        returned as a ``ZeroBaseForm`` over the coargument's arguments plus
        ``self._v``.
        """
        derivative = self._process_argument(o)
        # Convert ufl.Zero into ZeroBaseForm
        return (
            ZeroBaseForm(o.arguments() + self._v)  # type: ignore
            if derivative == 0
            else derivative
        )
×
1636

1637
    @process.register(Matrix)  # type: ignore
    def _(self, M: Matrix) -> BaseForm:
        """Differentiate a matrix.

        Matrix rule: D_w[v](M) = v if M == w else 0. Differentiating with
        respect to a matrix is not supported, so the result is always zero in
        the appropriate space: the matrix's arguments extended by the
        variation(s) ``self._v``.
        """
        zero_form_arguments = M.arguments() + self._v
        return ZeroBaseForm(zero_form_arguments)
1✔
1644

1645

1646
class BaseFormOperatorDerivativeRuleset(GateauxDerivativeRuleset):
    """Apply AFD (Automatic Functional Differentiation) to BaseFormOperator.

    Implements rules for the Gateaux derivative D_w[v](...) defined as
    D_w[v](B) = d/dtau B(w+tau v)|tau=0 where B is a ufl.BaseFormOperator.

    Only the outer base form operator (``outer_base_form_op``) is
    differentiated by the handlers decorated with
    ``pending_operations_recording``. Any other base form operator reaching
    those handlers is appended to ``self.pending_operations`` instead.
    """

    @staticmethod
    def pending_operations_recording(base_form_operator_handler):
        """Decorate a function to record pending operations.

        Used as a decorator for the handlers within this class body below.

        Args:
            base_form_operator_handler: Handler with signature
                ``(self, base_form_op, *dfs)`` computing the Gateaux
                derivative of the outer base form operator.

        Returns:
            A wrapper with the same signature. It calls
            ``base_form_operator_handler`` only when ``base_form_op`` equals
            ``self.outer_base_form_op``. Otherwise it records
            ``base_form_op`` and returns
            ``self._process_coefficient(base_form_op)``.
        """

        def wrapper(self, base_form_op, *dfs):
            """Decorate."""
            # Get the outer `BaseFormOperator` expression, i.e. the
            # operator that is being differentiated.
            expression = self.outer_base_form_op
            # If the base form operator we observe is different from the
            # outer `BaseFormOperator`:
            # -> Record that `BaseFormOperator` so that
            # `d(expression)/d(base_form_op)` can then be computed
            # later.
            # Else:
            # -> Compute the Gateaux derivative of `base_form_ops` by
            # calling the appropriate handler.
            if expression != base_form_op:
                self.pending_operations += (base_form_op,)
                return self._process_coefficient(base_form_op)
            return base_form_operator_handler(self, base_form_op, *dfs)

        return wrapper

    def __init__(
        self,
        coefficients: ExprList,
        arguments: ExprList,
        coefficient_derivatives: ExprMapping,
        outer_base_form_op: Expr,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise.

        Args:
            coefficients: Differentiation variable(s), forwarded to
                ``GateauxDerivativeRuleset``.
            arguments: Variation direction(s), forwarded to
                ``GateauxDerivativeRuleset``.
            coefficient_derivatives: Forwarded to ``GateauxDerivativeRuleset``.
            outer_base_form_op: The base form operator being differentiated.
                Other base form operators encountered are recorded as
                pending operations.
            compress: Forwarded to the parent constructor.
            visited_cache: Forwarded to the parent constructor.
            result_cache: Forwarded to the parent constructor.
        """
        super().__init__(
            coefficients,
            arguments,
            coefficient_derivatives,
            compress=compress,
            visited_cache=visited_cache,
            result_cache=result_cache,
        )
        self.outer_base_form_op = outer_base_form_op

    # Work around singledispatchmethod inheritance issue;
    # see https://bugs.python.org/issue36457.
    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Process ``o``.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.

        """
        return super().process(o)

    @process.register(Interpolate)
    @DAGTraverser.postorder
    @pending_operations_recording
    def _(self, i_op: Interpolate, dw: Expr) -> Expr:
        """Differentiate an interpolate.

        ``dw`` is the already-processed derivative of the interpolated
        expression.
        """
        # Interpolate rule: D_w[v](i_op(w, v*)) = i_op(v, v*), by linearity of Interpolate!
        if not dw:
            # i_op doesn't depend on w:
            #  -> It also covers the Hessian case since Interpolate is linear,
            #     e.g. D_w[v](D_w[v](i_op(w, v*))) = D_w[v](i_op(v, v*)) = 0 (since w not found).
            return ZeroBaseForm(i_op.arguments() + self._v)  # type: ignore
        return i_op._ufl_expr_reconstruct_(expr=dw)

    @process.register(ExternalOperator)
    @DAGTraverser.postorder
    @pending_operations_recording
    def external_operator(self, N: ExternalOperator, *dfs) -> Expr:
        """Differentiate an external_operator.

        ``dfs`` holds the processed derivative of each operand of ``N``. One
        term is built per operand and the terms are summed.

        Raises:
            NotImplementedError: If an operand derivative is nonzero but
                contains no arguments.
        """
        result: tuple[Expr, ...] = ()
        for i, df in enumerate(dfs):
            # Increment the derivative multi-index for operand i only.
            derivatives = tuple(dj + int(i == j) for j, dj in enumerate(N.derivatives))
            if len(extract_arguments(df)) != 0:
                # Handle the symbolic differentiation of external operators.
                # This bit returns:
                #
                #   `\sum_{i} dNdOi(..., Oi, ...; DOi(u)[v], ..., v*)`
                #
                # where we differentiate wrt u, Oi is the i-th operand,
                # N(..., Oi, ...; ..., v*) an ExternalOperator and v the
                # direction (Argument). dNdOi(..., Oi, ...; DOi(u)[v])
                # is an ExternalOperator representing the
                # Gateaux-derivative of N. For example:
                #  -> From N(u) = u**2, we get `dNdu(u; uhat, v*) = 2 * u * uhat`.
                new_args = N.argument_slots() + (df,)
                extop = N._ufl_expr_reconstruct_(
                    *N.ufl_operands, derivatives=derivatives, argument_slots=new_args
                )
            elif df == 0:
                # NOTE(review): unlike the Interpolate handler, this zero is
                # built from N.arguments() only (without self._v) — confirm
                # this arity is intended.
                extop = ZeroBaseForm(N.arguments())
            else:
                raise NotImplementedError(
                    "Frechet derivative of external operators need to be provided!"
                )
            result += (extop,)
        return sum(result)  # type: ignore
1✔
1758

1759

1760
class DerivativeRuleDispatcher(DAGTraverser):
    """Dispatch a derivative rule.

    Traverses an expression. For each derivative node (``Grad``,
    ``ReferenceGrad``, ``VariableDerivative``, ``CoefficientDerivative``,
    ``BaseFormOperatorDerivative``) it applies the matching ruleset to the
    already-processed operand. Rulesets are cached per configuration, so
    repeated derivative nodes reuse the same traverser and its caches.
    Base form operators whose differentiation must be deferred are
    accumulated in ``self.pending_operations``.
    """

    def __init__(
        self,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise."""
        super().__init__(compress=compress, visited_cache=visited_cache, result_cache=result_cache)
        # Record the operations delayed to the derivative expansion phase:
        # Example: dN(u)/du where `N` is a BaseFormOperator and `u` a Coefficient
        self.pending_operations = ()
        # Create DAGTraverser caches, keyed by (ruleset type, *constructor args).
        # Keys may contain ints (gdim/tdim) as well as Exprs.
        self._dag_traverser_cache: dict[tuple, DAGTraverser] = {}

    def _cached_ruleset(self, ruleset_type: type, *args) -> DAGTraverser:
        """Return the cached ``ruleset_type(*args)``, creating it on first use.

        The cache key is ``(ruleset_type, *args)``. Unlike
        ``dict.setdefault(key, ruleset_type(*args))``, the ruleset is
        only constructed on a cache miss instead of on every visit.
        """
        key = (ruleset_type, *args)
        dag_traverser = self._dag_traverser_cache.get(key)
        if dag_traverser is None:
            dag_traverser = ruleset_type(*args)
            self._dag_traverser_cache[key] = dag_traverser
        return dag_traverser

    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Process ``o``.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.

        """
        return super().process(o)

    @process.register(Expr)
    @process.register(BaseForm)  # type: ignore
    def _(self, o: Expr | BaseForm) -> Expr | BaseForm:
        """Apply to expr and base form."""
        return self.reuse_if_untouched(o)

    @process.register(Terminal)
    def _(self, o: Terminal) -> Terminal:
        """Apply to a terminal."""
        return o

    @process.register(Derivative)
    def _(self, o: Derivative) -> Expr:
        """Apply to a derivative (no specific handler registered)."""
        raise NotImplementedError(f"Missing derivative handler for {type(o).__name__}.")

    @process.register(Grad)
    @DAGTraverser.postorder
    def _(self, o: Grad, f: Expr | BaseForm) -> Expr:
        """Apply to a grad, using a ruleset for the last dimension of ``o``."""
        gdim = o.ufl_shape[-1]
        dag_traverser = self._cached_ruleset(GradRuleset, gdim)
        return dag_traverser(f)  # type: ignore

    @process.register(ReferenceGrad)
    @DAGTraverser.postorder
    def _(self, o: ReferenceGrad, f: Expr | BaseForm) -> Expr | BaseForm:
        """Apply to a reference_grad, using a ruleset for the last dimension of ``o``."""
        tdim = o.ufl_shape[-1]
        dag_traverser = self._cached_ruleset(ReferenceGradRuleset, tdim)
        return dag_traverser(f)  # type: ignore

    @process.register(VariableDerivative)
    @DAGTraverser.postorder_only_children([0])
    def _(self, o: Expr, f: Expr | BaseForm) -> Expr | BaseForm:
        """Apply to a variable_derivative."""
        _, op = o.ufl_operands
        dag_traverser = self._cached_ruleset(VariableRuleset, op)
        return dag_traverser(f)  # type: ignore

    @process.register(CoefficientDerivative)
    @DAGTraverser.postorder_only_children([0])
    def _(self, o: CoefficientDerivative, f: Expr | BaseForm) -> Expr | BaseForm:
        """Apply to a coefficient_derivative."""
        _, w, v, cd = o.ufl_operands
        # We need to go through the dag first to record the pending
        # operations
        dag_traverser = self._cached_ruleset(GateauxDerivativeRuleset, w, v, cd)
        # If f has been seen by the traverser, it immediately returns
        # the cached value.
        mapped_expr = dag_traverser(f)  # type: ignore
        # Need to account for pending operations that have been stored
        # in other integrands
        self.pending_operations += dag_traverser.pending_operations  # type: ignore
        return mapped_expr

    @process.register(BaseFormOperatorDerivative)
    @DAGTraverser.postorder_only_children([0])
    def _(self, o: BaseFormOperatorDerivative, f: Expr | BaseForm) -> Expr | BaseForm:
        """Apply to a base_form_operator_derivative."""
        _, w, v, cd = o.ufl_operands
        if isinstance(f, ZeroBaseForm):
            (arg,) = v.ufl_operands  # type: ignore
            arguments = f.arguments()
            # derivative(F, u, du) with `du` a Coefficient
            # is equivalent to taking the action of the derivative.
            # In that case, we don't add arguments to `ZeroBaseForm`.
            if isinstance(arg, BaseArgument):
                arguments += (arg,)
            return ZeroBaseForm(arguments)
        # Need a BaseFormOperatorDerivativeRuleset object
        # for each outer_base_form_op (= f).
        # We need to go through the dag first to record the pending operations
        dag_traverser = self._cached_ruleset(
            BaseFormOperatorDerivativeRuleset, w, v, cd, f  # type: ignore
        )
        # If f has been seen by the traverser, it immediately returns
        # the cached value.
        mapped_expr = dag_traverser(f)  # type: ignore
        mapped_f = dag_traverser._process_coefficient(f)  # type: ignore
        if mapped_f != 0:
            # If dN/dN needs to return an Argument in N space
            # with N a BaseFormOperator.
            return mapped_f
        # Need to account for pending operations that have been stored in other integrands
        self.pending_operations += dag_traverser.pending_operations  # type: ignore
        return mapped_expr

    @process.register(CoordinateDerivative)
    @DAGTraverser.postorder_only_children([0])
    def _(self, o: Expr, f: Expr | BaseForm) -> CoordinateDerivative:
        """Apply to a coordinate_derivative (rewrap the processed integrand)."""
        _, o1, o2, o3 = o.ufl_operands
        return CoordinateDerivative(f, o1, o2, o3)

    @process.register(BaseFormCoordinateDerivative)
    @DAGTraverser.postorder_only_children([0])
    def _(self, o: Expr, f: Expr | BaseForm) -> BaseFormCoordinateDerivative:
        """Apply to a base_form_coordinate_derivative (rewrap the processed operand)."""
        _, o1, o2, o3 = o.ufl_operands
        return BaseFormCoordinateDerivative(f, o1, o2, o3)

    @process.register(Indexed)
    @DAGTraverser.postorder
    def _(self, o: Indexed, Ap: Expr, ii: MultiIndex) -> Expr | BaseForm:
        """Apply to an indexed.

        If the processed operand ``Ap`` has more axes than ``ii`` indexes, the
        remaining axes are indexed with fresh free indices and then wrapped
        back into a tensor.
        """
        # Reuse if untouched
        if Ap is o.ufl_operands[0]:
            return o
        r = len(Ap.ufl_shape) - len(ii)
        if r:
            kk = indices(r)
            op = Indexed(Ap, MultiIndex(ii.indices() + kk))
            op = as_tensor(op, kk)
        else:
            op = Indexed(Ap, ii)
        return op
1✔
1919

1920

1921
class BaseFormOperatorDerivativeRecorder:
1✔
1922
    """A derivative recorded for a base form operator."""
1923

1924
    def __init__(self, var, **kwargs):
1✔
1925
        """Initialise."""
1926
        base_form_ops = kwargs.pop("base_form_ops", ())
1✔
1927

1928
        if kwargs.keys() != {"arguments", "coefficient_derivatives"}:
1✔
1929
            raise ValueError(
×
1930
                "Only `arguments` and `coefficient_derivatives` are "
1931
                "allowed as derivative arguments."
1932
            )
1933

1934
        self.var = var
1✔
1935
        self.der_kwargs = kwargs
1✔
1936
        self.base_form_ops = base_form_ops
1✔
1937

1938
    def __len__(self):
1✔
1939
        """Get the length."""
1940
        return len(self.base_form_ops)
×
1941

1942
    def __bool__(self):
1✔
1943
        """Convert to a bool."""
1944
        return bool(self.base_form_ops)
1✔
1945

1946
    def __add__(self, other):
1✔
1947
        """Add."""
1948
        if isinstance(other, list | tuple):
1✔
1949
            base_form_ops = self.base_form_ops + other
1✔
1950
        elif isinstance(other, BaseFormOperatorDerivativeRecorder):
×
1951
            if self.der_kwargs != other.der_kwargs:
×
1952
                raise ValueError(
×
1953
                    f"Derivative arguments must match when summing {type(self).__name__} objects."
1954
                )
1955
            base_form_ops = self.base_form_ops + other.base_form_ops
×
1956
        else:
1957
            raise NotImplementedError(
×
1958
                f"Sum of {type(self)} and {type(other)} objects is not supported."
1959
            )
1960

1961
        return BaseFormOperatorDerivativeRecorder(
1✔
1962
            self.var, base_form_ops=base_form_ops, **self.der_kwargs
1963
        )
1964

1965
    def __radd__(self, other):
1✔
1966
        """Add."""
1967
        # Recording order doesn't matter as collected
1968
        # `BaseFormOperator`s are sorted later on.
1969
        return self.__add__(other)
1✔
1970

1971
    def __iadd__(self, other):
1✔
1972
        """Add."""
1973
        if isinstance(other, list | tuple):
1✔
1974
            self.base_form_ops += other
1✔
1975
        elif isinstance(other, BaseFormOperatorDerivativeRecorder):
1✔
1976
            self.base_form_ops += other.base_form_ops
1✔
1977
        else:
1978
            raise NotImplementedError
×
1979
        return self
1✔
1980

1981

1982
def apply_derivatives(expression):
    """Apply derivatives to an expression.

    Derivative nodes are first evaluated with a ``DerivativeRuleDispatcher``.
    If that pass recorded base form operators as pending operations, the
    contribution ``Action(d(expression)/dN, dN/dvar)`` is added for each
    recorded base form operator ``N``. ``dN/dvar`` is computed by a
    recursive call.

    Args:
        expression: A Form, an Expr or a BaseFormOperator to be differentiated

    Returns:
        A differentiated expression
    """
    # Notation: Let `var` be the thing we are differentiating with respect to.

    dag_traverser = DerivativeRuleDispatcher()

    # If we hit a base form operator (bfo), then if `var` is:
    #    - a BaseFormOperator → Return `d(expression)/dw` where `w` is
    #      the coefficient produced by the bfo `var`.
    #    - else → Record the bfo on the DAGTraverser object and return 0.
    # Example:
    #    → If derivative(F(u, N(u); v), u) was taken the following line would compute `∂F/∂u`.
    dexpression_dvar = map_integrands(dag_traverser, expression)
    if (
        isinstance(expression, BaseForm)
        and isinstance(dexpression_dvar, int)
        and dexpression_dvar == 0
    ):
        # The arguments got lost, just keep an empty Form
        dexpression_dvar = Form([])

    # Get the recorded delayed operations
    pending_operations = dag_traverser.pending_operations
    if not pending_operations:
        return dexpression_dvar

    # Don't take into account empty Forms
    if isinstance(dexpression_dvar, Form) and dexpression_dvar.empty():
        dexpression_dvar = []
    else:
        dexpression_dvar = [dexpression_dvar]

    # Retrieve the base form operators, var, and the argument and
    # coefficient_derivatives for `derivative`.
    # NOTE(review): this assumes the rulesets accumulated a
    # BaseFormOperatorDerivativeRecorder; `pending_operations` starts as a
    # plain tuple in DerivativeRuleDispatcher, which has no `.var` — confirm.
    var = pending_operations.var
    base_form_ops = pending_operations.base_form_ops
    der_kwargs = pending_operations.der_kwargs
    # Deduplicate the recorded operators and visit them in ascending `count()`.
    for N in sorted(set(base_form_ops), key=lambda x: x.count()):
        # -- Replace dexpr/dvar by dexpr/dN -- #
        # We don't use `apply_derivatives` since the differentiation is
        # done via `\partial` and not `d`.
        dexpr_dN = map_integrands(
            dag_traverser, replace_derivative_nodes(expression, {var.ufl_operands[0]: N})
        )
        # Don't take into account empty Forms
        if isinstance(dexpr_dN, Form) and dexpr_dN.empty():
            continue

        # -- Add the BaseFormOperatorDerivative node -- #
        (var_arg,) = der_kwargs["arguments"].ufl_operands
        cd = der_kwargs["coefficient_derivatives"]
        # Not always the case since `derivative`'s syntax enables one to
        # use a Coefficient as the Gateaux direction
        if isinstance(var_arg, BaseArgument):
            # Construct the argument number based on the
            # BaseFormOperator arguments instead of naively using
            # `var_arg`. This is critical when BaseFormOperators are
            # used inside 0-forms.
            #
            # Example: F = 0.5 * u** 2 * dx + 0.5 * N(u; v*)** 2 * dx
            #    -> dFdu[vhat] = <u, vhat> + Action(<N(u; v*), v0>, dNdu(u; v1, v*))
            # with `vhat` a 0-numbered argument, and where `v1` and
            # `vhat` have the same function space but a different
            # number. Here, applying `vhat` (`var_arg`) naively would
            # result in `dNdu(u; vhat, v*)`, i.e. the 2-forms `dNdu`
            # would have two 0-numbered arguments. Instead we increment
            # the argument number of `vhat` to form `v1`.
            var_arg = type(var_arg)(
                var_arg.ufl_function_space(), number=len(N.arguments()), part=var_arg.part()
            )
        dN_dvar = apply_derivatives(BaseFormOperatorDerivative(N, var, ExprList(var_arg), cd))
        # -- Sum the Action: dF/du = ∂F/∂u + \sum_{i=1,...} Action(∂F/∂Ni, dNi/du) -- #
        # In this case: Action <=> ufl.action since `dN_var` has 2 arguments.
        # We use Action to handle the trivial case `dN_dvar` = 0.
        dexpression_dvar.append(Action(dexpr_dN, dN_dvar))
    # NOTE(review): if every contribution was skipped as empty, this sums an
    # empty list and returns the int 0 rather than an empty Form.
    return sum(dexpression_dvar)
1✔
2066

2067

2068
class CoordinateDerivativeRuleset(GenericDerivativeRuleset):
    """Apply AFD (Automatic Functional Differentiation) to expression.

    Implements rules for the Gateaux derivative D_w[v](...) defined as
    D_w[v](e) = d/dtau e(w+tau v)|tau=0
    where 'e' is a ufl form after pullback and w is a SpatialCoordinate.
    """

    def __init__(
        self,
        coefficients: ExprList,
        arguments: ExprList,
        coefficient_derivatives: ExprMapping,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise.

        Args:
            coefficients: The variable(s) ``w`` to differentiate with respect to.
            arguments: The direction(s) ``v``, paired positionally with
                ``coefficients``.
            coefficient_derivatives: Flattened pairs ``(f0, df0/dw, f1, df1/dw, ...)``
                for terminals whose derivative is known to be nonzero.
            compress: Forwarded to the base traverser.
            visited_cache: Forwarded to the base traverser.
            result_cache: Forwarded to the base traverser.

        Raises:
            ValueError: If ``coefficients``/``arguments`` are not ExprLists or
                ``coefficient_derivatives`` is not an ExprMapping.
        """
        # NOTE(review): `()` is passed as the base ruleset's first argument,
        # presumably the (scalar) variable shape -- confirm against
        # GenericDerivativeRuleset.
        super().__init__(
            (), compress=compress, visited_cache=visited_cache, result_cache=result_cache
        )
        # Type checking
        if not isinstance(coefficients, ExprList):
            raise ValueError("Expecting a ExprList of coefficients.")
        if not isinstance(arguments, ExprList):
            raise ValueError("Expecting a ExprList of arguments.")
        if not isinstance(coefficient_derivatives, ExprMapping):
            raise ValueError("Expecting a coefficient-coefficient ExprMapping.")
        # The coefficient(s) to differentiate w.r.t. and the
        # argument(s) s.t. D_w[v](e) = d/dtau e(w+tau v)|tau=0
        self._w = coefficients.ufl_operands
        self._v = arguments.ufl_operands
        self._w2v = dict(zip(self._w, self._v))
        # Build more convenient dict {f: df/dw} for each coefficient f
        # where df/dw is nonzero; operands alternate key, value, key, ...
        cd = coefficient_derivatives.ufl_operands
        self._cd = dict(zip(cd[::2], cd[1::2]))

    # Work around singledispatchmethod inheritance issue;
    # see https://bugs.python.org/issue36457.
    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Process ``o``.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.

        """
        return super().process(o)

    @process.register(GeometricQuantity)
    def _(self, o: Expr) -> Expr:
        """Return zero: geometric quantities are treated as independent of w."""
        # Explicitly defining dg/dw == 0
        return self.independent_terminal(o)

    @process.register(Argument)
    def _(self, o: Expr) -> Expr:
        """Return zero: arguments are treated as independent of w."""
        # Explicitly defining da/dw == 0
        return self.independent_terminal(o)

    @process.register(Coefficient)
    def _(self, o: Expr) -> Expr:
        """Differentiate a coefficient.

        Raises:
            NotImplementedError: Always; only reference-space coefficients
                (wrapped in ReferenceValue) are supported.
        """
        raise NotImplementedError(
            "CoordinateDerivative of coefficient in physical space is not implemented."
        )

    @process.register(Grad)
    def _(self, o: Expr) -> Expr:
        """Differentiate a grad.

        Raises:
            NotImplementedError: Always; only ReferenceGrad is supported.
        """
        raise NotImplementedError("CoordinateDerivative grad in physical space is not implemented.")

    @process.register(SpatialCoordinate)
    def _(self, o: Expr) -> Expr:
        """Differentiate a spatial_coordinate.

        Returns the direction ``v`` paired with ``o``.

        Raises:
            NotImplementedError: If ``o`` is not one of the variables being
                differentiated with respect to.
        """
        do = self._w2v.get(o)  # type: ignore
        # d x /d x => Argument(x.function_space())
        if do is not None:
            return do
        else:
            raise NotImplementedError(
                "CoordinateDerivative found a SpatialCoordinate that is different "
                "from the one being differentiated."
            )

    @process.register(ReferenceValue)
    def _(self, o: Expr) -> Expr:
        """Differentiate a reference_value.

        Looks ``o`` up in the supplied coefficient derivatives and returns the
        stored derivative, or zero when none was given.
        """
        do = self._cd.get(o)  # type: ignore
        if do is not None:
            return do
        else:
            return self.independent_terminal(o)

    @process.register(ReferenceGrad)
    def _(self, g: Expr) -> Expr:
        """Differentiate a reference_grad.

        Strips all nested ReferenceGrad layers from ``g`` to find the inner
        operand. If that operand is one of the differentiation variables, the
        same number of ReferenceGrads is applied to the paired direction.
        Otherwise a zero with the shape of ``g`` is returned.

        Raises:
            ValueError: If the inner operand is neither a SpatialCoordinate nor
                an operator wrapping a FormArgument.
        """
        # d (grad_X(...(x)) / dx => grad_X(...(Argument(x.function_space()))
        o = g
        ngrads = 0
        while isinstance(o, ReferenceGrad):
            (o,) = o.ufl_operands
            ngrads += 1
        operands = o.ufl_operands
        # Guard the lookup so an operand-less terminal yields the intended
        # ValueError rather than an IndexError.
        if not (
            isinstance(o, SpatialCoordinate)
            or (operands and isinstance(operands[0], FormArgument))
        ):
            raise ValueError(f"Expecting gradient of a FormArgument, not {ufl_err_str(o)}")

        def apply_grads(f):
            # Re-wrap ``f`` in as many ReferenceGrads as were stripped above.
            for _ in range(ngrads):
                f = ReferenceGrad(f)
            return f

        # Find o among all w without any indexing, which makes this
        # easy
        for w, v in zip(self._w, self._v):
            if (
                o == w
                and isinstance(v, ReferenceValue)
                and isinstance(v.ufl_operands[0], FormArgument)
            ):
                # Case: d/dt [w + t v]
                return apply_grads(v)
        # The zero must carry the shape of the whole gradient expression ``g``
        # (including its gradient axes), not that of the inner operand ``o``.
        return self.independent_terminal(g)

    @process.register(Jacobian)
    def _(self, o: Expr) -> Expr:
        """Differentiate a jacobian.

        Returns ``ReferenceGrad(v)`` for the first pair whose ``w`` shares the
        Jacobian's domain and whose ``v`` wraps a FormArgument, otherwise zero.
        """
        # d (grad_X(x))/d x => grad_X(Argument(x.function_space())
        for w, v in zip(self._w, self._v):
            if extract_unique_domain(o) == extract_unique_domain(w) and isinstance(
                v.ufl_operands[0],  # type: ignore
                FormArgument,
            ):
                return ReferenceGrad(v)
        return self.independent_terminal(o)
2205

2206

2207
class CoordinateDerivativeRuleDispatcher(DAGTraverser):
    """Dispatch CoordinateDerivative nodes to a CoordinateDerivativeRuleset."""

    def __init__(
        self,
        compress: bool | None = True,
        visited_cache: dict[tuple, Expr] | None = None,
        result_cache: dict[Expr, Expr] | None = None,
    ) -> None:
        """Initialise.

        Args:
            compress: Forwarded to DAGTraverser.
            visited_cache: Forwarded to DAGTraverser.
            result_cache: Forwarded to DAGTraverser.
        """
        super().__init__(compress=compress, visited_cache=visited_cache, result_cache=result_cache)
        # One ruleset per (ruleset type, w, v, cd) combination, reused across
        # CoordinateDerivative nodes that share the same operands.
        self._dag_traverser_cache: dict[tuple[type, Expr, Expr, Expr], DAGTraverser] = {}

    @singledispatchmethod
    def process(self, o: Expr) -> Expr:
        """Process ``o``.

        Args:
            o: `Expr` to be processed.

        Returns:
            Processed object.

        """
        return super().process(o)

    @process.register(Expr)
    @process.register(BaseForm)  # type: ignore
    def _(self, o: Expr | BaseForm) -> Expr | BaseForm:
        """Apply to expr and base form by rebuilding only if a child changed."""
        return self.reuse_if_untouched(o)

    @process.register(Terminal)
    def _(self, o: Expr) -> Expr:
        """Apply to a terminal: return it unchanged."""
        return o

    @process.register(Derivative)
    def _(self, o: Expr) -> Expr:
        """Apply to a derivative.

        Raises:
            NotImplementedError: For any Derivative subtype without a
                dedicated handler below.
        """
        raise NotImplementedError(f"Missing derivative handler for {type(o).__name__}.")

    @process.register(Grad)
    def _(self, o: Expr) -> Expr:
        """Apply to a grad: return it unchanged without visiting operands."""
        return o

    @process.register(ReferenceGrad)
    def _(self, o: Expr) -> Expr:
        """Apply to a reference_grad: return it unchanged without visiting operands."""
        return o

    @process.register(CoefficientDerivative)
    def _(self, o: Expr) -> Expr:
        """Apply to a coefficient_derivative: return it unchanged."""
        return o

    @process.register(CoordinateDerivative)
    @DAGTraverser.postorder_only_children([0])
    def _(self, o: Expr, f: Expr) -> Expr:
        """Apply to a coordinate_derivative.

        ``f`` is operand 0 of ``o`` after processing by this dispatcher. The
        remaining operands ``w``, ``v`` and ``cd`` configure the
        CoordinateDerivativeRuleset that is applied to ``f``.

        Raises:
            NotImplementedError: If any element in ``o`` uses a CustomPullback.
        """
        from ufl.algorithms import extract_unique_elements

        for element in extract_unique_elements(o):
            if isinstance(element.pullback, CustomPullback):
                raise NotImplementedError(
                    "CoordinateDerivative is not supported for elements with custom pull back."
                )
        _, w, v, cd = o.ufl_operands
        key = (CoordinateDerivativeRuleset, w, v, cd)
        # Build a ruleset only on a cache miss. ``setdefault`` would construct
        # (and type-check) a fresh ruleset on every call, then discard it.
        dag_traverser = self._dag_traverser_cache.get(key)
        if dag_traverser is None:
            dag_traverser = CoordinateDerivativeRuleset(w, v, cd)  # type: ignore
            self._dag_traverser_cache[key] = dag_traverser
        return dag_traverser(f)
2282

2283

2284
def apply_coordinate_derivatives(expression):
    """Apply coordinate derivatives to an expression.

    A new ``CoordinateDerivativeRuleDispatcher`` is created for this call and
    mapped over the integrands of ``expression`` via ``map_integrands``.

    Args:
        expression: Anything accepted by ``map_integrands`` (e.g. a form or
            an expression).

    Returns:
        The value returned by ``map_integrands``.
    """
    return map_integrands(CoordinateDerivativeRuleDispatcher(), expression)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc