MilesCranmer / PySR / 12100062277

30 Nov 2024 10:48PM UTC coverage: 93.389% (-0.3%) from 93.735%

Pull Request #748: Update to 1.0 backend
Commit (MilesCranmer, via GitHub): test: skip tensorboard test on windows

412 of 438 new or added lines in 13 files covered (94.06%).
1 existing line in 1 file now uncovered.
1342 of 1437 relevant lines covered (93.39%).
2.62 hits per line.

Source File: /pysr/sr.py (92.87% covered)
1
"""Define the PySRRegressor scikit-learn interface."""
2

3
import copy
3✔
4
import os
3✔
5
import pickle as pkl
3✔
6
import re
3✔
7
import sys
3✔
8
import tempfile
3✔
9
import warnings
3✔
10
from collections.abc import Callable
3✔
11
from dataclasses import dataclass, fields
3✔
12
from io import StringIO
3✔
13
from multiprocessing import cpu_count
3✔
14
from pathlib import Path
3✔
15
from typing import Any, Literal, cast
3✔
16

17
import numpy as np
3✔
18
import pandas as pd
3✔
19
from numpy import ndarray
3✔
20
from numpy.typing import NDArray
3✔
21
from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
3✔
22
from sklearn.utils import check_array, check_consistent_length, check_random_state
3✔
23
from sklearn.utils.validation import _check_feature_names_in  # type: ignore
3✔
24
from sklearn.utils.validation import check_is_fitted
3✔
25

26
from .denoising import denoise, multi_denoise
3✔
27
from .deprecated import DEPRECATED_KWARGS
3✔
28
from .export_latex import (
3✔
29
    sympy2latex,
30
    sympy2latextable,
31
    sympy2multilatextable,
32
    with_preamble,
33
)
34
from .export_sympy import assert_valid_sympy_symbol
3✔
35
from .expression_specs import (
3✔
36
    AbstractExpressionSpec,
37
    ExpressionSpec,
38
    ParametricExpressionSpec,
39
)
40
from .feature_selection import run_feature_selection
3✔
41
from .julia_extensions import load_required_packages
3✔
42
from .julia_helpers import (
3✔
43
    _escape_filename,
44
    _load_cluster_manager,
45
    jl_array,
46
    jl_deserialize,
47
    jl_is_function,
48
    jl_serialize,
49
)
50
from .julia_import import AnyValue, SymbolicRegression, VectorValue, jl
3✔
51
from .logger_specs import AbstractLoggerSpec
3✔
52
from .utils import (
3✔
53
    ArrayLike,
54
    PathLike,
55
    _preprocess_julia_floats,
56
    _safe_check_feature_names_in,
57
    _subscriptify,
58
    _suggest_keywords,
59
)
60

61
ALREADY_RAN = False
3✔
62

63

64
def _process_constraints(
3✔
65
    binary_operators: list[str],
66
    unary_operators: list,
67
    constraints: dict[str, int | tuple[int, int]],
68
) -> dict[str, int | tuple[int, int]]:
69
    constraints = constraints.copy()
3✔
70
    for op in unary_operators:
3✔
71
        if op not in constraints:
3✔
72
            constraints[op] = -1
3✔
73
    for op in binary_operators:
3✔
74
        if op not in constraints:
3✔
75
            if op in ["^", "pow"]:
3✔
76
                # Warn user that they should set up constraints
77
                warnings.warn(
3✔
78
                    "You are using the `^` operator, but have not set up `constraints` for it. "
79
                    "This may lead to overly complex expressions. "
80
                    "One typical constraint is to use `constraints={..., '^': (-1, 1)}`, which "
81
                    "will allow arbitrary-complexity base (-1) but only powers such as "
82
                    "a constant or variable (1). "
83
                    "For more tips, please see https://ai.damtp.cam.ac.uk/pysr/tuning/"
84
                )
85
            constraints[op] = (-1, -1)
3✔
86

87
        constraint_tuple = cast(tuple[int, int], constraints[op])
3✔
88
        if op in ["plus", "sub", "+", "-"]:
3✔
89
            if constraint_tuple[0] != constraint_tuple[1]:
3✔
90
                raise NotImplementedError(
3✔
91
                    "You need equal constraints on both sides for - and +, "
92
                    "due to simplification strategies."
93
                )
94
        elif op in ["mult", "*"]:
3✔
95
            # Make sure the complex expression is on the left side.
96
            if constraint_tuple[0] == -1:
3✔
97
                continue
3✔
NEW
98
            if constraint_tuple[1] == -1 or constraint_tuple[0] < constraint_tuple[1]:
×
NEW
99
                constraints[op] = (constraint_tuple[1], constraint_tuple[0])
×
100
    return constraints
3✔
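
# A minimal sketch of how `_process_constraints` fills defaults and reorders
# `*` constraints (hypothetical inputs): unary operators default to -1
# (unconstrained), and the `*` tuple is swapped so the larger allowance
# sits on the left.
#
#     >>> _process_constraints(["*", "+"], ["sin"], {"*": (1, 3), "+": (2, 2)})
#     {'*': (3, 1), '+': (2, 2), 'sin': -1}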
101

102

103
def _maybe_create_inline_operators(
3✔
104
    binary_operators: list[str],
105
    unary_operators: list[str],
106
    extra_sympy_mappings: dict[str, Callable] | None,
107
) -> tuple[list[str], list[str]]:
108
    binary_operators = binary_operators.copy()
3✔
109
    unary_operators = unary_operators.copy()
3✔
110
    for op_list in [binary_operators, unary_operators]:
3✔
111
        for i, op in enumerate(op_list):
3✔
112
            is_user_defined_operator = "(" in op
3✔
113

114
            if is_user_defined_operator:
3✔
115
                jl.seval(op)
3✔
116
                # Cut off from the first non-alphanumeric char:
117
                first_non_char = [j for j, char in enumerate(op) if char == "("][0]
3✔
118
                function_name = op[:first_non_char]
3✔
119
                # Assert that function_name only contains
120
                # alphabetical characters, numbers,
121
                # and underscores:
122
                if not re.match(r"^[a-zA-Z0-9_]+$", function_name):
3✔
123
                    raise ValueError(
3✔
124
                        f"Invalid function name {function_name}. "
125
                        "Only letters, numbers, "
126
                        "and underscores are allowed."
127
                    )
128
                if (extra_sympy_mappings is None) or (
3✔
129
                    function_name not in extra_sympy_mappings
130
                ):
131
                    raise ValueError(
3✔
132
                        f"Custom function {function_name} is not defined in `extra_sympy_mappings`. "
133
                        "You can define it with, "
134
                        "e.g., `model.set_params(extra_sympy_mappings={'inv': lambda x: 1/x})`, where "
135
                        "`lambda x: 1/x` is a valid SymPy function defining the operator. "
136
                        "You can also define these at initialization time."
137
                    )
138
                op_list[i] = function_name
3✔
139
    return binary_operators, unary_operators
3✔
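
# A hedged usage sketch (assumes a live Julia runtime, since `jl.seval` is
# called to define the operator): an inline operator string is replaced by
# its bare function name once a matching `extra_sympy_mappings` entry exists.
#
#     >>> _maybe_create_inline_operators(
#     ...     ["+"], ["inv(x) = 1/x"], {"inv": lambda x: 1 / x}
#     ... )
#     (['+'], ['inv'])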
140

141

142
def _check_assertions(
3✔
143
    X,
144
    use_custom_variable_names,
145
    variable_names,
146
    complexity_of_variables,
147
    weights,
148
    y,
149
    X_units,
150
    y_units,
151
):
152
    # Check for potential errors before they happen
153
    assert len(X.shape) == 2
3✔
154
    assert len(y.shape) in [1, 2]
3✔
155
    assert X.shape[0] == y.shape[0]
3✔
156
    if weights is not None:
3✔
157
        assert weights.shape == y.shape
3✔
158
        assert X.shape[0] == weights.shape[0]
3✔
159
    if use_custom_variable_names:
3✔
160
        assert len(variable_names) == X.shape[1]
3✔
161
        # Check none of the variable names are function names:
162
        for var_name in variable_names:
3✔
163
            # Check if alphanumeric only:
164
            if not re.match(r"^[₀₁₂₃₄₅₆₇₈₉a-zA-Z0-9_]+$", var_name):
3✔
165
                raise ValueError(
3✔
166
                    f"Invalid variable name {var_name}. "
167
                    "Only letters, numbers, "
168
                    "and underscores are allowed."
169
                )
170
            assert_valid_sympy_symbol(var_name)
3✔
171
    if (
3✔
172
        isinstance(complexity_of_variables, list)
173
        and len(complexity_of_variables) != X.shape[1]
174
    ):
175
        raise ValueError(
3✔
176
            "The number of elements in `complexity_of_variables` must equal the number of features in `X`."
177
        )
178
    if X_units is not None and len(X_units) != X.shape[1]:
3✔
179
        raise ValueError(
3✔
180
            "The number of units in `X_units` must equal the number of features in `X`."
181
        )
182
    if y_units is not None:
3✔
183
        good_y_units = False
3✔
184
        if isinstance(y_units, list):
3✔
185
            if len(y.shape) == 1:
3✔
186
                good_y_units = len(y_units) == 1
3✔
187
            else:
188
                good_y_units = len(y_units) == y.shape[1]
3✔
189
        else:
190
            good_y_units = len(y.shape) == 1 or y.shape[1] == 1
3✔
191

192
        if not good_y_units:
3✔
193
            raise ValueError(
3✔
194
                "The number of units in `y_units` must equal the number of output features in `y`."
195
            )
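
# A sketch of inputs that satisfy the assertions above (hypothetical arrays):
# X must be 2D, y must be 1D or 2D with the same number of rows, and custom
# variable names must be valid identifiers.
#
#     >>> import numpy as np
#     >>> X, y = np.ones((10, 2)), np.ones(10)
#     >>> _check_assertions(X, True, ["a", "b"], None, None, y, None, None)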
196

197

198
def _validate_export_mappings(extra_jax_mappings, extra_torch_mappings):
3✔
199
    # It is expected extra_jax/torch_mappings will be updated after fit.
200
    # Thus, validation is performed here instead of in _validate_init_params
201
    if extra_jax_mappings is not None:
3✔
202
        for value in extra_jax_mappings.values():
1✔
203
            if not isinstance(value, str):
1✔
NEW
204
                raise ValueError(
×
205
                    "extra_jax_mappings must have values that are strings! "
206
                    "e.g., {sympy.sqrt: 'jnp.sqrt'}."
207
                )
208
    if extra_torch_mappings is not None:
3✔
209
        for value in extra_torch_mappings.values():
1✔
210
            if not callable(value):
1✔
NEW
211
                raise ValueError(
×
212
                    "extra_torch_mappings must have values that are callable functions! "
213
                    "e.g., {sympy.sqrt: torch.sqrt}."
214
                )
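
# A hedged example of mappings that pass the validation above: jax mappings
# map sympy functions to strings, torch mappings map them to callables.
#
#     >>> import sympy
#     >>> _validate_export_mappings({sympy.sqrt: "jnp.sqrt"}, None)
#     >>> _validate_export_mappings(None, {sympy.sqrt: lambda x: x ** 0.5})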
215

216

217
# Class validation constants
218
VALID_OPTIMIZER_ALGORITHMS = ["BFGS", "NelderMead"]
3✔
219

220

221
@dataclass
3✔
222
class _DynamicallySetParams:
3✔
223
    """Defines some parameters that are set at runtime."""
224

225
    binary_operators: list[str]
3✔
226
    unary_operators: list[str]
3✔
227
    maxdepth: int
3✔
228
    constraints: dict[str, int | tuple[int, int]]
3✔
229
    batch_size: int
3✔
230
    update_verbosity: int
3✔
231
    progress: bool
3✔
232
    warmup_maxsize_by: float
3✔
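
# A sketch of a fully-populated instance (illustrative values only; the
# actual values are presumably resolved from the user's settings at fit time):
#
#     >>> p = _DynamicallySetParams(
#     ...     binary_operators=["+", "-", "*", "/"], unary_operators=[],
#     ...     maxdepth=30, constraints={}, batch_size=50,
#     ...     update_verbosity=1, progress=True, warmup_maxsize_by=0.0,
#     ... )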
233

234

235
class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
3✔
236
    """
237
    High-performance symbolic regression algorithm.
238

239
    This is the scikit-learn interface for SymbolicRegression.jl.
240
    This model will automatically search for equations which fit
241
    a given dataset subject to a particular loss and set of
242
    constraints.
243

244
    Most default parameters have been tuned over several example equations,
245
    but you should adjust `niterations`, `binary_operators`, `unary_operators`
246
    to your requirements. You can view more detailed explanations of the options
247
    on the [options page](https://ai.damtp.cam.ac.uk/pysr/options) of the
248
    documentation.
249

250
    Parameters
251
    ----------
252
    model_selection : str
253
        Model selection criterion when selecting a final expression from
254
        the list of best expressions at each complexity.
255
        Can be `'accuracy'`, `'best'`, or `'score'`. Default is `'best'`.
256
        `'accuracy'` selects the candidate model with the lowest loss
257
        (highest accuracy).
258
        `'score'` selects the candidate model with the highest score.
259
        Score is defined as the negated derivative of the log-loss with
260
        respect to complexity: if an expression has a much better
261
        loss at a slightly higher complexity, it is preferred.
262
        `'best'` selects the candidate model with the highest score
263
        among expressions whose loss is at most 1.5x that of the
264
        most accurate model.
265
    binary_operators : list[str]
266
        List of strings for binary operators used in the search.
267
        See the [operators page](https://ai.damtp.cam.ac.uk/pysr/operators/)
268
        for more details.
269
        Default is `["+", "-", "*", "/"]`.
270
    unary_operators : list[str]
271
        Operators which only take a single scalar as input.
272
        For example, `"cos"` or `"exp"`.
273
        Default is `None`.
274
    expression_spec : AbstractExpressionSpec
275
        The type of expression to search for. By default,
276
        this is just `ExpressionSpec()`. You can also use
277
        `TemplateExpressionSpec(...)` which allows you to specify
278
        a custom template for the expressions.
279
        Default is `ExpressionSpec()`.
280
    niterations : int
281
        Number of iterations of the algorithm to run. The best
282
        equations are printed and migrate between populations at the
283
        end of each iteration.
284
        Default is `100`.
285
    populations : int
286
        Number of populations running.
287
        Default is `31`.
288
    population_size : int
289
        Number of individuals in each population.
290
        Default is `27`.
291
    max_evals : int
292
        Limits the total number of evaluations of expressions to
293
        this number.  Default is `None`.
294
    maxsize : int
295
        Max complexity of an equation.  Default is `30`.
296
    maxdepth : int
297
        Max depth of an equation. You can use both `maxsize` and
298
        `maxdepth`. `maxdepth` is by default not used.
299
        Default is `None`.
300
    warmup_maxsize_by : float
301
        Whether to slowly increase max size from a small number up to
302
        the maxsize (if greater than 0).  If greater than 0, this gives the
303
        fraction of training time at which the current maxsize will
304
        reach the user-passed maxsize.
305
        Default is `0.0`.
306
    timeout_in_seconds : float
307
        Make the search return early once this many seconds have passed.
308
        Default is `None`.
309
    constraints : dict[str, int | tuple[int,int]]
310
        Dictionary of int (unary) or 2-tuples (binary), this enforces
311
        maxsize constraints on the individual arguments of operators.
312
        E.g., `'pow': (-1, 1)` says that power laws can have any
313
        complexity left argument, but only 1 complexity in the right
314
        argument. Use this to force more interpretable solutions.
315
        Default is `None`.
316
    nested_constraints : dict[str, dict]
317
        Specifies how many times a combination of operators can be
318
        nested. For example, `{"sin": {"cos": 0}, "cos": {"cos": 2}}`
319
        specifies that `cos` may never appear within a `sin`, but `sin`
320
        can be nested with itself an unlimited number of times. The
321
        second term specifies that `cos` can be nested up to 2 times
322
        within a `cos`, so that `cos(cos(cos(x)))` is allowed
323
        (as well as any combination of `+` or `-` within it), but
324
        `cos(cos(cos(cos(x))))` is not allowed. When an operator is not
325
        specified, it is assumed that it can be nested an unlimited
326
        number of times. This requires that there is no operator which
327
        is used both in the unary operators and the binary operators
328
        (e.g., `-` could be both subtract, and negation). For binary
329
        operators, you only need to provide a single number: both
330
        arguments are treated the same way, and the max of each
331
        argument is constrained.
332
        Default is `None`.
333
    elementwise_loss : str
334
        String of Julia code specifying an elementwise loss function.
335
        Can either be a loss from LossFunctions.jl, or your own loss
336
        written as a function. Examples of custom written losses include:
337
        `myloss(x, y) = abs(x-y)` for non-weighted, or
338
        `myloss(x, y, w) = w*abs(x-y)` for weighted.
339
        The included losses include:
340
        Regression: `LPDistLoss{P}()`, `L1DistLoss()`,
341
        `L2DistLoss()` (mean square), `LogitDistLoss()`,
342
        `HuberLoss(d)`, `L1EpsilonInsLoss(ϵ)`, `L2EpsilonInsLoss(ϵ)`,
343
        `PeriodicLoss(c)`, `QuantileLoss(τ)`.
344
        Classification: `ZeroOneLoss()`, `PerceptronLoss()`,
345
        `L1HingeLoss()`, `SmoothedL1HingeLoss(γ)`,
346
        `ModifiedHuberLoss()`, `L2MarginLoss()`, `ExpLoss()`,
347
        `SigmoidLoss()`, `DWDMarginLoss(q)`.
348
        Default is `"L2DistLoss()"`.
349
    loss_function : str
350
        Alternatively, you can specify the full objective function as
351
        a snippet of Julia code, including any sort of custom evaluation
352
        (including symbolic manipulations beforehand), and any sort
353
        of loss function or regularizations. The default `loss_function`
354
        used in SymbolicRegression.jl is roughly equal to:
355
        ```julia
356
        function eval_loss(tree, dataset::Dataset{T,L}, options)::L where {T,L}
357
            prediction, flag = eval_tree_array(tree, dataset.X, options)
358
            if !flag
359
                return L(Inf)
360
            end
361
            return sum((prediction .- dataset.y) .^ 2) / dataset.n
362
        end
363
        ```
364
        where the example elementwise loss is mean-squared error.
365
        You may pass a function with the same arguments as this (note
366
        that the name of the function doesn't matter). Here,
367
        both `prediction` and `dataset.y` are 1D arrays of length `dataset.n`.
368
        If using `batching`, then you should add an
369
        `idx` argument to the function, which is `nothing`
370
        for non-batched, and a 1D array of indices for batched.
371
        Default is `None`.
372
    complexity_of_operators : dict[str, int | float]
373
        If you would like to use a complexity other than 1 for an
374
        operator, specify the complexity here. For example,
375
        `{"sin": 2, "+": 1}` would give a complexity of 2 for each use
376
        of the `sin` operator, and a complexity of 1 for each use of
377
        the `+` operator (which is the default). You may specify real
378
        numbers for a complexity, and the total complexity of a tree
379
        will be rounded to the nearest integer after computing.
380
        Default is `None`.
381
    complexity_of_constants : int | float
382
        Complexity of constants. Default is `1`.
383
    complexity_of_variables : int | float | list[int | float]
384
        Global complexity of variables. To set different complexities for
385
        different variables, pass a list of complexities to the `fit` method
386
        with keyword `complexity_of_variables`. You cannot use both.
387
        Default is `1`.
388
    complexity_mapping : str
389
        Alternatively, you can pass a function (a string of Julia code) that
390
        takes the expression as input and returns the complexity. Make sure that
391
        this operates on `AbstractExpression` (and unpacks to `AbstractExpressionNode`),
392
        and returns an integer.
393
        Default is `None`.
394
    parsimony : float
395
        Multiplicative factor for how much to punish complexity.
396
        Default is `0.0`.
397
    dimensional_constraint_penalty : float
398
        Additive penalty for if dimensional analysis of an expression fails.
399
        By default, this is `1000.0`.
400
    dimensionless_constants_only : bool
401
        Whether to only search for dimensionless constants, if using units.
402
        Default is `False`.
403
    use_frequency : bool
404
        Whether to measure the frequency of complexities, and use that
405
        instead of parsimony to explore equation space. Will naturally
406
        find equations of all complexities.
407
        Default is `True`.
408
    use_frequency_in_tournament : bool
409
        Whether to use the frequency mentioned above in the tournament,
410
        rather than just the simulated annealing.
411
        Default is `True`.
412
    adaptive_parsimony_scaling : float
413
        If the adaptive parsimony strategy (`use_frequency` and
414
        `use_frequency_in_tournament`), this is how much to (exponentially)
415
        weight the contribution. If you find that the search is only optimizing
416
        the most complex expressions while the simpler expressions remain stagnant,
417
        you should increase this value.
418
        Default is `1040.0`.
419
    alpha : float
420
        Initial temperature for simulated annealing
421
        (requires `annealing` to be `True`).
422
        Default is `3.17`.
423
    annealing : bool
424
        Whether to use annealing.  Default is `False`.
425
    early_stop_condition : float | str
426
        Stop the search early if this loss is reached. You may also
427
        pass a string containing a Julia function which
428
        takes a loss and complexity as input, for example:
429
        `"f(loss, complexity) = (loss < 0.1) && (complexity < 10)"`.
430
        Default is `None`.
431
    ncycles_per_iteration : int
432
        Number of total mutations to run, per 10 samples of the
433
        population, per iteration.
434
        Default is `380`.
435
    fraction_replaced : float
436
        How much of the population to replace with migrating equations from
437
        other populations.
438
        Default is `0.00036`.
439
    fraction_replaced_hof : float
440
        How much of the population to replace with migrating equations from
441
        hall of fame. Default is `0.0614`.
442
    weight_add_node : float
443
        Relative likelihood for mutation to add a node.
444
        Default is `2.47`.
445
    weight_insert_node : float
446
        Relative likelihood for mutation to insert a node.
447
        Default is `0.0112`.
448
    weight_delete_node : float
449
        Relative likelihood for mutation to delete a node.
450
        Default is `0.870`.
451
    weight_do_nothing : float
452
        Relative likelihood for mutation to leave the individual.
453
        Default is `0.273`.
454
    weight_mutate_constant : float
455
        Relative likelihood for mutation to change the constant slightly
456
        in a random direction.
457
        Default is `0.0346`.
458
    weight_mutate_operator : float
459
        Relative likelihood for mutation to swap an operator.
460
        Default is `0.293`.
461
    weight_swap_operands : float
462
        Relative likelihood for swapping operands in binary operators.
463
        Default is `0.198`.
464
    weight_rotate_tree : float
465
        How often to perform a tree rotation at a random node.
466
        Default is `4.26`.
467
    weight_randomize : float
468
        Relative likelihood for mutation to completely delete and then
469
        randomly generate the equation.
470
        Default is `0.000502`.
471
    weight_simplify : float
472
        Relative likelihood for mutation to simplify constant parts by evaluation.
473
        Default is `0.00209`.
474
    weight_optimize: float
475
        Constant optimization can also be performed as a mutation, in addition to
476
        the normal strategy controlled by `optimize_probability` which happens
477
        every iteration. Using it as a mutation is useful if you want to use
478
        a large `ncycles_per_iteration`, and so may not optimize very often.
479
        Default is `0.0`.
480
    crossover_probability : float
481
        Absolute probability of crossover-type genetic operation, instead of a mutation.
482
        Default is `0.0259`.
483
    skip_mutation_failures : bool
484
        Whether to skip mutation and crossover failures, rather than
485
        simply re-sampling the current member.
486
        Default is `True`.
487
    migration : bool
488
        Whether to migrate.  Default is `True`.
489
    hof_migration : bool
490
        Whether to have the hall of fame migrate.  Default is `True`.
491
    topn : int
492
        How many top individuals migrate from each population.
493
        Default is `12`.
494
    should_simplify : bool
495
        Whether to use algebraic simplification in the search. Note that only
496
        a few simple rules are implemented. Default is `True`.
497
    should_optimize_constants : bool
498
        Whether to numerically optimize constants (Nelder-Mead/Newton)
499
        at the end of each iteration. Default is `True`.
500
    optimizer_algorithm : str
501
        Optimization scheme to use for optimizing constants. Can currently
502
        be `NelderMead` or `BFGS`.
503
        Default is `"BFGS"`.
504
    optimizer_nrestarts : int
505
        Number of times to restart the constants optimization process with
506
        different initial conditions.
507
        Default is `2`.
508
    optimizer_f_calls_limit : int
509
        How many function calls to allow during optimization.
510
        Default is `10_000`.
511
    optimize_probability : float
512
        Probability of optimizing the constants during a single iteration of
513
        the evolutionary algorithm.
514
        Default is `0.14`.
515
    optimizer_iterations : int
516
        Number of iterations that the constants optimizer can take.
517
        Default is `8`.
518
    perturbation_factor : float
519
        Constants are perturbed by a max factor of
520
        (perturbation_factor*T + 1). Either multiplied by this or
521
        divided by this.
522
        Default is `0.129`.
523
    probability_negate_constant : float
524
        Probability of negating a constant in the equation when mutating it.
525
        Default is `0.00743`.
526
    tournament_selection_n : int
527
        Number of expressions to consider in each tournament.
528
        Default is `15`.
529
    tournament_selection_p : float
530
        Probability of selecting the best expression in each
531
        tournament. The probability will decay as p*(1-p)^n for other
532
        expressions, sorted by loss.
533
        Default is `0.982`.
534
    parallelism: Literal["serial", "multithreading", "multiprocessing"] | None
535
        Parallelism to use for the search. Can be `"serial"`, `"multithreading"`, or `"multiprocessing"`.
536
        Default is `"multithreading"`.
537
    procs: int | None
538
        Number of processes to use for parallelism. If `None`, defaults to `cpu_count()`.
539
        Default is `None`.
540
    cluster_manager : str
541
        For distributed computing, this sets the job queue system. Set
542
        to one of "slurm", "pbs", "lsf", "sge", "qrsh", "scyld", or
543
        "htc". If set to one of these, PySR will run in distributed
544
        mode, and use `procs` to figure out how many processes to launch.
545
        Default is `None`.
546
    heap_size_hint_in_bytes : int
547
        For multiprocessing, this sets the `--heap-size-hint` parameter
548
        for new Julia processes. This can be configured when using
549
        multi-node distributed compute, to give a hint to each process
550
        about how much memory it can use before aggressive garbage
551
        collection.
552
    batching : bool
553
        Whether to compare population members on small batches during
554
        evolution. Still uses full dataset for comparing against hall
555
        of fame. Default is `False`.
556
    batch_size : int
557
        The amount of data to use if doing batching. Default is `50`.
558
    fast_cycle : bool
559
        Batch over population subsamples. This is a slightly different
560
        algorithm than regularized evolution, but does cycles 15%
561
        faster. May be algorithmically less efficient.
562
        Default is `False`.
563
    turbo: bool
564
        (Experimental) Whether to use LoopVectorization.jl to speed up the
565
        search evaluation. Certain operators may not be supported.
566
        Does not support 16-bit precision floats.
567
        Default is `False`.
568
    bumper: bool
569
        (Experimental) Whether to use Bumper.jl to speed up the search
570
        evaluation. Does not support 16-bit precision floats.
571
        Default is `False`.
572
    precision : int
573
        What precision to use for the data. By default this is `32`
574
        (float32), but you can select `64` or `16` as well, giving
575
        you 64 or 16 bits of floating point precision, respectively.
576
        If you pass complex data, the corresponding complex precision
577
        will be used (i.e., `64` for complex128, `32` for complex64).
578
        Default is `32`.
579
    autodiff_backend : Literal["Zygote"] | None
580
        Which backend to use for automatic differentiation during constant
581
        optimization. Currently only `"Zygote"` is supported. The default,
582
        `None`, uses forward-mode or finite difference.
583
        Default is `None`.
584
    random_state : int, Numpy RandomState instance or None
585
        Pass an int for reproducible results across multiple function calls.
586
        See :term:`Glossary <random_state>`.
587
        Default is `None`.
588
    deterministic : bool
589
        Make a PySR search give the same result every run.
590
        To use this, you must turn off parallelism
591
        (with `parallelism="serial"`),
592
        and set `random_state` to a fixed seed.
593
        Default is `False`.
594
    warm_start : bool
595
        Tells fit to continue from where the last call to fit finished.
596
        If false, each call to fit will be fresh, overwriting previous results.
597
        Default is `False`.
598
    verbosity : int
599
        What verbosity level to use. 0 means minimal print statements.
600
        Default is `1`.
601
    update_verbosity : int
602
        What verbosity level to use for package updates.
603
        Will take value of `verbosity` if not given.
604
        Default is `None`.
605
    print_precision : int
606
        How many significant digits to print for floats. Default is `5`.
607
    progress : bool
608
        Whether to use a progress bar instead of printing to stdout.
609
        Default is `True`.
610
    logger_spec: AbstractLoggerSpec | None
611
        Logger specification for the Julia backend. See, for example,
612
        `TensorBoardLoggerSpec`.
613
        Default is `None`.
614
    run_id : str
615
        A unique identifier for the run. Will be generated using the
616
        current date and time if not provided.
617
        Default is `None`.
618
    output_directory : str
619
        The base directory to save output files to. Files
620
        will be saved in a subdirectory according to the run ID.
621
        Will be set to `outputs/` if not provided.
622
        Default is `None`.
623
    temp_equation_file : bool
624
        Whether to put the hall of fame file in the temp directory.
625
        Deletion is then controlled with the `delete_tempfiles`
626
        parameter.
627
        Default is `False`.
628
    tempdir : str
629
        Directory for the temporary files. Default is `None`.
630
    delete_tempfiles : bool
631
        Whether to delete the temporary files after finishing.
632
        Default is `True`.
633
    update: bool
634
        Whether to automatically update Julia packages when `fit` is called.
635
        You should make sure that PySR is up-to-date itself first, as
636
        the packaged Julia packages may not necessarily include all
637
        updated dependencies.
638
        Default is `False`.
639
    output_jax_format : bool
640
        Whether to create a 'jax_format' column in the output,
641
        containing jax-callable functions and the default parameters in
642
        a jax array.
643
        Default is `False`.
644
    output_torch_format : bool
645
        Whether to create a 'torch_format' column in the output,
646
        containing a torch module with trainable parameters.
647
        Default is `False`.
648
    extra_sympy_mappings : dict[str, Callable]
649
        Provides mappings between custom `binary_operators` or
650
        `unary_operators` defined in julia strings, to those same
651
        operators defined in sympy.
652
        E.g., if `unary_operators=["inv(x)=1/x"]`, then for the fitted
653
        model to be exported to sympy, `extra_sympy_mappings`
654
        would be `{"inv": lambda x: 1/x}`.
655
        Default is `None`.
656
    extra_jax_mappings : dict[Callable, str]
657
        Similar to `extra_sympy_mappings` but for model export
658
        to jax. The dictionary maps sympy functions to jax functions.
659
        For example: `extra_jax_mappings={sympy.sin: "jnp.sin"}` maps
660
        the `sympy.sin` function to the equivalent jax expression `jnp.sin`.
661
        Default is `None`.
662
    extra_torch_mappings : dict[Callable, Callable]
663
        The same as `extra_jax_mappings` but for model export
664
        to pytorch. Note that the dictionary values should be callable
665
        pytorch expressions.
666
        For example: `extra_torch_mappings={sympy.sin: torch.sin}`.
667
        Default is `None`.
668
    denoise : bool
669
        Whether to use a Gaussian Process to denoise the data before
670
        inputting to PySR. Can help PySR fit noisy data.
671
        Default is `False`.
672
    select_k_features : int
673
        Whether to run feature selection in Python using random forests,
674
        before passing to the symbolic regression code. None means no
675
        feature selection; an int means select that many features.
676
        Default is `None`.
677
    **kwargs : dict
678
        Supports deprecated keyword arguments. Other arguments will
679
        result in an error.
680
    Attributes
681
    ----------
682
    equations_ : pandas.DataFrame | list[pandas.DataFrame]
683
        Processed DataFrame containing the results of model fitting.
684
    n_features_in_ : int
685
        Number of features seen during :term:`fit`.
686
    feature_names_in_ : ndarray of shape (`n_features_in_`,)
687
        Names of features seen during :term:`fit`. Defined only when `X`
688
        has feature names that are all strings.
689
    display_feature_names_in_ : ndarray of shape (`n_features_in_`,)
690
        Pretty names of features, used only during printing.
691
    X_units_ : list[str] of length n_features
692
        Units of each variable in the training dataset, `X`.
693
    y_units_ : str | list[str] of length n_out
694
        Units of each variable in the training dataset, `y`.
695
    nout_ : int
696
        Number of output dimensions.
697
    selection_mask_ : ndarray of shape (`n_features_in_`,)
698
        Mask of which features of `X` to use when `select_k_features` is set.
699
    tempdir_ : Path | None
700
        Path to the temporary equations directory.
701
    julia_state_stream_ : ndarray
702
        The serialized state for the julia SymbolicRegression.jl backend (after fitting),
703
        stored as an array of uint8, produced by Julia's Serialization.serialize function.
704
    julia_options_stream_ : ndarray
705
        The serialized julia options, stored as an array of uint8.
706
    expression_spec_ : AbstractExpressionSpec
707
        The expression specification used for this fit. This is equal to
708
        `self.expression_spec` if provided, or `ExpressionSpec()` otherwise.
709
    equation_file_contents_ : list[pandas.DataFrame]
710
        Contents of the equation file output by the Julia backend.
711
    show_pickle_warnings_ : bool
712
        Whether to show warnings about what attributes can be pickled.
713

714
    Examples
715
    --------
716
    ```python
717
    >>> import numpy as np
718
    >>> from pysr import PySRRegressor
719
    >>> randstate = np.random.RandomState(0)
720
    >>> X = 2 * randstate.randn(100, 5)
721
    >>> # y = 2.5382 * cos(x_3) + x_0^2 - 0.5
722
    >>> y = 2.5382 * np.cos(X[:, 3]) + X[:, 0] ** 2 - 0.5
723
    >>> model = PySRRegressor(
724
    ...     niterations=40,
725
    ...     binary_operators=["+", "*"],
726
    ...     unary_operators=[
727
    ...         "cos",
728
    ...         "exp",
729
    ...         "sin",
730
    ...         "inv(x) = 1/x",  # Custom operator (julia syntax)
731
    ...     ],
732
    ...     model_selection="best",
733
    ...     elementwise_loss="loss(x, y) = (x - y)^2",  # Custom loss function (julia syntax)
734
    ... )
735
    >>> model.fit(X, y)
736
    >>> model
737
    PySRRegressor.equations_ = [
738
    0         0.000000                                          3.8552167  3.360272e+01           1
739
    1         1.189847                                          (x0 * x0)  3.110905e+00           3
740
    2         0.010626                          ((x0 * x0) + -0.25573406)  3.045491e+00           5
741
    3         0.896632                              (cos(x3) + (x0 * x0))  1.242382e+00           6
742
    4         0.811362                ((x0 * x0) + (cos(x3) * 2.4384754))  2.451971e-01           8
743
    5  >>>>  13.733371          (((cos(x3) * 2.5382) + (x0 * x0)) + -0.5)  2.889755e-13          10
744
    6         0.194695  ((x0 * x0) + (((cos(x3) + -0.063180044) * 2.53...  1.957723e-13          12
745
    7         0.006988  ((x0 * x0) + (((cos(x3) + -0.32505524) * 1.538...  1.944089e-13          13
746
    8         0.000955  (((((x0 * x0) + cos(x3)) + -0.8251649) + (cos(...  1.940381e-13          15
747
    ]
748
    >>> model.score(X, y)
749
    1.0
750
    >>> model.predict(np.array([1,2,3,4,5]))
751
    array([-1.15907818, -1.15907818, -1.15907818, -1.15907818, -1.15907818])
752
    ```
753
    """
754

755
    equations_: pd.DataFrame | list[pd.DataFrame] | None
3✔
756
    n_features_in_: int
3✔
757
    feature_names_in_: ArrayLike[str]
3✔
758
    display_feature_names_in_: ArrayLike[str]
3✔
759
    complexity_of_variables_: int | float | list[int | float] | None
3✔
760
    X_units_: ArrayLike[str] | None
3✔
761
    y_units_: str | ArrayLike[str] | None
3✔
762
    nout_: int
3✔
763
    selection_mask_: NDArray[np.bool_] | None
3✔
764
    run_id_: str
3✔
765
    output_directory_: str
3✔
766
    julia_state_stream_: NDArray[np.uint8] | None
3✔
767
    julia_options_stream_: NDArray[np.uint8] | None
3✔
768
    equation_file_contents_: list[pd.DataFrame] | None
3✔
769
    show_pickle_warnings_: bool
3✔
770

771
    def __init__(
3✔
772
        self,
773
        model_selection: Literal["best", "accuracy", "score"] = "best",
774
        *,
775
        binary_operators: list[str] | None = None,
776
        unary_operators: list[str] | None = None,
777
        expression_spec: AbstractExpressionSpec | None = None,
778
        niterations: int = 100,
779
        populations: int = 31,
780
        population_size: int = 27,
781
        max_evals: int | None = None,
782
        maxsize: int = 30,
783
        maxdepth: int | None = None,
784
        warmup_maxsize_by: float | None = None,
785
        timeout_in_seconds: float | None = None,
786
        constraints: dict[str, int | tuple[int, int]] | None = None,
787
        nested_constraints: dict[str, dict[str, int]] | None = None,
788
        elementwise_loss: str | None = None,
789
        loss_function: str | None = None,
790
        complexity_of_operators: dict[str, int | float] | None = None,
791
        complexity_of_constants: int | float | None = None,
792
        complexity_of_variables: int | float | list[int | float] | None = None,
793
        complexity_mapping: str | None = None,
794
        parsimony: float = 0.0,
795
        dimensional_constraint_penalty: float | None = None,
796
        dimensionless_constants_only: bool = False,
797
        use_frequency: bool = True,
798
        use_frequency_in_tournament: bool = True,
799
        adaptive_parsimony_scaling: float = 1040.0,
800
        alpha: float = 3.17,
801
        annealing: bool = False,
802
        early_stop_condition: float | str | None = None,
803
        ncycles_per_iteration: int = 380,
804
        fraction_replaced: float = 0.00036,
805
        fraction_replaced_hof: float = 0.0614,
806
        weight_add_node: float = 2.47,
807
        weight_insert_node: float = 0.0112,
808
        weight_delete_node: float = 0.870,
809
        weight_do_nothing: float = 0.273,
810
        weight_mutate_constant: float = 0.0346,
811
        weight_mutate_operator: float = 0.293,
812
        weight_swap_operands: float = 0.198,
813
        weight_rotate_tree: float = 4.26,
814
        weight_randomize: float = 0.000502,
815
        weight_simplify: float = 0.00209,
816
        weight_optimize: float = 0.0,
817
        crossover_probability: float = 0.0259,
818
        skip_mutation_failures: bool = True,
819
        migration: bool = True,
820
        hof_migration: bool = True,
821
        topn: int = 12,
822
        should_simplify: bool = True,
823
        should_optimize_constants: bool = True,
824
        optimizer_algorithm: Literal["BFGS", "NelderMead"] = "BFGS",
825
        optimizer_nrestarts: int = 2,
826
        optimizer_f_calls_limit: int | None = None,
827
        optimize_probability: float = 0.14,
828
        optimizer_iterations: int = 8,
829
        perturbation_factor: float = 0.129,
830
        probability_negate_constant: float = 0.00743,
831
        tournament_selection_n: int = 15,
832
        tournament_selection_p: float = 0.982,
833
        parallelism: (
834
            Literal["serial", "multithreading", "multiprocessing"] | None
835
        ) = None,
836
        procs: int | None = None,
837
        cluster_manager: (
838
            Literal["slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"] | None
839
        ) = None,
840
        heap_size_hint_in_bytes: int | None = None,
841
        batching: bool = False,
842
        batch_size: int = 50,
843
        fast_cycle: bool = False,
844
        turbo: bool = False,
845
        bumper: bool = False,
846
        precision: Literal[16, 32, 64] = 32,
847
        autodiff_backend: Literal["Zygote"] | None = None,
848
        random_state: int | np.random.RandomState | None = None,
849
        deterministic: bool = False,
850
        warm_start: bool = False,
851
        verbosity: int = 1,
852
        update_verbosity: int | None = None,
853
        print_precision: int = 5,
854
        progress: bool = True,
855
        logger_spec: AbstractLoggerSpec | None = None,
856
        run_id: str | None = None,
857
        output_directory: str | None = None,
858
        temp_equation_file: bool = False,
859
        tempdir: str | None = None,
860
        delete_tempfiles: bool = True,
861
        update: bool = False,
862
        output_jax_format: bool = False,
863
        output_torch_format: bool = False,
864
        extra_sympy_mappings: dict[str, Callable] | None = None,
865
        extra_torch_mappings: dict[Callable, Callable] | None = None,
866
        extra_jax_mappings: dict[Callable, str] | None = None,
867
        denoise: bool = False,
868
        select_k_features: int | None = None,
869
        **kwargs,
870
    ):
871
        # Hyperparameters
872
        # - Model search parameters
873
        self.model_selection = model_selection
3✔
874
        self.binary_operators = binary_operators
3✔
875
        self.unary_operators = unary_operators
3✔
876
        self.expression_spec = expression_spec
3✔
877
        self.niterations = niterations
3✔
878
        self.populations = populations
3✔
879
        self.population_size = population_size
3✔
880
        self.ncycles_per_iteration = ncycles_per_iteration
3✔
881
        # - Equation Constraints
882
        self.maxsize = maxsize
3✔
883
        self.maxdepth = maxdepth
3✔
884
        self.constraints = constraints
3✔
885
        self.nested_constraints = nested_constraints
3✔
886
        self.warmup_maxsize_by = warmup_maxsize_by
3✔
887
        self.should_simplify = should_simplify
3✔
888
        # - Early exit conditions:
889
        self.max_evals = max_evals
3✔
890
        self.timeout_in_seconds = timeout_in_seconds
3✔
891
        self.early_stop_condition = early_stop_condition
3✔
892
        # - Loss parameters
893
        self.elementwise_loss = elementwise_loss
3✔
894
        self.loss_function = loss_function
3✔
895
        self.complexity_of_operators = complexity_of_operators
3✔
896
        self.complexity_of_constants = complexity_of_constants
3✔
897
        self.complexity_of_variables = complexity_of_variables
3✔
898
        self.complexity_mapping = complexity_mapping
3✔
899
        self.parsimony = parsimony
3✔
900
        self.dimensional_constraint_penalty = dimensional_constraint_penalty
3✔
901
        self.dimensionless_constants_only = dimensionless_constants_only
3✔
902
        self.use_frequency = use_frequency
3✔
903
        self.use_frequency_in_tournament = use_frequency_in_tournament
3✔
904
        self.adaptive_parsimony_scaling = adaptive_parsimony_scaling
3✔
905
        self.alpha = alpha
3✔
906
        self.annealing = annealing
3✔
907
        # - Evolutionary search parameters
908
        # -- Mutation parameters
909
        self.weight_add_node = weight_add_node
3✔
910
        self.weight_insert_node = weight_insert_node
3✔
911
        self.weight_delete_node = weight_delete_node
3✔
912
        self.weight_do_nothing = weight_do_nothing
3✔
913
        self.weight_mutate_constant = weight_mutate_constant
3✔
914
        self.weight_mutate_operator = weight_mutate_operator
3✔
915
        self.weight_swap_operands = weight_swap_operands
3✔
916
        self.weight_rotate_tree = weight_rotate_tree
3✔
917
        self.weight_randomize = weight_randomize
3✔
918
        self.weight_simplify = weight_simplify
3✔
919
        self.weight_optimize = weight_optimize
3✔
920
        self.crossover_probability = crossover_probability
3✔
921
        self.skip_mutation_failures = skip_mutation_failures
3✔
922
        # -- Migration parameters
923
        self.migration = migration
3✔
924
        self.hof_migration = hof_migration
3✔
925
        self.fraction_replaced = fraction_replaced
3✔
926
        self.fraction_replaced_hof = fraction_replaced_hof
3✔
927
        self.topn = topn
3✔
928
        # -- Constants parameters
929
        self.should_optimize_constants = should_optimize_constants
3✔
930
        self.optimizer_algorithm = optimizer_algorithm
3✔
931
        self.optimizer_nrestarts = optimizer_nrestarts
3✔
932
        self.optimizer_f_calls_limit = optimizer_f_calls_limit
3✔
933
        self.optimize_probability = optimize_probability
3✔
934
        self.optimizer_iterations = optimizer_iterations
3✔
935
        self.perturbation_factor = perturbation_factor
3✔
936
        self.probability_negate_constant = probability_negate_constant
3✔
937
        # -- Selection parameters
938
        self.tournament_selection_n = tournament_selection_n
3✔
939
        self.tournament_selection_p = tournament_selection_p
3✔
940
        # -- Performance parameters
941
        self.parallelism = parallelism
3✔
942
        self.procs = procs
3✔
943
        self.cluster_manager = cluster_manager
3✔
944
        self.heap_size_hint_in_bytes = heap_size_hint_in_bytes
3✔
945
        self.batching = batching
3✔
946
        self.batch_size = batch_size
3✔
947
        self.fast_cycle = fast_cycle
3✔
948
        self.turbo = turbo
3✔
949
        self.bumper = bumper
3✔
950
        self.precision = precision
3✔
951
        self.autodiff_backend = autodiff_backend
3✔
952
        self.random_state = random_state
3✔
953
        self.deterministic = deterministic
3✔
954
        self.warm_start = warm_start
3✔
955
        # Additional runtime parameters
956
        # - Runtime user interface
957
        self.verbosity = verbosity
3✔
958
        self.update_verbosity = update_verbosity
3✔
959
        self.print_precision = print_precision
3✔
960
        self.progress = progress
3✔
961
        self.logger_spec = logger_spec
3✔
962
        # - Project management
963
        self.run_id = run_id
3✔
964
        self.output_directory = output_directory
3✔
965
        self.temp_equation_file = temp_equation_file
3✔
966
        self.tempdir = tempdir
3✔
967
        self.delete_tempfiles = delete_tempfiles
3✔
968
        self.update = update
3✔
969
        self.output_jax_format = output_jax_format
3✔
970
        self.output_torch_format = output_torch_format
3✔
971
        self.extra_sympy_mappings = extra_sympy_mappings
3✔
972
        self.extra_jax_mappings = extra_jax_mappings
3✔
973
        self.extra_torch_mappings = extra_torch_mappings
3✔
974
        # Pre-modelling transformation
975
        self.denoise = denoise
3✔
976
        self.select_k_features = select_k_features
3✔
977

978
        # Once all valid parameters have been assigned handle the
979
        # deprecated kwargs
980
        if len(kwargs) > 0:  # pragma: no cover
981
            for k, v in kwargs.items():
982
                # Handle renamed kwargs
983
                if k in DEPRECATED_KWARGS:
984
                    updated_kwarg_name = DEPRECATED_KWARGS[k]
985
                    setattr(self, updated_kwarg_name, v)
986
                    warnings.warn(
987
                        f"`{k}` has been renamed to `{updated_kwarg_name}` in PySRRegressor. "
988
                        "Please use that instead.",
989
                        FutureWarning,
990
                    )
991
                elif k == "multithreading":
992
                    # Specific advice given in `_map_parallelism_params`
993
                    self.multithreading: bool | None = v
994
                # Handle kwargs that have been moved to the fit method
995
                elif k in ["weights", "variable_names", "Xresampled"]:
996
                    warnings.warn(
997
                        f"`{k}` is a data-dependent parameter and should be passed when fit is called. "
998
                        f"Ignoring parameter; please pass `{k}` during the call to fit instead.",
999
                        FutureWarning,
1000
                    )
1001
                elif k == "julia_project":
1002
                    warnings.warn(
1003
                        "The `julia_project` parameter has been deprecated. To use a custom "
1004
                        "julia project, please see `https://ai.damtp.cam.ac.uk/pysr/backend`.",
1005
                        FutureWarning,
1006
                    )
1007
                elif k == "julia_kwargs":
1008
                    warnings.warn(
1009
                        "The `julia_kwargs` parameter has been deprecated. To pass custom "
1010
                        "keyword arguments to the julia backend, you should use environment variables. "
1011
                        "See the Julia documentation for more information.",
1012
                        FutureWarning,
1013
                    )
1014
                else:
1015
                    suggested_keywords = _suggest_keywords(PySRRegressor, k)
1016
                    err_msg = (
1017
                        f"`{k}` is not a valid keyword argument for PySRRegressor."
1018
                    )
1019
                    if len(suggested_keywords) > 0:
1020
                        err_msg += f" Did you mean {', '.join(map(lambda s: f'`{s}`', suggested_keywords))}?"
1021
                    raise TypeError(err_msg)
1022

1023
    @classmethod
3✔
1024
    def from_file(
3✔
1025
        cls,
1026
        equation_file: None = None,  # Deprecated
1027
        *,
1028
        run_directory: PathLike,
1029
        binary_operators: list[str] | None = None,
1030
        unary_operators: list[str] | None = None,
1031
        n_features_in: int | None = None,
1032
        feature_names_in: ArrayLike[str] | None = None,
1033
        selection_mask: NDArray[np.bool_] | None = None,
1034
        nout: int = 1,
1035
        **pysr_kwargs,
1036
    ) -> "PySRRegressor":
1037
        """
1038
        Create a model from a saved model checkpoint or equation file.
1039

1040
        Parameters
1041
        ----------
1042
        run_directory : str
1043
            The directory containing outputs from a previous run.
1044
            This is of the form `[output_directory]/[run_id]`.
1045
            This argument is required.
1046
        binary_operators : list[str]
1047
            The same binary operators used when creating the model.
1048
            Not needed if loading from a pickle file.
1049
        unary_operators : list[str]
1050
            The same unary operators used when creating the model.
1051
            Not needed if loading from a pickle file.
1052
        n_features_in : int
1053
            Number of features passed to the model.
1054
            Not needed if loading from a pickle file.
1055
        feature_names_in : list[str]
1056
            Names of the features passed to the model.
1057
            Not needed if loading from a pickle file.
1058
        selection_mask : NDArray[np.bool_]
1059
            If using `select_k_features`, you must pass `model.selection_mask_` here.
1060
            Not needed if loading from a pickle file.
1061
        nout : int
1062
            Number of outputs of the model.
1063
            Not needed if loading from a pickle file.
1064
            Default is `1`.
1065
        **pysr_kwargs : dict
1066
            Any other keyword arguments to initialize the PySRRegressor object.
1067
            These will overwrite those stored in the pickle file.
1068
            Not needed if loading from a pickle file.
1069

1070
        Returns
1071
        -------
1072
        model : PySRRegressor
1073
            The model with fitted equations.
1074
        """
1075
        if equation_file is not None:
3✔
1076
            raise ValueError(
3✔
1077
                "Passing `equation_file` is deprecated and no longer compatible with "
1078
                "the most recent versions of PySR's backend. Please pass `run_directory` "
1079
                "instead, which contains all checkpoint files."
1080
            )
1081

1082
        pkl_filename = Path(run_directory) / "checkpoint.pkl"
3✔
1083
        if pkl_filename.exists():
3✔
1084
            print(f"Attempting to load model from {pkl_filename}...")
3✔
1085
            assert binary_operators is None
3✔
1086
            assert unary_operators is None
3✔
1087
            assert n_features_in is None
3✔
1088
            with open(pkl_filename, "rb") as f:
3✔
1089
                model = cast("PySRRegressor", pkl.load(f))
3✔
1090

1091
            # Update any parameters if necessary, such as
1092
            # extra_sympy_mappings:
1093
            model.set_params(**pysr_kwargs)
3✔
1094

1095
            if "equations_" not in model.__dict__ or model.equations_ is None:
3✔
1096
                model.refresh()
×
1097

1098
            return model
3✔
1099
        else:
1100
            print(
3✔
1101
                f"Checkpoint file {pkl_filename} does not exist. "
1102
                "Attempting to recreate model from scratch..."
1103
            )
1104
            csv_filename = Path(run_directory) / "hall_of_fame.csv"
3✔
1105
            csv_filename_bak = Path(run_directory) / "hall_of_fame.csv.bak"
3✔
1106
            if not csv_filename.exists() and not csv_filename_bak.exists():
3✔
NEW
1107
                raise FileNotFoundError(
×
1108
                    f"Hall of fame file `{csv_filename}` or `{csv_filename_bak}` does not exist. "
1109
                    "Please pass a `run_directory` containing a valid checkpoint file."
1110
                )
1111
            assert binary_operators is not None or unary_operators is not None
3✔
1112
            assert n_features_in is not None
3✔
1113
            model = cls(
3✔
1114
                binary_operators=binary_operators,
1115
                unary_operators=unary_operators,
1116
                **pysr_kwargs,
1117
            )
1118
            model.nout_ = nout
3✔
1119
            model.n_features_in_ = n_features_in
3✔
1120

1121
            if feature_names_in is None:
3✔
1122
                model.feature_names_in_ = np.array(
3✔
1123
                    [f"x{i}" for i in range(n_features_in)]
1124
                )
1125
                model.display_feature_names_in_ = np.array(
3✔
1126
                    [f"x{_subscriptify(i)}" for i in range(n_features_in)]
1127
                )
1128
            else:
1129
                assert len(feature_names_in) == n_features_in
3✔
1130
                model.feature_names_in_ = feature_names_in
3✔
1131
                model.display_feature_names_in_ = feature_names_in
3✔
1132

1133
            if selection_mask is None:
3✔
1134
                model.selection_mask_ = np.ones(n_features_in, dtype=np.bool_)
3✔
1135
            else:
NEW
1136
                model.selection_mask_ = selection_mask
×
1137

1138
            model.refresh(run_directory=run_directory)
3✔
1139

1140
            return model
3✔
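
    # A usage sketch for `from_file` (hypothetical paths): a run directory
    # containing `checkpoint.pkl` needs no extra arguments, while recreating
    # from a bare `hall_of_fame.csv` requires the original operators and
    # feature count:
    #
    #     >>> model = PySRRegressor.from_file(run_directory="outputs/my_run")
    #     >>> model = PySRRegressor.from_file(
    #     ...     run_directory="outputs/my_run",
    #     ...     binary_operators=["+", "*"],
    #     ...     n_features_in=5,
    #     ... )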
1141

1142
    def __repr__(self) -> str:
3✔
1143
        """
1144
        Print all current equations fitted by the model.
1145

1146
        The string `>>>>` denotes which equation is selected by the
1147
        `model_selection` strategy.
1148
        """
1149
        if not hasattr(self, "equations_") or self.equations_ is None:
3✔
1150
            return "PySRRegressor.equations_ = None"
3✔
1151

1152
        output = "PySRRegressor.equations_ = [\n"
3✔
1153

1154
        equations = self.equations_
3✔
1155
        if not isinstance(equations, list):
3✔
1156
            all_equations = [equations]
3✔
1157
        else:
1158
            all_equations = equations
×
1159

1160
        for i, equations in enumerate(all_equations):
3✔
1161
            selected = pd.Series([""] * len(equations), index=equations.index)
3✔
1162
            chosen_row = idx_model_selection(equations, self.model_selection)
3✔
1163
            selected[chosen_row] = ">>>>"
3✔
1164
            repr_equations = pd.DataFrame(
3✔
1165
                dict(
1166
                    pick=selected,
1167
                    score=equations["score"],
1168
                    equation=equations["equation"],
1169
                    loss=equations["loss"],
1170
                    complexity=equations["complexity"],
1171
                )
1172
            )
1173

1174
            if len(all_equations) > 1:
3✔
1175
                output += "[\n"
×
1176

1177
            for line in repr_equations.__repr__().split("\n"):
3✔
1178
                output += "\t" + line + "\n"
3✔
1179

1180
            if len(all_equations) > 1:
3✔
1181
                output += "]"
×
1182

1183
            if i < len(all_equations) - 1:
3✔
1184
                output += ", "
×
1185

1186
        output += "]"
3✔
1187
        return output
3✔
1188

1189
    def __getstate__(self) -> dict[str, Any]:
3✔
1190
        """
1191
        Handle pickle serialization for PySRRegressor.
1192

1193
        The Scikit-learn standard requires estimators to be serializable via
1194
        `pickle.dumps()`. However, some attributes do not support pickling
1195
        and need to be hidden, such as the JAX and Torch representations.
1196
        """
1197
        state = self.__dict__
3✔
1198
        show_pickle_warning = not (
3✔
1199
            "show_pickle_warnings_" in state and not state["show_pickle_warnings_"]
1200
        )
1201
        state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"]
3✔
1202
        for state_key in state_keys_containing_lambdas:
3✔
1203
            if state[state_key] is not None and show_pickle_warning:
3✔
1204
                warnings.warn(
×
1205
                    f"`{state_key}` cannot be pickled and will be removed from the "
1206
                    "serialized instance. When loading the model, please redefine "
1207
                    f"`{state_key}` at runtime."
1208
                )
1209
        state_keys_to_clear = state_keys_containing_lambdas
3✔
1210
        pickled_state = {
3✔
1211
            key: (None if key in state_keys_to_clear else value)
1212
            for key, value in state.items()
1213
        }
1214
        if ("equations_" in pickled_state) and (
3✔
1215
            pickled_state["equations_"] is not None
1216
        ):
1217
            pickled_state["output_torch_format"] = False
3✔
1218
            pickled_state["output_jax_format"] = False
3✔
1219
            if self.nout_ == 1:
3✔
1220
                pickled_columns = ~pickled_state["equations_"].columns.isin(
3✔
1221
                    ["jax_format", "torch_format"]
1222
                )
1223
                pickled_state["equations_"] = (
3✔
1224
                    pickled_state["equations_"].loc[:, pickled_columns].copy()
1225
                )
1226
            else:
1227
                pickled_columns = [
3✔
1228
                    ~dataframe.columns.isin(["jax_format", "torch_format"])
1229
                    for dataframe in pickled_state["equations_"]
1230
                ]
1231
                pickled_state["equations_"] = [
3✔
1232
                    dataframe.loc[:, single_pickled_columns]
1233
                    for dataframe, single_pickled_columns in zip(
1234
                        pickled_state["equations_"], pickled_columns
1235
                    )
1236
                ]
1237
        return pickled_state
3✔
1238
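
    # A pickling sketch (hypothetical usage): lambda-holding attributes are
    # dropped by `__getstate__`, so they must be re-set after loading:
    #
    #     blob = pkl.dumps(model)    # warns; strips `extra_sympy_mappings`
    #     model2 = pkl.loads(blob)
    #     model2.set_params(extra_sympy_mappings={"inv": lambda x: 1 / x})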

1239
    def _checkpoint(self):
3✔
1240
        """Save the model's current state to a checkpoint file.
1241

1242
        This should only be used internally by PySRRegressor.
1243
        """
1244
        # Save model state:
1245
        self.show_pickle_warnings_ = False
3✔
1246
        with open(self.get_pkl_filename(), "wb") as f:
3✔
1247
            try:
3✔
1248
                pkl.dump(self, f)
3✔
1249
            except Exception as e:
1✔
1250
                print(f"Error checkpointing model: {e}")
1✔
1251
        self.show_pickle_warnings_ = True
3✔
1252

1253
    def get_pkl_filename(self) -> Path:
3✔
1254
        path = Path(self.output_directory_) / self.run_id_ / "checkpoint.pkl"
3✔
1255
        path.parent.mkdir(parents=True, exist_ok=True)
3✔
1256
        return path
3✔
1257

1258
    @property
1259
    def equations(self):  # pragma: no cover
1260
        warnings.warn(
1261
            "PySRRegressor.equations is now deprecated. "
1262
            "Please use PySRRegressor.equations_ instead.",
1263
            FutureWarning,
1264
        )
1265
        return self.equations_
1266

1267
    @property
3✔
1268
    def julia_options_(self):
3✔
1269
        """The deserialized julia options."""
1270
        return jl_deserialize(self.julia_options_stream_)
3✔
1271

1272
    @property
3✔
1273
    def julia_state_(self):
3✔
1274
        """The deserialized state."""
1275
        return cast(
3✔
1276
            tuple[VectorValue, AnyValue] | None,
1277
            jl_deserialize(self.julia_state_stream_),
1278
        )
1279

1280
    @property
3✔
1281
    def raw_julia_state_(self):
3✔
1282
        warnings.warn(
3✔
1283
            "PySRRegressor.raw_julia_state_ is now deprecated. "
1284
            "Please use PySRRegressor.julia_state_ instead, or julia_state_stream_ "
1285
            "for the raw stream of bytes.",
1286
            FutureWarning,
1287
        )
1288
        return self.julia_state_
3✔
1289

1290
    @property
3✔
1291
    def expression_spec_(self):
3✔
1292
        return self.expression_spec or ExpressionSpec()
3✔
1293

1294
    def get_best(
3✔
1295
        self, index: int | list[int] | None = None
1296
    ) -> pd.Series | list[pd.Series]:
1297
        """
1298
        Get best equation using `model_selection`.
1299

1300
        Parameters
1301
        ----------
1302
        index : int | list[int]
1303
            If you wish to select a particular equation from `self.equations_`,
1304
            give the row number here. This overrides the `model_selection`
1305
            parameter. If there are multiple output features, then pass
1306
            a list of indices with the order the same as the output feature.
1307

1308
        Returns
1309
        -------
1310
        best_equation : pandas.Series
1311
            Dictionary representing the best expression found.
1312

1313
        Raises
1314
        ------
1315
        NotImplementedError
1316
            Raised when an invalid model selection strategy is provided.
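
        Examples
        --------
        A sketch (assumes a fitted, single-output model):

        >>> best = model.get_best()          # row chosen by `model_selection`
        >>> third = model.get_best(index=2)  # third row of `self.equations_`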
1317
        """
1318
        check_is_fitted(self, attributes=["equations_"])
3✔
1319

1320
        if index is not None:
3✔
1321
            if isinstance(self.equations_, list):
3✔
1322
                assert isinstance(
3✔
1323
                    index, list
1324
                ), "With multiple output features, index must be a list."
1325
                return [eq.iloc[i] for eq, i in zip(self.equations_, index)]
3✔
1326
            else:
1327
                equations_ = cast(pd.DataFrame, self.equations_)
3✔
1328
                return cast(pd.Series, equations_.iloc[index])
3✔
1329

1330
        if isinstance(self.equations_, list):
3✔
1331
            return [
3✔
1332
                cast(pd.Series, eq.loc[idx_model_selection(eq, self.model_selection)])
1333
                for eq in self.equations_
1334
            ]
1335
        else:
1336
            equations_ = cast(pd.DataFrame, self.equations_)
3✔
1337
            return cast(
3✔
1338
                pd.Series,
1339
                equations_.loc[idx_model_selection(equations_, self.model_selection)],
1340
            )
1341

1342
    @property
3✔
1343
    def equation_file_(self):
3✔
1344
        raise NotImplementedError(
3✔
1345
            "PySRRegressor.equation_file_ is now deprecated. "
1346
            "Please use PySRRegressor.output_directory_ and PySRRegressor.run_id_ "
1347
            "instead. For loading, you should pass `run_directory`."
1348
        )
1349

1350
    def _setup_equation_file(self):
3✔
1351
        """Set the pathname of the output directory."""
1352
        if self.warm_start and (
3✔
1353
            hasattr(self, "run_id_") or hasattr(self, "output_directory_")
1354
        ):
1355
            assert hasattr(self, "output_directory_")
3✔
1356
            assert hasattr(self, "run_id_")
3✔
1357
            if self.run_id is not None:
3✔
1358
                assert self.run_id_ == self.run_id
3✔
1359
            if self.output_directory is not None:
3✔
1360
                assert self.output_directory_ == self.output_directory
3✔
1361
        else:
1362
            self.output_directory_ = (
3✔
1363
                tempfile.mkdtemp()
1364
                if self.temp_equation_file
1365
                else (
1366
                    "outputs"
1367
                    if self.output_directory is None
1368
                    else self.output_directory
1369
                )
1370
            )
1371
            self.run_id_ = (
3✔
1372
                cast(str, SymbolicRegression.SearchUtilsModule.generate_run_id())
1373
                if self.run_id is None
1374
                else self.run_id
1375
            )
1376
            if self.temp_equation_file:
3✔
1377
                assert self.output_directory is None
3✔
1378

1379
    def _clear_equation_file_contents(self):
3✔
1380
        self.equation_file_contents_ = None
3✔
1381

1382
    def _validate_and_modify_params(self) -> _DynamicallySetParams:
3✔
1383
        """
1384
        Ensure parameters passed at initialization are valid.
1385

1386
        Also returns a container of parameters to update from their
1387
        values given at initialization.
1388

1389
        Returns
1390
        -------
1391
        param_container : _DynamicallySetParams
1392
            Parameters to modify from their initialized
1393
            values. For example, default parameters are set here
1394
            when a parameter is left set to `None`.
1395
        """
1396
        # Immutable parameter validation
1397
        # Ensure instance parameters are allowable values:
1398
        if self.tournament_selection_n > self.population_size:
3✔
1399
            raise ValueError(
3✔
1400
                "`tournament_selection_n` parameter must be smaller than `population_size`."
1401
            )
1402

1403
        if self.maxsize > 40:
3✔
1404
            warnings.warn(
×
1405
                "Note: Using a large maxsize for the equation search will be "
1406
                "exponentially slower and use significant memory."
1407
            )
1408
        elif self.maxsize < 7:
3✔
1409
            raise ValueError("PySR requires a maxsize of at least 7")
3✔
1410

1411
        if self.elementwise_loss is not None and self.loss_function is not None:
3✔
1412
            raise ValueError(
3✔
1413
                "You cannot set both `elementwise_loss` and `loss_function`."
1414
            )
1415

1416
        # NotImplementedError - Values that could be supported at a later time
1417
        if self.optimizer_algorithm not in VALID_OPTIMIZER_ALGORITHMS:
3✔
1418
            raise NotImplementedError(
3✔
1419
                f"PySR currently only supports the following optimizer algorithms: {VALID_OPTIMIZER_ALGORITHMS}"
1420
            )
1421

1422
        param_container = _DynamicallySetParams(
3✔
1423
            binary_operators=["+", "*", "-", "/"],
1424
            unary_operators=[],
1425
            maxdepth=self.maxsize,
1426
            constraints={},
1427
            batch_size=1,
1428
            update_verbosity=int(self.verbosity),
1429
            progress=self.progress,
1430
            warmup_maxsize_by=0.0,
1431
        )
1432

1433
        for param_name in map(lambda x: x.name, fields(_DynamicallySetParams)):
3✔
1434
            user_param_value = getattr(self, param_name)
3✔
1435
            if user_param_value is None:
3✔
1436
                # Leave as the default in DynamicallySetParams
1437
                ...
3✔
1438
            else:
1439
                # If user has specified it, we will override the default.
1440
                # However, there are some special cases to mutate it:
1441
                new_param_value = _mutate_parameter(param_name, user_param_value)
3✔
1442
                setattr(param_container, param_name, new_param_value)
3✔
1443
        # TODO: This should just be part of the __init__ of _DynamicallySetParams
1444

1445
        assert (
3✔
1446
            len(param_container.binary_operators) > 0
1447
            or len(param_container.unary_operators) > 0
1448
        ), "At least one operator must be provided."
1449

1450
        return param_container
3✔
1451

1452
    def _validate_and_set_fit_params(
3✔
1453
        self,
1454
        X,
1455
        y,
1456
        Xresampled,
1457
        weights,
1458
        variable_names,
1459
        complexity_of_variables,
1460
        X_units,
1461
        y_units,
1462
    ) -> tuple[
1463
        ndarray,
1464
        ndarray,
1465
        ndarray | None,
1466
        ndarray | None,
1467
        ArrayLike[str],
1468
        int | float | list[int | float] | None,
1469
        ArrayLike[str] | None,
1470
        str | ArrayLike[str] | None,
1471
    ]:
1472
        """
1473
        Validate the parameters passed to the :term:`fit` method.
1474

1475
        This method also sets the `nout_` attribute.
1476

1477
        Parameters
1478
        ----------
1479
        X : ndarray | pandas.DataFrame
1480
            Training data of shape `(n_samples, n_features)`.
1481
        y : ndarray | pandas.DataFrame
1482
            Target values of shape `(n_samples,)` or `(n_samples, n_targets)`.
1483
            Will be cast to `X`'s dtype if necessary.
1484
        Xresampled : ndarray | pandas.DataFrame
1485
            Resampled training data used for denoising,
1486
            of shape `(n_resampled, n_features)`.
1487
        weights : ndarray | pandas.DataFrame
1488
            Weight array of the same shape as `y`.
1489
            Each element is how to weight the mean-square-error loss
1490
            for that particular element of y.
1491
        variable_names : ndarray of length n_features
1492
            Names of each feature in the training dataset, `X`.
1493
        complexity_of_variables : int | float | list[int | float]
1494
            Complexity of each feature in the training dataset, `X`.
1495
        X_units : list[str] of length n_features
1496
            Units of each feature in the training dataset, `X`.
1497
        y_units : str | list[str] of length n_out
1498
            Units of each feature in the training dataset, `y`.
1499

1500
        Returns
1501
        -------
1502
        X_validated : ndarray of shape (n_samples, n_features)
1503
            Validated training data.
1504
        y_validated : ndarray of shape (n_samples,) or (n_samples, n_targets)
1505
            Validated target data.
1506
        Xresampled : ndarray of shape (n_resampled, n_features)
1507
            Validated resampled training data used for denoising.
        weights_validated : ndarray of shape (n_samples,) or None
            Validated weight array, or `None` if no weights were passed.
1508
        variable_names_validated : list[str] of length n_features
1509
            Validated list of variable names for each feature in `X`.
        complexity_of_variables_validated : int | float | list[int | float] | None
            Validated complexity of each feature in `X`.
1510
        X_units : list[str] of length n_features
1511
            Validated units for `X`.
1512
        y_units : str | list[str] of length n_out
1513
            Validated units for `y`.
1514

1515
        """
1516
        if isinstance(X, pd.DataFrame):
3✔
1517
            if variable_names:
3✔
1518
                variable_names = None
×
1519
                warnings.warn(
×
1520
                    "`variable_names` has been reset to `None` as `X` is a DataFrame. "
1521
                    "Using DataFrame column names instead."
1522
                )
1523

1524
            if (
3✔
1525
                pd.api.types.is_object_dtype(X.columns)
1526
                and X.columns.str.contains(" ").any()
1527
            ):
1528
                X.columns = X.columns.str.replace(" ", "_")
×
1529
                warnings.warn(
×
1530
                    "Spaces in DataFrame column names are not supported. "
1531
                    "Spaces have been replaced with underscores. \n"
1532
                    "Please rename the columns to valid names."
1533
                )
1534
        elif variable_names and any(" " in name for name in variable_names):
3✔
1535
            variable_names = [name.replace(" ", "_") for name in variable_names]
×
1536
            warnings.warn(
×
1537
                "Spaces in `variable_names` are not supported. "
1538
                "Spaces have been replaced with underscores. \n"
1539
                "Please use valid names instead."
1540
            )
1541

1542
        if (
3✔
1543
            complexity_of_variables is not None
1544
            and self.complexity_of_variables is not None
1545
        ):
1546
            raise ValueError(
3✔
1547
                "You cannot set `complexity_of_variables` at both `fit` and `__init__`. "
1548
                "Pass it at `__init__` to set it to global default, OR use `fit` to set it for "
1549
                "each variable individually."
1550
            )
1551
        elif complexity_of_variables is not None:
3✔
1552
            complexity_of_variables = complexity_of_variables
3✔
1553
        elif self.complexity_of_variables is not None:
3✔
1554
            complexity_of_variables = self.complexity_of_variables
3✔
1555
        else:
1556
            complexity_of_variables = None
3✔
1557

1558
        # Data validation and feature name fetching via sklearn
1559
        # This method sets the n_features_in_ attribute
1560
        if Xresampled is not None:
3✔
1561
            Xresampled = check_array(Xresampled)
3✔
1562
        if weights is not None:
3✔
1563
            weights = check_array(weights, ensure_2d=False)
3✔
1564
            check_consistent_length(weights, y)
3✔
1565
        X, y = self._validate_data_X_y(X, y)
3✔
1566
        self.feature_names_in_ = _safe_check_feature_names_in(
3✔
1567
            self, variable_names, generate_names=False
1568
        )
1569

1570
        if self.feature_names_in_ is None:
3✔
1571
            self.feature_names_in_ = np.array([f"x{i}" for i in range(X.shape[1])])
3✔
1572
            self.display_feature_names_in_ = np.array(
3✔
1573
                [f"x{_subscriptify(i)}" for i in range(X.shape[1])]
1574
            )
1575
            variable_names = self.feature_names_in_
3✔
1576
        else:
1577
            self.display_feature_names_in_ = self.feature_names_in_
3✔
1578
            variable_names = self.feature_names_in_
3✔
1579

1580
        # Handle multioutput data
1581
        if len(y.shape) == 1 or (len(y.shape) == 2 and y.shape[1] == 1):
3✔
1582
            y = y.reshape(-1)
3✔
1583
        elif len(y.shape) == 2:
3✔
1584
            self.nout_ = y.shape[1]
3✔
1585
        else:
1586
            raise NotImplementedError("y shape not supported!")
×
1587

1588
        self.complexity_of_variables_ = copy.deepcopy(complexity_of_variables)
3✔
1589
        self.X_units_ = copy.deepcopy(X_units)
3✔
1590
        self.y_units_ = copy.deepcopy(y_units)
3✔
1591

1592
        return (
3✔
1593
            X,
1594
            y,
1595
            Xresampled,
1596
            weights,
1597
            variable_names,
1598
            complexity_of_variables,
1599
            X_units,
1600
            y_units,
1601
        )
1602

1603
    def _validate_data_X_y(self, X: Any, y: Any) -> tuple[ndarray, ndarray]:
3✔
1604
        raw_out = self._validate_data(X=X, y=y, reset=True, multi_output=True)  # type: ignore
3✔
1605
        return cast(tuple[ndarray, ndarray], raw_out)
3✔
1606

1607
    def _validate_data_X(self, X: Any) -> ndarray:
3✔
1608
        raw_out = self._validate_data(X=X, reset=False)  # type: ignore
3✔
1609
        return cast(ndarray, raw_out)
3✔
1610

1611
    def _get_precision_mapped_dtype(self, X: np.ndarray) -> type:
3✔
1612
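        # Map `precision` (16/32/64 bits) to a concrete NumPy dtype; complex
        # inputs only support 32- and 64-bit (complex64/complex128).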
        is_complex = np.issubdtype(X.dtype, np.complexfloating)
3✔
1613
        is_real = not is_complex
3✔
1614
        if is_real:
3✔
1615
            return {16: np.float16, 32: np.float32, 64: np.float64}[self.precision]
3✔
1616
        else:
1617
            return {32: np.complex64, 64: np.complex128}[self.precision]
3✔
1618

1619
    def _pre_transform_training_data(
3✔
1620
        self,
1621
        X: ndarray,
1622
        y: ndarray,
1623
        Xresampled: ndarray | None,
1624
        variable_names: ArrayLike[str],
1625
        complexity_of_variables: int | float | list[int | float] | None,
1626
        X_units: ArrayLike[str] | None,
1627
        y_units: ArrayLike[str] | str | None,
1628
        random_state: np.random.RandomState,
1629
    ):
1630
        """
1631
        Transform the training data before fitting the symbolic regressor.
1632

1633
        This method also updates/sets the `selection_mask_` attribute.
1634

1635
        Parameters
1636
        ----------
1637
        X : ndarray
1638
            Training data of shape (n_samples, n_features).
1639
        y : ndarray
1640
            Target values of shape (n_samples,) or (n_samples, n_targets).
1641
            Will be cast to X's dtype if necessary.
1642
        Xresampled : ndarray | None
1643
            Resampled training data, of shape `(n_resampled, n_features)`,
1644
            used for denoising.
1645
        variable_names : list[str]
1646
            Names of each variable in the training dataset, `X`.
1647
            Of length `n_features`.
1648
        complexity_of_variables : int | float | list[int | float] | None
1649
            Complexity of each variable in the training dataset, `X`.
1650
        X_units : list[str]
1651
            Units of each variable in the training dataset, `X`.
1652
        y_units : str | list[str]
1653
            Units of each variable in the training dataset, `y`.
1654
        random_state : np.random.RandomState
1655
            Random state instance, for reproducible results across multiple
1656
            function calls. See :term:`Glossary <random_state>`.
1657

1658
        Returns
1659
        -------
1660
        X_transformed : ndarray of shape (n_samples, n_features)
1661
            Transformed training data. n_samples will be equal to
1662
            `Xresampled.shape[0]` if `self.denoise` is `True`,
1663
            and `Xresampled is not None`, otherwise it will be
1664
            equal to `X.shape[0]`. n_features will be equal to
1665
            `self.select_k_features` if `self.select_k_features is not None`,
1666
            otherwise it will be equal to `X.shape[1]`.
1667
        y_transformed : ndarray of shape (n_samples,) or (n_samples, n_outputs)
1668
            Transformed target data. n_samples will be equal to
1669
            `Xresampled.shape[0]` if `self.denoise` is `True`,
1670
            and `Xresampled is not None`, otherwise it will be
1671
            equal to `X.shape[0]`.
1672
        variable_names_transformed : list[str] of length n_features
1673
            Names of each variable in the transformed dataset,
1674
            `X_transformed`.
1675
        X_units_transformed : list[str] of length n_features
1676
            Units of each variable in the transformed dataset.
1677
        y_units_transformed : str | list[str] of length n_out
1678
            Units of each variable in the transformed dataset.
1679
        """
1680
        # Feature selection transformation
1681
        if self.select_k_features:
3✔
1682
            selection_mask = run_feature_selection(
3✔
1683
                X, y, self.select_k_features, random_state=random_state
1684
            )
1685
            X = X[:, selection_mask]
3✔
1686

1687
            if Xresampled is not None:
3✔
1688
                Xresampled = Xresampled[:, selection_mask]
3✔
1689

1690
            # Reduce variable_names to selection
1691
            variable_names = cast(
3✔
1692
                ArrayLike[str],
1693
                [
1694
                    variable_names[i]
1695
                    for i in range(len(variable_names))
1696
                    if selection_mask[i]
1697
                ],
1698
            )
1699

1700
            if isinstance(complexity_of_variables, list):
3✔
1701
                complexity_of_variables = [
×
1702
                    complexity_of_variables[i]
1703
                    for i in range(len(complexity_of_variables))
1704
                    if selection_mask[i]
1705
                ]
1706
                self.complexity_of_variables_ = copy.deepcopy(complexity_of_variables)
×
1707

1708
            if X_units is not None:
3✔
1709
                X_units = cast(
3✔
1710
                    ArrayLike[str],
1711
                    [X_units[i] for i in range(len(X_units)) if selection_mask[i]],
1712
                )
1713
                self.X_units_ = copy.deepcopy(X_units)
3✔
1714

1715
            # Re-perform data validation and feature name updating
1716
            X, y = self._validate_data_X_y(X, y)
3✔
1717
            # Update feature names with selected variable names
1718
            self.selection_mask_ = selection_mask
3✔
1719
            self.feature_names_in_ = _check_feature_names_in(self, variable_names)
3✔
1720
            self.display_feature_names_in_ = self.feature_names_in_
3✔
1721
            print(f"Using features {self.feature_names_in_}")
3✔
1722

1723
        # Denoising transformation
1724
        if self.denoise:
3✔
1725
            if self.nout_ > 1:
3✔
1726
                X, y = multi_denoise(
3✔
1727
                    X, y, Xresampled=Xresampled, random_state=random_state
1728
                )
1729
            else:
1730
                X, y = denoise(X, y, Xresampled=Xresampled, random_state=random_state)
3✔
1731

1732
        return X, y, variable_names, complexity_of_variables, X_units, y_units
3✔
1733

1734
    def _run(
3✔
1735
        self,
1736
        X: ndarray,
1737
        y: ndarray,
1738
        runtime_params: _DynamicallySetParams,
1739
        weights: ndarray | None,
1740
        category: ndarray | None,
1741
        seed: int,
1742
    ):
1743
        """
1744
        Run the symbolic regression fitting process on the julia backend.
1745

1746
        Parameters
1747
        ----------
1748
        X : ndarray
1749
            Training data of shape `(n_samples, n_features)`.
1750
        y : ndarray
1751
            Target values of shape `(n_samples,)` or `(n_samples, n_targets)`.
1752
            Will be cast to `X`'s dtype if necessary.
1753
        runtime_params : _DynamicallySetParams
1754
            Dynamically set versions of some parameters passed in __init__.
1755
        weights : ndarray | None
1756
            Weight array of the same shape as `y`.
1757
            Each element is how to weight the mean-square-error loss
1758
            for that particular element of y.
1759
        category : ndarray | None
1760
            If `expression_spec` is a `ParametricExpressionSpec`, then this
1761
            argument should be a list of integers representing the category
1762
            of each sample in `X`.
1763
        seed : int
1764
            Random seed for julia backend process.
1765

1766
        Returns
1767
        -------
1768
        self : object
1769
            Reference to `self` with fitted attributes.
1770

1771
        Raises
1772
        ------
1773
        ImportError
1774
            Raised when the julia backend fails to import a package.
1775
        """
1776
        # Needs to be global, as we don't want to recreate/reinitialize Julia
1777
        # for every new instance of PySRRegressor
1778
        global ALREADY_RAN
1779

1780
        # These are the parameters which may be modified from the ones
1781
        # specified in init, so we define them here locally:
1782
        binary_operators = runtime_params.binary_operators
3✔
1783
        unary_operators = runtime_params.unary_operators
3✔
1784
        constraints = runtime_params.constraints
3✔
1785

1786
        nested_constraints = self.nested_constraints
3✔
1787
        complexity_of_operators = self.complexity_of_operators
3✔
1788
        complexity_of_variables = self.complexity_of_variables_
3✔
1789
        cluster_manager = self.cluster_manager
3✔
1790

1791
        # Start julia backend processes
1792
        if not ALREADY_RAN and runtime_params.update_verbosity != 0:
3✔
1793
            print("Compiling Julia backend...")
3✔
1794

1795
        parallelism, numprocs = _map_parallelism_params(
3✔
1796
            self.parallelism, self.procs, getattr(self, "multithreading", None)
1797
        )
1798

1799
        if self.deterministic and parallelism != "serial":
3✔
1800
            raise ValueError(
3✔
1801
                "To ensure deterministic searches, you must set `parallelism='serial'`. "
1802
                "Additionally, make sure to set `random_state` to a seed."
1803
            )
1804
        if self.random_state is not None and (
3✔
1805
            parallelism != "serial" or not self.deterministic
1806
        ):
1807
            warnings.warn(
3✔
1808
                "Note: Setting `random_state` without also setting `deterministic=True` "
1809
                "and `parallelism='serial'` will result in non-deterministic searches."
1810
            )
1811

1812
        if cluster_manager is not None:
3✔
NEW
1813
            if parallelism != "multiprocessing":
×
NEW
1814
                raise ValueError(
×
1815
                    "To use cluster managers, you must set `parallelism='multiprocessing'`."
1816
                )
UNCOV
1817
            cluster_manager = _load_cluster_manager(cluster_manager)
×
1818

1819
        # TODO(mcranmer): These functions should be part of this class.
1820
        binary_operators, unary_operators = _maybe_create_inline_operators(
3✔
1821
            binary_operators=binary_operators,
1822
            unary_operators=unary_operators,
1823
            extra_sympy_mappings=self.extra_sympy_mappings,
1824
        )
1825
        if constraints is not None:
3✔
1826
            _constraints = _process_constraints(
3✔
1827
                binary_operators=binary_operators,
1828
                unary_operators=unary_operators,
1829
                constraints=constraints,
1830
            )
1831
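            # The backend expects one constraint entry per operator, aligned
            # with the operator lists (-1 means unconstrained):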
            una_constraints = [_constraints[op] for op in unary_operators]
3✔
1832
            bin_constraints = [_constraints[op] for op in binary_operators]
3✔
1833
        else:
NEW
1834
            una_constraints = None
×
NEW
1835
            bin_constraints = None
×
1836

1837
        # Parse dict into Julia Dict for nested constraints:
1838
        if nested_constraints is not None:
3✔
1839
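            # Build a Julia `Dict` literal; e.g., {"sin": {"sin": 0}} becomes
            # "Dict((sin) => Dict((sin) => 0, ), )", forbidding `sin` nested
            # inside `sin`: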
            nested_constraints_str = "Dict("
3✔
1840
            for outer_k, outer_v in nested_constraints.items():
3✔
1841
                nested_constraints_str += f"({outer_k}) => Dict("
3✔
1842
                for inner_k, inner_v in outer_v.items():
3✔
1843
                    nested_constraints_str += f"({inner_k}) => {inner_v}, "
3✔
1844
                nested_constraints_str += "), "
3✔
1845
            nested_constraints_str += ")"
3✔
1846
            nested_constraints = jl.seval(nested_constraints_str)
3✔
1847

1848
        # Parse dict into Julia Dict for complexities:
1849
        if complexity_of_operators is not None:
3✔
1850
            complexity_of_operators_str = "Dict("
3✔
1851
            for k, v in complexity_of_operators.items():
3✔
1852
                complexity_of_operators_str += f"({k}) => {v}, "
3✔
1853
            complexity_of_operators_str += ")"
3✔
1854
            complexity_of_operators = jl.seval(complexity_of_operators_str)
3✔
1855
        # TODO: Refactor this into helper function
1856

1857
        if isinstance(complexity_of_variables, list):
3✔
1858
            complexity_of_variables = jl_array(complexity_of_variables)
3✔
1859

1860
        custom_loss = jl.seval(
3✔
1861
            str(self.elementwise_loss)
1862
            if self.elementwise_loss is not None
1863
            else "nothing"
1864
        )
1865
        custom_full_objective = jl.seval(
3✔
1866
            str(self.loss_function) if self.loss_function is not None else "nothing"
1867
        )
1868

1869
        early_stop_condition = jl.seval(
3✔
1870
            str(self.early_stop_condition)
1871
            if self.early_stop_condition is not None
1872
            else "nothing"
1873
        )
1874

1875
        load_required_packages(
3✔
1876
            turbo=self.turbo,
1877
            bumper=self.bumper,
1878
            autodiff_backend=self.autodiff_backend,
1879
            cluster_manager=cluster_manager,
1880
            logger_spec=self.logger_spec,
1881
        )
1882

1883
        if self.autodiff_backend is not None:
3✔
NEW
1884
            autodiff_backend = jl.Symbol(self.autodiff_backend)
×
1885
        else:
1886
            autodiff_backend = None
3✔
1887

1888
        mutation_weights = SymbolicRegression.MutationWeights(
3✔
1889
            mutate_constant=self.weight_mutate_constant,
1890
            mutate_operator=self.weight_mutate_operator,
1891
            swap_operands=self.weight_swap_operands,
1892
            rotate_tree=self.weight_rotate_tree,
1893
            add_node=self.weight_add_node,
1894
            insert_node=self.weight_insert_node,
1895
            delete_node=self.weight_delete_node,
1896
            simplify=self.weight_simplify,
1897
            randomize=self.weight_randomize,
1898
            do_nothing=self.weight_do_nothing,
1899
            optimize=self.weight_optimize,
1900
        )
1901

1902
        jl_binary_operators: list[Any] = []
3✔
1903
        jl_unary_operators: list[Any] = []
3✔
1904
        for input_list, output_list, name in [
3✔
1905
            (binary_operators, jl_binary_operators, "binary"),
1906
            (unary_operators, jl_unary_operators, "unary"),
1907
        ]:
1908
            for op in input_list:
3✔
1909
                jl_op = jl.seval(op)
3✔
1910
                if not jl_is_function(jl_op):
3✔
1911
                    raise ValueError(
3✔
1912
                        f"When building `{name}_operators`, `'{op}'` did not return a Julia function"
1913
                    )
1914
                output_list.append(jl_op)
3✔
1915

1916
        complexity_mapping = (
3✔
1917
            jl.seval(self.complexity_mapping) if self.complexity_mapping else None
1918
        )
1919

1920
        logger = self.logger_spec.create_logger() if self.logger_spec else None
3✔
1921

1922
        # Call to Julia backend.
1923
        # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
1924
        options = SymbolicRegression.Options(
3✔
1925
            binary_operators=jl_array(jl_binary_operators, dtype=jl.Function),
1926
            unary_operators=jl_array(jl_unary_operators, dtype=jl.Function),
1927
            bin_constraints=jl_array(bin_constraints),
1928
            una_constraints=jl_array(una_constraints),
1929
            complexity_of_operators=complexity_of_operators,
1930
            complexity_of_constants=self.complexity_of_constants,
1931
            complexity_of_variables=complexity_of_variables,
1932
            complexity_mapping=complexity_mapping,
1933
            expression_type=self.expression_spec_.julia_expression_type(),
1934
            expression_options=self.expression_spec_.julia_expression_options(),
1935
            nested_constraints=nested_constraints,
1936
            elementwise_loss=custom_loss,
1937
            loss_function=custom_full_objective,
1938
            maxsize=int(self.maxsize),
1939
            output_directory=_escape_filename(self.output_directory_),
1940
            npopulations=int(self.populations),
1941
            batching=self.batching,
1942
            batch_size=int(
1943
                min([runtime_params.batch_size, len(X)]) if self.batching else len(X)
1944
            ),
1945
            mutation_weights=mutation_weights,
1946
            tournament_selection_p=self.tournament_selection_p,
1947
            tournament_selection_n=self.tournament_selection_n,
1948
            # These have the same name:
1949
            parsimony=self.parsimony,
1950
            dimensional_constraint_penalty=self.dimensional_constraint_penalty,
1951
            dimensionless_constants_only=self.dimensionless_constants_only,
1952
            alpha=self.alpha,
1953
            maxdepth=runtime_params.maxdepth,
1954
            fast_cycle=self.fast_cycle,
1955
            turbo=self.turbo,
1956
            bumper=self.bumper,
1957
            autodiff_backend=autodiff_backend,
1958
            migration=self.migration,
1959
            hof_migration=self.hof_migration,
1960
            fraction_replaced_hof=self.fraction_replaced_hof,
1961
            should_simplify=self.should_simplify,
1962
            should_optimize_constants=self.should_optimize_constants,
1963
            warmup_maxsize_by=runtime_params.warmup_maxsize_by,
1964
            use_frequency=self.use_frequency,
1965
            use_frequency_in_tournament=self.use_frequency_in_tournament,
1966
            adaptive_parsimony_scaling=self.adaptive_parsimony_scaling,
1967
            npop=self.population_size,
1968
            ncycles_per_iteration=self.ncycles_per_iteration,
1969
            fraction_replaced=self.fraction_replaced,
1970
            topn=self.topn,
1971
            print_precision=self.print_precision,
1972
            optimizer_algorithm=self.optimizer_algorithm,
1973
            optimizer_nrestarts=self.optimizer_nrestarts,
1974
            optimizer_f_calls_limit=self.optimizer_f_calls_limit,
1975
            optimizer_probability=self.optimize_probability,
1976
            optimizer_iterations=self.optimizer_iterations,
1977
            perturbation_factor=self.perturbation_factor,
1978
            probability_negate_constant=self.probability_negate_constant,
1979
            annealing=self.annealing,
1980
            timeout_in_seconds=self.timeout_in_seconds,
1981
            crossover_probability=self.crossover_probability,
1982
            skip_mutation_failures=self.skip_mutation_failures,
1983
            max_evals=self.max_evals,
1984
            early_stop_condition=early_stop_condition,
1985
            seed=seed,
1986
            deterministic=self.deterministic,
1987
            define_helper_functions=False,
1988
        )
1989

1990
        self.julia_options_stream_ = jl_serialize(options)
3✔
1991

1992
        # Convert data to desired precision
1993
        test_X = np.array(X)
3✔
1994
        np_dtype = self._get_precision_mapped_dtype(test_X)
3✔
1995

1996
        # Convert the data to a Julia array (transposed, since the backend
        # expects shape (n_features, n_samples)):
1997
        jl_X = jl_array(np.array(X, dtype=np_dtype).T)
3✔
1998
        if len(y.shape) == 1:
3✔
1999
            jl_y = jl_array(np.array(y, dtype=np_dtype))
3✔
2000
        else:
2001
            jl_y = jl_array(np.array(y, dtype=np_dtype).T)
3✔
2002
        if weights is not None:
3✔
2003
            if len(weights.shape) == 1:
3✔
2004
                jl_weights = jl_array(np.array(weights, dtype=np_dtype))
3✔
2005
            else:
2006
                jl_weights = jl_array(np.array(weights, dtype=np_dtype).T)
3✔
2007
        else:
2008
            jl_weights = None
3✔
2009

2010
        if category is not None:
3✔
2011
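            # Julia is 1-indexed, so shift Python's 0-based categories up by one: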
            offset_for_julia_indexing = 1
3✔
2012
            jl_category = jl_array(
3✔
2013
                (category + offset_for_julia_indexing).astype(np.int64)
2014
            )
2015
            jl_extra = jl.seval("NamedTuple{(:class,)}")((jl_category,))
3✔
2016
        else:
2017
            jl_extra = jl.NamedTuple()
3✔
2018

2019
        if len(y.shape) > 1:
3✔
2020
            # We set these manually so that they respect Python's 0 indexing
2021
            # (by default Julia will use y1, y2...)
2022
            jl_y_variable_names = jl_array(
3✔
2023
                [f"y{_subscriptify(i)}" for i in range(y.shape[1])]
2024
            )
2025
        else:
2026
            jl_y_variable_names = None
3✔
2027

2028
        out = SymbolicRegression.equation_search(
3✔
2029
            jl_X,
2030
            jl_y,
2031
            weights=jl_weights,
2032
            extra=jl_extra,
2033
            niterations=int(self.niterations),
2034
            variable_names=jl_array([str(v) for v in self.feature_names_in_]),
2035
            display_variable_names=jl_array(
2036
                [str(v) for v in self.display_feature_names_in_]
2037
            ),
2038
            y_variable_names=jl_y_variable_names,
2039
            X_units=jl_array(self.X_units_),
2040
            y_units=(
2041
                jl_array(self.y_units_)
2042
                if isinstance(self.y_units_, list)
2043
                else self.y_units_
2044
            ),
2045
            options=options,
2046
            numprocs=numprocs,
2047
            parallelism=parallelism,
2048
            saved_state=self.julia_state_,
2049
            return_state=True,
2050
            run_id=self.run_id_,
2051
            addprocs_function=cluster_manager,
2052
            heap_size_hint_in_bytes=self.heap_size_hint_in_bytes,
2053
            progress=runtime_params.progress
2054
            and self.verbosity > 0
2055
            and len(y.shape) == 1,
2056
            verbosity=int(self.verbosity),
2057
            logger=logger,
2058
        )
2059
        if self.logger_spec is not None:
3✔
2060
            self.logger_spec.write_hparams(logger, self.get_params())
3✔
2061

2062
        self.julia_state_stream_ = jl_serialize(out)
3✔
2063

2064
        # Set attributes
2065
        self.equations_ = self.get_hof(out)
3✔
2066

2067
        ALREADY_RAN = True
3✔
2068

2069
        return self
3✔
2070

2071
    def fit(
3✔
2072
        self,
2073
        X,
2074
        y,
2075
        *,
2076
        Xresampled=None,
2077
        weights=None,
2078
        variable_names: ArrayLike[str] | None = None,
2079
        complexity_of_variables: int | float | list[int | float] | None = None,
2080
        X_units: ArrayLike[str] | None = None,
2081
        y_units: str | ArrayLike[str] | None = None,
2082
        category: ndarray | None = None,
2083
    ) -> "PySRRegressor":
2084
        """
2085
        Search for equations to fit the dataset and store them in `self.equations_`.
2086

2087
        Parameters
2088
        ----------
2089
        X : ndarray | pandas.DataFrame
2090
            Training data of shape (n_samples, n_features).
2091
        y : ndarray | pandas.DataFrame
2092
            Target values of shape (n_samples,) or (n_samples, n_targets).
2093
            Will be cast to X's dtype if necessary.
2094
        Xresampled : ndarray | pandas.DataFrame
2095
            Resampled training data, of shape (n_resampled, n_features),
2096
            used to generate denoised data. This
2097
            will be used as the training data, rather than `X`.
2098
        weights : ndarray | pandas.DataFrame
2099
            Weight array of the same shape as `y`.
2100
            Each element is how to weight the mean-square-error loss
2101
            for that particular element of `y`. Alternatively,
2102
            if a custom `loss` was set, it can be used
2103
            in arbitrary ways.
2104
        variable_names : list[str]
2105
            A list of names for the variables, rather than "x0", "x1", etc.
2106
            If `X` is a pandas dataframe, the column names will be used
2107
            instead of `variable_names`. Cannot contain spaces or special
2108
            characters. Avoid variable names which are also
2109
            function names in `sympy`, such as "N".
2110
        complexity_of_variables : int | float | list[int | float]
            The complexity of each variable in `X`. If a single value is
            passed, it is used for all variables. Cannot be set here if it
            was already set at `__init__`.
        X_units : list[str]
2111
            A list of units for each variable in `X`. Each unit should be
2112
            a string representing a Julia expression. See DynamicQuantities.jl
2113
            https://symbolicml.org/DynamicQuantities.jl/dev/units/ for more
2114
            information.
2115
        y_units : str | list[str]
2116
            Similar to `X_units`, but as a unit for the target variable, `y`.
2117
            If `y` is a matrix, a list of units should be passed. If `X_units`
2118
            is given but `y_units` is not, then `y_units` will be arbitrary.
2119
        category : list[int]
2120
            If `expression_spec` is a `ParametricExpressionSpec`, then this
2121
            argument should be a list of integers representing the category
2122
            of each sample.
2123

2124
        Returns
2125
        -------
2126
        self : object
2127
            Fitted estimator.
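
        Examples
        --------
        A minimal sketch (hypothetical data; the search itself is stochastic):

        >>> import numpy as np
        >>> X = np.random.randn(100, 2)
        >>> y = X[:, 0] ** 2 - np.cos(X[:, 1])
        >>> model = PySRRegressor(niterations=5, unary_operators=["cos"])
        >>> model.fit(X, y)  # doctest: +SKIP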
2128
        """
2129
        # Init attributes that are not specified in BaseEstimator
2130
        if self.warm_start and hasattr(self, "julia_state_stream_"):
3✔
2131
            pass
3✔
2132
        else:
2133
            if hasattr(self, "julia_state_stream_"):
3✔
2134
                warnings.warn(
3✔
2135
                    "The discovered expressions are being reset. "
2136
                    "Please set `warm_start=True` if you wish to continue "
2137
                    "to start a search where you left off.",
2138
                )
2139

2140
            self.equations_ = None
3✔
2141
            self.nout_ = 1
3✔
2142
            self.selection_mask_ = None
3✔
2143
            self.julia_state_stream_ = None
3✔
2144
            self.julia_options_stream_ = None
3✔
2145
            self.complexity_of_variables_ = None
3✔
2146
            self.X_units_ = None
3✔
2147
            self.y_units_ = None
3✔
2148

2149
        self._setup_equation_file()
3✔
2150
        self._clear_equation_file_contents()
3✔
2151

2152
        runtime_params = self._validate_and_modify_params()
3✔
2153

2154
        if category is not None:
3✔
2155
            assert Xresampled is None
3✔
2156

2157
        if isinstance(self.expression_spec, ParametricExpressionSpec):
3✔
2158
            assert category is not None
3✔
2159

2160
        # TODO: Put `category` here
2161
        (
3✔
2162
            X,
2163
            y,
2164
            Xresampled,
2165
            weights,
2166
            variable_names,
2167
            complexity_of_variables,
2168
            X_units,
2169
            y_units,
2170
        ) = self._validate_and_set_fit_params(
2171
            X,
2172
            y,
2173
            Xresampled,
2174
            weights,
2175
            variable_names,
2176
            complexity_of_variables,
2177
            X_units,
2178
            y_units,
2179
        )
2180

2181
        if X.shape[0] > 10000 and not self.batching:
3✔
2182
            warnings.warn(
3✔
2183
                "Note: you are running with more than 10,000 datapoints. "
2184
                "You should consider turning on batching (https://ai.damtp.cam.ac.uk/pysr/options/#batching). "
2185
                "You should also reconsider if you need that many datapoints. "
2186
                "Unless you have a large amount of noise (in which case you "
2187
                "should smooth your dataset first), generally < 10,000 datapoints "
2188
                "is enough to find a functional form with symbolic regression. "
2189
                "More datapoints will lower the search speed."
2190
            )
2191

2192
        random_state = check_random_state(self.random_state)  # For np random
3✔
2193
        seed = cast(int, random_state.randint(0, 2**31 - 1))  # For julia random
3✔
2194

2195
        # Pre transformations (feature selection and denoising)
2196
        X, y, variable_names, complexity_of_variables, X_units, y_units = (
3✔
2197
            self._pre_transform_training_data(
2198
                X,
2199
                y,
2200
                Xresampled,
2201
                variable_names,
2202
                complexity_of_variables,
2203
                X_units,
2204
                y_units,
2205
                random_state,
2206
            )
2207
        )
2208

2209
        # Warn about large feature counts (still warn if feature count is large
2210
        # after running feature selection)
2211
        if self.n_features_in_ >= 10:
3✔
2212
            warnings.warn(
3✔
2213
                "Note: you are running with 10 features or more. "
2214
                "Genetic algorithms like used in PySR scale poorly with large numbers of features. "
2215
                "You should run PySR for more `niterations` to ensure it can find "
2216
                "the correct variables, and consider using a larger `maxsize`."
2217
            )
2218

2219
        # Assertion checks
2220
        use_custom_variable_names = variable_names is not None
3✔
2221
        # TODO: this is always true.
2222

2223
        _check_assertions(
3✔
2224
            X,
2225
            use_custom_variable_names,
2226
            variable_names,
2227
            complexity_of_variables,
2228
            weights,
2229
            y,
2230
            X_units,
2231
            y_units,
2232
        )
2233

2234
        # Initially, just save model parameters, so that
2235
        # it can be loaded from an early exit:
2236
        if not self.temp_equation_file:
3✔
2237
            self._checkpoint()
3✔
2238

2239
        # Perform the search:
2240
        self._run(X, y, runtime_params, weights=weights, seed=seed, category=category)
3✔
2241

2242
        # Then, after fit, we save again, so the pickle file contains
2243
        # the equations:
2244
        if not self.temp_equation_file:
3✔
2245
            self._checkpoint()
3✔
2246

2247
        return self
3✔
2248

2249
    def refresh(self, run_directory: PathLike | None = None) -> None:
3✔
2250
        """
2251
        Update self.equations_ with any new options passed.
2252

2253
        For example, updating `extra_sympy_mappings`
2254
        will require a `.refresh()` to update the equations.
2255

2256
        Parameters
2257
        ----------
2258
        run_directory : str or Path, optional
2259
            Path to a run directory (e.g., `outputs/<run_id>`) containing
2260
            checkpoint files to load. By default, the model's existing
            `output_directory_` and `run_id_` are used.
2261
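
        Examples
        --------
        A sketch, pointing at a hypothetical run directory:

        >>> model.set_params(extra_sympy_mappings={"inv": lambda x: 1 / x})
        >>> model.refresh(run_directory="outputs/20241130_224800_Ab1Cd2")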
        """
2262
        if run_directory is not None:
3✔
2263
            self.output_directory_ = str(Path(run_directory).parent)
3✔
2264
            self.run_id_ = Path(run_directory).name
3✔
2265
            self._clear_equation_file_contents()
3✔
2266
        check_is_fitted(self, attributes=["run_id_", "output_directory_"])
3✔
2267
        self.equations_ = self.get_hof()
3✔
2268

2269
    def predict(
3✔
2270
        self,
2271
        X,
2272
        index: int | list[int] | None = None,
2273
        *,
2274
        category: ndarray | None = None,
2275
    ) -> ndarray:
2276
        """
2277
        Predict y from input X using the equation chosen by `model_selection`.
2278

2279
        You may see what equation is used by printing this object. X should
2280
        have the same columns as the training data.
2281

2282
        Parameters
2283
        ----------
2284
        X : ndarray | pandas.DataFrame
2285
            Data of shape `(n_samples, n_features)` to predict on.
2286
        index : int | list[int]
2287
            If you want to compute the output of an expression using a
2288
            particular row of `self.equations_`, you may specify the index here.
2289
            For multiple output equations, you must pass a list of indices
2290
            in the same order.
2291
        category : ndarray | None
2292
            If `expression_spec` is a `ParametricExpressionSpec`, then this
2293
            argument should be a list of integers representing the category
2294
            of each sample in `X`.
2295

2296
        Returns
2297
        -------
2298
        y_predicted : ndarray of shape (n_samples, nout_)
2299
            Values predicted by substituting `X` into the fitted symbolic
2300
            regression model.
2301

2302
        Raises
2303
        ------
2304
        ValueError
2305
            Raises if the `best_equation` cannot be evaluated.
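
        Examples
        --------
        A sketch (assumes a model fitted on two features):

        >>> import numpy as np
        >>> X_test = np.random.randn(10, 2)
        >>> y_pred = model.predict(X_test)           # equation chosen by `model_selection`
        >>> y_alt = model.predict(X_test, index=0)   # simplest equation instead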
2306
        """
2307
        check_is_fitted(
3✔
2308
            self, attributes=["selection_mask_", "feature_names_in_", "nout_"]
2309
        )
2310
        best_equation = self.get_best(index=index)
3✔
2311

2312
        # When X is a numpy array or a pandas dataframe with a RangeIndex,
2313
        # the self.feature_names_in_ generated during fit, for the same X,
2314
        # will cause a warning to be thrown during _validate_data.
2315
        # To avoid this, convert X to a dataframe, apply the selection mask,
2316
        # and then set the column/feature_names of X to be equal to those
2317
        # generated during fit.
2318
        if not isinstance(X, pd.DataFrame):
3✔
2319
            X = check_array(X)
3✔
2320
            X = pd.DataFrame(X)
3✔
2321
        if isinstance(X.columns, pd.RangeIndex):
3✔
2322
            if self.selection_mask_ is not None:
3✔
2323
                # RangeIndex enforces column order allowing columns to
2324
                # be correctly filtered with self.selection_mask_
2325
                X = X[X.columns[self.selection_mask_]]
3✔
2326
            X.columns = self.feature_names_in_
3✔
2327
        # Without feature information, CallableEquation/lambda_format equations
2328
        # require that the column order of X matches that of the X used during
2329
        # the fitting process. _validate_data removes this feature information
2330
        # when it converts the dataframe to an np array. Thus, to ensure feature
2331
        # order is preserved after conversion, the dataframe columns must be
2332
        # reordered/reindexed to match those of the transformed (denoised and
2333
        # feature selected) X in fit.
2334
        X = X.reindex(columns=self.feature_names_in_)
3✔
2335
        X = self._validate_data_X(X)
3✔
2336
        if self.expression_spec_.evaluates_in_julia:
3✔
2337
            # Julia wants the right dtype
2338
            X = X.astype(self._get_precision_mapped_dtype(X))
3✔
2339

2340
        if category is not None:
3✔
2341
            offset_for_julia_indexing = 1
3✔
2342
            args: tuple = (
3✔
2343
                jl_array((category + offset_for_julia_indexing).astype(np.int64)),
2344
            )
2345
        else:
2346
            args = ()
3✔
2347

2348
        try:
3✔
2349
            if isinstance(best_equation, list):
3✔
2350
                assert self.nout_ > 1
3✔
2351
                return np.stack(
3✔
2352
                    [
2353
                        cast(ndarray, eq["lambda_format"](X, *args))
2354
                        for eq in best_equation
2355
                    ],
2356
                    axis=1,
2357
                )
2358
            else:
2359
                return cast(ndarray, best_equation["lambda_format"](X, *args))
3✔
2360
        except Exception as error:
×
2361
            raise ValueError(
×
2362
                "Failed to evaluate the expression. "
2363
                "If you are using a custom operator, make sure to define it in `extra_sympy_mappings`, "
2364
                "e.g., `model.set_params(extra_sympy_mappings={'inv': lambda x: 1/x})`, where "
2365
                "`lambda x: 1/x` is a valid SymPy function defining the operator. "
2366
                "You can then run `model.refresh()` to re-load the expressions."
2367
            ) from error
2368

2369
    def sympy(self, index: int | list[int] | None = None):
        """
        Return sympy representation of the equation(s) chosen by `model_selection`.

        Parameters
        ----------
        index : int | list[int]
            If you wish to select a particular equation from
            `self.equations_`, give the index number here. This overrides
            the `model_selection` parameter. If there are multiple output
            features, then pass a list of indices in the same order as the
            output features.

        Returns
        -------
        best_equation : str or list[str] of length nout_
            SymPy representation of the best equation.
        """
        if not self.expression_spec_.supports_sympy:
            raise ValueError(
                f"`expression_spec={self.expression_spec_}` does not support sympy export."
            )
        self.refresh()
        best_equation = self.get_best(index=index)
        if isinstance(best_equation, list):
            assert self.nout_ > 1
            return [eq["sympy_format"] for eq in best_equation]
        else:
            return best_equation["sympy_format"]
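    # Usage sketch (hypothetical, assuming a fitted single-output `model`):
    #
    #     expr = model.sympy()          # equation chosen by `model_selection`
    #     expr5 = model.sympy(index=5)  # row 5 of `model.equations_` instead
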
    def latex(
        self, index: int | list[int] | None = None, precision: int = 3
    ) -> str | list[str]:
        """
        Return latex representation of the equation(s) chosen by `model_selection`.

        Parameters
        ----------
        index : int | list[int]
            If you wish to select a particular equation from
            `self.equations_`, give the index number here. This overrides
            the `model_selection` parameter. If there are multiple output
            features, then pass a list of indices in the same order as the
            output features.
        precision : int
            The number of significant figures shown in the LaTeX
            representation.
            Default is `3`.

        Returns
        -------
        best_equation : str or list[str] of length nout_
            LaTeX expression of the best equation.
        """
        if not self.expression_spec_.supports_latex:
            raise ValueError(
                f"`expression_spec={self.expression_spec_}` does not support latex export."
            )
        self.refresh()
        sympy_representation = self.sympy(index=index)
        if self.nout_ > 1:
            output = []
            for s in sympy_representation:
                latex = sympy2latex(s, prec=precision)
                output.append(latex)
            return output
        return sympy2latex(sympy_representation, prec=precision)
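    # Usage sketch (hypothetical, assuming a fitted `model`):
    #
    #     tex = model.latex(precision=2)  # LaTeX string for the chosen equation
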
    def jax(self, index=None):
        """
        Return jax representation of the equation(s) chosen by `model_selection`.

        Each equation (multiple given if there are multiple outputs) is a dictionary
        containing {"callable": func, "parameters": params}. To call `func`, pass
        func(X, params). This function is differentiable using `jax.grad`.

        Parameters
        ----------
        index : int | list[int]
            If you wish to select a particular equation from
            `self.equations_`, give the index number here. This overrides
            the `model_selection` parameter. If there are multiple output
            features, then pass a list of indices in the same order as the
            output features.

        Returns
        -------
        best_equation : dict[str, Any]
            Dictionary with the callable jax function under the "callable" key,
            and a jax array of parameters under the "parameters" key.
        """
        if not self.expression_spec_.supports_jax:
            raise ValueError(
                f"`expression_spec={self.expression_spec_}` does not support jax export."
            )
        self.set_params(output_jax_format=True)
        self.refresh()
        best_equation = self.get_best(index=index)
        if isinstance(best_equation, list):
            assert self.nout_ > 1
            return [eq["jax_format"] for eq in best_equation]
        else:
            return best_equation["jax_format"]
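    # Usage sketch following the docstring above (hypothetical, assuming a
    # fitted `model` and a feature array `X`):
    #
    #     import jax
    #     eq = model.jax()
    #     f, params = eq["callable"], eq["parameters"]
    #     y = f(X, params)  # evaluate the expression
    #     grads = jax.grad(lambda p: f(X, p).sum())(params)  # w.r.t. parameters
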
    def pytorch(self, index=None):
        """
        Return pytorch representation of the equation(s) chosen by `model_selection`.

        Each equation (multiple given if there are multiple outputs) is a PyTorch module
        containing the parameters as trainable attributes. You can use the module like
        any other PyTorch module: `module(X)`, where `X` is a tensor with the same
        column ordering as trained with.

        Parameters
        ----------
        index : int | list[int]
            If you wish to select a particular equation from
            `self.equations_`, give the index number here. This overrides
            the `model_selection` parameter. If there are multiple output
            features, then pass a list of indices in the same order as the
            output features.

        Returns
        -------
        best_equation : torch.nn.Module
            PyTorch module representing the expression.
        """
        if not self.expression_spec_.supports_torch:
            raise ValueError(
                f"`expression_spec={self.expression_spec_}` does not support torch export."
            )
        self.set_params(output_torch_format=True)
        self.refresh()
        best_equation = self.get_best(index=index)
        if isinstance(best_equation, list):
            return [eq["torch_format"] for eq in best_equation]
        else:
            return best_equation["torch_format"]
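    # Usage sketch (hypothetical, assuming a fitted `model`; the tensor must
    # keep the column ordering used during fit):
    #
    #     import torch
    #     module = model.pytorch()
    #     y = module(torch.tensor(X, dtype=torch.float32))
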
    def get_equation_file(self, i: int | None = None) -> Path:
        if i is not None:
            return (
                Path(self.output_directory_)
                / self.run_id_
                / f"hall_of_fame_output{i}.csv"
            )
        else:
            return Path(self.output_directory_) / self.run_id_ / "hall_of_fame.csv"
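    # For example (hypothetical values), with `output_directory_="outputs"` and
    # `run_id_="my_run"`, `get_equation_file(2)` returns
    # outputs/my_run/hall_of_fame_output2.csv, and `get_equation_file()`
    # returns outputs/my_run/hall_of_fame.csv.
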
    def _read_equation_file(self) -> list[pd.DataFrame]:
        """Read the hall of fame file created by `SymbolicRegression.jl`."""

        try:
            if self.nout_ > 1:
                all_outputs = []
                for i in range(1, self.nout_ + 1):
                    cur_filename = str(self.get_equation_file(i)) + ".bak"
                    if not os.path.exists(cur_filename):
                        cur_filename = str(self.get_equation_file(i))
                    with open(cur_filename, "r", encoding="utf-8") as f:
                        buf = f.read()
                    buf = _preprocess_julia_floats(buf)
                    df = self._postprocess_dataframe(pd.read_csv(StringIO(buf)))
                    all_outputs.append(df)
            else:
                filename = str(self.get_equation_file()) + ".bak"
                if not os.path.exists(filename):
                    filename = str(self.get_equation_file())
                with open(filename, "r", encoding="utf-8") as f:
                    buf = f.read()
                buf = _preprocess_julia_floats(buf)
                all_outputs = [self._postprocess_dataframe(pd.read_csv(StringIO(buf)))]

        except FileNotFoundError:
            raise RuntimeError(
                "Couldn't find equation file! The equation search likely exited "
                "before a single iteration completed."
            )
        return all_outputs
    def _postprocess_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        df = df.rename(
            columns={
                "Complexity": "complexity",
                "Loss": "loss",
                "Equation": "equation",
            },
        )

        return df
    def get_hof(self, search_output=None) -> pd.DataFrame | list[pd.DataFrame]:
        """Get the equations from a hall of fame file or search output.

        If no arguments are given, the results from the
        previous call to PySR will be used.
        """
        check_is_fitted(
            self,
            attributes=[
                "nout_",
                "run_id_",
                "output_directory_",
                "selection_mask_",
                "feature_names_in_",
            ],
        )
        should_read_from_file = (
            not hasattr(self, "equation_file_contents_")
            or self.equation_file_contents_ is None
        )
        if should_read_from_file:
            self.equation_file_contents_ = self._read_equation_file()

        _validate_export_mappings(self.extra_jax_mappings, self.extra_torch_mappings)

        equation_file_contents = cast(list[pd.DataFrame], self.equation_file_contents_)

        ret_outputs = [
            pd.concat(
                [
                    output,
                    calculate_scores(output),
                    self.expression_spec_.create_exports(self, output, search_output),
                ],
                axis=1,
            )
            for output in equation_file_contents
        ]

        if self.nout_ > 1:
            return ret_outputs
        return ret_outputs[0]
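    # Each returned frame concatenates, per equation: the raw hall-of-fame
    # columns ("complexity", "loss", "equation"), the computed "score" column,
    # and any exports added by the expression spec (such as "sympy_format" or
    # "lambda_format"). A usage sketch (hypothetical fitted `model`):
    #
    #     hof = model.get_hof()
    #     best_row = hof.loc[hof["score"].idxmax()]
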
    def latex_table(
        self,
        indices: list[int] | None = None,
        precision: int = 3,
        columns: list[str] = ["equation", "complexity", "loss", "score"],
    ) -> str:
        """Create a LaTeX/booktabs table for all, or some, of the equations.

        Parameters
        ----------
        indices : list[int] | list[list[int]]
            If you wish to select a particular subset of equations from
            `self.equations_`, give the row numbers here. By default,
            all equations will be used. If there are multiple output
            features, then pass a list of lists.
        precision : int
            The number of significant figures shown in the LaTeX
            representations.
            Default is `3`.
        columns : list[str]
            Which columns to include in the table.
            Default is `["equation", "complexity", "loss", "score"]`.

        Returns
        -------
        latex_table_str : str
            A string that will render a table in LaTeX of the equations.
        """
        if not self.expression_spec_.supports_latex:
            raise ValueError(
                f"`expression_spec={self.expression_spec_}` does not support latex export."
            )
        self.refresh()

        if isinstance(self.equations_, list):
            if indices is not None:
                assert isinstance(indices, list)
                assert isinstance(indices[0], list)
                assert len(indices) == self.nout_

            table_string = sympy2multilatextable(
                self.equations_, indices=indices, precision=precision, columns=columns
            )
        elif isinstance(self.equations_, pd.DataFrame):
            if indices is not None:
                assert isinstance(indices, list)
                assert isinstance(indices[0], int)

            table_string = sympy2latextable(
                self.equations_, indices=indices, precision=precision, columns=columns
            )
        else:
            raise ValueError(
                "Invalid type for equations_ to pass to `latex_table`. "
                "Expected a DataFrame or a list of DataFrames."
            )

        return with_preamble(table_string)
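    # Usage sketch (hypothetical, assuming a fitted `model`):
    #
    #     table = model.latex_table(precision=2, columns=["equation", "loss"])
    #     with open("equations.tex", "w") as f:
    #         f.write(table)  # includes a preamble, via with_preamble
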
def idx_model_selection(equations: pd.DataFrame, model_selection: str):
    """Select an expression and return its index."""
    if model_selection == "accuracy":
        chosen_idx = equations["loss"].idxmin()
    elif model_selection == "best":
        threshold = 1.5 * equations["loss"].min()
        filtered_equations = equations.query(f"loss <= {threshold}")
        chosen_idx = filtered_equations["score"].idxmax()
    elif model_selection == "score":
        chosen_idx = equations["score"].idxmax()
    else:
        raise NotImplementedError(
            f"{model_selection} is not a valid model selection strategy."
        )
    return chosen_idx
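# A worked sketch of the three strategies on a toy frame (hypothetical values):
#
#     import pandas as pd
#     eqs = pd.DataFrame({"loss": [1.0, 0.4, 0.39], "score": [0.0, 0.9, 0.01]})
#     idx_model_selection(eqs, "accuracy")  # -> 2 (lowest loss)
#     idx_model_selection(eqs, "score")     # -> 1 (highest score)
#     # "best": only rows with loss <= 1.5 * 0.39 = 0.585 are considered,
#     # i.e. rows 1 and 2, and row 1 has the highest score among them:
#     idx_model_selection(eqs, "best")      # -> 1
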
def calculate_scores(df: pd.DataFrame) -> pd.DataFrame:
    """Calculate scores for each equation based on loss and complexity.

    Score is defined as the negated derivative of the log-loss with respect to complexity.
    A higher score means the equation achieved a much better loss at a slightly higher complexity.
    """
    scores = []
    lastMSE = None
    lastComplexity = 0

    for _, row in df.iterrows():
        curMSE = row["loss"]
        curComplexity = row["complexity"]

        if lastMSE is None:
            cur_score = 0.0
        else:
            if curMSE > 0.0:
                cur_score = -np.log(curMSE / lastMSE) / (curComplexity - lastComplexity)
            else:
                cur_score = np.inf

        scores.append(cur_score)
        lastMSE = curMSE
        lastComplexity = curComplexity

    return pd.DataFrame(
        {
            "score": np.array(scores),
        },
        index=df.index,
    )
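# A worked example of the score definition above (hypothetical values): moving
# from (complexity=1, loss=1.0) to (complexity=3, loss=0.5) gives
# score = -ln(0.5 / 1.0) / (3 - 1) ≈ 0.347; the first row always scores 0.0.
#
#     import pandas as pd
#     df = pd.DataFrame({"complexity": [1, 3], "loss": [1.0, 0.5]})
#     calculate_scores(df)["score"]  # -> [0.0, 0.3466...]
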
def _mutate_parameter(param_name: str, param_value):
    if param_name == "batch_size" and param_value < 1:
        warnings.warn(
            "Given `batch_size` must be greater than or equal to one. "
            "`batch_size` has been increased to equal one."
        )
        return 1

    if (
        param_name == "progress"
        and param_value == True
        and "buffer" not in sys.stdout.__dir__()
    ):
        warnings.warn(
            "Note: it looks like you are running in Jupyter. "
            "The progress bar will be turned off."
        )
        return False

    return param_value
def _map_parallelism_params(
    parallelism: Literal["serial", "multithreading", "multiprocessing"] | None,
    procs: int | None,
    multithreading: bool | None,
) -> tuple[Literal["serial", "multithreading", "multiprocessing"], int | None]:
    """Map old and new parallelism parameters to the new format.

    Parameters
    ----------
    parallelism : str or None
        New parallelism parameter. Can be "serial", "multithreading", or "multiprocessing".
    procs : int or None
        Number of processes parameter.
    multithreading : bool or None
        Old multithreading parameter.

    Returns
    -------
    parallelism : str
        Mapped parallelism mode.
    procs : int or None
        Mapped number of processes.

    Raises
    ------
    ValueError
        If both old and new parameters are specified, or if invalid combinations are given.
    """
    # Check for mixing old and new parameters
    using_new = parallelism is not None
    using_old = multithreading is not None

    if using_new and using_old:
        raise ValueError(
            "Cannot mix old and new parallelism parameters. "
            "Use either `parallelism` (optionally with `procs`), or the "
            "deprecated `multithreading` and `procs`."
        )
    elif using_old:
        warnings.warn(
            "The `multithreading: bool` parameter has been deprecated in favor "
            "of `parallelism: Literal['multithreading', 'serial', 'multiprocessing']`.\n"
            "Previous usage of `multithreading=True` (default) is now `parallelism='multithreading'`; "
            "`multithreading=False, procs=0` is now `parallelism='serial'`; and "
            "`multithreading=False, procs={int}` is now `parallelism='multiprocessing', procs={int}`."
        )
        if multithreading:
            _parallelism: Literal["multithreading", "multiprocessing", "serial"] = (
                "multithreading"
            )
            _procs = None
        elif procs is not None and procs > 0:
            _parallelism = "multiprocessing"
            _procs = procs
        else:
            _parallelism = "serial"
            _procs = None
    elif using_new:
        _parallelism = cast(
            Literal["serial", "multithreading", "multiprocessing"], parallelism
        )
        _procs = procs
    else:
        _parallelism = "multithreading"
        _procs = None

    if _parallelism not in {"serial", "multithreading", "multiprocessing"}:
        raise ValueError(
            "`parallelism` must be one of 'serial', 'multithreading', or 'multiprocessing'"
        )
    elif _parallelism == "serial" and _procs is not None:
        warnings.warn(
            "`procs` is specified but will be ignored since `parallelism='serial'`"
        )
        _procs = None
    elif _parallelism == "multithreading" and _procs is not None:
        warnings.warn(
            "`procs` is specified but will be ignored since `parallelism='multithreading'`"
        )
        _procs = None
    elif _parallelism == "multiprocessing" and _procs is None:
        _procs = cpu_count()

    return _parallelism, _procs
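# A sketch of the resulting mappings (hypothetical calls; the old-style ones
# also emit the deprecation warning above):
#
#     _map_parallelism_params(None, None, None)      # -> ("multithreading", None)
#     _map_parallelism_params("serial", None, None)  # -> ("serial", None)
#     _map_parallelism_params("multiprocessing", None, None)
#                                                    # -> ("multiprocessing", cpu_count())
#     _map_parallelism_params(None, 4, False)        # old style -> ("multiprocessing", 4)
#     _map_parallelism_params(None, 0, False)        # old style -> ("serial", None)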