• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

brian-team / brian2 / 18967526696

31 Oct 2025 08:51AM UTC coverage: 92.107% (+0.003%) from 92.104%
18967526696

push

github

web-flow
Merge pull request #1712 from brian-team/pre_commit_update

Use ruff

2518 of 2655 branches covered (94.84%)

14914 of 16192 relevant lines covered (92.11%)

2.62 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.16
/brian2/codegen/optimisation.py
1
"""
2
Simplify and optimise sequences of statements by rewriting and pulling out loop invariants.
3
"""
4

5
import ast
2✔
6
import copy
2✔
7
import itertools
2✔
8
from collections import OrderedDict
2✔
9
from functools import reduce
2✔
10

11
from brian2.core.functions import DEFAULT_CONSTANTS, DEFAULT_FUNCTIONS
2✔
12
from brian2.core.preferences import prefs
2✔
13
from brian2.core.variables import AuxiliaryVariable
2✔
14
from brian2.parsing.bast import (
2✔
15
    BrianASTRenderer,
16
    brian_ast,
17
    brian_dtype_from_dtype,
18
    dtype_hierarchy,
19
)
20
from brian2.parsing.rendering import NodeRenderer
2✔
21
from brian2.utils.stringtools import get_identifiers, word_substitute
2✔
22

23
from .statements import Statement
2✔
24

25
# Namespace used for constant folding: the Python implementation of every
# default function together with the value of every default constant (a
# constant overrides a function of the same name).
defaults_ns = {
    **{name: function.pyfunc for name, function in DEFAULT_FUNCTIONS.items()},
    **{name: constant.value for name, constant in DEFAULT_CONSTANTS.items()},
}


__all__ = ["optimise_statements", "ArithmeticSimplifier", "Simplifier"]
31

32

33
def evaluate_expr(expr, ns):
    """
    Attempt to evaluate the expression string ``expr`` within the namespace ``ns``.

    On success the result is ``(value, True)``. If evaluation fails with a
    `NameError` (unknown identifier) or an `ArithmeticError` (e.g. division by
    zero), the result is ``(expr, False)``. Note that ``ns`` is passed to
    `eval` as its globals dictionary, so `eval` may add a ``__builtins__``
    entry to it.

    Examples
    --------
    >>> assumptions = {'exp': DEFAULT_FUNCTIONS['exp'].pyfunc,
    ...                'inf': DEFAULT_CONSTANTS['inf'].value}
    >>> evaluate_expr('1/2', assumptions)
    (0.5, True)
    >>> evaluate_expr('exp(-inf)', assumptions)
    (0.0, True)
    >>> evaluate_expr('sin(2*pi*freq*t)', assumptions)
    ('sin(2*pi*freq*t)', False)
    >>> evaluate_expr('1/0', assumptions)
    ('1/0', False)
    """
    try:
        result = eval(expr, ns)
    except (NameError, ArithmeticError):
        # Not evaluable here: hand the original expression back untouched
        return expr, False
    return result, True
57

58

59
def expression_complexity(expr, variables):
    """
    Return the ``complexity`` attribute of the Brian AST built from ``expr``.

    Parameters
    ----------
    expr : str
        The expression to analyse.
    variables : dict of (str, Variable)
        Variable definitions passed on to `brian_ast`.
    """
    parsed = brian_ast(expr, variables)
    return parsed.complexity
61

62

63
def optimise_statements(scalar_statements, vector_statements, variables, blockname=""):
2✔
64
    """
65
    Optimise a sequence of scalar and vector statements
66

67
    Performs the following optimisations:
68

69
    1. Constant evaluations (e.g. exp(0) to 1). See `evaluate_expr`.
70
    2. Arithmetic simplifications (e.g. 0*x to 0). See `ArithmeticSimplifier`, `collect`.
71
    3. Pulling out loop invariants (e.g. v*exp(-dt/tau) to a=exp(-dt/tau) outside the loop and v*a inside).
72
       See `Simplifier`.
73
    4. Boolean simplifications (allowing the replacement of expressions with booleans with a sequence of if/thens).
74
       See `Simplifier`.
75

76
    Parameters
77
    ----------
78
    scalar_statements : sequence of Statement
79
        Statements that only involve scalar values and should be evaluated in the scalar block.
80
    vector_statements : sequence of Statement
81
        Statements that involve vector values and should be evaluated in the vector block.
82
    variables : dict of (str, Variable)
83
        Definition of the types of the variables.
84
    blockname : str, optional
85
        Name of the block (used for LIO constant prefixes to avoid name clashes)
86

87
    Returns
88
    -------
89
    new_scalar_statements : sequence of Statement
90
        As above but with loop invariants pulled out from vector statements
91
    new_vector_statements : sequence of Statement
92
        Simplified/optimised versions of statements
93
    """
