• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MilesCranmer / PySR / 10138721258

29 Jul 2024 05:35AM UTC coverage: 93.448% (-0.3%) from 93.797%
10138721258

Pull #681

github

web-flow
Merge 7af2bd516 into 3aee19e38
Pull Request #681: Enhance cross-platform compatibility for loading PySRRegressor models

15 of 17 new or added lines in 2 files covered. (88.24%)

3 existing lines in 1 file now uncovered.

1141 of 1221 relevant lines covered (93.45%)

2.59 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.27
/pysr/export_sympy.py
1
"""Define utilities to export to sympy"""
2

3
from typing import Callable, Dict, List, Optional
3✔
4

5
import sympy  # type: ignore
3✔
6
from sympy import sympify
3✔
7

8
from .utils import ArrayLike
3✔
9

10
# Operator names that may appear in PySR equation strings, mapped to
# sympy-compatible callables. ``pysr2sympy`` passes these to ``sympify`` as
# local names so the textual operators resolve to sympy expressions.
sympy_mappings = dict(
    # Basic arithmetic and powers.
    div=lambda a, b: a / b,
    mult=lambda a, b: a * b,
    sqrt=lambda a: sympy.sqrt(a),
    sqrt_abs=lambda a: sympy.sqrt(abs(a)),
    square=lambda a: a**2,
    cube=lambda a: a**3,
    plus=lambda a, b: a + b,
    sub=lambda a, b: a - b,
    neg=lambda a: -a,
    pow=lambda a, b: a**b,
    pow_abs=lambda a, b: abs(a) ** b,
    # Trigonometric, hyperbolic and exponential functions.
    cos=sympy.cos,
    sin=sympy.sin,
    tan=sympy.tan,
    cosh=sympy.cosh,
    sinh=sympy.sinh,
    tanh=sympy.tanh,
    exp=sympy.exp,
    # Inverse trigonometric / hyperbolic functions.
    acos=sympy.acos,
    asin=sympy.asin,
    atan=sympy.atan,
    acosh=lambda a: sympy.acosh(a),
    # abs(a) + 1 >= 1, so the argument always lies in acosh's real domain.
    acosh_abs=lambda a: sympy.acosh(abs(a) + 1),
    asinh=sympy.asinh,
    # Mod(a + 1, 2) - 1 wraps a real argument into [-1, 1) before atanh.
    atanh=lambda a: sympy.atanh(sympy.Mod(a + 1, 2) - sympy.S(1)),
    atanh_clip=lambda a: sympy.atanh(sympy.Mod(a + 1, 2) - sympy.S(1)),
    # Miscellaneous special functions.
    abs=abs,
    mod=sympy.Mod,
    erf=sympy.erf,
    erfc=sympy.erfc,
    # Logarithms; the *_abs variants take the absolute value first.
    log=lambda a: sympy.log(a),
    log10=lambda a: sympy.log(a, 10),
    log2=lambda a: sympy.log(a, 2),
    log1p=lambda a: sympy.log(a + 1),
    log_abs=lambda a: sympy.log(abs(a)),
    log10_abs=lambda a: sympy.log(abs(a), 10),
    log2_abs=lambda a: sympy.log(abs(a), 2),
    log1p_abs=lambda a: sympy.log(abs(a) + 1),
    # Rounding and sign.
    floor=sympy.floor,
    ceil=sympy.ceiling,
    sign=sympy.sign,
    gamma=sympy.gamma,
    # ceiling(a - 0.5): nearest integer, with exact halves rounded down.
    round=lambda a: sympy.ceiling(a - 0.5),
    # Piecewise (branching) operators.
    max=lambda a, b: sympy.Piecewise((b, a < b), (a, True)),
    min=lambda a, b: sympy.Piecewise((a, a < b), (b, True)),
    greater=lambda a, b: sympy.Piecewise((1.0, a > b), (0.0, True)),
    cond=lambda a, b: sympy.Piecewise((b, a > 0), (0.0, True)),
    logical_or=lambda a, b: sympy.Piecewise((1.0, (a > 0) | (b > 0)), (0.0, True)),
    logical_and=lambda a, b: sympy.Piecewise((1.0, (a > 0) & (b > 0)), (0.0, True)),
    relu=lambda a: sympy.Piecewise((0.0, a < 0), (a, True)),
)
62

63

64
def create_sympy_symbols_map(
    feature_names_in: ArrayLike[str],
) -> Dict[str, sympy.Symbol]:
    """Build a name -> ``sympy.Symbol`` lookup for the given feature names.

    Parameters
    ----------
    feature_names_in : ArrayLike[str]
        Feature names; each becomes a symbol of the same name.

    Returns
    -------
    Dict[str, sympy.Symbol]
        Mapping from every feature name to its symbol. If a name repeats,
        the symbol created for its last occurrence is kept.
    """
    symbols_by_name: Dict[str, sympy.Symbol] = {}
    for name in feature_names_in:
        symbols_by_name[name] = sympy.Symbol(name)
    return symbols_by_name
68

69

70
def create_sympy_symbols(
    feature_names_in: ArrayLike[str],
) -> List[sympy.Symbol]:
    """Create one ``sympy.Symbol`` per feature name.

    Parameters
    ----------
    feature_names_in : ArrayLike[str]
        Feature names to turn into symbols.

    Returns
    -------
    List[sympy.Symbol]
        Symbols in the same order as ``feature_names_in``.
    """
    return list(map(sympy.Symbol, feature_names_in))
74

75

76
def pysr2sympy(
    equation: str,
    *,
    feature_names_in: Optional[ArrayLike[str]] = None,
    extra_sympy_mappings: Optional[Dict[str, Callable]] = None,
):
    """Parse a PySR equation string into a sympy expression.

    Parameters
    ----------
    equation : str
        Equation text using PySR operator names (the keys of
        ``sympy_mappings``) and feature names.
    feature_names_in : ArrayLike[str], optional
        Variable names that may appear in ``equation``; each is exposed to
        the parser as a ``sympy.Symbol``. Defaults to no named variables.
    extra_sympy_mappings : Dict[str, Callable], optional
        Additional operator names mapped to sympy-compatible callables.
        These take precedence over the built-in ``sympy_mappings``, so a
        custom operator can shadow a default of the same name.

    Returns
    -------
    sympy.Basic
        The result of ``sympify``. The call passes ``evaluate=False`` so the
        parsed structure is kept, unless ``sympify`` rejects that keyword.

    Raises
    ------
    TypeError
        If ``sympify`` raises a ``TypeError`` for any reason other than
        rejecting the ``evaluate`` keyword. The original error is chained.
    """
    if feature_names_in is None:
        feature_names_in = []
    # In a dict merge, later entries win. Built-in operator mappings override
    # feature symbols, and user-supplied extras override the built-ins.
    local_sympy_mappings = {
        **create_sympy_symbols_map(feature_names_in),
        **sympy_mappings,
        **(extra_sympy_mappings if extra_sympy_mappings is not None else {}),
    }

    try:
        return sympify(equation, locals=local_sympy_mappings, evaluate=False)
    except TypeError as e:
        # Fallback for a ``sympify`` that does not accept ``evaluate``
        # (presumably an older sympy release): retry without the keyword.
        if "got an unexpected keyword argument 'evaluate'" in str(e):
            return sympify(equation, locals=local_sympy_mappings)
        raise TypeError(f"Error processing equation '{equation}'") from e
96

97

98
def assert_valid_sympy_symbol(var_name: str) -> None:
    """Reject a variable name that collides with a known function name.

    Parameters
    ----------
    var_name : str
        Candidate variable (feature) name.

    Raises
    ------
    ValueError
        If ``var_name`` is a key of ``sympy_mappings`` or a name defined in
        the ``sympy`` module namespace.
    """
    clashes_with_operator = var_name in sympy_mappings
    if clashes_with_operator or var_name in vars(sympy):
        raise ValueError(f"Variable name {var_name} is already a function name.")
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc