• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

tBuLi / kingdon / 20096369315

09 Dec 2025 11:24AM UTC coverage: 88.356% (+0.2%) from 88.11%
20096369315

push

github

tBuLi
Added test for grad of weighted geometric product

40 of 45 new or added lines in 3 files covered. (88.89%)

14 existing lines in 2 files now uncovered.

1677 of 1898 relevant lines covered (88.36%)

0.88 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.38
/kingdon/codegen.py
1
from __future__ import annotations
1✔
2

3
import string
1✔
4
from itertools import product, combinations, groupby, chain
1✔
5
from collections import namedtuple, defaultdict
1✔
6
from typing import NamedTuple, Callable, Tuple, Dict
1✔
7
from functools import reduce, cached_property
1✔
8
import linecache
1✔
9
import warnings
1✔
10
import operator
1✔
11
from dataclasses import dataclass
1✔
12
import inspect
1✔
13
import builtins
1✔
14
import keyword
1✔
15

16
from sympy.utilities.iterables import iterable, flatten
1✔
17
from sympy.printing.lambdarepr import LambdaPrinter
1✔
18
from sympy.simplify.cse_main import numbered_symbols
1✔
19
from sympy import Symbol
1✔
20

21

22
@dataclass
class AdditionChains:
    """
    Lazily computed table of addition chains for every integer in ``1..limit``.

    An addition chain for ``n`` is an increasing tuple that starts at 1, ends at ``n``,
    and in which every element is the sum of its predecessor and some earlier element.
    :func:`power_supply` walks these chains to build powers with few multiplications.

    :param limit: Largest integer for which a chain is computed.
    """
    limit: int

    @cached_property
    def minimal_chains(self) -> Dict[int, Tuple[int, ...]]:
        """
        Map every integer in ``1..limit`` to an addition chain ending in that integer.

        Chains are grown pass by pass: each known chain is extended by adding its last
        element to each of its elements, and the first chain found for a value is kept.
        The table is computed once per instance and cached.
        """
        chains = {1: (1,)}
        while any(i not in chains for i in range(1, self.limit + 1)):
            # Iterate over a snapshot: the dict is mutated inside the loop, and chains
            # discovered in this pass are only extended in a later pass.
            for chain in chains.copy().values():
                right_summand = chain[-1]
                for left_summand in chain:
                    value = left_summand + right_summand
                    if value <= self.limit and value not in chains:
                        chains[value] = (*chain, value)
        return chains

    def __getitem__(self, n: int) -> Tuple[int, ...]:
        """Return the addition chain ending in ``n``; raises :class:`KeyError` if ``n`` is not covered."""
        return self.minimal_chains[n]

    def __contains__(self, item) -> bool:
        """Return whether a chain ending in ``item`` is available (never raises for missing items)."""
        return item in self.minimal_chains
×
43

44
def power_supply(x: "MultiVector", exponents: Tuple[int, ...], operation: Callable[["MultiVector", "MultiVector"], "MultiVector"] = operator.mul):
    """
    Generates powers of a given multivector using few multiplications.
    For example, to raise a multivector :math:`x` to the power :math:`a = 15`, only 5
    multiplications are needed since :math:`x^{2} = x * x`, :math:`x^{3} = x * x^2`,
    :math:`x^{5} = x^2 * x^3`, :math:`x^{10} = x^5 * x^5`, :math:`x^{15} = x^5 * x^{10}`.
    The :class:`power_supply` uses :class:`AdditionChains` to determine these
    chains.

    When called with only a single integer, e.g. :code:`power_supply(x, 15)`, iterating
    over it yields the above sequence in order; ending with :math:`x^{15}`.

    When called with a sequence of integers, the generator instead yields only the
    requested powers, in the requested order. Any intermediate powers needed to reach
    them are computed (and reused) but not yielded.

    :param x: The MultiVector to be raised to a power.
    :param exponents: Either an :code:`int` :math:`a`, in which case every power along
        the addition chain for :math:`a` is yielded, ending with :math:`x^a`; or a
        sequence of positive integers, in which case exactly those powers are yielded.
    :param operation: Binary operation combining two powers; :code:`operator.mul` by default.
    :return: Generator over the requested powers of :code:`x`.
    """
    if isinstance(exponents, int):
        addition_chains = AdditionChains(exponents)
        exponents = addition_chains[exponents]
    else:
        # Materialize first so that an iterator is not consumed by max().
        exponents = tuple(exponents)
        addition_chains = AdditionChains(max(exponents, default=1))

    powers = {1: x}
    for step in exponents:
        if step not in powers:
            # Walk the chain for `step`, filling in any missing power. Each chain
            # element is its predecessor plus an earlier element, so both operands
            # are already in `powers` when they are needed.
            chain = addition_chains[step]
            for prev, value in zip(chain, chain[1:]):
                if value not in powers:
                    powers[value] = operation(powers[prev], powers[value - prev])

        yield powers[step]
1✔
77

78

79
class CodegenOutput(NamedTuple):
    """
    Output of a codegen function.

    :param keys_out: tuple with the output blades in binary rep, one per output value.
    :param func: callable that takes (several) sequence(s) of values and
        returns a sequence of :code:`len(keys_out)` values.
    """
    keys_out: Tuple[int, ...]
    func: Callable
1✔
89

90

91
def codegen_product(x, y, filter_func=None, sign_func=None, keyout_func=operator.xor):
    """
    Helper function for the codegen of all product-type functions.

    Every pair of terms ``(kx, vx)`` from :code:`x` and ``(ky, vy)`` from :code:`y`
    contributes ``vx * vy`` (or its negation) to the output key ``keyout_func(kx, ky)``.
    Contributions landing on the same key are summed.

    :param x: Fully symbolic :class:`~kingdon.multivector.MultiVector`.
    :param y: Fully symbolic :class:`~kingdon.multivector.MultiVector`.
    :param filter_func: Optional predicate called as ``filter_func(kx, ky, key_out)``;
        pairs for which it is falsy are skipped.
    :param sign_func: function to compute sign between terms, called with the 2-tuple
        ``(kx, ky)``. Defaults to ``x.algebra.signs[(kx, ky)]``. A falsy sign drops the
        pair, a positive sign keeps ``vx * vy`` and any other sign negates it.
    :param keyout_func: Maps ``(kx, ky)`` to the output key; bitwise xor by default.
    :return: dict mapping output keys (binary rep) to summed terms, in first-seen order.
    """
    if not sign_func:
        sign_func = lambda pair: x.algebra.signs[pair]

    terms = {}
    for (kx, vx), (ky, vy) in product(x.items(), y.items()):
        sign = sign_func((kx, ky))
        if not sign:
            continue
        key_out = keyout_func(kx, ky)
        if filter_func and not filter_func(kx, ky, key_out):
            continue
        term = vx * vy if sign > 0 else -vx * vy
        if key_out not in terms:
            terms[key_out] = term
        else:
            terms[key_out] += term
    return terms
1✔
116

117

118
def codegen_gp(x, y):
    """
    Generate the geometric product between :code:`x` and :code:`y`.

    :param x: Fully symbolic :class:`~kingdon.multivector.MultiVector`.
    :param y: Fully symbolic :class:`~kingdon.multivector.MultiVector`.
    :return: dict mapping the basis blades present in the product (binary convention)
        to their symbolic coefficients, as built by :func:`codegen_product`.
    """
    # The geometric product keeps every term: no filter, and the algebra's own signs.
    return codegen_product(x, y, filter_func=None, sign_func=None)
1✔
128

129

130
def codegen_sw(x, y):
    r"""
    Generate the conjugation of :code:`y` by :code:`x`: :math:`x y \widetilde{x}`.

    :return: the product :code:`x * y * ~x`, built from the operands' own operators.
    """
    left = x * y
    return left * ~x
1✔
137

138

139
def codegen_cp(x, y):
    """
    Generate the commutator product of :code:`x` and :code:`y`: :code:`x.cp(y) = 0.5*(x*y-y*x)`.

    Only pairs of blades whose sign changes when the two factors are swapped survive.

    :return: dict mapping output blades (binary) to symbolic coefficients.
    """
    algebra = x.algebra

    def anticommuting(kx, ky, k_out):
        # Non-zero exactly when the two orderings carry different signs.
        return algebra.signs[kx, ky] - algebra.signs[ky, kx]

    return codegen_product(x, y, filter_func=anticommuting)
1✔
148

149

150
def codegen_acp(x, y):
    """
    Generate the anti-commutator product of :code:`x` and :code:`y`: :code:`x.acp(y) = 0.5*(x*y+y*x)`.

    Only pairs of blades whose sign is unchanged (and non-zero) when the factors are swapped survive.

    :return: dict mapping output blades (binary) to symbolic coefficients.
    """
    algebra = x.algebra

    def commuting(kx, ky, k_out):
        # Non-zero exactly when both orderings carry the same non-zero sign.
        return algebra.signs[kx, ky] + algebra.signs[ky, kx]

    return codegen_product(x, y, filter_func=commuting)
1✔
159

160

161
def codegen_ip(x, y, diff_func=abs):
    """
    Generate the inner product of :code:`x` and :code:`y`.

    :param diff_func: How to treat the difference between the binary reps of the basis blades.
        if :code:`abs`, compute the symmetric inner product. When :code:`lambda x: -x` this
        function generates left-contraction, and when :code:`lambda x: x`, right-contraction.
    :return: dict mapping output blades (binary) to symbolic coefficients.
    """
    def keep_term(kx, ky, k_out):
        # A term survives only if its output blade equals diff_func(kx - ky).
        return k_out == diff_func(kx - ky)

    return codegen_product(x, y, filter_func=keep_term)
1✔
172

173

174
def codegen_lc(x, y):
    """
    Generate the left-contraction of :code:`x` and :code:`y`.

    Implemented as :func:`codegen_ip` with the blade difference negated.

    :return: dict mapping output blades (binary) to symbolic coefficients.
    """
    return codegen_ip(x, y, diff_func=operator.neg)
1✔
181

182

183
def codegen_rc(x, y):
    """
    Generate the right-contraction of :code:`x` and :code:`y`.

    Implemented as :func:`codegen_ip` with the blade difference used as-is.

    :return: dict mapping output blades (binary) to symbolic coefficients.
    """
    identity = lambda diff: diff
    return codegen_ip(x, y, diff_func=identity)
1✔
190

191

192
def codegen_sp(x, y):
    """
    Generate the scalar product of :code:`x` and :code:`y`.

    Implemented as :func:`codegen_ip` keeping only terms that land on the scalar blade (key 0).

    :return: dict mapping output blades (binary) to symbolic coefficients.
    """
    to_scalar = lambda diff: 0
    return codegen_ip(x, y, diff_func=to_scalar)
1✔
199

200

201
def codegen_proj(x, y):
    r"""
    Generate the projection of :code:`x` onto :code:`y`: :math:`(x \cdot y) \widetilde{y}`.

    Note that :math:`\widetilde{y}` is used instead of :math:`y^{-1}`, so this matches the
    usual projection only when :code:`y` is normalized.

    :return: the product :code:`(x | y) * ~y`.
    """
    inner = x | y
    return inner * ~y
1✔
208

209

210
def codegen_op(x, y):
    """
    Generate the outer product of :code:`x` and :code:`y`: :code:`x.op(y) = x ^ y`.

    :param x: Fully symbolic :class:`~kingdon.multivector.MultiVector`.
    :param y: Fully symbolic :class:`~kingdon.multivector.MultiVector`.
    :return: dict with integer keys indicating the corresponding basis blade in binary
        convention, mapped to the symbolic coefficient of that blade.
    """
    # xor equals the plain sum exactly when the two blades share no basis vector,
    # so this keeps only the grade-raising terms.
    return codegen_product(x, y, filter_func=lambda kx, ky, k_out: k_out == kx + ky)
1✔
221

222

223
def codegen_rp(x, y):
    r"""
    Generate the regressive product of :code:`x` and :code:`y`:
    :math:`x \vee y`.

    :param x: Fully symbolic :class:`~kingdon.multivector.MultiVector`.
    :param y: Fully symbolic :class:`~kingdon.multivector.MultiVector`.
    :return: dict mapping output blades (binary) to symbolic coefficients.
    """
    algebra = x.algebra
    key_pss = len(algebra) - 1

    def complement(key):
        # Binary key of the blade containing every basis vector not in `key`.
        return key_pss - key

    def keyout_func(kx, ky):
        return complement(kx ^ ky)

    def filter_func(kx, ky, k_out):
        # Equivalent to requiring that kx | ky covers every basis vector.
        return key_pss == kx + ky - k_out

    def sign_func(pair):
        # Sign is composed of dualization of each blade, exterior product, and undual.
        kx, ky = pair
        joined = kx ^ ky
        return (
            algebra.signs[kx, complement(kx)]
            * algebra.signs[ky, complement(ky)]
            * algebra.signs[complement(kx), complement(ky)]
            * algebra.signs[complement(joined), joined]
        )

    return codegen_product(x, y, filter_func=filter_func, keyout_func=keyout_func, sign_func=sign_func)
250

251

252
# (numerator, denominator) pair, e.g. returned by codegen_inv(..., symbolic=True).
Fraction = namedtuple('Fraction', 'numer denom')
Fraction.__doc__ = """
Tuple representing a fraction.
"""
256

257

258
def codegen_inv(y, symbolic=False):
    """
    Generate the inverse of :code:`y`.

    Algebras with fewer than 6 dimensions use the closed-form Hitzer inverse; larger
    algebras fall back to the Shirokov inverse.

    :param y: Fully symbolic :class:`~kingdon.multivector.MultiVector`.
    :param symbolic: When True, return a :class:`Fraction` of numerator and denominator
        without performing the division.
    :return: :class:`Fraction` when :code:`symbolic`, otherwise the numerator with every
        coefficient divided by the scalar part of the denominator.
    """
    inverse_codegen = codegen_hitzer_inv if y.algebra.d < 6 else codegen_shirokov_inv
    numer, denom = inverse_codegen(y, symbolic=True)

    if not symbolic:
        scalar_denom = denom.e
        return numer.map(lambda coeff: coeff / scalar_denom)
    return Fraction(numer, denom)
1✔
270

271

272
def codegen_hitzer_inv(x, symbolic=False):
    """
    Generate code for the inverse of :code:`x` using the Hitzer inverse,
    which works up to 5D algebras.

    A numerator ``num`` is built from :code:`x` and its involutions with a formula that
    depends on the algebra dimension ``d``; the denominator is the scalar product
    :code:`x.sp(num)`, so that :math:`x^{-1} = num / denom`.

    :param x: Fully symbolic :class:`~kingdon.multivector.MultiVector`.
    :param symbolic: When True, return :code:`Fraction(num, denom)` without dividing.
    :return: :class:`Fraction` when :code:`symbolic`, otherwise ``num`` with every
        coefficient divided by the scalar part of ``denom``.
    :raises NotImplementedError: when ``d`` is not in ``0..5``.
    """
    alg = x.algebra
    d = alg.d
    if d == 0:
        # Only scalars exist: the numerator is 1 and the denominator is x itself.
        num = alg.blades.e
    elif d == 1:
        num = x.involute()
    elif d == 2:
        num = x.conjugate()
    elif d == 3:
        xconj = x.conjugate()
        num = xconj * ~(x * xconj)
    elif d == 4:
        xconj = x.conjugate()
        x_xconj = x * xconj
        # Flip the sign of the grade-3 and grade-4 parts of x * conj(x).
        num = xconj * (x_xconj - 2 * x_xconj.grade(3, 4))
    elif d == 5:
        xconj = x.conjugate()
        x_xconj = x * xconj
        combo = xconj * ~x_xconj
        x_combo = x * combo
        # Flip the sign of the grade-1 and grade-4 parts of x * combo.
        num = combo * (x_combo - 2 * x_combo.grade(1, 4))
    else:
        raise NotImplementedError(f"Closed form inverses are not known in {d=} dimensions.")
    denom = x.sp(num)

    if symbolic:
        return Fraction(num, denom)
    # Divide every coefficient of the numerator by the scalar denominator.
    denom = denom.e
    return num.map(lambda v: v / denom)
×
306

307

308
def codegen_shirokov_inv(x, symbolic=False):
    """
    Generate code for the inverse of :code:`x` using the Shirokov inverse,
    which works in any algebra, but it can be expensive to compute.

    The loop builds :math:`U_i = x^i - \\sum_{j=1}^{i-1} C_j x^{i-j}` with
    :math:`C_i = N \\langle U_i \\rangle_0 / i` and :math:`N = 2^{\\lceil d/2 \\rceil}`,
    stopping at the first :math:`U_k` that is a pure scalar. The inverse is then
    :math:`(U_{k-1} - C_{k-1}) / U_k`.

    :param x: Fully symbolic :class:`~kingdon.multivector.MultiVector`.
    :param symbolic: When True, return :code:`Fraction(adj, denom)` without dividing.
    :return: :class:`Fraction` when :code:`symbolic`, otherwise ``adj`` with every
        coefficient divided by the scalar part of the final :math:`U_k`.
    """
    alg = x.algebra
    # N = 2 ** ceil(d / 2): upper bound on the number of iterations.
    n = 2 ** ((alg.d + 1) // 2)
    supply = power_supply(x, tuple(range(1, n + 1)))  # Generate powers of x efficiently.
    powers = []  # powers[k] holds x ** (k + 1)
    cs = []      # cs[k] holds C_(k + 1)
    xs = []      # xs[k] holds U_(k + 1) (only the non-scalar ones)
    for i in range(1, n + 1):
        powers.append(next(supply))
        xi = powers[i - 1]
        # Subtract C_(j+1) * x ** (i - j - 1) for every earlier coefficient.
        for j in range(i - 1):
            power_idx = i - j - 2
            xi_diff = powers[power_idx] * cs[j]
            xi = xi - xi_diff
        if xi.grades == (0,):
            # U_i is a pure scalar: it is the denominator.
            break
        xs.append(xi)
        # C_i = N * <U_i>_0 / i; an exact zero is kept as-is rather than divided.
        cs.append(s if (s := xi.e) == 0 else n * s / i)
    # NOTE(review): if the loop finishes without the break, xs[-1]/cs[-1] and xi all
    # refer to the last (non-scalar) U_n; presumably this cannot happen — confirm.

    if i == 1:
        # x itself was already a scalar: the numerator is 1.
        adj = alg.blades.e
    else:
        adj = xs[-1] - cs[-1]

    if symbolic:
        return Fraction(adj, xi)
    xi = xi.e
    return adj.map(lambda v: v / xi)
×
340

341

342
def codegen_div(x, y):
    """
    Generate code for :math:`x y^{-1}`.

    :return: :code:`x` times the numerator of :code:`y`'s symbolic inverse, with every
        coefficient divided by the scalar part of that inverse's denominator.
    :raises ZeroDivisionError: when the symbolic denominator is falsy.
    """
    numer, denom = codegen_inv(y, symbolic=True)
    if not denom:
        raise ZeroDivisionError
    scale = denom.e
    unscaled = x * numer
    return unscaled.map(lambda coeff: coeff / scale)
1✔
351

352

353
def codegen_normsq(x):
    r"""
    Generate the squared norm :math:`x \widetilde{x}`, i.e. :code:`x` times its reverse.
    """
    reversed_x = ~x
    return x * reversed_x
1✔
355

356

357
def codegen_outerexp(x, asterms=False):
    r"""
    Generate the outer exponential of :code:`x`.

    Terms are built as :math:`W_0 = 1`, :math:`W_1 = x` and
    :math:`W_j = (W_{j-1} \wedge x) / j`, stopping at the first vanishing term or once
    :math:`j` exceeds the algebra dimension.

    :param x: Fully symbolic :class:`~kingdon.multivector.MultiVector`. A
        :class:`RuntimeWarning` is issued when it is not of a single grade.
    :param asterms: When True, return the list of terms :math:`W_j` instead of their sum.
    """
    alg = x.algebra
    if len(x.grades) != 1:
        warnings.warn('Outer exponential might not converge for mixed-grade multivectors.', RuntimeWarning)

    terms = [alg.scalar([1]), x]
    for j in range(2, alg.d + 1):
        term = terms[-1] ^ x
        # Dividing like this avoids floating point numbers, which is excellent.
        term._values = tuple(v / j for v in term._values)
        if not term:
            break
        terms.append(term)

    return terms if asterms else reduce(operator.add, terms)
1✔
378

379
def codegen_outersin(x):
    """
    Generate the outer sine of :code:`x`: the sum of the odd-indexed terms of its outer exponential.
    """
    terms = codegen_outerexp(x, asterms=True)
    return reduce(operator.add, terms[1::2])
1✔
383

384

385
def codegen_outercos(x):
    """
    Generate the outer cosine of :code:`x`: the sum of the even-indexed terms of its outer exponential.
    """
    terms = codegen_outerexp(x, asterms=True)
    return reduce(operator.add, terms[0::2])
1✔
389

390

391
def codegen_outertan(x):
    """
    Generate the outer tangent of :code:`x`: outer sine divided by outer cosine.

    Both are summed from a single expansion of the outer exponential.
    """
    terms = codegen_outerexp(x, asterms=True)
    cosine = reduce(operator.add, terms[0::2])
    sine = reduce(operator.add, terms[1::2])
    return sine / cosine
1✔
398

399

400
def codegen_add(x, y):
    """
    Generate the sum of :code:`x` and :code:`y` as a dict of blade key -> coefficient.

    Keys of :code:`x` come first (in :code:`x`'s order), followed by keys only present in :code:`y`.
    """
    summed = dict(x.items())
    for key, value in y.items():
        summed[key] = summed[key] + value if key in summed else value
    return summed
1✔
408

409

410
def codegen_sub(x, y):
    """
    Generate the difference :code:`x - y` as a dict of blade key -> coefficient.

    Keys of :code:`x` come first; keys only present in :code:`y` get the negated coefficient.
    """
    difference = dict(x.items())
    for key, value in y.items():
        difference[key] = difference[key] - value if key in difference else -value
    return difference
1✔
418

419
def codegen_neg(x):
    """Generate :code:`-x` as a dict of blade key -> negated coefficient."""
    negated = {}
    for key, value in x.items():
        negated[key] = -value
    return negated
1✔
421

422

423
def codegen_involutions(x, invert_grades=(2, 3)):
    """
    Codegen for the involutions of Clifford algebras:
    reverse, grade involute, and Clifford involution.

    :param invert_grades: The grades that flip sign under this involution mod 4, e.g. (2, 3) for reversion.
    :return: dict with the same keys as :code:`x`, coefficients negated where the grade flips.
    """
    def flips(key):
        # The grade of a blade is the number of set bits in its binary key.
        return bin(key).count('1') % 4 in invert_grades

    result = {}
    for key, value in x.items():
        result[key] = -value if flips(key) else value
    return result
432

433

434
def codegen_reverse(x):
    """Generate the reverse of :code:`x`: grades 2 and 3 (mod 4) change sign."""
    return codegen_involutions(x, (2, 3))
1✔
436

437

438
def codegen_involute(x):
    """Generate the grade involution of :code:`x`: grades 1 and 3 (mod 4) change sign."""
    return codegen_involutions(x, (1, 3))
1✔
440

441

442
def codegen_conjugate(x):
    """Generate the Clifford conjugate of :code:`x`: grades 1 and 2 (mod 4) change sign."""
    return codegen_involutions(x, (1, 2))
1✔
444

445

446
def codegen_sqrt(x):
    """
    Take the square root using the study number approach as described in
    https://doi.org/10.1002/mma.8639

    :code:`x` is split as :math:`a + bI` with :math:`a` its scalar part. When
    :math:`(bI)^2` is a scalar, :math:`\\sqrt{a + bI} = c + bI / (2c)` with
    :math:`c = \\sqrt{(a + \\sqrt{a^2 - (bI)^2}) / 2}`.

    :param x: Fully symbolic :class:`~kingdon.multivector.MultiVector`. A
        :class:`RuntimeWarning` is issued when it is not a scalar plus at most one other grade.
    :return: The square root as a multivector.
    """
    alg = x.algebra  # NOTE(review): unused below.
    if x.grades == (0,):
        # Pure scalar: take the square root coefficient-wise.
        return x.map(lambda v: v**0.5)
    a, bI = x.grade(0), x - x.grade(0)
    has_solution = len(x.grades) <= 2 and 0 in x.grades
    if not has_solution:
        warnings.warn("Cannot verify that we really are taking the sqrt of a Study number.", RuntimeWarning)

    bI_sq = bI * bI
    if not bI_sq:
        # (bI)^2 vanishes, so c reduces to sqrt(a).
        # NOTE(review): a zero scalar part then divides by zero below.
        cp = a.e**0.5
    else:
        # Scalar part of a^2 - (bI)^2, i.e. the norm of the Study number.
        normS = (a * a - bI_sq).e
        cp = (0.5 * (a.e + normS**0.5))**0.5
    return (0.5 * bI / cp) + cp
1✔
466

467

468
def codegen_polarity(x, undual=False):
    """
    Generate the polarity of :code:`x`: right-multiplication by the inverse pseudoscalar.

    :param undual: When True, generate the inverse map :code:`x * pss` instead.
    :raises ZeroDivisionError: when the pseudoscalar squares to zero.
    """
    algebra = x.algebra
    if undual:
        return x * algebra.pss

    key_pss = len(algebra) - 1
    pss_square = algebra.signs[key_pss, key_pss]
    if pss_square == 0:
        raise ZeroDivisionError
    if pss_square == -1:
        return - x * algebra.pss
    if pss_square == 1:
        return x * algebra.pss
    # Any other sign value falls through and returns None.
1✔
479

480

481
def codegen_unpolarity(x):
    """Generate the inverse of :func:`codegen_polarity`, i.e. :code:`x * pss`."""
    return codegen_polarity(x, True)
1✔
483

484

485
def codegen_hodge(x, undual=False):
    """
    Generate the Hodge dual of :code:`x`.

    Each blade is mapped to its complement (binary key ``pss - key``). The coefficient is
    negated when the sign of ``blade * complement`` is negative, or of
    ``complement * blade`` when :code:`undual` is True.

    :param undual: When True, generate the inverse (Hodge undual).
    :return: dict mapping complement keys to (possibly negated) coefficients.
    """
    algebra = x.algebra
    key_pss = len(algebra) - 1
    dual = {}
    for key, value in x.items():
        key_dual = key_pss - key
        order = (key_dual, key) if undual else (key, key_dual)
        dual[key_dual] = -value if algebra.signs[order] < 0 else value
    return dual
491

492

493
def codegen_unhodge(x):
    """Generate the inverse of :func:`codegen_hodge` (the Hodge undual)."""
    return codegen_hodge(x, True)
1✔
495

496

497
def _lambdify_mv(mv):
    """
    Compile the coefficients of the symbolic multivector :code:`mv` into a function.

    The generated function takes a single argument ``x`` holding the free symbols of
    :code:`mv` sorted by name.

    :return: :class:`CodegenOutput` with :code:`mv`'s keys and the compiled function.
    """
    free_args = sorted(mv.free_symbols, key=operator.attrgetter('name'))
    func = lambdify(
        args={'x': free_args},
        exprs=list(mv.values()),
        funcname=f'custom_{mv.type_number}',
        cse=mv.algebra.cse,
    )
    keys = tuple(mv.keys())
    return CodegenOutput(keys, func)
1✔
505

506

507
def do_codegen(codegen, *mvs) -> CodegenOutput:
    """
    Run a symbolic codegen function and compile its result into a Python function.

    :param codegen: callable that performs codegen for the given :code:`mvs`. This can be any callable
        that returns either a :class:`~kingdon.multivector.MultiVector`, a dictionary, a list/tuple of
        MultiVectors, or an instance of :class:`CodegenOutput`.
    :param mvs: Any remaining positional arguments are taken to be symbolic :class:`~kingdon.multivector.MultiVector`'s.
        An argument may also be a list of MultiVectors, whose values are concatenated into a single argument.
    :return: Instance of :class:`CodegenOutput`.
    """
    # Arguments may be lists of MultiVectors (see `funcname` and `args` below), so look
    # through a leading list to find the algebra.
    first = mvs[0][0] if isinstance(mvs[0], list) else mvs[0]
    algebra = first.algebra

    res = codegen(*mvs)

    # Turn a list of Multivectors into a single Multivector of lists.
    if isinstance(res, (list, tuple)):
        reshaped_res = defaultdict(list)
        for mv in res:
            for k, v in mv.items():
                reshaped_res[k].append(v)
        res = reshaped_res

    type_numbers = (mv[0].type_number if isinstance(mv, list) else mv.type_number for mv in mvs)
    funcname = f'{codegen.__name__}_' + '_x_'.join(format(t, 'X') for t in type_numbers)
    args = {arg_name: [tuple(chain(*(x.values() for x in arg)))] if isinstance(arg, list) else arg.values()
            for arg_name, arg in zip(string.ascii_uppercase, mvs)}

    # Sort the keys in canonical order. Dict results are indexed by binary key; otherwise
    # the coefficient is read as the attribute named after the canonical blade.
    res = {key: res[key] if isinstance(res, dict) else getattr(res, canon)
           for canon, key in algebra.canon2bin.items() if key in res.keys()}

    if not algebra.cse and any(isinstance(v, str) for v in res.values()):
        return func_builder(res, *mvs, funcname=funcname)

    keys, exprs = tuple(res.keys()), list(res.values())
    func = lambdify(args, exprs, funcname=funcname, cse=algebra.cse, dependencies=None)
    return CodegenOutput(keys, func)
544

545
def do_compile(codegen, *tapes):
    """
    Compile the expression produced by :code:`codegen` on :code:`tapes` into a function.

    :param codegen: callable applied to :code:`tapes`. It returns either a string, which
        is treated as the single scalar output (key 0), or an object providing
        :code:`expr` (source of the returned value) and :code:`keys()` (output blades).
    :param tapes: Inputs providing :code:`algebra`, :code:`type_number` and :code:`expr`;
        the latter is used as the parameter name in the generated signature.
    :return: Instance of :class:`CodegenOutput`.
    """
    algebra = tapes[0].algebra
    namespace = algebra.numspace

    res = codegen(*tapes)
    funcname = f'{codegen.__name__}_' + '_x_'.join(f"{tape.type_number}" for tape in tapes)
    header = f"def {funcname}({', '.join(t.expr for t in tapes)}):"
    if isinstance(res, str):
        body = f"    return ({res},)"
        keys_out = (0,)
    else:
        body = f"    return {res.expr}"
        keys_out = tuple(res.keys())
    # Header and body on separate lines so the cached source is a conventional definition.
    funcstr = f"{header}\n{body}\n"

    funclocals = {}
    filename = f'<{funcname}>'
    c = compile(funcstr, filename, 'exec')
    exec(c, namespace, funclocals)
    # mtime has to be None or else linecache.checkcache will remove it
    linecache.cache[filename] = (len(funcstr), None, funcstr.splitlines(True), filename)  # type: ignore

    func = funclocals[funcname]
    return CodegenOutput(keys_out, func)
568

569

570
def func_builder(res_vals: defaultdict, *mvs, funcname: str) -> CodegenOutput:
    """
    Build a Python function for the product between given multivectors.

    :param res_vals: Dict to be converted into a function. The keys correspond to the basis blades in binary,
        while the values are strings to be converted into source code.
    :param mvs: all the multivectors that the resulting function is a product of. Their values
        (as strings) name the local variables unpacked from the positional arguments A, B, ...
    :param funcname: Name of the function. Be aware: if a function by that name already existed,
        its linecache entry will be overwritten.
    :return: :class:`CodegenOutput` with the output keys and the callable.
    """
    args = string.ascii_uppercase[:len(mvs)]
    header = f'def {funcname}({", ".join(args)}):'
    body = ''
    if res_vals:
        for mv, arg in zip(mvs, args):
            names = [str(v) for v in mv.values()]
            # Nothing to unpack for an empty argument; `    , = A` would be a SyntaxError.
            if names:
                body += f'    {", ".join(names)}, = {arg}\n'
        return_val = f'    return [{", ".join(res_vals.values())},]'
    else:
        return_val = f'    return list()'
    func_source = f'{header}\n{body}\n{return_val}'

    # Dynamically build a function. A "<...>" pseudo-filename (as used by lambdify and
    # do_compile in this module) cannot be mistaken for a real file on disk.
    filename = f'<{funcname}>'
    func_locals = {}
    c = compile(func_source, filename, 'exec')
    exec(c, {}, func_locals)

    # Add the generated code to linecache such that it is inspect-safe.
    linecache.cache[filename] = (len(func_source), None, func_source.splitlines(True), filename)
    func = func_locals[funcname]
    func.__module__ = __name__
    return CodegenOutput(tuple(res_vals.keys()), func)
×
602

603

604
def lambdify(args: dict, exprs: list, funcname: str, dependencies: tuple = None, printer=LambdaPrinter, dummify=False, cse=False):
    """
    Function that turns symbolic expressions into Python functions. Heavily inspired by
    :mod:`sympy`'s function by the same name, but adapted for the needs of :code:`kingdon`.

    Particularly, this version gives us more control over the names of the function and its
    arguments, and is more performant, particularly when the given expressions are strings.

    Example usage:

    .. code-block ::

        alg = Algebra(2)
        a = alg.multivector(name='a')
        b = alg.multivector(name='b')
        args = {'A': a.values(), 'B': b.values()}
        exprs = tuple(codegen_cp(a, b).values())
        func = lambdify(args, exprs, funcname='cp', cse=False)

    This will produce code of the following shape (one returned expression per output blade):

    .. code-block ::

        def cp(A, B):
            (a, a1, a2, a12,) = A
            (b, b1, b2, b12,) = B
            return [...]

    It is recommended not to call this function directly, but rather to use
    :func:`do_codegen` which provides a clean API around this function.

    :param args: dictionary of type dict[str | Symbol, tuple[Symbol]].
    :param exprs: tuple[Expr]
    :param funcname: string to be used as the bases for the name of the function.
    :param dependencies: These are extra expressions that can be provided such that quantities can be precomputed.
        Given as (lhs, rhs) pairs, they are emitted as assignments before the return statement.
        For example, in the inverse of a multivector, this is used to compute the scalar denominator only once,
        after which all values in expr are multiplied by it. When :code:`cse = True`, these dependencies are also
        included in the CSE process.
    :param printer: Printer (class, instance, or plain function) used for expressions. The default
        :class:`LambdaPrinter` class is replaced by an instance with kingdon-specific settings.
    :param dummify: Passed on to :class:`KingdonPrinter`.
    :param cse: If truthy (default :code:`False`), CSE is applied to the expressions and dependencies,
        unless any expression is a string. A callable is used as the CSE routine instead of sympy's.
        This typically greatly improves performance and reduces numba's initialization time.
    :return: Function that represents that can be used to calculate the values of exprs.
    """
    if printer is LambdaPrinter:
        printer = LambdaPrinter(
            {'fully_qualified_modules': False, 'inline': True,
             'allow_unknown_functions': True,
             'user_functions': {}}
        )

    # Convert kingdon objects that know how to become sympy objects; leave everything else as-is.
    tosympy = lambda x: x.tosympy() if hasattr(x, 'tosympy') else x
    args = {name: [tosympy(v) for v in values]
            for name, values in args.items()}
    exprs = [tosympy(expr) for expr in exprs]
    if dependencies:
        dependencies = [(tosympy(y), tosympy(x)) for y, x in dependencies]
    names = tuple(arg if isinstance(arg, str) else arg.name for arg in args.keys())
    iterable_args = tuple(args.values())

    funcprinter = KingdonPrinter(printer, dummify)

    # Rebuild the nesting of `template` (lists/tuples) from the flat sequence `flat`.
    def unflatten(template, flat):
        it = iter(flat)
        def walk(t):
            return type(t)(walk(x) for x in t) if isinstance(t, (list, tuple)) else next(it)
        return walk(template)

    # Dependencies are (lhs, rhs) pairs; when CSE is active the rhs are included in the CSE pass below.
    lhsides, rhsides = zip(*dependencies) if dependencies else ([], [])
    flat_exprs = flatten(exprs)
    symbols = numbered_symbols(cls=Symbol, prefix='_x')
    if cse and not any(isinstance(expr, str) for expr in flat_exprs):
        if not callable(cse):
            from sympy.simplify.cse_main import cse
        if dependencies:
            flat_rhsides = flatten(rhsides)
            # Run CSE over the outputs and dependency right-hand sides together, then split the results back.
            cses, _all_exprs = cse([*flat_exprs, *flat_rhsides], list=False, order='none', ignore=lhsides, symbols=symbols)
            _flat_exprs = _all_exprs[:len(flat_exprs)]
            _rhsides = unflatten(rhsides, _all_exprs[len(flat_exprs):])
            cses.extend(list(zip(flatten(lhsides), flatten(_rhsides))))
        else:
            cses, _flat_exprs = cse(flat_exprs, list=False, symbols=symbols)
    else:
        # No CSE: dependencies become plain assignments and expressions are used unchanged.
        cses, _flat_exprs = list(zip(flatten(lhsides), flatten(rhsides))), flat_exprs
    _exprs = unflatten(exprs, _flat_exprs)

    # If every output is falsy (e.g. all zero), emit a literal '0' for each.
    if not any(_exprs):
        _exprs = list('0' for expr in _exprs)
    funcstr = funcprinter.doprint(funcname, iterable_args, names, _exprs, cses=cses)

    # Provide lambda expression with builtins, and compatible implementation of range
    namespace = {'builtins': builtins, 'range': range}

    funclocals = {}
    filename = f'<{funcname}>'
    c = compile(funcstr, filename, 'exec')
    exec(c, namespace, funclocals)
    # mtime has to be None or else linecache.checkcache will remove it
    linecache.cache[filename] = (len(funcstr), None, funcstr.splitlines(True), filename) # type: ignore

    func = funclocals[funcname]
    func.__module__ = __name__
    return func
1✔
706

707

708
class KingdonPrinter:
    """
    Turns argument symbols, expressions and CSE assignments into the source code of a
    Python function (see :meth:`doprint`). Used by :func:`lambdify`.
    """
    def __init__(self, printer=None, dummify=False):
        """
        :param printer: Plain function used to print expressions, or a printer class/instance
            whose :code:`doprint` is used. Defaults to a fresh :class:`LambdaPrinter`.
        :param dummify: Stored on the instance; not read anywhere in this class.
        """
        self._dummify = dummify

        #XXX: This has to be done here because of circular imports
        # NOTE(review): LambdaPrinter is also imported at module top, so this local import looks redundant.
        from sympy.printing.lambdarepr import LambdaPrinter

        if printer is None:
            printer = LambdaPrinter()

        if inspect.isfunction(printer):
            self._exprrepr = printer
        else:
            if inspect.isclass(printer):
                printer = printer()

            self._exprrepr = printer.doprint

        # Used to print the generated function arguments in a standard way
        self._argrepr = LambdaPrinter().doprint

    def doprint(self, funcname, args, names, expr, *, cses=()):
        """
        Returns the function definition code as a string.

        :param funcname: Name of the generated function.
        :param args: Per-argument (possibly nested) sequences of symbols to unpack into.
        :param names: Parameter names of the generated function, one per entry of :code:`args`.
        :param expr: Expression(s) to return.
        :param cses: (symbol, expression) pairs emitted as assignments before the return;
            a :code:`None` expression emits :code:`del symbol` instead.
        """
        funcbody = []

        if not iterable(args):
            args = [args]

        if cses:
            subvars, subexprs = zip(*cses)
            exprs = [expr] + list(subexprs)
            argstrs, exprs = self._preprocess(args, exprs)
            expr, subexprs = exprs[0], exprs[1:]
            cses = zip(subvars, subexprs)
        else:
            argstrs, expr = self._preprocess(args, expr)

        # Generate argument unpacking and final argument list
        funcargs = []
        unpackings = []

        for name, argstr, arg in zip(names, argstrs, args):
            if not arg:
                # Empty argument: keep the parameter, nothing to unpack.
                funcargs.append(name)
            elif iterable(argstr) and iterable(argstr[0]):
                # Nested argument: unpack into name_0, name_1, ... and then unpack each of those.
                funcargs.append(name)
                unpackings.extend(self._print_unpacking([f'{name}_{i}' for i in range(len(argstr))], name))
                for i, subargstr in enumerate(argstr):
                    unpackings.extend(self._print_unpacking(subargstr, f'{name}_{i}'))
            elif iterable(argstr):
                funcargs.append(name)
                unpackings.extend(self._print_unpacking(argstr, name))
            else:
                funcargs.append(argstr)

        funcsig = 'def {}({}):'.format(funcname, ', '.join(funcargs))

        # Wrap input arguments before unpacking
        funcbody.extend(self._print_funcargwrapping(funcargs))

        funcbody.extend(unpackings)

        for s, e in cses:
            if e is None:
                funcbody.append('del {}'.format(s))
            else:
                funcbody.append('{} = {}'.format(s, self._exprrepr(e)))

        str_expr = _recursive_to_string(self._exprrepr, expr)

        # Parenthesize multi-line output so it remains a single return expression.
        if '\n' in str_expr:
            str_expr = '({})'.format(str_expr)
        funcbody.append('return {}'.format(str_expr))

        funclines = [funcsig]
        funclines.extend(['    ' + line for line in funcbody])

        return '\n'.join(funclines) + '\n'

    @classmethod
    def _is_safe_ident(cls, ident):
        """Return whether :code:`ident` is a string usable as a Python identifier (not a keyword)."""
        # NOTE(review): not called anywhere in the visible source.
        return isinstance(ident, str) and ident.isidentifier() \
                and not keyword.iskeyword(ident)

    def _preprocess(self, args, expr):
        """Convert args to their string form, recursing into nested sequences.

        Symbols (anything with a ``name``) are printed by name, other symbol-like objects
        through the argument printer, and everything else with :func:`str`. No argument
        is renamed, so :code:`expr` is returned unchanged.

        Returns string form of args, and updated expr.
        """
        argstrs = [None]*len(args)
        for i, arg in enumerate(args):
            if iterable(arg):
                s, expr = self._preprocess(arg, expr)
            elif hasattr(arg, 'name'):
                s = arg.name
            elif hasattr(arg, 'is_symbol') and arg.is_symbol:
                s = self._argrepr(arg)
            else:
                s = str(arg)
            argstrs[i] = s
        return argstrs, expr

    def _print_funcargwrapping(self, args):
        """Generate argument wrapping code.

        args is the argument list of the generated function (strings).

        Return value is a list of lines of code that will be inserted at
        the beginning of the function definition. This implementation inserts nothing.
        """
        return []

    def _print_unpacking(self, unpackto, arg):
        """Generate argument unpacking code.

        arg is the function argument to be unpacked (a string), and
        unpackto is a list or nested lists of the variable names (strings) to
        unpack to. Produces e.g. ``(a, b,) = A``.
        """
        def unpack_lhs(lvalues):
            return '({},)'.format(', '.join(
                unpack_lhs(val) if iterable(val) else val for val in lvalues))

        return ['{} = {}'.format(unpack_lhs(unpackto), arg)]
1✔
835

836
def _recursive_to_string(doprint, arg):
    """
    Render :code:`arg` as Python source text.

    Strings are returned verbatim. Falsy values are rendered with :func:`str`. Lists become
    ``[...]`` and tuples ``(...,)``, with their elements rendered recursively. Any other
    iterable raises :class:`NotImplementedError`, and remaining objects go through :code:`doprint`.
    """
    if isinstance(arg, str):
        return arg
    if not arg:
        return str(arg)  # Empty list or tuple
    if not iterable(arg):
        return doprint(arg)

    if isinstance(arg, list):
        opening, closing = "[", "]"
    elif isinstance(arg, tuple):
        opening, closing = "(", ",)"
    else:
        raise NotImplementedError("unhandled type: %s, %s" % (type(arg), arg))
    inner = ', '.join(_recursive_to_string(doprint, item) for item in arg)
    return opening + inner + closing
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc