MilesCranmer / PySR / build 13483409665

23 Feb 2025 01:16PM UTC coverage: 92.664% (-0.7%) from 93.372%

Pull Request #794 (github): Automatic slurm allocations
Commit (MilesCranmer): Merge branch 'master' into fix-cluster-manager

11 of 26 new or added lines in 3 files covered. (42.31%)
1 existing line in 1 file now uncovered.
1440 of 1554 relevant lines covered (92.66%)
2.61 hits per line

Source File: /pysr/sr.py (92.47% covered)
"""Define the PySRRegressor scikit-learn interface."""

import copy
import logging
import os
import pickle as pkl
import re
import sys
import tempfile
import warnings
from collections.abc import Callable
from dataclasses import dataclass, fields
from io import StringIO
from multiprocessing import cpu_count
from pathlib import Path
from typing import Any, Literal, cast

import numpy as np
import pandas as pd
from numpy import ndarray
from numpy.typing import NDArray
from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
from sklearn.utils import check_array, check_consistent_length, check_random_state
from sklearn.utils.validation import _check_feature_names_in  # type: ignore
from sklearn.utils.validation import check_is_fitted

from .denoising import denoise, multi_denoise
from .deprecated import DEPRECATED_KWARGS
from .export_latex import (
    sympy2latex,
    sympy2latextable,
    sympy2multilatextable,
    with_preamble,
)
from .export_sympy import assert_valid_sympy_symbol
from .expression_specs import (
    AbstractExpressionSpec,
    ExpressionSpec,
    ParametricExpressionSpec,
)
from .feature_selection import run_feature_selection
from .julia_extensions import load_required_packages
from .julia_helpers import (
    _escape_filename,
    jl_array,
    jl_deserialize,
    jl_is_function,
    jl_serialize,
    load_cluster_manager,
)
from .julia_import import AnyValue, SymbolicRegression, VectorValue, jl
from .logger_specs import AbstractLoggerSpec
from .utils import (
    ArrayLike,
    PathLike,
    _preprocess_julia_floats,
    _safe_check_feature_names_in,
    _subscriptify,
    _suggest_keywords,
)

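# Compatibility shim: `validate_data` only exists in newer scikit-learn
# releases; OLD_SKLEARN records whether we are on an older version.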
try:
    from sklearn.utils.validation import validate_data

    OLD_SKLEARN = False
except ImportError:
    OLD_SKLEARN = True

ALREADY_RAN = False

pysr_logger = logging.getLogger(__name__)


def _process_constraints(
    binary_operators: list[str],
    unary_operators: list,
    constraints: dict[str, int | tuple[int, int]],
) -> dict[str, int | tuple[int, int]]:
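    """Return a copy of `constraints` with defaults filled in for any
    operator the user did not constrain.

    Unary operators default to -1 (unconstrained) and binary operators to
    (-1, -1). `+`/`-` must have equal constraints on both arguments, and `*`
    constraints are reordered so the less restrictive limit comes first.
    """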
    constraints = constraints.copy()
    for op in unary_operators:
        if op not in constraints:
            constraints[op] = -1
    for op in binary_operators:
        if op not in constraints:
            if op in ["^", "pow"]:
                # Warn user that they should set up constraints
                warnings.warn(
                    "You are using the `^` operator, but have not set up `constraints` for it. "
                    "This may lead to overly complex expressions. "
                    "One typical constraint is to use `constraints={..., '^': (-1, 1)}`, which "
                    "will allow arbitrary-complexity base (-1) but only powers such as "
                    "a constant or variable (1). "
                    "For more tips, please see https://ai.damtp.cam.ac.uk/pysr/tuning/"
                )
            constraints[op] = (-1, -1)

        constraint_tuple = cast(tuple[int, int], constraints[op])
        if op in ["plus", "sub", "+", "-"]:
            if constraint_tuple[0] != constraint_tuple[1]:
                raise NotImplementedError(
                    "You need equal constraints on both sides for - and +, "
                    "due to simplification strategies."
                )
        elif op in ["mult", "*"]:
            # Make sure the complex expression is in the left side.
            if constraint_tuple[0] == -1:
                continue
            if constraint_tuple[1] == -1 or constraint_tuple[0] < constraint_tuple[1]:
                constraints[op] = (constraint_tuple[1], constraint_tuple[0])
    return constraints
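
# For example (a sketch of the defaulting behavior):
# _process_constraints(["*", "^"], ["sin"], {}) returns
# {"sin": -1, "*": (-1, -1), "^": (-1, -1)}, warning about the
# unconstrained `^`.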


def _maybe_create_inline_operators(
    binary_operators: list[str],
    unary_operators: list[str],
    extra_sympy_mappings: dict[str, Callable] | None,
    expression_spec: AbstractExpressionSpec,
) -> tuple[list[str], list[str]]:
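    """Evaluate inline (Julia-syntax) operator definitions such as
    `"inv(x) = 1/x"` and replace them in the operator lists with their bare
    function names, validating each name and its sympy mapping.
    """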
    binary_operators = binary_operators.copy()
    unary_operators = unary_operators.copy()
    for op_list in [binary_operators, unary_operators]:
        for i, op in enumerate(op_list):
            is_user_defined_operator = "(" in op

            if is_user_defined_operator:
                jl.seval(op)
                # Cut off the function name at the first "(":
                first_non_char = [j for j, char in enumerate(op) if char == "("][0]
                function_name = op[:first_non_char]
                # Assert that function_name only contains
                # alphanumeric characters and underscores:
                if not re.match(r"^[a-zA-Z0-9_]+$", function_name):
                    raise ValueError(
                        f"Invalid function name {function_name}. "
                        "Only alphanumeric characters and underscores are allowed."
                    )
                missing_sympy_mapping = (
                    extra_sympy_mappings is None
                    or function_name not in extra_sympy_mappings
                )
                if missing_sympy_mapping and expression_spec.supports_sympy:
                    raise ValueError(
                        f"Custom function {function_name} is not defined in `extra_sympy_mappings`. "
                        "You can define it with, "
                        "e.g., `model.set_params(extra_sympy_mappings={'inv': lambda x: 1/x})`, where "
                        "`lambda x: 1/x` is a valid SymPy function defining the operator. "
                        "You can also define these at initialization time."
                    )
                op_list[i] = function_name
    return binary_operators, unary_operators
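
# For example (a sketch): _maybe_create_inline_operators(
#     ["+"], ["inv(x) = 1/x"], {"inv": lambda x: 1/x}, ExpressionSpec()
# ) evaluates the Julia definition and returns (["+"], ["inv"]).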


def _check_assertions(
    X,
    use_custom_variable_names,
    variable_names,
    complexity_of_variables,
    weights,
    y,
    X_units,
    y_units,
):
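    """Validate shapes, variable names, units, and related arguments before fitting."""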
    # Check for potential errors before they happen
    assert len(X.shape) == 2
    assert len(y.shape) in [1, 2]
    assert X.shape[0] == y.shape[0]
    if weights is not None:
        assert weights.shape == y.shape
        assert X.shape[0] == weights.shape[0]
    if use_custom_variable_names:
        assert len(variable_names) == X.shape[1]
        # Check none of the variable names are function names:
        for var_name in variable_names:
            # Check if alphanumeric only:
            if not re.match(r"^[₀₁₂₃₄₅₆₇₈₉a-zA-Z0-9_]+$", var_name):
                raise ValueError(
                    f"Invalid variable name {var_name}. "
                    "Only alphanumeric characters, underscores, "
                    "and subscript digits are allowed."
                )
            assert_valid_sympy_symbol(var_name)
    if (
        isinstance(complexity_of_variables, list)
        and len(complexity_of_variables) != X.shape[1]
    ):
        raise ValueError(
            "The number of elements in `complexity_of_variables` must equal the number of features in `X`."
        )
    if X_units is not None and len(X_units) != X.shape[1]:
        raise ValueError(
            "The number of units in `X_units` must equal the number of features in `X`."
        )
    if y_units is not None:
        good_y_units = False
        if isinstance(y_units, list):
            if len(y.shape) == 1:
                good_y_units = len(y_units) == 1
            else:
                good_y_units = len(y_units) == y.shape[1]
        else:
            good_y_units = len(y.shape) == 1 or y.shape[1] == 1

        if not good_y_units:
            raise ValueError(
                "The number of units in `y_units` must equal the number of output features in `y`."
            )


def _validate_export_mappings(extra_jax_mappings, extra_torch_mappings):
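    """Validate the types of user-supplied jax/torch export mappings."""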
    # It is expected extra_jax/torch_mappings will be updated after fit.
    # Thus, validation is performed here instead of in _validate_init_params.
    if extra_jax_mappings is not None:
        for value in extra_jax_mappings.values():
            if not isinstance(value, str):
                raise ValueError(
                    "extra_jax_mappings must have values that are strings! "
                    "e.g., {sympy.sqrt: 'jnp.sqrt'}."
                )
    if extra_torch_mappings is not None:
        for value in extra_torch_mappings.values():
            if not callable(value):
                raise ValueError(
                    "extra_torch_mappings must have values that are callable! "
                    "e.g., {sympy.sqrt: torch.sqrt}."
                )


# Class validation constants
VALID_OPTIMIZER_ALGORITHMS = ["BFGS", "NelderMead"]


@dataclass
class _DynamicallySetParams:
    """Defines some parameters that are set at runtime."""

    binary_operators: list[str]
    unary_operators: list[str]
    maxdepth: int
    constraints: dict[str, int | tuple[int, int]]
    batch_size: int
    update_verbosity: int
    progress: bool
    warmup_maxsize_by: float


class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
    """
    High-performance symbolic regression algorithm.

    This is the scikit-learn interface for SymbolicRegression.jl.
    This model will automatically search for equations which fit
    a given dataset subject to a particular loss and set of
    constraints.

    Most default parameters have been tuned over several example equations,
    but you should adjust `niterations`, `binary_operators`, and `unary_operators`
    to your requirements. You can view more detailed explanations of the options
    on the [options page](https://ai.damtp.cam.ac.uk/pysr/options) of the
    documentation.

    Parameters
    ----------
    model_selection : str
        Model selection criterion when selecting a final expression from
        the list of best expressions at each complexity.
        Can be `'accuracy'`, `'best'`, or `'score'`. Default is `'best'`.
        `'accuracy'` selects the candidate model with the lowest loss
        (highest accuracy).
        `'score'` selects the candidate model with the highest score.
        Score is defined as the negated derivative of the log-loss with
        respect to complexity: if an expression has a much better
        loss at a slightly higher complexity, it is preferred.
        `'best'` selects the candidate model with the highest score
        among expressions with a loss better than at least 1.5x the
        most accurate model.
    binary_operators : list[str]
        List of strings for binary operators used in the search.
        See the [operators page](https://ai.damtp.cam.ac.uk/pysr/operators/)
        for more details.
        Default is `["+", "-", "*", "/"]`.
    unary_operators : list[str]
        Operators which only take a single scalar as input.
        For example, `"cos"` or `"exp"`.
        Default is `None`.
    expression_spec : AbstractExpressionSpec
        The type of expression to search for. By default,
        this is just `ExpressionSpec()`. You can also use
        `TemplateExpressionSpec(...)` which allows you to specify
        a custom template for the expressions.
        Default is `ExpressionSpec()`.
    niterations : int
        Number of iterations of the algorithm to run. The best
        equations are printed and migrate between populations at the
        end of each iteration.
        Default is `100`.
    populations : int
        Number of populations running.
        Default is `31`.
    population_size : int
        Number of individuals in each population.
        Default is `27`.
    max_evals : int
        Limits the total number of evaluations of expressions to
        this number. Default is `None`.
    maxsize : int
        Max complexity of an equation. Default is `30`.
    maxdepth : int
        Max depth of an equation. You can use both `maxsize` and
        `maxdepth`. `maxdepth` is by default not used.
        Default is `None`.
    warmup_maxsize_by : float
        Whether to slowly increase the max size from a small number up to
        `maxsize` (if greater than 0). If greater than 0, this specifies
        the fraction of training time at which the current maxsize will
        reach the user-passed maxsize.
        Default is `0.0`.
    timeout_in_seconds : float
        Make the search return early once this many seconds have passed.
        Default is `None`.
    constraints : dict[str, int | tuple[int, int]]
        Dictionary of int (unary) or 2-tuples (binary), this enforces
        maxsize constraints on the individual arguments of operators.
        E.g., `'pow': (-1, 1)` says that power laws can have any
        complexity left argument, but only 1 complexity in the right
        argument. Use this to force more interpretable solutions.
        Default is `None`.
    nested_constraints : dict[str, dict]
        Specifies how many times a combination of operators can be
        nested. For example, `{"sin": {"cos": 0}, "cos": {"cos": 2}}`
        specifies that `cos` may never appear within a `sin`, but `sin`
        can be nested with itself an unlimited number of times. The
        second term specifies that `cos` can be nested up to 2 times
        within a `cos`, so that `cos(cos(cos(x)))` is allowed
        (as well as any combination of `+` or `-` within it), but
        `cos(cos(cos(cos(x))))` is not allowed. When an operator is not
        specified, it is assumed that it can be nested an unlimited
        number of times. This requires that there is no operator which
        is used both in the unary operators and the binary operators
        (e.g., `-` could be both subtract and negation). For binary
        operators, you only need to provide a single number: both
        arguments are treated the same way, and the max of each
        argument is constrained.
        Default is `None`.
    elementwise_loss : str
        String of Julia code specifying an elementwise loss function.
        Can either be a loss from LossFunctions.jl, or your own loss
        written as a function. Examples of custom written losses include:
        `myloss(x, y) = abs(x-y)` for non-weighted, or
        `myloss(x, y, w) = w*abs(x-y)` for weighted.
        Built-in losses include:
        Regression: `LPDistLoss{P}()`, `L1DistLoss()`,
        `L2DistLoss()` (mean square), `LogitDistLoss()`,
        `HuberLoss(d)`, `L1EpsilonInsLoss(ϵ)`, `L2EpsilonInsLoss(ϵ)`,
        `PeriodicLoss(c)`, `QuantileLoss(τ)`.
        Classification: `ZeroOneLoss()`, `PerceptronLoss()`,
        `L1HingeLoss()`, `SmoothedL1HingeLoss(γ)`,
        `ModifiedHuberLoss()`, `L2MarginLoss()`, `ExpLoss()`,
        `SigmoidLoss()`, `DWDMarginLoss(q)`.
        Default is `"L2DistLoss()"`.
    loss_function : str
        Alternatively, you can specify the full objective function as
        a snippet of Julia code, including any sort of custom evaluation
        (including symbolic manipulations beforehand), and any sort
        of loss function or regularizations. The default `loss_function`
        used in SymbolicRegression.jl is roughly equal to:
        ```julia
        function eval_loss(tree, dataset::Dataset{T,L}, options)::L where {T,L}
            prediction, flag = eval_tree_array(tree, dataset.X, options)
            if !flag
                return L(Inf)
            end
            return sum((prediction .- dataset.y) .^ 2) / dataset.n
        end
        ```
        where the example elementwise loss is mean-squared error.
        You may pass a function with the same arguments as this (note
        that the name of the function doesn't matter). Here,
        both `prediction` and `dataset.y` are 1D arrays of length `dataset.n`.
        Default is `None`.
    loss_function_expression : str
        Similar to `loss_function`, but takes as input the full
        expression object as the first argument, rather than
        the innermost `AbstractExpressionNode`. This is useful
        for specifying custom loss functions on `TemplateExpressionSpec`.
        Default is `None`.
    complexity_of_operators : dict[str, int | float]
        If you would like to use a complexity other than 1 for an
        operator, specify the complexity here. For example,
        `{"sin": 2, "+": 1}` would give a complexity of 2 for each use
        of the `sin` operator, and a complexity of 1 for each use of
        the `+` operator (which is the default). You may specify real
        numbers for a complexity, and the total complexity of a tree
        will be rounded to the nearest integer after computing.
        Default is `None`.
    complexity_of_constants : int | float
        Complexity of constants. Default is `1`.
    complexity_of_variables : int | float | list[int | float]
        Global complexity of variables. To set different complexities for
        different variables, pass a list of complexities to the `fit` method
        with keyword `complexity_of_variables`. You cannot use both.
        Default is `1`.
    complexity_mapping : str
        Alternatively, you can pass a function (a string of Julia code) that
        takes the expression as input and returns the complexity. Make sure that
        this operates on `AbstractExpression` (and unpacks to `AbstractExpressionNode`),
        and returns an integer.
        Default is `None`.
    parsimony : float
        Multiplicative factor for how much to punish complexity.
        Default is `0.0`.
    dimensional_constraint_penalty : float
        Additive penalty for if dimensional analysis of an expression fails.
        By default, this is `1000.0`.
    dimensionless_constants_only : bool
        Whether to only search for dimensionless constants, if using units.
        Default is `False`.
    use_frequency : bool
        Whether to measure the frequency of complexities, and use that
        instead of parsimony to explore equation space. Will naturally
        find equations of all complexities.
        Default is `True`.
    use_frequency_in_tournament : bool
        Whether to use the frequency mentioned above in the tournament,
        rather than just the simulated annealing.
        Default is `True`.
    adaptive_parsimony_scaling : float
        If the adaptive parsimony strategy is used (`use_frequency` and
        `use_frequency_in_tournament`), this is how much to (exponentially)
        weight the contribution. If you find that the search is only optimizing
        the most complex expressions while the simpler expressions remain stagnant,
        you should increase this value.
        Default is `1040.0`.
    alpha : float
        Initial temperature for simulated annealing
        (requires `annealing` to be `True`).
        Default is `3.17`.
    annealing : bool
        Whether to use annealing. Default is `False`.
    early_stop_condition : float | str
        Stop the search early if this loss is reached. You may also
        pass a string containing a Julia function which
        takes a loss and complexity as input, for example:
        `"f(loss, complexity) = (loss < 0.1) && (complexity < 10)"`.
        Default is `None`.
    ncycles_per_iteration : int
        Number of total mutations to run, per 10 samples of the
        population, per iteration.
        Default is `380`.
    fraction_replaced : float
        How much of population to replace with migrating equations from
        other populations.
        Default is `0.00036`.
    fraction_replaced_hof : float
        How much of population to replace with migrating equations from
        hall of fame. Default is `0.0614`.
    weight_add_node : float
        Relative likelihood for mutation to add a node.
        Default is `2.47`.
    weight_insert_node : float
        Relative likelihood for mutation to insert a node.
        Default is `0.0112`.
    weight_delete_node : float
        Relative likelihood for mutation to delete a node.
        Default is `0.870`.
    weight_do_nothing : float
        Relative likelihood for mutation to leave the individual.
        Default is `0.273`.
    weight_mutate_constant : float
        Relative likelihood for mutation to change the constant slightly
        in a random direction.
        Default is `0.0346`.
    weight_mutate_operator : float
        Relative likelihood for mutation to swap an operator.
        Default is `0.293`.
    weight_swap_operands : float
        Relative likelihood for swapping operands in binary operators.
        Default is `0.198`.
    weight_rotate_tree : float
        How often to perform a tree rotation at a random node.
        Default is `4.26`.
    weight_randomize : float
        Relative likelihood for mutation to completely delete and then
        randomly generate the equation.
        Default is `0.000502`.
    weight_simplify : float
        Relative likelihood for mutation to simplify constant parts by evaluation.
        Default is `0.00209`.
    weight_optimize : float
        Constant optimization can also be performed as a mutation, in addition to
        the normal strategy controlled by `optimize_probability`, which happens
        every iteration. Using it as a mutation is useful if you want to use
        a large `ncycles_per_iteration`, and may not optimize very often.
        Default is `0.0`.
    crossover_probability : float
        Absolute probability of crossover-type genetic operation, instead of a mutation.
        Default is `0.0259`.
    skip_mutation_failures : bool
        Whether to skip mutation and crossover failures, rather than
        simply re-sampling the current member.
        Default is `True`.
    migration : bool
        Whether to migrate. Default is `True`.
    hof_migration : bool
        Whether to have the hall of fame migrate. Default is `True`.
    topn : int
        How many top individuals migrate from each population.
        Default is `12`.
    should_simplify : bool
        Whether to use algebraic simplification in the search. Note that only
        a few simple rules are implemented. Default is `True`.
    should_optimize_constants : bool
        Whether to numerically optimize constants (Nelder-Mead/Newton)
        at the end of each iteration. Default is `True`.
    optimizer_algorithm : str
        Optimization scheme to use for optimizing constants. Can currently
        be `NelderMead` or `BFGS`.
        Default is `"BFGS"`.
    optimizer_nrestarts : int
        Number of times to restart the constants optimization process with
        different initial conditions.
        Default is `2`.
    optimizer_f_calls_limit : int
        How many function calls to allow during optimization.
        Default is `10_000`.
    optimize_probability : float
        Probability of optimizing the constants during a single iteration of
        the evolutionary algorithm.
        Default is `0.14`.
    optimizer_iterations : int
        Number of iterations that the constants optimizer can take.
        Default is `8`.
    perturbation_factor : float
        Constants are perturbed by a max factor of
        (perturbation_factor*T + 1). Either multiplied by this or
        divided by this.
        Default is `0.129`.
    probability_negate_constant : float
        Probability of negating a constant in the equation when mutating it.
        Default is `0.00743`.
    tournament_selection_n : int
        Number of expressions to consider in each tournament.
        Default is `15`.
    tournament_selection_p : float
        Probability of selecting the best expression in each
        tournament. The probability will decay as p*(1-p)^n for other
        expressions, sorted by loss.
        Default is `0.982`.
    parallelism : Literal["serial", "multithreading", "multiprocessing"] | None
        Parallelism to use for the search. Can be `"serial"`, `"multithreading"`, or `"multiprocessing"`.
        Default is `"multithreading"`.
    procs : int | None
        Number of processes to use for parallelism. If `None`, defaults to `cpu_count()`.
        Default is `None`.
    cluster_manager : str
        For distributed computing, this sets the job queue system. Set
        to one of "slurm_native", "slurm", "pbs", "lsf", "sge", "qrsh", "scyld",
        or "htc". If set to one of these, PySR will run in distributed
        mode, and use `procs` to figure out how many processes to launch.
        Default is `None`.
    heap_size_hint_in_bytes : int
        For multiprocessing, this sets the `--heap-size-hint` parameter
        for new Julia processes. This can be configured when using
        multi-node distributed compute, to give a hint to each process
        about how much memory they can use before aggressive garbage
        collection.
    batching : bool
        Whether to compare population members on small batches during
        evolution. Still uses full dataset for comparing against hall
        of fame. Default is `False`.
    batch_size : int
        The amount of data to use if doing batching. Default is `50`.
    fast_cycle : bool
        Batch over population subsamples. This is a slightly different
        algorithm than regularized evolution, but does cycles 15%
        faster. May be algorithmically less efficient.
        Default is `False`.
    turbo : bool
        (Experimental) Whether to use LoopVectorization.jl to speed up the
        search evaluation. Certain operators may not be supported.
        Does not support 16-bit precision floats.
        Default is `False`.
    bumper : bool
        (Experimental) Whether to use Bumper.jl to speed up the search
        evaluation. Does not support 16-bit precision floats.
        Default is `False`.
    precision : int
        What precision to use for the data. By default this is `32`
        (float32), but you can select `64` or `16` as well, giving
        you 64 or 16 bits of floating point precision, respectively.
        If you pass complex data, the corresponding complex precision
        will be used (i.e., `64` for complex128, `32` for complex64).
        Default is `32`.
    autodiff_backend : Literal["Zygote"] | None
        Which backend to use for automatic differentiation during constant
        optimization. Currently only `"Zygote"` is supported. The default,
        `None`, uses forward-mode or finite difference.
        Default is `None`.
    random_state : int, Numpy RandomState instance or None
        Pass an int for reproducible results across multiple function calls.
        See :term:`Glossary <random_state>`.
        Default is `None`.
    deterministic : bool
        Make a PySR search give the same result every run.
        To use this, you must turn off parallelism
        (with `parallelism="serial"`),
        and set `random_state` to a fixed seed.
        Default is `False`.
    warm_start : bool
        Tells fit to continue from where the last call to fit finished.
        If false, each call to fit will be fresh, overwriting previous results.
        Default is `False`.
    verbosity : int
        What verbosity level to use. 0 means minimal print statements.
        Default is `1`.
    update_verbosity : int
        What verbosity level to use for package updates.
        Will take value of `verbosity` if not given.
        Default is `None`.
    print_precision : int
        How many significant digits to print for floats. Default is `5`.
    progress : bool
        Whether to use a progress bar instead of printing to stdout.
        Default is `True`.
    logger_spec : AbstractLoggerSpec | None
        Logger specification for the Julia backend. See, for example,
        `TensorBoardLoggerSpec`.
        Default is `None`.
    input_stream : str
        The stream to read user input from. By default, this is `"stdin"`.
        If you encounter issues with reading from `stdin`, like a hang,
        you can simply pass `"devnull"` to this argument. You can also
        reference an arbitrary Julia object in the `Main` namespace.
        Default is `"stdin"`.
    run_id : str
        A unique identifier for the run. Will be generated using the
        current date and time if not provided.
        Default is `None`.
    output_directory : str
        The base directory to save output files to. Files
        will be saved in a subdirectory according to the run ID.
        Will be set to `outputs/` if not provided.
        Default is `None`.
    temp_equation_file : bool
        Whether to put the hall of fame file in the temp directory.
        Deletion is then controlled with the `delete_tempfiles`
        parameter.
        Default is `False`.
    tempdir : str
        Directory for the temporary files. Default is `None`.
    delete_tempfiles : bool
        Whether to delete the temporary files after finishing.
        Default is `True`.
    update : bool
        Whether to automatically update Julia packages when `fit` is called.
        You should make sure that PySR is up-to-date itself first, as
        the packaged Julia packages may not necessarily include all
        updated dependencies.
        Default is `False`.
    output_jax_format : bool
        Whether to create a 'jax_format' column in the output,
        containing jax-callable functions and the default parameters in
        a jax array.
        Default is `False`.
    output_torch_format : bool
        Whether to create a 'torch_format' column in the output,
        containing a torch module with trainable parameters.
        Default is `False`.
    extra_sympy_mappings : dict[str, Callable]
        Provides mappings between custom `binary_operators` or
        `unary_operators` defined in julia strings, to those same
        operators defined in sympy.
        E.g., if `unary_operators=["inv(x)=1/x"]`, then for the fitted
        model to be exported to sympy, `extra_sympy_mappings`
        would be `{"inv": lambda x: 1/x}`.
        Default is `None`.
    extra_jax_mappings : dict[Callable, str]
        Similar to `extra_sympy_mappings` but for model export
        to jax. The dictionary maps sympy functions to jax functions.
        For example: `extra_jax_mappings={sympy.sin: "jnp.sin"}` maps
        the `sympy.sin` function to the equivalent jax expression `jnp.sin`.
        Default is `None`.
    extra_torch_mappings : dict[Callable, Callable]
        The same as `extra_jax_mappings` but for model export
        to pytorch. Note that the dictionary keys should be callable
        pytorch expressions.
        For example: `extra_torch_mappings={sympy.sin: torch.sin}`.
        Default is `None`.
    denoise : bool
        Whether to use a Gaussian Process to denoise the data before
        inputting to PySR. Can help PySR fit noisy data.
        Default is `False`.
    select_k_features : int
        Whether to run feature selection in Python using random forests,
        before passing to the symbolic regression code. None means no
        feature selection; an int means select that many features.
        Default is `None`.
    **kwargs : dict
        Supports deprecated keyword arguments. Other arguments will
        result in an error.

    Attributes
    ----------
    equations_ : pandas.DataFrame | list[pandas.DataFrame]
        Processed DataFrame containing the results of model fitting.
    n_features_in_ : int
        Number of features seen during :term:`fit`.
    feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Names of features seen during :term:`fit`. Defined only when `X`
        has feature names that are all strings.
    display_feature_names_in_ : ndarray of shape (`n_features_in_`,)
        Pretty names of features, used only during printing.
    X_units_ : list[str] of length n_features
        Units of each variable in the training dataset, `X`.
    y_units_ : str | list[str] of length n_out
        Units of each variable in the training dataset, `y`.
    nout_ : int
        Number of output dimensions.
    selection_mask_ : ndarray of shape (`n_features_in_`,)
        Mask of which features of `X` to use when `select_k_features` is set.
    tempdir_ : Path | None
        Path to the temporary equations directory.
    julia_state_stream_ : ndarray
        The serialized state for the julia SymbolicRegression.jl backend (after fitting),
        stored as an array of uint8, produced by Julia's Serialization.serialize function.
    julia_options_stream_ : ndarray
        The serialized julia options, stored as an array of uint8.
    logger_ : AnyValue | None
        The logger instance used for this fit, if any.
    expression_spec_ : AbstractExpressionSpec
        The expression specification used for this fit. This is equal to
        `self.expression_spec` if provided, or `ExpressionSpec()` otherwise.
    equation_file_contents_ : list[pandas.DataFrame]
        Contents of the equation file output by the Julia backend.
    show_pickle_warnings_ : bool
        Whether to show warnings about what attributes can be pickled.

    Examples
    --------
    ```python
    >>> import numpy as np
    >>> from pysr import PySRRegressor
    >>> randstate = np.random.RandomState(0)
    >>> X = 2 * randstate.randn(100, 5)
    >>> # y = 2.5382 * cos(x_3) + x_0^2 - 0.5
    >>> y = 2.5382 * np.cos(X[:, 3]) + X[:, 0] ** 2 - 0.5
    >>> model = PySRRegressor(
    ...     niterations=40,
    ...     binary_operators=["+", "*"],
    ...     unary_operators=[
    ...         "cos",
    ...         "exp",
    ...         "sin",
    ...         "inv(x) = 1/x",  # Custom operator (julia syntax)
    ...     ],
    ...     model_selection="best",
    ...     elementwise_loss="loss(x, y) = (x - y)^2",  # Custom loss function (julia syntax)
    ... )
    >>> model.fit(X, y)
    >>> model
    PySRRegressor.equations_ = [
    0         0.000000                                          3.8552167  3.360272e+01           1
    1         1.189847                                          (x0 * x0)  3.110905e+00           3
    2         0.010626                          ((x0 * x0) + -0.25573406)  3.045491e+00           5
    3         0.896632                              (cos(x3) + (x0 * x0))  1.242382e+00           6
    4         0.811362                ((x0 * x0) + (cos(x3) * 2.4384754))  2.451971e-01           8
    5  >>>>  13.733371          (((cos(x3) * 2.5382) + (x0 * x0)) + -0.5)  2.889755e-13          10
    6         0.194695  ((x0 * x0) + (((cos(x3) + -0.063180044) * 2.53...  1.957723e-13          12
    7         0.006988  ((x0 * x0) + (((cos(x3) + -0.32505524) * 1.538...  1.944089e-13          13
    8         0.000955  (((((x0 * x0) + cos(x3)) + -0.8251649) + (cos(...  1.940381e-13          15
    ]
    >>> model.score(X, y)
    1.0
    >>> model.predict(np.array([1,2,3,4,5]))
    array([-1.15907818, -1.15907818, -1.15907818, -1.15907818, -1.15907818])
    ```
    """

    equations_: pd.DataFrame | list[pd.DataFrame] | None
    n_features_in_: int
    feature_names_in_: ArrayLike[str]
    display_feature_names_in_: ArrayLike[str]
    complexity_of_variables_: int | float | list[int | float] | None
    X_units_: ArrayLike[str] | None
    y_units_: str | ArrayLike[str] | None
    nout_: int
    selection_mask_: NDArray[np.bool_] | None
    run_id_: str
    output_directory_: str
    julia_state_stream_: NDArray[np.uint8] | None
    julia_options_stream_: NDArray[np.uint8] | None
    logger_: AnyValue | None
    equation_file_contents_: list[pd.DataFrame] | None
    show_pickle_warnings_: bool
796
    def __init__(
        self,
        model_selection: Literal["best", "accuracy", "score"] = "best",
        *,
        binary_operators: list[str] | None = None,
        unary_operators: list[str] | None = None,
        expression_spec: AbstractExpressionSpec | None = None,
        niterations: int = 100,
        populations: int = 31,
        population_size: int = 27,
        max_evals: int | None = None,
        maxsize: int = 30,
        maxdepth: int | None = None,
        warmup_maxsize_by: float | None = None,
        timeout_in_seconds: float | None = None,
        constraints: dict[str, int | tuple[int, int]] | None = None,
        nested_constraints: dict[str, dict[str, int]] | None = None,
        elementwise_loss: str | None = None,
        loss_function: str | None = None,
        loss_function_expression: str | None = None,
        complexity_of_operators: dict[str, int | float] | None = None,
        complexity_of_constants: int | float | None = None,
        complexity_of_variables: int | float | list[int | float] | None = None,
        complexity_mapping: str | None = None,
        parsimony: float = 0.0,
        dimensional_constraint_penalty: float | None = None,
        dimensionless_constants_only: bool = False,
        use_frequency: bool = True,
        use_frequency_in_tournament: bool = True,
        adaptive_parsimony_scaling: float = 1040.0,
        alpha: float = 3.17,
        annealing: bool = False,
        early_stop_condition: float | str | None = None,
        ncycles_per_iteration: int = 380,
        fraction_replaced: float = 0.00036,
        fraction_replaced_hof: float = 0.0614,
        weight_add_node: float = 2.47,
        weight_insert_node: float = 0.0112,
        weight_delete_node: float = 0.870,
        weight_do_nothing: float = 0.273,
        weight_mutate_constant: float = 0.0346,
        weight_mutate_operator: float = 0.293,
        weight_swap_operands: float = 0.198,
        weight_rotate_tree: float = 4.26,
        weight_randomize: float = 0.000502,
        weight_simplify: float = 0.00209,
        weight_optimize: float = 0.0,
        crossover_probability: float = 0.0259,
        skip_mutation_failures: bool = True,
        migration: bool = True,
        hof_migration: bool = True,
        topn: int = 12,
        should_simplify: bool = True,
        should_optimize_constants: bool = True,
        optimizer_algorithm: Literal["BFGS", "NelderMead"] = "BFGS",
        optimizer_nrestarts: int = 2,
        optimizer_f_calls_limit: int | None = None,
        optimize_probability: float = 0.14,
        optimizer_iterations: int = 8,
        perturbation_factor: float = 0.129,
        probability_negate_constant: float = 0.00743,
        tournament_selection_n: int = 15,
        tournament_selection_p: float = 0.982,
        # fmt: off
        parallelism: Literal["serial", "multithreading", "multiprocessing"] | None = None,
        procs: int | None = None,
        cluster_manager: Literal["slurm_native", "slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"] | str | None = None,
        # fmt: on
        heap_size_hint_in_bytes: int | None = None,
        batching: bool = False,
        batch_size: int = 50,
        fast_cycle: bool = False,
        turbo: bool = False,
        bumper: bool = False,
        precision: Literal[16, 32, 64] = 32,
        autodiff_backend: Literal["Zygote"] | None = None,
        random_state: int | np.random.RandomState | None = None,
        deterministic: bool = False,
        warm_start: bool = False,
        verbosity: int = 1,
        update_verbosity: int | None = None,
        print_precision: int = 5,
        progress: bool = True,
        logger_spec: AbstractLoggerSpec | None = None,
        input_stream: str = "stdin",
        run_id: str | None = None,
        output_directory: str | None = None,
        temp_equation_file: bool = False,
        tempdir: str | None = None,
        delete_tempfiles: bool = True,
        update: bool = False,
        output_jax_format: bool = False,
        output_torch_format: bool = False,
        extra_sympy_mappings: dict[str, Callable] | None = None,
        extra_torch_mappings: dict[Callable, Callable] | None = None,
        extra_jax_mappings: dict[Callable, str] | None = None,
        denoise: bool = False,
        select_k_features: int | None = None,
        **kwargs,
    ):
        # Hyperparameters
        # - Model search parameters
        self.model_selection = model_selection
        self.binary_operators = binary_operators
        self.unary_operators = unary_operators
        self.expression_spec = expression_spec
        self.niterations = niterations
        self.populations = populations
        self.population_size = population_size
        self.ncycles_per_iteration = ncycles_per_iteration
        # - Equation Constraints
        self.maxsize = maxsize
        self.maxdepth = maxdepth
        self.constraints = constraints
        self.nested_constraints = nested_constraints
        self.warmup_maxsize_by = warmup_maxsize_by
        self.should_simplify = should_simplify
        # - Early exit conditions:
        self.max_evals = max_evals
        self.timeout_in_seconds = timeout_in_seconds
        self.early_stop_condition = early_stop_condition
        # - Loss parameters
        self.elementwise_loss = elementwise_loss
        self.loss_function = loss_function
        self.loss_function_expression = loss_function_expression
        self.complexity_of_operators = complexity_of_operators
        self.complexity_of_constants = complexity_of_constants
        self.complexity_of_variables = complexity_of_variables
        self.complexity_mapping = complexity_mapping
        self.parsimony = parsimony
        self.dimensional_constraint_penalty = dimensional_constraint_penalty
        self.dimensionless_constants_only = dimensionless_constants_only
        self.use_frequency = use_frequency
        self.use_frequency_in_tournament = use_frequency_in_tournament
        self.adaptive_parsimony_scaling = adaptive_parsimony_scaling
        self.alpha = alpha
        self.annealing = annealing
        # - Evolutionary search parameters
        # -- Mutation parameters
        self.weight_add_node = weight_add_node
        self.weight_insert_node = weight_insert_node
        self.weight_delete_node = weight_delete_node
        self.weight_do_nothing = weight_do_nothing
        self.weight_mutate_constant = weight_mutate_constant
        self.weight_mutate_operator = weight_mutate_operator
        self.weight_swap_operands = weight_swap_operands
        self.weight_rotate_tree = weight_rotate_tree
        self.weight_randomize = weight_randomize
        self.weight_simplify = weight_simplify
        self.weight_optimize = weight_optimize
        self.crossover_probability = crossover_probability
        self.skip_mutation_failures = skip_mutation_failures
        # -- Migration parameters
        self.migration = migration
        self.hof_migration = hof_migration
        self.fraction_replaced = fraction_replaced
        self.fraction_replaced_hof = fraction_replaced_hof
        self.topn = topn
        # -- Constants parameters
        self.should_optimize_constants = should_optimize_constants
        self.optimizer_algorithm = optimizer_algorithm
        self.optimizer_nrestarts = optimizer_nrestarts
        self.optimizer_f_calls_limit = optimizer_f_calls_limit
        self.optimize_probability = optimize_probability
        self.optimizer_iterations = optimizer_iterations
        self.perturbation_factor = perturbation_factor
        self.probability_negate_constant = probability_negate_constant
        # -- Selection parameters
        self.tournament_selection_n = tournament_selection_n
        self.tournament_selection_p = tournament_selection_p
        # -- Performance parameters
        self.parallelism = parallelism
        self.procs = procs
        self.cluster_manager = cluster_manager
        self.heap_size_hint_in_bytes = heap_size_hint_in_bytes
        self.batching = batching
        self.batch_size = batch_size
        self.fast_cycle = fast_cycle
        self.turbo = turbo
        self.bumper = bumper
        self.precision = precision
        self.autodiff_backend = autodiff_backend
        self.random_state = random_state
        self.deterministic = deterministic
        self.warm_start = warm_start
        # Additional runtime parameters
        # - Runtime user interface
        self.verbosity = verbosity
        self.update_verbosity = update_verbosity
        self.print_precision = print_precision
        self.progress = progress
        self.logger_spec = logger_spec
        self.input_stream = input_stream
        # - Project management
        self.run_id = run_id
        self.output_directory = output_directory
        self.temp_equation_file = temp_equation_file
        self.tempdir = tempdir
        self.delete_tempfiles = delete_tempfiles
        self.update = update
        self.output_jax_format = output_jax_format
        self.output_torch_format = output_torch_format
        self.extra_sympy_mappings = extra_sympy_mappings
        self.extra_jax_mappings = extra_jax_mappings
        self.extra_torch_mappings = extra_torch_mappings
        # Pre-modelling transformation
        self.denoise = denoise
        self.select_k_features = select_k_features

        # Once all valid parameters have been assigned, handle the
        # deprecated kwargs
        if len(kwargs) > 0:  # pragma: no cover
            for k, v in kwargs.items():
                # Handle renamed kwargs
                if k in DEPRECATED_KWARGS:
                    updated_kwarg_name = DEPRECATED_KWARGS[k]
                    setattr(self, updated_kwarg_name, v)
                    warnings.warn(
                        f"`{k}` has been renamed to `{updated_kwarg_name}` in PySRRegressor. "
                        "Please use that instead.",
                        FutureWarning,
                    )
                elif k == "multithreading":
                    # Specific advice given in `_map_parallelism_params`
                    self.multithreading: bool | None = v
                # Handle kwargs that have been moved to the fit method
                elif k in ["weights", "variable_names", "Xresampled"]:
                    warnings.warn(
                        f"`{k}` is a data-dependent parameter and should be passed when fit is called. "
                        f"Ignoring parameter; please pass `{k}` during the call to fit instead.",
                        FutureWarning,
                    )
                elif k == "julia_project":
                    warnings.warn(
                        "The `julia_project` parameter has been deprecated. To use a custom "
                        "julia project, please see `https://ai.damtp.cam.ac.uk/pysr/backend`.",
                        FutureWarning,
                    )
                elif k == "julia_kwargs":
                    warnings.warn(
                        "The `julia_kwargs` parameter has been deprecated. To pass custom "
                        "keyword arguments to the julia backend, you should use environment variables. "
                        "See the Julia documentation for more information.",
                        FutureWarning,
                    )
                else:
                    suggested_keywords = _suggest_keywords(PySRRegressor, k)
                    err_msg = (
                        f"`{k}` is not a valid keyword argument for PySRRegressor."
                    )
                    if len(suggested_keywords) > 0:
                        err_msg += f" Did you mean {', '.join(map(lambda s: f'`{s}`', suggested_keywords))}?"
                    raise TypeError(err_msg)

    @classmethod
    def from_file(
        cls,
        equation_file: None = None,  # Deprecated
        *,
        run_directory: PathLike,
        binary_operators: list[str] | None = None,
        unary_operators: list[str] | None = None,
        n_features_in: int | None = None,
        feature_names_in: ArrayLike[str] | None = None,
        selection_mask: NDArray[np.bool_] | None = None,
        nout: int = 1,
        **pysr_kwargs,
    ) -> "PySRRegressor":
        """
        Create a model from a saved model checkpoint or equation file.

        Parameters
        ----------
        run_directory : str
            The directory containing outputs from a previous run.
            This is of the form `[output_directory]/[run_id]`.
        binary_operators : list[str]
            The same binary operators used when creating the model.
            Not needed if loading from a pickle file.
        unary_operators : list[str]
            The same unary operators used when creating the model.
            Not needed if loading from a pickle file.
        n_features_in : int
            Number of features passed to the model.
            Not needed if loading from a pickle file.
        feature_names_in : list[str]
            Names of the features passed to the model.
            Not needed if loading from a pickle file.
        selection_mask : NDArray[np.bool_]
            If using `select_k_features`, you must pass `model.selection_mask_` here.
            Not needed if loading from a pickle file.
        nout : int
            Number of outputs of the model.
            Not needed if loading from a pickle file.
            Default is `1`.
        **pysr_kwargs : dict
            Any other keyword arguments to initialize the PySRRegressor object.
            These will overwrite those stored in the pickle file.
            Not needed if loading from a pickle file.

        Returns
        -------
        model : PySRRegressor
            The model with fitted equations.
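
        Examples
        --------
        A minimal sketch (the run directory shown is hypothetical; use the
        `[output_directory]/[run_id]` path from your own run):
        ```python
        >>> model = PySRRegressor.from_file(run_directory="outputs/20250223_120000_aB3cD4")
        ```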
        """
        if equation_file is not None:
            raise ValueError(
                "Passing `equation_file` is deprecated and no longer compatible with "
                "the most recent versions of PySR's backend. Please pass `run_directory` "
                "instead, which contains all checkpoint files."
            )

        pkl_filename = Path(run_directory) / "checkpoint.pkl"
        if pkl_filename.exists():
            pysr_logger.info(f"Attempting to load model from {pkl_filename}...")
            assert binary_operators is None
            assert unary_operators is None
            assert n_features_in is None
            with open(pkl_filename, "rb") as f:
                model = cast("PySRRegressor", pkl.load(f))

            # Update any parameters if necessary, such as
            # extra_sympy_mappings:
            model.set_params(**pysr_kwargs)

            if "equations_" not in model.__dict__ or model.equations_ is None:
                model.refresh()

            if model.expression_spec is not None:
                warnings.warn(
                    "Loading model from checkpoint file with a non-default expression spec "
                    "is not fully supported as it relies on dynamic objects. This may result in unexpected behavior.",
                )

            return model
        else:
            pysr_logger.info(
                f"Checkpoint file {pkl_filename} does not exist. "
                "Attempting to recreate model from scratch..."
            )
            csv_filename = Path(run_directory) / "hall_of_fame.csv"
            csv_filename_bak = Path(run_directory) / "hall_of_fame.csv.bak"
            if not csv_filename.exists() and not csv_filename_bak.exists():
                raise FileNotFoundError(
                    f"Hall of fame file `{csv_filename}` or `{csv_filename_bak}` does not exist. "
                    "Please pass a `run_directory` containing a valid checkpoint file."
                )
            assert binary_operators is not None or unary_operators is not None
            assert n_features_in is not None
            model = cls(
                binary_operators=binary_operators,
                unary_operators=unary_operators,
                **pysr_kwargs,
            )
            model.nout_ = nout
            model.n_features_in_ = n_features_in

            if feature_names_in is None:
                model.feature_names_in_ = np.array(
                    [f"x{i}" for i in range(n_features_in)]
                )
                model.display_feature_names_in_ = np.array(
                    [f"x{_subscriptify(i)}" for i in range(n_features_in)]
                )
            else:
                assert len(feature_names_in) == n_features_in
                model.feature_names_in_ = feature_names_in
                model.display_feature_names_in_ = feature_names_in

            if selection_mask is None:
                model.selection_mask_ = np.ones(n_features_in, dtype=np.bool_)
            else:
                model.selection_mask_ = selection_mask

            model.refresh(run_directory=run_directory)

            return model
1175
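    # A minimal usage sketch (the run directory name here is hypothetical):
    #
    #     model = PySRRegressor.from_file(
    #         run_directory="outputs/20240101_120000_ab12cd"
    #     )
    #     print(model)  # displays the recovered hall of fame
    #
    # If `checkpoint.pkl` is missing, the operators and feature count must be
    # passed explicitly so the model can be rebuilt from `hall_of_fame.csv`.
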
    def __repr__(self) -> str:
        """
        Return a string showing all current equations fitted by the model.

        The string `>>>>` denotes which equation is selected by the
        `model_selection`.
        """
        if not hasattr(self, "equations_") or self.equations_ is None:
            return "PySRRegressor.equations_ = None"

        output = "PySRRegressor.equations_ = [\n"

        equations = self.equations_
        if not isinstance(equations, list):
            all_equations = [equations]
        else:
            all_equations = equations

        for i, equations in enumerate(all_equations):
            selected = pd.Series([""] * len(equations), index=equations.index)
            chosen_row = idx_model_selection(equations, self.model_selection)
            selected[chosen_row] = ">>>>"
            repr_equations = pd.DataFrame(
                dict(
                    pick=selected,
                    score=equations["score"],
                    equation=equations["equation"],
                    loss=equations["loss"],
                    complexity=equations["complexity"],
                )
            )

            if len(all_equations) > 1:
                output += "[\n"

            for line in repr_equations.__repr__().split("\n"):
                output += "\t" + line + "\n"

            if len(all_equations) > 1:
                output += "]"

            if i < len(all_equations) - 1:
                output += ", "

        output += "]"
        return output

    def __getstate__(self) -> dict[str, Any]:
        """
        Handle pickle serialization for PySRRegressor.

        The Scikit-learn standard requires estimators to be serializable via
        `pickle.dumps()`. However, some attributes do not support pickling
        and need to be hidden, such as the JAX and Torch representations.
        """
        state = self.__dict__
        show_pickle_warning = not (
            "show_pickle_warnings_" in state and not state["show_pickle_warnings_"]
        )
        state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"]
        for state_key in state_keys_containing_lambdas:
            warn_msg = (
                f"`{state_key}` cannot be pickled and will be removed from the "
                "serialized instance. When loading the model, please redefine "
                f"`{state_key}` at runtime."
            )
            if state[state_key] is not None:
                if show_pickle_warning:
                    warnings.warn(warn_msg)
                else:
                    pysr_logger.debug(warn_msg)
        state_keys_to_clear = state_keys_containing_lambdas
        state_keys_to_clear.append("logger_")
        pickled_state = {
            key: (None if key in state_keys_to_clear else value)
            for key, value in state.items()
        }
        if ("equations_" in pickled_state) and (
            pickled_state["equations_"] is not None
        ):
            pickled_state["output_torch_format"] = False
            pickled_state["output_jax_format"] = False
            if self.nout_ == 1:
                pickled_columns = ~pickled_state["equations_"].columns.isin(
                    ["jax_format", "torch_format"]
                )
                pickled_state["equations_"] = (
                    pickled_state["equations_"].loc[:, pickled_columns].copy()
                )
            else:
                pickled_columns = [
                    ~dataframe.columns.isin(["jax_format", "torch_format"])
                    for dataframe in pickled_state["equations_"]
                ]
                pickled_state["equations_"] = [
                    dataframe.loc[:, single_pickled_columns]
                    for dataframe, single_pickled_columns in zip(
                        pickled_state["equations_"], pickled_columns
                    )
                ]
        return pickled_state

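    # A short sketch of the round trip this enables (names are illustrative):
    #
    #     import pickle
    #     payload = pickle.dumps(model)   # lambdas and `logger_` are dropped
    #     restored = pickle.loads(payload)
    #     # Redefine any un-picklable mappings at runtime, then refresh:
    #     restored.set_params(extra_sympy_mappings={"inv": lambda x: 1 / x})
    #     restored.refresh()
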
    def _checkpoint(self):
        """Save the model's current state to a checkpoint file.

        This should only be used internally by PySRRegressor.
        """
        # Save model state:
        self.show_pickle_warnings_ = False
        with open(self.get_pkl_filename(), "wb") as f:
            try:
                pkl.dump(self, f)
            except Exception as e:
                pysr_logger.debug(f"Error checkpointing model: {e}")
        self.show_pickle_warnings_ = True

    def get_pkl_filename(self) -> Path:
        path = Path(self.output_directory_) / self.run_id_ / "checkpoint.pkl"
        path.parent.mkdir(parents=True, exist_ok=True)
        return path

    @property
    def equations(self):  # pragma: no cover
        warnings.warn(
            "PySRRegressor.equations is now deprecated. "
            "Please use PySRRegressor.equations_ instead.",
            FutureWarning,
        )
        return self.equations_

    @property
    def julia_options_(self):
        """The deserialized julia options."""
        return jl_deserialize(self.julia_options_stream_)

    @property
    def julia_state_(self):
        """The deserialized state."""
        return cast(
            tuple[VectorValue, AnyValue] | None,
            jl_deserialize(self.julia_state_stream_),
        )

    @property
    def raw_julia_state_(self):
        warnings.warn(
            "PySRRegressor.raw_julia_state_ is now deprecated. "
            "Please use PySRRegressor.julia_state_ instead, or julia_state_stream_ "
            "for the raw stream of bytes.",
            FutureWarning,
        )
        return self.julia_state_

    @property
    def expression_spec_(self):
        return self.expression_spec or ExpressionSpec()

    def get_best(
        self, index: int | list[int] | None = None
    ) -> pd.Series | list[pd.Series]:
        """
        Get best equation using `model_selection`.

        Parameters
        ----------
        index : int | list[int]
            If you wish to select a particular equation from `self.equations_`,
            give the row number here. This overrides the `model_selection`
            parameter. If there are multiple output features, then pass
            a list of indices with the order the same as the output feature.

        Returns
        -------
        best_equation : pandas.Series | list[pandas.Series]
            Series representing the best expression found, with one Series
            per output feature if there are multiple outputs.

        Raises
        ------
        NotImplementedError
            Raised when an invalid model selection strategy is provided.
        """
        check_is_fitted(self, attributes=["equations_"])

        if index is not None:
            if isinstance(self.equations_, list):
                assert isinstance(
                    index, list
                ), "With multiple output features, index must be a list."
                return [eq.iloc[i] for eq, i in zip(self.equations_, index)]
            else:
                equations_ = cast(pd.DataFrame, self.equations_)
                return cast(pd.Series, equations_.iloc[index])

        if isinstance(self.equations_, list):
            return [
                cast(pd.Series, eq.loc[idx_model_selection(eq, self.model_selection)])
                for eq in self.equations_
            ]
        else:
            equations_ = cast(pd.DataFrame, self.equations_)
            return cast(
                pd.Series,
                equations_.loc[idx_model_selection(equations_, self.model_selection)],
            )

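    # Sketch of selecting equations after a fit (indices are illustrative):
    #
    #     best = model.get_best()          # chosen via `model_selection`
    #     third = model.get_best(index=2)  # force row 2 of `equations_`
    #     # With two output features, pass one index per output:
    #     pair = model.get_best(index=[2, 4])
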
    @property
    def equation_file_(self):
        raise NotImplementedError(
            "PySRRegressor.equation_file_ is now deprecated. "
            "Please use PySRRegressor.output_directory_ and PySRRegressor.run_id_ "
            "instead. For loading, you should pass `run_directory`."
        )

    def _setup_equation_file(self):
        """Set the output directory and run ID used for this search."""
        if self.warm_start and (
            hasattr(self, "run_id_") or hasattr(self, "output_directory_")
        ):
            assert hasattr(self, "output_directory_")
            assert hasattr(self, "run_id_")
            if self.run_id is not None:
                assert self.run_id_ == self.run_id
            if self.output_directory is not None:
                assert self.output_directory_ == self.output_directory
        else:
            self.output_directory_ = (
                tempfile.mkdtemp()
                if self.temp_equation_file
                else (
                    "outputs"
                    if self.output_directory is None
                    else self.output_directory
                )
            )
            self.run_id_ = (
                cast(str, SymbolicRegression.SearchUtilsModule.generate_run_id())
                if self.run_id is None
                else self.run_id
            )
            if self.temp_equation_file:
                assert self.output_directory is None

    def _clear_equation_file_contents(self):
        self.equation_file_contents_ = None

    def _validate_and_modify_params(self) -> _DynamicallySetParams:
        """
        Ensure parameters passed at initialization are valid.

        Also returns a container of parameters, updated from the
        values given at initialization where necessary.

        Returns
        -------
        param_container : _DynamicallySetParams
            Dataclass of parameters to modify from their initialized
            values. For example, default parameters are set here
            when a parameter is left set to `None`.
        """
        # Immutable parameter validation
        # Ensure instance parameters are allowable values:
        if self.tournament_selection_n > self.population_size:
            raise ValueError(
                "`tournament_selection_n` parameter must be smaller than `population_size`."
            )

        if self.maxsize > 40:
            warnings.warn(
                "Note: Using a large maxsize for the equation search will be "
                "exponentially slower and use significant memory."
            )
        elif self.maxsize < 7:
            raise ValueError("PySR requires a maxsize of at least 7")

        # NotImplementedError - Values that could be supported at a later time
        if self.optimizer_algorithm not in VALID_OPTIMIZER_ALGORITHMS:
            raise NotImplementedError(
                f"PySR currently only supports the following optimizer algorithms: {VALID_OPTIMIZER_ALGORITHMS}"
            )

        param_container = _DynamicallySetParams(
            binary_operators=["+", "*", "-", "/"],
            unary_operators=[],
            maxdepth=self.maxsize,
            constraints={},
            batch_size=1,
            update_verbosity=int(self.verbosity),
            progress=self.progress,
            warmup_maxsize_by=0.0,
        )

        for param_name in map(lambda x: x.name, fields(_DynamicallySetParams)):
            user_param_value = getattr(self, param_name)
            if user_param_value is None:
                # Leave as the default in _DynamicallySetParams
                ...
            else:
                # If the user has specified it, we will override the default.
                # However, there are some special cases to mutate it:
                new_param_value = _mutate_parameter(param_name, user_param_value)
                setattr(param_container, param_name, new_param_value)
        # TODO: This should just be part of the __init__ of _DynamicallySetParams

        assert (
            len(param_container.binary_operators) > 0
            or len(param_container.unary_operators) > 0
        ), "At least one operator must be provided."

        return param_container

    def _validate_and_set_fit_params(
        self,
        X,
        y,
        Xresampled,
        weights,
        variable_names,
        complexity_of_variables,
        X_units,
        y_units,
    ) -> tuple[
        ndarray,
        ndarray,
        ndarray | None,
        ndarray | None,
        ArrayLike[str],
        int | float | list[int | float] | None,
        ArrayLike[str] | None,
        str | ArrayLike[str] | None,
    ]:
        """
        Validate the parameters passed to the :term:`fit` method.

        This method also sets the `nout_` attribute.

        Parameters
        ----------
        X : ndarray | pandas.DataFrame
            Training data of shape `(n_samples, n_features)`.
        y : ndarray | pandas.DataFrame
            Target values of shape `(n_samples,)` or `(n_samples, n_targets)`.
            Will be cast to `X`'s dtype if necessary.
        Xresampled : ndarray | pandas.DataFrame
            Resampled training data used for denoising,
            of shape `(n_resampled, n_features)`.
        weights : ndarray | pandas.DataFrame
            Weight array of the same shape as `y`.
            Each element is how to weight the mean-square-error loss
            for that particular element of y.
        variable_names : ndarray of length n_features
            Names of each feature in the training dataset, `X`.
        complexity_of_variables : int | float | list[int | float]
            Complexity of each feature in the training dataset, `X`.
        X_units : list[str] of length n_features
            Units of each feature in the training dataset, `X`.
        y_units : str | list[str] of length n_out
            Units of each feature in the training dataset, `y`.

        Returns
        -------
        X_validated : ndarray of shape (n_samples, n_features)
            Validated training data.
        y_validated : ndarray of shape (n_samples,) or (n_samples, n_targets)
            Validated target data.
        Xresampled : ndarray of shape (n_resampled, n_features)
            Validated resampled training data used for denoising.
        weights : ndarray of the same shape as `y_validated`
            Validated weight array, if one was passed.
        variable_names_validated : list[str] of length n_features
            Validated list of variable names for each feature in `X`.
        complexity_of_variables : int | float | list[int | float] | None
            Validated complexity for each feature in `X`.
        X_units : list[str] of length n_features
            Validated units for `X`.
        y_units : str | list[str] of length n_out
            Validated units for `y`.

        """
        if isinstance(X, pd.DataFrame):
            if variable_names:
                variable_names = None
                warnings.warn(
                    "`variable_names` has been reset to `None` as `X` is a DataFrame. "
                    "Using DataFrame column names instead."
                )

            if (
                pd.api.types.is_object_dtype(X.columns)
                and X.columns.str.contains(" ").any()
            ):
                X.columns = X.columns.str.replace(" ", "_")
                warnings.warn(
                    "Spaces in DataFrame column names are not supported. "
                    "Spaces have been replaced with underscores. \n"
                    "Please rename the columns to valid names."
                )
        elif variable_names and any([" " in name for name in variable_names]):
            variable_names = [name.replace(" ", "_") for name in variable_names]
            warnings.warn(
                "Spaces in `variable_names` are not supported. "
                "Spaces have been replaced with underscores. \n"
                "Please use valid names instead."
            )

        if (
            complexity_of_variables is not None
            and self.complexity_of_variables is not None
        ):
            raise ValueError(
                "You cannot set `complexity_of_variables` at both `fit` and `__init__`. "
                "Pass it at `__init__` to set it to global default, OR use `fit` to set it for "
                "each variable individually."
            )
        elif complexity_of_variables is not None:
            # Use the value passed to `fit` as-is.
            pass
        elif self.complexity_of_variables is not None:
            complexity_of_variables = self.complexity_of_variables
        else:
            complexity_of_variables = None

        # Data validation and feature name fetching via sklearn
        # This method sets the n_features_in_ attribute
        if Xresampled is not None:
            Xresampled = check_array(Xresampled)
        if weights is not None:
            weights = check_array(weights, ensure_2d=False)
            check_consistent_length(weights, y)
        X, y = self._validate_data_X_y(X, y)
        self.feature_names_in_ = _safe_check_feature_names_in(
            self, variable_names, generate_names=False
        )

        if self.feature_names_in_ is None:
            self.feature_names_in_ = np.array([f"x{i}" for i in range(X.shape[1])])
            self.display_feature_names_in_ = np.array(
                [f"x{_subscriptify(i)}" for i in range(X.shape[1])]
            )
            variable_names = self.feature_names_in_
        else:
            self.display_feature_names_in_ = self.feature_names_in_
            variable_names = self.feature_names_in_

        # Handle multioutput data
        if len(y.shape) == 1 or (len(y.shape) == 2 and y.shape[1] == 1):
            y = y.reshape(-1)
        elif len(y.shape) == 2:
            self.nout_ = y.shape[1]
        else:
            raise NotImplementedError("y shape not supported!")

        self.complexity_of_variables_ = copy.deepcopy(complexity_of_variables)
        self.X_units_ = copy.deepcopy(X_units)
        self.y_units_ = copy.deepcopy(y_units)

        return (
            X,
            y,
            Xresampled,
            weights,
            variable_names,
            complexity_of_variables,
            X_units,
            y_units,
        )

    def _validate_data_X_y(self, X: Any, y: Any) -> tuple[ndarray, ndarray]:
        if OLD_SKLEARN:
            raw_out = self._validate_data(X=X, y=y, reset=True, multi_output=True)  # type: ignore
        else:
            raw_out = validate_data(self, X=X, y=y, reset=True, multi_output=True)  # type: ignore
        return cast(tuple[ndarray, ndarray], raw_out)

    def _validate_data_X(self, X: Any) -> ndarray:
        if OLD_SKLEARN:
            raw_out = self._validate_data(X=X, reset=False)  # type: ignore
        else:
            raw_out = validate_data(self, X=X, reset=False)  # type: ignore
        return cast(ndarray, raw_out)

    def _get_precision_mapped_dtype(self, X: np.ndarray) -> type:
        is_complex = np.issubdtype(X.dtype, np.complexfloating)
        is_real = not is_complex
        if is_real:
            return {16: np.float16, 32: np.float32, 64: np.float64}[self.precision]
        else:
            return {32: np.complex64, 64: np.complex128}[self.precision]

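    # Sketch of the mapping above (values are illustrative):
    #
    #     model = PySRRegressor(precision=32)
    #     model._get_precision_mapped_dtype(np.zeros(3))           # np.float32
    #     model._get_precision_mapped_dtype(np.zeros(3, complex))  # np.complex64
    #
    # Note there is no 16-bit complex dtype, so complex data requires
    # `precision` of 32 or 64.
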
    def _pre_transform_training_data(
        self,
        X: ndarray,
        y: ndarray,
        Xresampled: ndarray | None,
        variable_names: ArrayLike[str],
        complexity_of_variables: int | float | list[int | float] | None,
        X_units: ArrayLike[str] | None,
        y_units: ArrayLike[str] | str | None,
        random_state: np.random.RandomState,
    ):
        """
        Transform the training data before fitting the symbolic regressor.

        This method also updates/sets the `selection_mask_` attribute.

        Parameters
        ----------
        X : ndarray
            Training data of shape (n_samples, n_features).
        y : ndarray
            Target values of shape (n_samples,) or (n_samples, n_targets).
            Will be cast to X's dtype if necessary.
        Xresampled : ndarray | None
            Resampled training data, of shape `(n_resampled, n_features)`,
            used for denoising.
        variable_names : list[str]
            Names of each variable in the training dataset, `X`.
            Of length `n_features`.
        complexity_of_variables : int | float | list[int | float] | None
            Complexity of each variable in the training dataset, `X`.
        X_units : list[str]
            Units of each variable in the training dataset, `X`.
        y_units : str | list[str]
            Units of each variable in the training dataset, `y`.
        random_state : np.random.RandomState
            Validated random state object, used to make feature selection
            and denoising reproducible. See :term:`Glossary <random_state>`.

        Returns
        -------
        X_transformed : ndarray of shape (n_samples, n_features)
            Transformed training data. n_samples will be equal to
            `Xresampled.shape[0]` if `self.denoise` is `True`,
            and `Xresampled is not None`, otherwise it will be
            equal to `X.shape[0]`. n_features will be equal to
            `self.select_k_features` if `self.select_k_features is not None`,
            otherwise it will be equal to `X.shape[1]`.
        y_transformed : ndarray of shape (n_samples,) or (n_samples, n_outputs)
            Transformed target data. n_samples will be equal to
            `Xresampled.shape[0]` if `self.denoise` is `True`,
            and `Xresampled is not None`, otherwise it will be
            equal to `X.shape[0]`.
        variable_names_transformed : list[str] of length n_features
            Names of each variable in the transformed dataset,
            `X_transformed`.
        X_units_transformed : list[str] of length n_features
            Units of each variable in the transformed dataset.
        y_units_transformed : str | list[str] of length n_out
            Units of each variable in the transformed dataset.
        """
        # Feature selection transformation
        if self.select_k_features:
            selection_mask = run_feature_selection(
                X, y, self.select_k_features, random_state=random_state
            )
            X = X[:, selection_mask]

            if Xresampled is not None:
                Xresampled = Xresampled[:, selection_mask]

            # Reduce variable_names to selection
            variable_names = cast(
                ArrayLike[str],
                [
                    variable_names[i]
                    for i in range(len(variable_names))
                    if selection_mask[i]
                ],
            )

            if isinstance(complexity_of_variables, list):
                complexity_of_variables = [
                    complexity_of_variables[i]
                    for i in range(len(complexity_of_variables))
                    if selection_mask[i]
                ]
                self.complexity_of_variables_ = copy.deepcopy(complexity_of_variables)

            if X_units is not None:
                X_units = cast(
                    ArrayLike[str],
                    [X_units[i] for i in range(len(X_units)) if selection_mask[i]],
                )
                self.X_units_ = copy.deepcopy(X_units)

            # Re-perform data validation and feature name updating
            X, y = self._validate_data_X_y(X, y)
            # Update feature names with selected variable names
            self.selection_mask_ = selection_mask
            self.feature_names_in_ = _check_feature_names_in(self, variable_names)
            self.display_feature_names_in_ = self.feature_names_in_
            pysr_logger.info(f"Using features {self.feature_names_in_}")

        # Denoising transformation
        if self.denoise:
            if self.nout_ > 1:
                X, y = multi_denoise(
                    X, y, Xresampled=Xresampled, random_state=random_state
                )
            else:
                X, y = denoise(X, y, Xresampled=Xresampled, random_state=random_state)

        return X, y, variable_names, complexity_of_variables, X_units, y_units

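    # Public-API sketch of the two transformations above (parameter values
    # are illustrative):
    #
    #     model = PySRRegressor(
    #         select_k_features=3,  # keep the 3 most predictive features
    #         denoise=True,         # smooth the target before searching
    #     )
    #     model.fit(X, y)
    #     print(model.selection_mask_)  # boolean mask over input columns
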
    def _run(
        self,
        X: ndarray,
        y: ndarray,
        runtime_params: _DynamicallySetParams,
        weights: ndarray | None,
        category: ndarray | None,
        seed: int,
    ):
        """
        Run the symbolic regression fitting process on the julia backend.

        Parameters
        ----------
        X : ndarray
            Training data of shape `(n_samples, n_features)`.
        y : ndarray
            Target values of shape `(n_samples,)` or `(n_samples, n_targets)`.
            Will be cast to `X`'s dtype if necessary.
        runtime_params : _DynamicallySetParams
            Dynamically set versions of some parameters passed in __init__.
        weights : ndarray | None
            Weight array of the same shape as `y`.
            Each element is how to weight the mean-square-error loss
            for that particular element of y.
        category : ndarray | None
            If `expression_spec` is a `ParametricExpressionSpec`, then this
            argument should be a list of integers representing the category
            of each sample in `X`.
        seed : int
            Random seed for julia backend process.

        Returns
        -------
        self : object
            Reference to `self` with fitted attributes.

        Raises
        ------
        ImportError
            Raised when the julia backend fails to import a package.
        """
        # Needs to be global, as we don't want to re-initialize julia
        # for every new instance of PySRRegressor
        global ALREADY_RAN

        # These are the parameters which may be modified from the ones
        # specified in init, so we define them here locally:
        binary_operators = runtime_params.binary_operators
        unary_operators = runtime_params.unary_operators
        constraints = runtime_params.constraints

        nested_constraints = self.nested_constraints
        complexity_of_operators = self.complexity_of_operators
        complexity_of_variables = self.complexity_of_variables_
        cluster_manager = self.cluster_manager

        # Start julia backend processes
        if not ALREADY_RAN and runtime_params.update_verbosity != 0:
            pysr_logger.info("Compiling Julia backend...")

        parallelism, numprocs = _map_parallelism_params(
            self.parallelism, self.procs, getattr(self, "multithreading", None)
        )

        if self.deterministic and parallelism != "serial":
            raise ValueError(
                "To ensure deterministic searches, you must set `parallelism='serial'`. "
                "Additionally, make sure to set `random_state` to a seed."
            )
        if self.random_state is not None and (
            parallelism != "serial" or not self.deterministic
        ):
            warnings.warn(
                "Note: Setting `random_state` without also setting `deterministic=True` "
                "and `parallelism='serial'` will result in non-deterministic searches."
            )

        if cluster_manager is not None:
            if parallelism != "multiprocessing":
                raise ValueError(
                    "To use cluster managers, you must set `parallelism='multiprocessing'`."
                )
            cluster_manager = load_cluster_manager(cluster_manager)

        # TODO(mcranmer): These functions should be part of this class.
        binary_operators, unary_operators = _maybe_create_inline_operators(
            binary_operators=binary_operators,
            unary_operators=unary_operators,
            extra_sympy_mappings=self.extra_sympy_mappings,
            expression_spec=self.expression_spec_,
        )
        if constraints is not None:
            _constraints = _process_constraints(
                binary_operators=binary_operators,
                unary_operators=unary_operators,
                constraints=constraints,
            )
            una_constraints = [_constraints[op] for op in unary_operators]
            bin_constraints = [_constraints[op] for op in binary_operators]
        else:
            una_constraints = None
            bin_constraints = None

        # Parse dict into Julia Dict for nested constraints:
        if nested_constraints is not None:
            nested_constraints_str = "Dict("
            for outer_k, outer_v in nested_constraints.items():
                nested_constraints_str += f"({outer_k}) => Dict("
                for inner_k, inner_v in outer_v.items():
                    nested_constraints_str += f"({inner_k}) => {inner_v}, "
                nested_constraints_str += "), "
            nested_constraints_str += ")"
            nested_constraints = jl.seval(nested_constraints_str)
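        # For example (worked through by hand), nested_constraints of
        # `{"sin": {"sin": 0, "cos": 2}}` serializes to the Julia expression
        # `Dict((sin) => Dict((sin) => 0, (cos) => 2, ), )`, i.e. no `sin`
        # nested inside `sin`, and at most complexity-2 `cos` subexpressions
        # inside `sin`.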

        # Parse dict into Julia Dict for complexities:
        if complexity_of_operators is not None:
            complexity_of_operators_str = "Dict("
            for k, v in complexity_of_operators.items():
                complexity_of_operators_str += f"({k}) => {v}, "
            complexity_of_operators_str += ")"
            complexity_of_operators = jl.seval(complexity_of_operators_str)
        # TODO: Refactor this into helper function

        if isinstance(complexity_of_variables, list):
            complexity_of_variables = jl_array(complexity_of_variables)

        custom_loss = jl.seval(
            str(self.elementwise_loss)
            if self.elementwise_loss is not None
            else "nothing"
        )
        custom_full_objective = jl.seval(
            str(self.loss_function) if self.loss_function is not None else "nothing"
        )
        custom_loss_expression = jl.seval(
            str(self.loss_function_expression)
            if self.loss_function_expression is not None
            else "nothing"
        )

        early_stop_condition = jl.seval(
            str(self.early_stop_condition)
            if self.early_stop_condition is not None
            else "nothing"
        )

        input_stream = jl.seval(self.input_stream)

        load_required_packages(
            turbo=self.turbo,
            bumper=self.bumper,
            autodiff_backend=self.autodiff_backend,
            cluster_manager=cluster_manager,
            logger_spec=self.logger_spec,
        )

        if self.autodiff_backend is not None:
            autodiff_backend = jl.Symbol(self.autodiff_backend)
        else:
            autodiff_backend = None

        mutation_weights = SymbolicRegression.MutationWeights(
            mutate_constant=self.weight_mutate_constant,
            mutate_operator=self.weight_mutate_operator,
            swap_operands=self.weight_swap_operands,
            rotate_tree=self.weight_rotate_tree,
            add_node=self.weight_add_node,
            insert_node=self.weight_insert_node,
            delete_node=self.weight_delete_node,
            simplify=self.weight_simplify,
            randomize=self.weight_randomize,
            do_nothing=self.weight_do_nothing,
            optimize=self.weight_optimize,
        )

        jl_binary_operators: list[Any] = []
        jl_unary_operators: list[Any] = []
        for input_list, output_list, name in [
            (binary_operators, jl_binary_operators, "binary"),
            (unary_operators, jl_unary_operators, "unary"),
        ]:
            for op in input_list:
                jl_op = jl.seval(op)
                if not jl_is_function(jl_op):
                    raise ValueError(
                        f"When building `{name}_operators`, `'{op}'` did not return a Julia function"
                    )
                output_list.append(jl_op)

        complexity_mapping = (
            jl.seval(self.complexity_mapping) if self.complexity_mapping else None
        )

        if hasattr(self, "logger_") and self.logger_ is not None and self.warm_start:
            logger = self.logger_
        else:
            logger = self.logger_spec.create_logger() if self.logger_spec else None

        self.logger_ = logger

        # Call to Julia backend.
        # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
        options = SymbolicRegression.Options(
            binary_operators=jl_array(jl_binary_operators, dtype=jl.Function),
            unary_operators=jl_array(jl_unary_operators, dtype=jl.Function),
            bin_constraints=jl_array(bin_constraints),
            una_constraints=jl_array(una_constraints),
            complexity_of_operators=complexity_of_operators,
            complexity_of_constants=self.complexity_of_constants,
            complexity_of_variables=complexity_of_variables,
            complexity_mapping=complexity_mapping,
            expression_spec=self.expression_spec_.julia_expression_spec(),
            nested_constraints=nested_constraints,
            elementwise_loss=custom_loss,
            loss_function=custom_full_objective,
            loss_function_expression=custom_loss_expression,
            maxsize=int(self.maxsize),
            output_directory=_escape_filename(self.output_directory_),
            npopulations=int(self.populations),
            batching=self.batching,
            batch_size=int(
                min([runtime_params.batch_size, len(X)]) if self.batching else len(X)
            ),
            mutation_weights=mutation_weights,
            tournament_selection_p=self.tournament_selection_p,
            tournament_selection_n=self.tournament_selection_n,
            # These have the same name:
            parsimony=self.parsimony,
            dimensional_constraint_penalty=self.dimensional_constraint_penalty,
            dimensionless_constants_only=self.dimensionless_constants_only,
            alpha=self.alpha,
            maxdepth=runtime_params.maxdepth,
            fast_cycle=self.fast_cycle,
            turbo=self.turbo,
            bumper=self.bumper,
            autodiff_backend=autodiff_backend,
            migration=self.migration,
            hof_migration=self.hof_migration,
            fraction_replaced_hof=self.fraction_replaced_hof,
            should_simplify=self.should_simplify,
            should_optimize_constants=self.should_optimize_constants,
            warmup_maxsize_by=runtime_params.warmup_maxsize_by,
            use_frequency=self.use_frequency,
            use_frequency_in_tournament=self.use_frequency_in_tournament,
            adaptive_parsimony_scaling=self.adaptive_parsimony_scaling,
            npop=self.population_size,
            ncycles_per_iteration=self.ncycles_per_iteration,
            fraction_replaced=self.fraction_replaced,
            topn=self.topn,
            print_precision=self.print_precision,
            optimizer_algorithm=self.optimizer_algorithm,
            optimizer_nrestarts=self.optimizer_nrestarts,
            optimizer_f_calls_limit=self.optimizer_f_calls_limit,
            optimizer_probability=self.optimize_probability,
            optimizer_iterations=self.optimizer_iterations,
            perturbation_factor=self.perturbation_factor,
            probability_negate_constant=self.probability_negate_constant,
            annealing=self.annealing,
            timeout_in_seconds=self.timeout_in_seconds,
            crossover_probability=self.crossover_probability,
            skip_mutation_failures=self.skip_mutation_failures,
            max_evals=self.max_evals,
            input_stream=input_stream,
            early_stop_condition=early_stop_condition,
            seed=seed,
            deterministic=self.deterministic,
            define_helper_functions=False,
        )

        self.julia_options_stream_ = jl_serialize(options)

        # Convert data to desired precision
        test_X = np.array(X)
        np_dtype = self._get_precision_mapped_dtype(test_X)

        # This converts the data into a Julia array:
        jl_X = jl_array(np.array(X, dtype=np_dtype).T)
        if len(y.shape) == 1:
            jl_y = jl_array(np.array(y, dtype=np_dtype))
        else:
            jl_y = jl_array(np.array(y, dtype=np_dtype).T)
        if weights is not None:
            if len(weights.shape) == 1:
                jl_weights = jl_array(np.array(weights, dtype=np_dtype))
            else:
                jl_weights = jl_array(np.array(weights, dtype=np_dtype).T)
        else:
            jl_weights = None

        if category is not None:
            offset_for_julia_indexing = 1
            jl_category = jl_array(
                (category + offset_for_julia_indexing).astype(np.int64)
            )
            jl_extra = jl.seval("NamedTuple{(:class,)}")((jl_category,))
        else:
            jl_extra = jl.NamedTuple()

        if len(y.shape) > 1:
            # We set these manually so that they respect Python's 0 indexing
            # (by default Julia will use y1, y2...)
            jl_y_variable_names = jl_array(
                [f"y{_subscriptify(i)}" for i in range(y.shape[1])]
            )
        else:
            jl_y_variable_names = None

        out = SymbolicRegression.equation_search(
            jl_X,
            jl_y,
            weights=jl_weights,
            extra=jl_extra,
            niterations=int(self.niterations),
            variable_names=jl_array([str(v) for v in self.feature_names_in_]),
            display_variable_names=jl_array(
                [str(v) for v in self.display_feature_names_in_]
            ),
            y_variable_names=jl_y_variable_names,
            X_units=jl_array(self.X_units_),
            y_units=(
                jl_array(self.y_units_)
                if isinstance(self.y_units_, list)
                else self.y_units_
            ),
            options=options,
            numprocs=numprocs,
            parallelism=parallelism,
            saved_state=self.julia_state_,
            return_state=True,
            run_id=self.run_id_,
            addprocs_function=cluster_manager,
            heap_size_hint_in_bytes=self.heap_size_hint_in_bytes,
            progress=runtime_params.progress
            and self.verbosity > 0
            and len(y.shape) == 1,
            verbosity=int(self.verbosity),
            logger=logger,
        )
        if self.logger_spec is not None:
            self.logger_spec.write_hparams(logger, self.get_params())
            if not self.warm_start:
                self.logger_spec.close(logger)

        self.julia_state_stream_ = jl_serialize(out)

        # Set attributes
        self.equations_ = self.get_hof(out)

        ALREADY_RAN = True

        return self

    def fit(
        self,
        X,
        y,
        *,
        Xresampled=None,
        weights=None,
        variable_names: ArrayLike[str] | None = None,
        complexity_of_variables: int | float | list[int | float] | None = None,
        X_units: ArrayLike[str] | None = None,
        y_units: str | ArrayLike[str] | None = None,
        category: ndarray | None = None,
    ) -> "PySRRegressor":
        """
        Search for equations to fit the dataset and store them in `self.equations_`.

        Parameters
        ----------
        X : ndarray | pandas.DataFrame
            Training data of shape (n_samples, n_features).
        y : ndarray | pandas.DataFrame
            Target values of shape (n_samples,) or (n_samples, n_targets).
            Will be cast to X's dtype if necessary.
        Xresampled : ndarray | pandas.DataFrame
            Resampled training data, of shape (n_resampled, n_features),
            on which to generate denoised data. This
            will be used as the training data, rather than `X`.
        weights : ndarray | pandas.DataFrame
            Weight array of the same shape as `y`.
            Each element is how to weight the mean-square-error loss
            for that particular element of `y`. Alternatively,
            if a custom `loss` was set, it can be used
            in arbitrary ways.
        variable_names : list[str]
            A list of names for the variables, rather than "x0", "x1", etc.
            If `X` is a pandas dataframe, the column names will be used
            instead of `variable_names`. Cannot contain spaces or special
            characters. Avoid variable names which are also
            function names in `sympy`, such as "N".
        X_units : list[str]
            A list of units for each variable in `X`. Each unit should be
            a string representing a Julia expression. See DynamicQuantities.jl
            https://symbolicml.org/DynamicQuantities.jl/dev/units/ for more
            information.
        y_units : str | list[str]
            Similar to `X_units`, but as a unit for the target variable, `y`.
            If `y` is a matrix, a list of units should be passed. If `X_units`
            is given but `y_units` is not, then `y_units` will be arbitrary.
        category : list[int]
            If `expression_spec` is a `ParametricExpressionSpec`, then this
            argument should be a list of integers representing the category
            of each sample.

        Returns
        -------
        self : object
            Fitted estimator.
        """
        # Init attributes that are not specified in BaseEstimator
        if self.warm_start and hasattr(self, "julia_state_stream_"):
            pass
        else:
            if hasattr(self, "julia_state_stream_"):
                warnings.warn(
                    "The discovered expressions are being reset. "
                    "Please set `warm_start=True` if you wish to continue "
                    "a search from where you left off.",
                )

            self.equations_ = None
            self.nout_ = 1
            self.selection_mask_ = None
            self.julia_state_stream_ = None
            self.julia_options_stream_ = None
            self.complexity_of_variables_ = None
            self.X_units_ = None
            self.y_units_ = None

        self._setup_equation_file()
        self._clear_equation_file_contents()

        runtime_params = self._validate_and_modify_params()

        if category is not None:
            assert Xresampled is None

        if isinstance(self.expression_spec, ParametricExpressionSpec):
            assert category is not None

        # TODO: Put `category` here
        (
            X,
            y,
            Xresampled,
            weights,
            variable_names,
            complexity_of_variables,
            X_units,
            y_units,
        ) = self._validate_and_set_fit_params(
            X,
            y,
            Xresampled,
            weights,
            variable_names,
            complexity_of_variables,
            X_units,
            y_units,
        )

        if X.shape[0] > 10000 and not self.batching:
            warnings.warn(
                "Note: you are running with more than 10,000 datapoints. "
                "You should consider turning on batching (https://ai.damtp.cam.ac.uk/pysr/options/#batching). "
                "You should also reconsider whether you need that many datapoints. "
                "Unless you have a large amount of noise (in which case you "
                "should smooth your dataset first), generally < 10,000 datapoints "
                "is enough to find a functional form with symbolic regression. "
                "More datapoints will lower the search speed."
            )

        random_state = check_random_state(self.random_state)  # For np random
        seed = cast(int, random_state.randint(0, 2**31 - 1))  # For julia random

        # Pre transformations (feature selection and denoising)
        X, y, variable_names, complexity_of_variables, X_units, y_units = (
            self._pre_transform_training_data(
                X,
                y,
                Xresampled,
                variable_names,
                complexity_of_variables,
                X_units,
                y_units,
                random_state,
            )
        )

        # Warn about large feature counts (still warn if feature count is large
        # after running feature selection)
        if self.n_features_in_ >= 10:
            warnings.warn(
                "Note: you are running with 10 features or more. "
                "Genetic algorithms, like those used in PySR, scale poorly with large numbers of features. "
                "You should run PySR for more `niterations` to ensure it can find "
                "the correct variables, and consider using a larger `maxsize`."
            )

        # Assertion checks
        use_custom_variable_names = variable_names is not None
        # TODO: this is always true.

        _check_assertions(
            X,
            use_custom_variable_names,
            variable_names,
            complexity_of_variables,
            weights,
            y,
            X_units,
            y_units,
        )

        # Initially, just save model parameters, so that
        # it can be loaded from an early exit:
        if not self.temp_equation_file:
            self._checkpoint()

        # Perform the search:
        self._run(X, y, runtime_params, weights=weights, seed=seed, category=category)

        # Then, after fit, we save again, so the pickle file contains
        # the equations:
        if not self.temp_equation_file:
            self._checkpoint()

        return self

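    # A minimal end-to-end sketch (data and operator choices are illustrative):
    #
    #     import numpy as np
    #     X = np.random.randn(200, 2)
    #     y = X[:, 0] ** 2 + np.sin(X[:, 1])
    #     model = PySRRegressor(
    #         binary_operators=["+", "*"],
    #         unary_operators=["sin"],
    #         niterations=20,
    #     )
    #     model.fit(X, y, variable_names=["a", "b"])
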
    def refresh(self, run_directory: PathLike | None = None) -> None:
        """
        Update self.equations_ with any new options passed.

        For example, updating `extra_sympy_mappings`
        will require a `.refresh()` to update the equations.

        Parameters
        ----------
        run_directory : str | Path
            Run directory, of the form `[output_directory]/[run_id]`,
            containing the hall of fame file to be loaded.
            The default will use the current `output_directory_` and `run_id_`.
        """
        if run_directory is not None:
            self.output_directory_ = str(Path(run_directory).parent)
            self.run_id_ = Path(run_directory).name
            self._clear_equation_file_contents()
        check_is_fitted(self, attributes=["run_id_", "output_directory_"])
        self.equations_ = self.get_hof()

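    # Sketch: reload equations after changing a mapping (the operator name
    # is illustrative):
    #
    #     model.set_params(extra_sympy_mappings={"inv": lambda x: 1 / x})
    #     model.refresh()  # rebuilds `equations_` with the new mapping
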
    def predict(
3✔
2325
        self,
2326
        X,
2327
        index: int | list[int] | None = None,
2328
        *,
2329
        category: ndarray | None = None,
2330
    ) -> ndarray:
2331
        """
2332
        Predict y from input X using the equation chosen by `model_selection`.
2333

2334
        You may see what equation is used by printing this object. X should
2335
        have the same columns as the training data.
2336

2337
        Parameters
2338
        ----------
2339
        X : ndarray | pandas.DataFrame
2340
            Training data of shape `(n_samples, n_features)`.
2341
        index : int | list[int]
2342
            If you want to compute the output of an expression using a
2343
            particular row of `self.equations_`, you may specify the index here.
2344
            For multiple output equations, you must pass a list of indices
2345
            in the same order.
2346
        category : ndarray | None
2347
            If `expression_spec` is a `ParametricExpressionSpec`, then this
2348
            argument should be a list of integers representing the category
2349
            of each sample in `X`.
2350

2351
        Returns
2352
        -------
2353
        y_predicted : ndarray of shape (n_samples, nout_)
2354
            Values predicted by substituting `X` into the fitted symbolic
2355
            regression model.
2356

2357
        Raises
2358
        ------
2359
        ValueError
2360
            Raises if the `best_equation` cannot be evaluated.
2361
        """
2362
        check_is_fitted(
3✔
2363
            self, attributes=["selection_mask_", "feature_names_in_", "nout_"]
2364
        )
2365
        best_equation = self.get_best(index=index)
3✔
2366

2367
        # When X is an numpy array or a pandas dataframe with a RangeIndex,
2368
        # the self.feature_names_in_ generated during fit, for the same X,
2369
        # will cause a warning to be thrown during _validate_data.
2370
        # To avoid this, convert X to a dataframe, apply the selection mask,
2371
        # and then set the column/feature_names of X to be equal to those
2372
        # generated during fit.
2373
        if not isinstance(X, pd.DataFrame):
3✔
2374
            X = check_array(X)
3✔
2375
            X = pd.DataFrame(X)
3✔
2376
        if isinstance(X.columns, pd.RangeIndex):
3✔
2377
            if self.selection_mask_ is not None:
3✔
2378
                # RangeIndex enforces column order allowing columns to
2379
                # be correctly filtered with self.selection_mask_
2380
                X = X[X.columns[self.selection_mask_]]
3✔
2381
            X.columns = self.feature_names_in_
3✔
2382
        # Without feature information, CallableEquation/lambda_format equations
2383
        # require that the column order of X matches that of the X used during
2384
        # the fitting process. _validate_data removes this feature information
2385
        # when it converts the dataframe to an np array. Thus, to ensure feature
2386
        # order is preserved after conversion, the dataframe columns must be
2387
        # reordered/reindexed to match those of the transformed (denoised and
2388
        # feature selected) X in fit.
2389
        X = X.reindex(columns=self.feature_names_in_)
3✔
2390
        X = self._validate_data_X(X)
3✔
2391
        if self.expression_spec_.evaluates_in_julia:
3✔
2392
            # Julia wants the right dtype
2393
            X = X.astype(self._get_precision_mapped_dtype(X))
3✔
2394

2395
        if category is not None:
3✔
2396
            offset_for_julia_indexing = 1
3✔
2397
            args: tuple = (
3✔
2398
                jl_array((category + offset_for_julia_indexing).astype(np.int64)),
2399
            )
2400
        else:
2401
            args = ()
3✔
2402

2403
        try:
3✔
2404
            if isinstance(best_equation, list):
3✔
2405
                assert self.nout_ > 1
3✔
2406
                return np.stack(
3✔
2407
                    [
2408
                        cast(ndarray, eq["lambda_format"](X, *args))
2409
                        for eq in best_equation
2410
                    ],
2411
                    axis=1,
2412
                )
2413
            else:
2414
                return cast(ndarray, best_equation["lambda_format"](X, *args))
3✔
2415
        except Exception as error:
×
2416
            raise ValueError(
×
2417
                "Failed to evaluate the expression. "
2418
                "If you are using a custom operator, make sure to define it in `extra_sympy_mappings`, "
2419
                "e.g., `model.set_params(extra_sympy_mappings={'inv': lambda x: 1/x})`, where "
2420
                "`lambda x: 1/x` is a valid SymPy function defining the operator. "
2421
                "You can then run `model.refresh()` to re-load the expressions."
2422
            ) from error
2423

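    # Example usage for `predict` (a minimal sketch; `model` is assumed to be a
    # fitted PySRRegressor and `X` to share columns with the training data):
    #
    #     y_pred = model.predict(X)            # equation per `model_selection`
    #     y_row3 = model.predict(X, index=3)   # row 3 of `model.equations_`
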
    def sympy(self, index: int | list[int] | None = None):
        """
        Return sympy representation of the equation(s) chosen by `model_selection`.

        Parameters
        ----------
        index : int | list[int]
            If you wish to select a particular equation from
            `self.equations_`, give the index number here. This overrides
            the `model_selection` parameter. If there are multiple output
            features, then pass a list of indices in the same order
            as the output features.

        Returns
        -------
        best_equation : str or list[str] of length nout_
            SymPy representation of the best equation.
        """
        if not self.expression_spec_.supports_sympy:
            raise ValueError(
                f"`expression_spec={self.expression_spec_}` does not support sympy export."
            )
        self.refresh()
        best_equation = self.get_best(index=index)
        if isinstance(best_equation, list):
            assert self.nout_ > 1
            return [eq["sympy_format"] for eq in best_equation]
        else:
            return best_equation["sympy_format"]

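    # Example usage for `sympy` (a minimal sketch; assumes a fitted model):
    #
    #     expr = model.sympy()              # best equation per `model_selection`
    #     expr_at_2 = model.sympy(index=2)  # row 2 of `self.equations_`
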
    def latex(
        self, index: int | list[int] | None = None, precision: int = 3
    ) -> str | list[str]:
        """
        Return latex representation of the equation(s) chosen by `model_selection`.

        Parameters
        ----------
        index : int | list[int]
            If you wish to select a particular equation from
            `self.equations_`, give the index number here. This overrides
            the `model_selection` parameter. If there are multiple output
            features, then pass a list of indices in the same order
            as the output features.
        precision : int
            The number of significant figures shown in the LaTeX
            representation.
            Default is `3`.

        Returns
        -------
        best_equation : str or list[str] of length nout_
            LaTeX expression of the best equation.
        """
        if not self.expression_spec_.supports_latex:
            raise ValueError(
                f"`expression_spec={self.expression_spec_}` does not support latex export."
            )
        self.refresh()
        sympy_representation = self.sympy(index=index)
        if self.nout_ > 1:
            output = []
            for s in sympy_representation:
                latex = sympy2latex(s, prec=precision)
                output.append(latex)
            return output
        return sympy2latex(sympy_representation, prec=precision)

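    # Example usage for `latex` (a minimal sketch; assumes a fitted
    # single-output model):
    #
    #     print(model.latex(precision=5))
    #
    # For multiple outputs, a list of LaTeX strings is returned instead.
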
    def jax(self, index=None):
        """
        Return jax representation of the equation(s) chosen by `model_selection`.

        Each equation (multiple given if there are multiple outputs) is a dictionary
        containing {"callable": func, "parameters": params}. To call `func`, pass
        func(X, params). This function is differentiable using `jax.grad`.

        Parameters
        ----------
        index : int | list[int]
            If you wish to select a particular equation from
            `self.equations_`, give the index number here. This overrides
            the `model_selection` parameter. If there are multiple output
            features, then pass a list of indices in the same order
            as the output features.

        Returns
        -------
        best_equation : dict[str, Any]
            Dictionary with the callable jax function under the "callable" key,
            and a jax array of parameters under the "parameters" key.
        """
        if not self.expression_spec_.supports_jax:
            raise ValueError(
                f"`expression_spec={self.expression_spec_}` does not support jax export."
            )
        self.set_params(output_jax_format=True)
        self.refresh()
        best_equation = self.get_best(index=index)
        if isinstance(best_equation, list):
            assert self.nout_ > 1
            return [eq["jax_format"] for eq in best_equation]
        else:
            return best_equation["jax_format"]

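    # Example usage for `jax` (a minimal sketch; requires the optional jax
    # dependency, with `model` and `X` assumed from earlier usage):
    #
    #     eq = model.jax()
    #     y = eq["callable"](X, eq["parameters"])
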
    def pytorch(self, index=None):
        """
        Return pytorch representation of the equation(s) chosen by `model_selection`.

        Each equation (multiple given if there are multiple outputs) is a PyTorch module
        containing the parameters as trainable attributes. You can use the module like
        any other PyTorch module: `module(X)`, where `X` is a tensor with the same
        column ordering as trained with.

        Parameters
        ----------
        index : int | list[int]
            If you wish to select a particular equation from
            `self.equations_`, give the index number here. This overrides
            the `model_selection` parameter. If there are multiple output
            features, then pass a list of indices in the same order
            as the output features.

        Returns
        -------
        best_equation : torch.nn.Module
            PyTorch module representing the expression.
        """
        if not self.expression_spec_.supports_torch:
            raise ValueError(
                f"`expression_spec={self.expression_spec_}` does not support torch export."
            )
        self.set_params(output_torch_format=True)
        self.refresh()
        best_equation = self.get_best(index=index)
        if isinstance(best_equation, list):
            return [eq["torch_format"] for eq in best_equation]
        else:
            return best_equation["torch_format"]

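    # Example usage for `pytorch` (a minimal sketch; requires the optional
    # torch dependency; `X_tensor` is a hypothetical tensor with the training
    # column order):
    #
    #     module = model.pytorch()
    #     y = module(X_tensor)
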
    def get_equation_file(self, i: int | None = None) -> Path:
        if i is not None:
            return (
                Path(self.output_directory_)
                / self.run_id_
                / f"hall_of_fame_output{i}.csv"
            )
        else:
            return Path(self.output_directory_) / self.run_id_ / "hall_of_fame.csv"

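    # For instance (with hypothetical attribute values): given
    # `output_directory_ == "outputs"` and `run_id_ == "20240131_120000_ab1cd"`,
    # `get_equation_file()` yields
    # Path("outputs/20240131_120000_ab1cd/hall_of_fame.csv"), and
    # `get_equation_file(2)` yields "hall_of_fame_output2.csv" in the same
    # directory.
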
    def _read_equation_file(self) -> list[pd.DataFrame]:
        """Read the hall of fame file created by `SymbolicRegression.jl`."""

        try:
            if self.nout_ > 1:
                all_outputs = []
                for i in range(1, self.nout_ + 1):
                    cur_filename = str(self.get_equation_file(i)) + ".bak"
                    if not os.path.exists(cur_filename):
                        cur_filename = str(self.get_equation_file(i))
                    with open(cur_filename, "r", encoding="utf-8") as f:
                        buf = f.read()
                    buf = _preprocess_julia_floats(buf)
                    df = self._postprocess_dataframe(pd.read_csv(StringIO(buf)))
                    all_outputs.append(df)
            else:
                filename = str(self.get_equation_file()) + ".bak"
                if not os.path.exists(filename):
                    filename = str(self.get_equation_file())
                with open(filename, "r", encoding="utf-8") as f:
                    buf = f.read()
                buf = _preprocess_julia_floats(buf)
                all_outputs = [self._postprocess_dataframe(pd.read_csv(StringIO(buf)))]

        except FileNotFoundError:
            raise RuntimeError(
                "Couldn't find equation file! The equation search likely exited "
                "before a single iteration completed."
            )
        return all_outputs

    def _postprocess_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        df = df.rename(
            columns={
                "Complexity": "complexity",
                "Loss": "loss",
                "Equation": "equation",
            },
        )

        return df

    def get_hof(self, search_output=None) -> pd.DataFrame | list[pd.DataFrame]:
        """Get the equations from a hall of fame file or search output.

        If no arguments are given, the results from the previous
        call to PySR will be used.
        """
        check_is_fitted(
            self,
            attributes=[
                "nout_",
                "run_id_",
                "output_directory_",
                "selection_mask_",
                "feature_names_in_",
            ],
        )
        should_read_from_file = (
            not hasattr(self, "equation_file_contents_")
            or self.equation_file_contents_ is None
        )
        if should_read_from_file:
            self.equation_file_contents_ = self._read_equation_file()

        _validate_export_mappings(self.extra_jax_mappings, self.extra_torch_mappings)

        equation_file_contents = cast(list[pd.DataFrame], self.equation_file_contents_)

        ret_outputs = [
            pd.concat(
                [
                    output,
                    calculate_scores(output),
                    self.expression_spec_.create_exports(self, output, search_output),
                ],
                axis=1,
            )
            for output in equation_file_contents
        ]

        if self.nout_ > 1:
            return ret_outputs
        return ret_outputs[0]

    def latex_table(
        self,
        indices: list[int] | None = None,
        precision: int = 3,
        columns: list[str] = ["equation", "complexity", "loss", "score"],
    ) -> str:
        """Create a LaTeX/booktabs table for all, or some, of the equations.

        Parameters
        ----------
        indices : list[int] | list[list[int]]
            If you wish to select a particular subset of equations from
            `self.equations_`, give the row numbers here. By default,
            all equations will be used. If there are multiple output
            features, then pass a list of lists.
        precision : int
            The number of significant figures shown in the LaTeX
            representations.
            Default is `3`.
        columns : list[str]
            Which columns to include in the table.
            Default is `["equation", "complexity", "loss", "score"]`.

        Returns
        -------
        latex_table_str : str
            A string that will render a table in LaTeX of the equations.
        """
        if not self.expression_spec_.supports_latex:
            raise ValueError(
                f"`expression_spec={self.expression_spec_}` does not support latex export."
            )
        self.refresh()

        if isinstance(self.equations_, list):
            if indices is not None:
                assert isinstance(indices, list)
                assert isinstance(indices[0], list)
                assert len(indices) == self.nout_

            table_string = sympy2multilatextable(
                self.equations_, indices=indices, precision=precision, columns=columns
            )
        elif isinstance(self.equations_, pd.DataFrame):
            if indices is not None:
                assert isinstance(indices, list)
                assert isinstance(indices[0], int)

            table_string = sympy2latextable(
                self.equations_, indices=indices, precision=precision, columns=columns
            )
        else:
            raise ValueError(
                "Invalid type for equations_ to pass to `latex_table`. "
                "Expected a DataFrame or a list of DataFrames."
            )

        return with_preamble(table_string)


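# Example usage for `PySRRegressor.latex_table` (a minimal sketch; assumes a
# fitted single-output model):
#
#     print(model.latex_table(indices=[1, 3], precision=4))
#
# which returns a booktabs table wrapped in a LaTeX preamble.
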
def idx_model_selection(equations: pd.DataFrame, model_selection: str):
    """Select an expression and return its index."""
    if model_selection == "accuracy":
        chosen_idx = equations["loss"].idxmin()
    elif model_selection == "best":
        threshold = 1.5 * equations["loss"].min()
        filtered_equations = equations.query(f"loss <= {threshold}")
        chosen_idx = filtered_equations["score"].idxmax()
    elif model_selection == "score":
        chosen_idx = equations["score"].idxmax()
    else:
        raise NotImplementedError(
            f"{model_selection} is not a valid model selection strategy."
        )
    return chosen_idx


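# Example (a minimal sketch) of how the three strategies differ on a toy
# Pareto front:
#
#     eqs = pd.DataFrame(
#         {"complexity": [1, 3, 4], "loss": [1.0, 0.5, 0.48], "score": [0.0, 0.35, 0.04]}
#     )
#     idx_model_selection(eqs, "accuracy")  # -> 2 (lowest loss)
#     idx_model_selection(eqs, "score")     # -> 1 (highest score)
#     idx_model_selection(eqs, "best")      # -> 1 (highest score among rows
#                                           #    with loss <= 1.5 * min loss)
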
def calculate_scores(df: pd.DataFrame) -> pd.DataFrame:
    """Calculate scores for each equation based on loss and complexity.

    Score is defined as the negated derivative of the log-loss with respect to complexity.
    A higher score means the equation achieved a much better loss at a slightly higher complexity.
    """
    scores = []
    lastMSE = None
    lastComplexity = 0

    for _, row in df.iterrows():
        curMSE = row["loss"]
        curComplexity = row["complexity"]

        if lastMSE is None:
            cur_score = 0.0
        else:
            if curMSE > 0.0:
                cur_score = -np.log(curMSE / lastMSE) / (curComplexity - lastComplexity)
            else:
                cur_score = np.inf

        scores.append(cur_score)
        lastMSE = curMSE
        lastComplexity = curComplexity

    return pd.DataFrame(
        {
            "score": np.array(scores),
        },
        index=df.index,
    )


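# Worked example (a minimal sketch, matching the toy frame above): for losses
# [1.0, 0.5, 0.48] at complexities [1, 3, 4], the scores are
#
#     [0.0, -log(0.5 / 1.0) / (3 - 1), -log(0.48 / 0.5) / (4 - 3)]
#     ≈ [0.0, 0.35, 0.04]
#
# i.e., the second equation buys a large loss reduction per unit of added
# complexity, while the third buys almost none.
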
def _mutate_parameter(param_name: str, param_value):
    if param_name == "batch_size" and param_value < 1:
        warnings.warn(
            "Given `batch_size` must be greater than or equal to one. "
            "`batch_size` has been increased to equal one."
        )
        return 1

    if (
        param_name == "progress"
        and param_value == True
        and "buffer" not in sys.stdout.__dir__()
    ):
        warnings.warn(
            "Note: it looks like you are running in Jupyter. "
            "The progress bar will be turned off."
        )
        return False

    return param_value


def _map_parallelism_params(
    parallelism: Literal["serial", "multithreading", "multiprocessing"] | None,
    procs: int | None,
    multithreading: bool | None,
) -> tuple[Literal["serial", "multithreading", "multiprocessing"], int | None]:
    """Map old and new parallelism parameters to the new format.

    Parameters
    ----------
    parallelism : str or None
        New parallelism parameter. Can be "serial", "multithreading", or "multiprocessing".
    procs : int or None
        Number of processes parameter.
    multithreading : bool or None
        Old multithreading parameter.

    Returns
    -------
    parallelism : str
        Mapped parallelism mode.
    procs : int or None
        Mapped number of processes.

    Raises
    ------
    ValueError
        If both old and new parameters are specified, or if invalid combinations are given.
    """
    # Check for mixing old and new parameters
    using_new = parallelism is not None
    using_old = multithreading is not None

    if using_new and using_old:
        raise ValueError(
            "Cannot mix old and new parallelism parameters. "
            "Use either `parallelism` and `procs`, or the deprecated "
            "`multithreading` and `procs`."
        )
    elif using_old:
        warnings.warn(
            "The `multithreading: bool` parameter has been deprecated in favor "
            "of `parallelism: Literal['multithreading', 'serial', 'multiprocessing']`.\n"
            "Previous usage of `multithreading=True` (default) is now `parallelism='multithreading'`; "
            "`multithreading=False, procs=0` is now `parallelism='serial'`; and "
            "`multithreading=True, procs={int}` is now `parallelism='multiprocessing', procs={int}`."
        )
        if multithreading:
            _parallelism: Literal["multithreading", "multiprocessing", "serial"] = (
                "multithreading"
            )
            _procs = None
        elif procs is not None and procs > 0:
            _parallelism = "multiprocessing"
            _procs = procs
        else:
            _parallelism = "serial"
            _procs = None
    elif using_new:
        _parallelism = cast(
            Literal["serial", "multithreading", "multiprocessing"], parallelism
        )
        _procs = procs
    else:
        _parallelism = "multithreading"
        _procs = None

    if _parallelism not in {"serial", "multithreading", "multiprocessing"}:
        raise ValueError(
            "`parallelism` must be one of 'serial', 'multithreading', or 'multiprocessing'"
        )
    elif _parallelism == "serial" and _procs is not None:
        warnings.warn(
            "`procs` is specified but will be ignored since `parallelism='serial'`"
        )
        _procs = None
    elif _parallelism == "multithreading" and _procs is not None:
        warnings.warn(
            "`procs` is specified but will be ignored since `parallelism='multithreading'`"
        )
        _procs = None
    elif _parallelism == "multiprocessing" and _procs is None:
        _procs = cpu_count()

    return _parallelism, _procs
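

# Example (a minimal sketch of the mapping, following the branches above):
#
#     _map_parallelism_params(None, None, None)
#     # -> ("multithreading", None)  (the default)
#     _map_parallelism_params("multiprocessing", None, None)
#     # -> ("multiprocessing", cpu_count())
#     _map_parallelism_params(None, 4, True)
#     # -> ("multithreading", None), with a deprecation warning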