94
    boolvars = {
4✔
95
        k: v
96
        for k, v in variables.items()
97
        if hasattr(v, "dtype") and brian_dtype_from_dtype(v.dtype) == "boolean"
98
    }
99
    # We use the Simplifier class by rendering each expression, which generates new scalar statements
100
    # stored in the Simplifier object, and these are then added to the scalar statements.
101
    simplifier = Simplifier(variables, scalar_statements, extra_lio_prefix=blockname)
4✔
102
    new_vector_statements = []
4✔
103
    for stmt in vector_statements:
4✔
104
        # Carry out constant evaluation, arithmetic simplification and loop invariants
105
        new_expr = simplifier.render_expr(stmt.expr)
4✔
106
        new_stmt = Statement(
4✔
107
            stmt.var,
108
            stmt.op,
109
            new_expr,
110
            stmt.comment,
111
            dtype=stmt.dtype,
112
            constant=stmt.constant,
113
            subexpression=stmt.subexpression,
114
            scalar=stmt.scalar,
115
        )
116
        # Now check if boolean simplification can be carried out
117
        complexity_std = expression_complexity(new_expr, simplifier.variables)
4✔
118
        idents = get_identifiers(new_expr)
4✔
119
        used_boolvars = [var for var in boolvars if var in idents]
4✔
120
        if len(used_boolvars):
4✔
121
            # We want to iterate over all the possible assignments of boolean variables to values in (True, False)
122
            bool_space = [[False, True] for _ in used_boolvars]
4✔
123
            expanded_expressions = {}
4✔
124
            complexities = {}
4✔
125
            for bool_vals in itertools.product(*bool_space):
4✔
126
                # substitute those values into the expr and simplify (including potentially pulling out new
127
                # loop invariants)
128
                subs = {
4✔
129
                    var: str(val)
130
                    for var, val in zip(used_boolvars, bool_vals, strict=True)
131
                }
132
                curexpr = word_substitute(new_expr, subs)
4✔
133
                curexpr = simplifier.render_expr(curexpr)
4✔
134
                key = tuple(
4✔
135
                    (var, val)
136
                    for var, val in zip(used_boolvars, bool_vals, strict=True)
137
                )
138
                expanded_expressions[key] = curexpr
4✔
139
                complexities[key] = expression_complexity(curexpr, simplifier.variables)
4✔
140
            # See Statement for details on these
141
            new_stmt.used_boolean_variables = used_boolvars
4✔
142
            new_stmt.boolean_simplified_expressions = expanded_expressions
4✔
143
            new_stmt.complexity_std = complexity_std
4✔
144
            new_stmt.complexities = complexities
4✔
145
        new_vector_statements.append(new_stmt)
4✔
146
    # Generate additional scalar statements for the loop invariants
147
    new_scalar_statements = copy.copy(scalar_statements)
4✔
148
    for expr, name in simplifier.loop_invariants.items():
4✔
149
        dtype_name = simplifier.loop_invariant_dtypes[name]
4✔
150
        if dtype_name == "boolean":
4✔
151
            dtype = bool
4✔
152
        elif dtype_name == "integer":
4✔
153
            dtype = int
4✔
154
        else:
155
            dtype = prefs.core.default_float_dtype
4✔
156
        new_stmt = Statement(
4✔
157
            name,
158
            ":=",
159
            expr,
160
            "",
161
            dtype=dtype,
162
            constant=True,
163
            subexpression=False,
164
            scalar=True,
165
        )
166
        new_scalar_statements.append(new_stmt)
4✔
167
    return new_scalar_statements, new_vector_statements
4✔
168

169

170
def _replace_with_zero(zero_node, node):
2✔
171
    """
172
    Helper function to return a "zero node" of the correct type.
173

174
    Parameters
175
    ----------
176
    zero_node : `ast.Constant`
177
        The node to replace
178
    node : `ast.Node`
179
        The node that determines the type
180

181
    Returns
182
    -------
183
    zero_node : `ast.Constant`
184
        The original ``zero_node`` with its value replaced by 0 or 0.0.
185
    """
186
    # must not change the dtype of the output,
187
    # e.g. handle 0/float->0.0 and 0.0/int->0.0
188
    zero_node.dtype = node.dtype
4✔
189
    if node.dtype == "integer":
4✔
190
        zero_node.value = 0
2✔
191
    else:
192
        zero_node.value = prefs.core.default_float_dtype(0.0)
4✔
193
    return zero_node
4✔
194

195

196
class ArithmeticSimplifier(BrianASTRenderer):
    """
    Carries out the following arithmetic simplifications:

    1. Constant evaluation (e.g. exp(0)=1): scalar, stateless sub-expressions are
       evaluated with `evaluate_expr` in the namespace ``assumptions_ns``, which
       starts as a copy of the default functions and constants.
    2. Binary operators, e.g. 0*x=0, 1*x=x, etc. You have to take care that the dtypes match here, e.g.
       if x is an integer, then 1.0*x shouldn't be replaced with x but left as 1.0*x.

    Parameters
    ----------
    variables : dict of (str, Variable)
        Usual definition of variables.

    Notes
    -----
    The ``assumptions`` attribute is initialised to an empty list; it is not
    read anywhere in this class.
    """

    def __init__(self, variables):
        BrianASTRenderer.__init__(self, variables, copy_variables=False)
        self.assumptions = []
        # Namespace used for constant evaluation (a copy, so it can be extended)
        self.assumptions_ns = dict(defaults_ns)
        # Separate renderer used to re-annotate the trees returned by `collect`
        self.bast_renderer = BrianASTRenderer(variables, copy_variables=False)

    def render_node(self, node):
        """
        Simplify ``node`` and, where possible, replace it by a constant.

        Assumes that the node has already been fully processed by BrianASTRenderer
        (so that it carries ``dtype``, ``scalar`` and ``stateless`` attributes).
        Scalar, stateless, non-trivial nodes whose expression can be evaluated
        are replaced by an `ast.Constant` of the node's dtype.
        """
        if not hasattr(node, "simplified"):
            node = super().render_node(node)
            node.simplified = True
        # can't evaluate vector expressions, so abandon in this case
        if not node.scalar:
            return node
        # No evaluation necessary for simple names or numbers
        if node.__class__.__name__ in ["Name", "NameConstant", "Num", "Constant"]:
            return node
        # Don't evaluate stateful nodes (e.g. those containing a rand() call)
        if not node.stateless:
            return node
        # try fully evaluating using assumptions
        expr = NodeRenderer().render_node(node)
        val, evaluated = evaluate_expr(expr, self.assumptions_ns)
        if not evaluated:
            return node
        # Cast the result so that the replacement keeps the dtype of the node
        if node.dtype == "boolean":
            val = bool(val)
        elif node.dtype == "integer":
            val = int(val)
        else:
            val = prefs.core.default_float_dtype(val)
        # ast.Constant represents booleans as well as numbers on all supported
        # Python versions, so no NameConstant/Name fallback is needed
        newnode = ast.Constant(val)
        newnode.dtype = node.dtype
        newnode.scalar = True
        newnode.stateless = node.stateless
        newnode.complexity = 0
        return newnode

    def render_BinOp(self, node):
        """
        Simplify a binary operation node.

        Float-valued ``*``, ``/``, ``+`` and ``-`` nodes are first passed through
        `collect` (once per node, tracked with the ``collected`` attribute) and
        the result is rendered again. Otherwise the operands are rendered and
        ``0*x``, ``0/x`` and ``0//x`` (stateless nodes only), ``1*x``, ``x/1``,
        ``x//1`` (integers only), ``x+0`` and ``x-0`` are simplified, as long as
        the simplification does not change the dtype of the result.
        """
        if node.dtype == "float":  # only try to collect float type nodes
            if node.op.__class__.__name__ in [
                "Mult",
                "Div",
                "Add",
                "Sub",
            ] and not hasattr(node, "collected"):
                newnode = self.bast_renderer.render_node(collect(node))
                newnode.collected = True
                return self.render_node(newnode)
        left = node.left = self.render_node(node.left)
        right = node.right = self.render_node(node.right)
        node = super().render_BinOp(node)
        op = node.op
        # Handle multiplication by 0 or 1
        if op.__class__.__name__ == "Mult":
            for operand, other in [(left, right), (right, left)]:
                if operand.__class__.__name__ in ["Num", "Constant"]:
                    op_value = operand.value
                    if op_value == 0:
                        # Do not remove stateful functions
                        if node.stateless:
                            return _replace_with_zero(operand, node)
                    if op_value == 1:
                        # only simplify this if the type wouldn't be cast by the operation
                        if (
                            dtype_hierarchy[operand.dtype]
                            <= dtype_hierarchy[other.dtype]
                        ):
                            return other
        # Handle division by 1, or 0/x
        elif op.__class__.__name__ == "Div":
            if (
                left.__class__.__name__ in ["Num", "Constant"] and left.value == 0
            ):  # 0/x
                if node.stateless:
                    # Do not remove stateful functions
                    return _replace_with_zero(left, node)
            if (
                right.__class__.__name__ in ["Num", "Constant"] and right.value == 1
            ):  # x/1
                # only simplify this if the type wouldn't be cast by the operation
                if dtype_hierarchy[right.dtype] <= dtype_hierarchy[left.dtype]:
                    return left
        elif op.__class__.__name__ == "FloorDiv":
            if (
                left.__class__.__name__ in ["Num", "Constant"] and left.value == 0
            ):  # 0//x
                if node.stateless:
                    # Do not remove stateful functions
                    return _replace_with_zero(left, node)
            # Only optimise floor division by 1 if both numbers are integers,
            # for floating point values, floor division by 1 changes the value,
            # and division by 1.0 can change the type for an integer value
            if (
                left.dtype == right.dtype == "integer"
                and right.__class__.__name__ in ["Num", "Constant"]
                and right.value == 1
            ):  # x//1
                return left
        # Handle addition of 0
        elif op.__class__.__name__ == "Add":
            for operand, other in [(left, right), (right, left)]:
                if (
                    operand.__class__.__name__ in ["Num", "Constant"]
                    and operand.value == 0
                ):
                    # only simplify this if the type wouldn't be cast by the operation
                    if dtype_hierarchy[operand.dtype] <= dtype_hierarchy[other.dtype]:
                        return other
        # Handle subtraction of 0
        elif op.__class__.__name__ == "Sub":
            if right.__class__.__name__ in ["Num", "Constant"] and right.value == 0:
                # only simplify this if the type wouldn't be cast by the operation
                if dtype_hierarchy[right.dtype] <= dtype_hierarchy[left.dtype]:
                    return left

        # simplify e.g. 2*float to 2.0*float to make things more explicit: not strictly necessary
        # but might be useful for some codegen targets
        if node.dtype == "float" and op.__class__.__name__ in [
            "Mult",
            "Add",
            "Sub",
            "Div",
        ]:
            for subnode in [node.left, node.right]:
                if subnode.__class__.__name__ in ["Num", "Constant"] and not (
                    subnode.value is True or subnode.value is False
                ):
                    subnode.dtype = "float"
                    subnode.value = prefs.core.default_float_dtype(subnode.value)
        return node
354

355

356
class Simplifier(BrianASTRenderer):
    """
    Carry out arithmetic simplifications (see `ArithmeticSimplifier`) and loop invariants

    Parameters
    ----------
    variables : dict of (str, Variable)
        Usual definition of variables. An `AuxiliaryVariable` is added to
        ``self.variables`` for every loop invariant that is pulled out.
    scalar_statements : sequence of Statement
        Predefined scalar statements that can be used as part of simplification.
        They are stored as the ``scalar_statements`` attribute; this class does
        not read them itself.
    extra_lio_prefix : str or None, optional
        Extra text inserted into the loop-invariant names (e.g. a block name) to
        avoid name clashes. ``None`` and ``""`` both mean "no extra prefix".

    Notes
    -----

    After calling `render_expr` on a sequence of expressions (coming from vector statements typically),
    the following attributes will have been populated:

    ``loop_invariants`` : OrderedDict of (expression, varname)
        varname will be of the form ``_lio_N``, or ``_lio_PREFIX_N`` when
        ``extra_lio_prefix`` is given, where ``N`` is a counter starting at 1.
        The expressions are strings that correspond to scalar-only expressions
        that can be evaluated outside of the vector block; identical
        expressions share a single name.
    ``loop_invariant_dtypes`` : dict of (varname, dtypename)
        dtypename will be one of ``'boolean'``, ``'integer'``, ``'float'``.
    """

    def __init__(self, variables, scalar_statements, extra_lio_prefix=""):
        BrianASTRenderer.__init__(self, variables, copy_variables=False)
        self.loop_invariants = OrderedDict()
        self.loop_invariant_dtypes = {}
        # Counter used to number the loop-invariant variables (_lio_1, _lio_2, ...)
        self.value = 0
        self.node_renderer = NodeRenderer()
        self.arithmetic_simplifier = ArithmeticSimplifier(variables)
        self.scalar_statements = scalar_statements
        # Normalise the prefix: None or "" -> no prefix, "name" -> "name_"
        self.extra_lio_prefix = f"{extra_lio_prefix}_" if extra_lio_prefix else ""

    def render_expr(self, expr):
        """
        Parse ``expr``, simplify it arithmetically, pull out loop invariants and
        return the resulting expression string.
        """
        node = brian_ast(expr, self.variables)
        node = self.arithmetic_simplifier.render_node(node)
        node = self.render_node(node)
        return self.node_renderer.render_node(node)

    def render_node(self, node):
        """
        Replace non-trivial scalar sub-expressions by loop-invariant names.

        Assumes that the node has already been fully processed by BrianASTRenderer
        """
        # can we pull this out? (scalar and not already a plain name/number)
        if node.scalar and node.complexity > 0:
            expr = self.node_renderer.render_node(
                self.arithmetic_simplifier.render_node(node)
            )
            if expr in self.loop_invariants:
                # Reuse the name of an identical, already pulled-out expression
                name = self.loop_invariants[expr]
            else:
                self.value += 1
                name = f"_lio_{self.extra_lio_prefix}{self.value}"
                self.loop_invariants[expr] = name
                self.loop_invariant_dtypes[name] = node.dtype
                numpy_dtype = {
                    "boolean": bool,
                    "integer": int,
                    "float": prefs.core.default_float_dtype,
                }[node.dtype]
                self.variables[name] = AuxiliaryVariable(
                    name, dtype=numpy_dtype, scalar=True
                )
            # None is the expression context, we don't use it so we just set to None
            newnode = ast.Name(name, None)
            newnode.scalar = True
            newnode.dtype = node.dtype
            newnode.complexity = 0
            newnode.stateless = node.stateless
            return newnode
        # otherwise, render node as usual
        return super().render_node(node)
434

435

436
def reduced_node(terms, op):
    """
    Combine a sequence of terms into a left-nested chain of binary operations.

    ``None`` entries are ignored. For example, if terms were [a, b, c] and op
    was multiplication then the reduction would be (a*b)*c.

    Parameters
    ----------
    terms : list
        AST nodes (``None`` entries are skipped).
    op : AST node
        Could be `ast.Mult` or `ast.Add`; a fresh instance is created for
        every operation.

    Returns
    -------
    node : AST node or None
        ``None`` if no non-``None`` term is left, the term itself if exactly one
        is left, otherwise the new `ast.BinOp` chain.

    Examples
    --------
    >>> import ast
    >>> nodes = [ast.Name(id='x'), ast.Name(id='y'), ast.Name(id='z')]
    >>> ast.unparse(reduced_node(nodes, ast.Mult))
    'x * y * z'
    >>> nodes = [ast.Name(id='x')]
    >>> ast.unparse(reduced_node(nodes, ast.Add))
    'x'
    """
    present = [candidate for candidate in terms if candidate is not None]
    if present:
        return reduce(lambda acc, nxt: ast.BinOp(acc, op(), nxt), present)
    return None
464

465

466
def cancel_identical_terms(primary, inverted):
    """
    Cancel terms in a collection, e.g. a+b-a should be cancelled to b

    Every node is rendered to an expression string with `NodeRenderer`. A
    stateless primary term is cancelled against the first inverted term
    that renders to the same string; each primary term cancels at most one
    inverted term.

    Parameters
    ----------
    primary : list of AST nodes
        These are the nodes that are positive with respect to the operator, e.g.
        in x*y/z it would be [x, y].
    inverted : list of AST nodes
        These are the nodes that are inverted with respect to the operator, e.g.
        in x*y/z it would be [z].

    Returns
    -------
    primary : list of AST nodes
        Primary nodes after cancellation
    inverted : list of AST nodes
        Inverted nodes after cancellation (the input list itself if nothing
        was cancelled)
    """
    renderer = NodeRenderer()
    rendered = {term: renderer.render_node(term) for term in primary}
    rendered.update({term: renderer.render_node(term) for term in inverted})

    kept_primary = []
    for term in primary:
        expression = rendered[term]
        has_match = any(rendered[other] == expression for other in inverted)
        # Stateful terms (e.g. rand()) must never cancel each other
        if not (has_match and term.stateless):
            kept_primary.append(term)
            continue
        survivors = []
        for candidate in inverted:
            if rendered[candidate] == expression:
                # Clear the target after the first hit so that later
                # duplicates survive
                expression = ""
            else:
                survivors.append(candidate)
        inverted = survivors
    return kept_primary, inverted
508

509

510
def collect(node):
    """
    Attempts to collect commutative operations into one and simplifies them.

    For example, if x and y are scalars, and z is a vector, then (x*z)*y should
    be rewritten as (x*y)*z to minimise the number of vector operations. Similarly,
    ((x*2)*3)*4 should be rewritten as x*24.

    Works for either multiplication/division or addition/subtraction nodes.

    The final output is a subexpression of the following maximal form:

        (((numerical_value*(product of scalars))/(product of scalars))*(product of vectors))/(product of vectors)

    Any possible cancellations will have been done.

    Parameters
    ----------
    node : Brian AST node
        The node to be collected/simplified. Its ``collected`` attribute is set
        to ``True``.

    Returns
    -------
    node : Brian AST node
        Simplified node. For ``*``/``/`` and ``+``/``-`` `BinOp` nodes this is a
        newly assembled tree whose newly created `ast.BinOp`/`ast.Constant` nodes
        carry no Brian attributes such as ``dtype`` (`ArithmeticSimplifier`
        re-renders the result for that reason); otherwise the input node.
    """
    node.collected = True
    orignode_dtype = node.dtype
    # we only work on */ or +- ops, which are both BinOp
    if node.__class__.__name__ != "BinOp":
        return node
    # primary would be the * or + nodes, and inverted would be the / or - nodes
    terms_primary = []
    terms_inverted = []
    # we handle both multiplicative and additive nodes in the same way by using these variables
    if node.op.__class__.__name__ in ["Mult", "Div"]:
        op_primary = ast.Mult
        op_inverted = ast.Div
        op_null = prefs.core.default_float_dtype(1.0)  # the identity for the operator
        op_py_primary = lambda x, y: x * y
        op_py_inverted = lambda x, y: x / y
    elif node.op.__class__.__name__ in ["Add", "Sub"]:
        op_primary = ast.Add
        op_inverted = ast.Sub
        op_null = prefs.core.default_float_dtype(0.0)
        op_py_primary = lambda x, y: x + y
        op_py_inverted = lambda x, y: x - y
    else:
        return node
    if node.dtype == "integer":
        op_null_with_dtype = int(op_null)
    else:
        op_null_with_dtype = op_null
    # recursively collect terms into the terms_primary and terms_inverted lists
    collect_commutative(node, op_primary, op_inverted, terms_primary, terms_inverted)
    # Fold all numeric literals into a single value x, starting from the identity.
    # Numbers are represented as ast.Constant nodes (ast.Num is only a deprecated
    # alias that creates Constant instances), so Constant is the only class checked.
    x = op_null
    remaining_terms_primary = []
    remaining_terms_inverted = []
    for term in terms_primary:
        if term.__class__.__name__ == "Constant":
            x = op_py_primary(x, term.value)
        else:
            remaining_terms_primary.append(term)
    for term in terms_inverted:
        if term.__class__.__name__ == "Constant":
            x = op_py_inverted(x, term.value)
        else:
            remaining_terms_inverted.append(term)
    # if the fully evaluated node is just the identity/null element then we
    # don't have to make it into an explicit term
    if x != op_null:
        num_node = ast.Constant(x)
    else:
        num_node = None
    terms_primary = remaining_terms_primary
    terms_inverted = remaining_terms_inverted
    node = num_node
    # Scalar terms are combined first and vector terms last, giving the nested
    # form described in the docstring
    for scalar in (True, False):
        primary_terms = [term for term in terms_primary if term.scalar == scalar]
        inverted_terms = [term for term in terms_inverted if term.scalar == scalar]
        primary_terms, inverted_terms = cancel_identical_terms(
            primary_terms, inverted_terms
        )

        # produce nodes that are the reduction of the operator on these subsets;
        # inverted terms are combined with the primary operator: a/(b*c), a-(b+c)
        prod_primary = reduced_node(primary_terms, op_primary)
        prod_inverted = reduced_node(inverted_terms, op_primary)

        # construct the simplest version of the fully simplified node (only doing operations where necessary)
        node = reduced_node([node, prod_primary], op_primary)
        if prod_inverted is not None:
            if node is None:
                node = ast.Constant(op_null_with_dtype)
            node = ast.BinOp(node, op_inverted(), prod_inverted)

    if node is None:  # everything cancelled
        node = ast.Constant(op_null_with_dtype)
    # If a single surviving term has a lower dtype than the original expression
    # (e.g. only an integer term is left), combine it with the identity element
    # so that the result keeps the original dtype
    if (
        hasattr(node, "dtype")
        and dtype_hierarchy[node.dtype] < dtype_hierarchy[orignode_dtype]
    ):
        node = ast.BinOp(ast.Constant(op_null_with_dtype), op_primary(), node)
    node.collected = True
    return node
619

620

621
def collect_commutative(
    node, primary, inverted, terms_primary, terms_inverted, add_to_inverted=False
):
    """
    Recursively flatten a tree of commutative operations into two term lists.

    ``node`` must be a `BinOp` whose operator is ``primary`` (e.g. `ast.Mult`)
    or ``inverted`` (e.g. `ast.Div`). Operands that are themselves such
    operations are descended into; all other operands are appended to
    ``terms_primary`` or ``terms_inverted``, depending on whether they end up
    in the numerator or the denominator (for additive nodes: added or
    subtracted). ``add_to_inverted`` tracks which side the current sub-tree
    sits on.
    """
    node_is_primary = node.op.__class__ is primary
    # The right operand of an inverted operation (/ or -) switches sides
    right_goes_inverted = add_to_inverted if node_is_primary else not add_to_inverted
    for child, child_goes_inverted in (
        (node.left, add_to_inverted),
        (node.right, right_goes_inverted),
    ):
        # Regrouping is only exact for float operands or for +/- operations;
        # integer mult/div is not, since with C-style division 3/(4/3) != (3*3)/4
        exact = child.dtype == "float" or (
            hasattr(child, "op") and child.op.__class__.__name__ in ["Add", "Sub"]
        )
        if (
            child.__class__.__name__ == "BinOp"
            and child.op.__class__ in [primary, inverted]
            and exact
        ):
            collect_commutative(
                child,
                primary,
                inverted,
                terms_primary,
                terms_inverted,
                add_to_inverted=child_goes_inverted,
            )
        elif child_goes_inverted:
            terms_inverted.append(child)
        else:
            terms_primary.append(child)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc