• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MilesCranmer / SymbolicRegression.jl / 11204590927

06 Oct 2024 07:29PM UTC coverage: 95.808% (+1.2%) from 94.617%
11204590927

Pull #326

github

web-flow
Merge e2b369ea7 into 8f67533b9
Pull Request #326: BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)

466 of 482 new or added lines in 24 files covered. (96.68%)

1 existing line in 1 file now uncovered.

2651 of 2767 relevant lines covered (95.81%)

73863189.31 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.37
/src/Options.jl
1
module OptionsModule
2

3
using DispatchDoctor: @unstable
4
using Optim: Optim
5
using Dates: Dates
6
using StatsBase: StatsBase
7
using DynamicExpressions: OperatorEnum, Expression, default_node_type
8
using ADTypes: AbstractADType, ADTypes
9
using LossFunctions: L2DistLoss, SupervisedLoss
10
using Optim: Optim
11
using LineSearches: LineSearches
12
#TODO - eventually move some of these
13
# into the SR call itself, rather than
14
# passing huge options at once.
15
using ..OperatorsModule:
16
    plus,
17
    pow,
18
    safe_pow,
19
    mult,
20
    sub,
21
    safe_log,
22
    safe_log10,
23
    safe_log2,
24
    safe_log1p,
25
    safe_sqrt,
26
    safe_acosh,
27
    atanh_clip
28
using ..MutationWeightsModule: MutationWeights, mutations
29
import ..OptionsStructModule: Options
30
using ..OptionsStructModule: ComplexityMapping, operator_specialization
31
using ..UtilsModule: max_ops, @save_kwargs, @ignore
32

33
# Normalize one raw constraint spec into a vector holding one limit per entry of
# `operators`. `default` is the "unconstrained" limit and also fixes the type every
# looked-up limit is asserted to have; `already_normalized` short-circuits inputs that
# are already in vector form.
@unstable function _per_operator_constraints(
    raw, @nospecialize(operators), default::D, already_normalized::Bool
) where {D}
    already_normalized && return raw
    lookup = raw isa Array ? Dict(raw) : raw
    lookup === nothing && return fill(default, length(operators))
    return [haskey(lookup, op) ? lookup[op]::D : default for op in operators]
end

"""
    build_constraints(; una_constraints, bin_constraints, unary_operators, binary_operators)

Build constraints on operator-level complexity from user-passed constraint specs.

Each of `una_constraints` and `bin_constraints` may be:
- `nothing`, meaning no operator is constrained;
- an `Array` of `op => limit` pairs (e.g. `[(*) => (-1, 3)]`), converted to a `Dict`;
- another `haskey`/`getindex` lookup keyed by operator, such as a `Dict`;
- an already-built `Vector{Int}` (unary) or `Vector{Tuple{Int,Int}}` (binary),
  which is returned unchanged.

Operators without an entry get `-1` (unary) or `(-1, -1)` (binary). Returns the
tuple `(unary_limits, binary_limits)`, each ordered like its operator list.
"""
@unstable function build_constraints(;
    una_constraints,
    bin_constraints,
    @nospecialize(unary_operators),
    @nospecialize(binary_operators)
)::Tuple{Vector{Int},Vector{Tuple{Int,Int}}}
    # TODO: Need to disable simplification if (*, -, +, /) are constrained?
    #  Or, just quit simplification if constraints are violated.
    unary_limits = _per_operator_constraints(
        una_constraints, unary_operators, -1, una_constraints isa Vector{Int}
    )
    binary_limits = _per_operator_constraints(
        bin_constraints,
        binary_operators,
        (-1, -1),
        bin_constraints isa Vector{Tuple{Int,Int}},
    )
    return unary_limits, binary_limits
end
83

84
"""
    build_nested_constraints(; binary_operators, unary_operators, nested_constraints)

Validate user-supplied nested-operator constraints and convert them to index form.

`nested_constraints` is either `nothing` (returned unchanged), a `Dict`, or an
iterable of `op => [nested_op => max_nesting, ...]` pairs. Each entry becomes
`(degree, op_index, [(nested_degree, nested_index, max_nesting), ...])`, where the
degree is `2` for a binary and `1` for a unary operator, and the index is the
operator's position in the corresponding operator list.

Throws an `ErrorException` if an operator appears in both operator lists, or if a
constraint names an operator outside the operator set. An `@assert` fails if a
`max_nesting` value is not an `Int` that is at least `-1`.
"""
@unstable function build_nested_constraints(;
    @nospecialize(binary_operators), @nospecialize(unary_operators), nested_constraints
)
    nested_constraints === nothing && return nested_constraints

    # Constraints are keyed by operator identity, so an operator that is both unary
    # and binary would be ambiguous:
    for op in binary_operators
        if op ∈ unary_operators
            error(
                "Operator $(op) is both a binary and unary operator. " *
                "You can't use nested constraints.",
            )
        end
    end

    # Normalize to a dict of `op => (nested_op => max_nesting, ...)` without rebinding
    # the `nested_constraints` argument:
    _nested_constraints = if nested_constraints isa Dict
        nested_constraints
    else
        Dict([cons[1] => Dict(cons[2]...) for cons in nested_constraints]...)
    end

    for (op, nested_constraint) in _nested_constraints
        if !(op ∈ binary_operators || op ∈ unary_operators)
            error("Operator $(op) is not in the operator set.")
        end
        for (nested_op, max_nesting) in nested_constraint
            if !(nested_op ∈ binary_operators || nested_op ∈ unary_operators)
                error("Operator $(nested_op) is not in the operator set.")
            end
            # Check the type first so that a non-numeric value fails this assertion
            # instead of raising a `MethodError` from `>=`.
            @assert max_nesting isa Int && max_nesting >= -1
        end
    end

    # `(degree, index)` of an operator already validated to be in the operator set.
    function degree_and_index(o)
        if o ∈ binary_operators
            return (2, findfirst(isequal(o), binary_operators)::Int)
        else
            return (1, findfirst(isequal(o), unary_operators)::Int)
        end
    end

    # Lastly, we clean it up into a list of (degree, op_idx, nested entries).
    return [
        let (degree, idx) = degree_and_index(op)
            nested_entries = [
                let (nested_degree, nested_idx) = degree_and_index(nested_op)
                    (nested_degree, nested_idx, max_nesting)
                end for (nested_op, max_nesting) in nested_constraint
            ]
            (degree, idx, nested_entries)
        end for (op, nested_constraint) in _nested_constraints
    ]
end
141

142
# Translate a user-facing binary operator into the implementation used internally:
# `plus`/`mult`/`sub` become the Base operators, `div` becomes `/`, and both `^` and
# `pow` become `safe_pow`. Anything else is passed through unchanged.
function binopmap(op::F) where {F}
    op == plus && return (+)
    op == mult && return (*)
    op == sub && return (-)
    op == div && return (/)
    (op == (^) || op == pow) && return safe_pow
    return op
end
158
# Undo the `safe_pow` substitution performed by `binopmap`, giving back `^`;
# every other operator is returned as-is.
function inverse_binopmap(op::F) where {F}
    if op != safe_pow
        return op
    end
    return (^)
end
164

165
# Swap a user-facing unary operator for its `OperatorsModule` replacement
# (`log` -> `safe_log`, ..., `atanh` -> `atanh_clip`); operators without a
# replacement are returned unchanged.
function unaopmap(op::F) where {F}
    op == log && return safe_log
    op == log10 && return safe_log10
    op == log2 && return safe_log2
    op == log1p && return safe_log1p
    op == sqrt && return safe_sqrt
    op == acosh && return safe_acosh
    op == atanh && return atanh_clip
    return op
end
183
# Reverse of `unaopmap`: map each `OperatorsModule` replacement back to the Base
# function it stands in for (`safe_log` -> `log`, ..., `atanh_clip` -> `atanh`);
# any other operator is returned unchanged.
function inverse_unaopmap(op::F) where {F}
    op == safe_log && return log
    op == safe_log10 && return log10
    op == safe_log2 && return log2
    op == safe_log1p && return log1p
    op == safe_sqrt && return sqrt
    op == safe_acosh && return acosh
    op == atanh_clip && return atanh
    return op
end
201

202
# Normalize the `mutation_weights` option into a `MutationWeights` instance.
create_mutation_weights(w::MutationWeights) = w
create_mutation_weights(w::NamedTuple) = MutationWeights(; w...)

# `Options` also accepts a plain vector for `mutation_weights`. Read it positionally,
# as the deprecated `mutationWeights` keyword does: if it is shorter than `mutations`,
# the missing trailing weights are padded with zeros.
@unstable function create_mutation_weights(w::AbstractVector)
    padded = if length(w) < length(mutations)
        vcat(w, zeros(length(mutations) - length(w)))
    else
        w
    end
    return MutationWeights(padded...)
end
204

205
# Maps each deprecated `Options` keyword to its current name. `Options` consults this
# table for keywords it does not recognize. Targets prefixed with `deprecated_` get a
# deprecation warning without a suggested replacement. Pairs are inserted in the order
# listed; `ImmutableDict` iterates the most recently inserted pair first.
const deprecated_options_mapping = let renamed_keywords = (
        :mutationWeights => :mutation_weights,
        :hofMigration => :hof_migration,
        :shouldOptimizeConstants => :should_optimize_constants,
        :hofFile => :output_file,
        :perturbationFactor => :perturbation_factor,
        :batchSize => :batch_size,
        :crossoverProbability => :crossover_probability,
        :warmupMaxsizeBy => :warmup_maxsize_by,
        :useFrequency => :use_frequency,
        :useFrequencyInTournament => :use_frequency_in_tournament,
        :ncyclesperiteration => :ncycles_per_iteration,
        :fractionReplaced => :fraction_replaced,
        :fractionReplacedHof => :fraction_replaced_hof,
        :probNegate => :probability_negate_constant,
        :optimize_probability => :optimizer_probability,
        :probPickFirst => :tournament_selection_p,
        :earlyStopCondition => :early_stop_condition,
        :stateReturn => :deprecated_return_state,
        :return_state => :deprecated_return_state,
        :enable_autodiff => :deprecated_enable_autodiff,
        :ns => :tournament_selection_n,
        :loss => :elementwise_loss,
    )
    Base.ImmutableDict(renamed_keywords...)
end
229

230
# For static analysis tools:
# placeholder so that `DEFAULT_OPTIONS` appears to be defined before it is used by
# `@save_kwargs DEFAULT_OPTIONS function Options(...)`.
# NOTE(review): presumably `@ignore` drops this line at load time and `@save_kwargs`
# defines the real constant from `Options`'s keyword defaults. Confirm in UtilsModule.
@ignore const DEFAULT_OPTIONS = ()
232

233
# Markdown description of every keyword accepted by `Options(; ...)`; it is interpolated
# into the `Options` docstring. The closing quotes must stay at column 0 so that the
# triple-quoted string keeps its indentation.
const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators (functions) to use.
    Each operator should be defined for two input scalars,
    and one output scalar. All operators
    need to be defined over the entire real line (excluding infinity - these
    are stopped before they are input), or return `NaN` where not defined.
    For speed, define it so it takes two reals
    of the same type as input, and outputs the same type. For the SymbolicUtils
    simplification backend, you will need to define a generic method of the
    operator so it takes arbitrary types.
- `unary_operators`: Same, but for
    unary operators (one input scalar, gives an output scalar).
- `constraints`: Array of pairs specifying size constraints
    for each operator. The constraints for a binary operator should be a 2-tuple
    (e.g., `(-1, -1)`) and the constraints for a unary operator should be an `Int`.
    A size constraint is a limit to the size of the subtree
    in each argument of an operator. e.g., `[(^)=>(-1, 3)]` means that the
    `^` operator can have arbitrary size (`-1`) in its left argument,
    but a maximum size of `3` in its right argument. Default is
    no constraints.
- `batching`: Whether to evolve based on small mini-batches of data,
    rather than the entire dataset.
- `batch_size`: What batch size to use if using batching.
- `elementwise_loss`: What elementwise loss function to use. Can be one of
    the following losses, or any other loss of type
    `SupervisedLoss`. You can also pass a function that takes
    a scalar target (left argument), and scalar predicted (right
    argument), and returns a scalar. This will be averaged
    over the predicted data. If weights are supplied, your
    function should take a third argument for the weight scalar.
    Defaults to `L2DistLoss()`; it cannot be combined with `loss_function`.
    Included losses:
        Regression:
            - `LPDistLoss{P}()`,
            - `L1DistLoss()`,
            - `L2DistLoss()` (mean square),
            - `LogitDistLoss()`,
            - `HuberLoss(d)`,
            - `L1EpsilonInsLoss(ϵ)`,
            - `L2EpsilonInsLoss(ϵ)`,
            - `PeriodicLoss(c)`,
            - `QuantileLoss(τ)`,
        Classification:
            - `ZeroOneLoss()`,
            - `PerceptronLoss()`,
            - `L1HingeLoss()`,
            - `SmoothedL1HingeLoss(γ)`,
            - `ModifiedHuberLoss()`,
            - `L2MarginLoss()`,
            - `ExpLoss()`,
            - `SigmoidLoss()`,
            - `DWDMarginLoss(q)`.
- `loss_function`: Alternatively, you may redefine the loss used
    as any function of `tree::AbstractExpressionNode{T}`, `dataset::Dataset{T}`,
    and `options::Options`, so long as you output a non-negative
    scalar of the dataset's loss type (`L` in `Dataset{T,L}`, as in the example
    below). This is useful if you want to use a loss
    that takes into account derivatives, or correlations across
    the dataset. This also means you could use a custom evaluation
    for a particular expression. If you are using
    `batching=true`, then your function should
    accept a fourth argument `idx`, which is either `nothing`
    (indicating that the full dataset should be used), or a vector
    of indices to use for the batch.
    For example,

        function my_loss(tree, dataset::Dataset{T,L}, options)::L where {T,L}
            prediction, flag = eval_tree_array(tree, dataset.X, options)
            if !flag
                return L(Inf)
            end
            return sum((prediction .- dataset.y) .^ 2) / dataset.n
        end

- `expression_type::Type{E}=Expression`: The type of expression to use.
    For example, `Expression`.
- `node_type::Type{N}=default_node_type(Expression)`: The type of node to use for the search.
    For example, `Node` or `GraphNode`. The default is computed by `default_node_type(expression_type)`.
- `populations`: How many populations of equations to use.
- `population_size`: How many equations in each population.
- `ncycles_per_iteration`: How many generations to consider per iteration.
- `tournament_selection_n`: Number of expressions considered in each tournament.
- `tournament_selection_p`: The fittest expression in a tournament is to be
    selected with probability `p`, the next fittest with probability `p*(1-p)`,
    and so forth.
- `topn`: Number of equations to return to the host process, and to
    consider for the hall of fame.
- `complexity_of_operators`: What complexity should be assigned to each operator,
    and the occurrence of a constant or variable. By default, this is 1
    for all operators. Can be a real number as well, in which case
    the complexity of an expression will be rounded to the nearest integer.
    Input this in the form of, e.g., [(^) => 3, sin => 2].
- `complexity_of_constants`: What complexity should be assigned to use of a constant.
    By default, this is 1.
- `complexity_of_variables`: What complexity should be assigned to use of a variable,
    which can also be a vector indicating different per-variable complexity.
    By default, this is 1.
- `alpha`: The probability of accepting an equation mutation
    during regularized evolution is given by exp(-delta_loss/(alpha * T)),
    where T goes from 1 to 0. Thus, alpha=infinite is the same as no annealing.
- `maxsize`: Maximum size of equations during the search.
- `maxdepth`: Maximum depth of equations during the search, by default
    this is set equal to the maxsize.
- `parsimony`: A multiplicative factor for how much complexity is
    punished.
- `dimensional_constraint_penalty`: An additive factor if the dimensional
    constraint is violated.
- `dimensionless_constants_only`: Whether to only allow dimensionless
    constants.
- `use_frequency`: Whether to use a parsimony that adapts to the
    relative proportion of equations at each complexity; this will
    ensure that there are a balanced number of equations considered
    for every complexity.
- `use_frequency_in_tournament`: Whether to use the adaptive parsimony described
    above inside the score, rather than just at the mutation accept/reject stage.
- `adaptive_parsimony_scaling`: How much to scale the adaptive parsimony term
    in the loss. Increase this if the search is spending too much time
    optimizing the most complex equations.
- `turbo`: Whether to use `LoopVectorization.@turbo` to evaluate expressions.
    This can be significantly faster, but is only compatible with certain
    operators. *Experimental!*
- `bumper`: Whether to use Bumper.jl for faster evaluation. *Experimental!*
- `migration`: Whether to migrate equations between processes.
- `hof_migration`: Whether to migrate equations from the hall of fame
    to processes.
- `fraction_replaced`: What fraction of each population to replace with
    migrated equations at the end of each cycle.
- `fraction_replaced_hof`: What fraction to replace with hall of fame
    equations at the end of each cycle.
- `should_simplify`: Whether to simplify equations. By default (`nothing`), this is
    `true` unless a custom `loss_function` or any of `constraints`,
    `bin_constraints`, `una_constraints`, or `nested_constraints` is passed,
    in which case it is set to `false`.
- `should_optimize_constants`: Whether to use an optimization algorithm
    to periodically optimize constants in equations.
- `optimizer_algorithm`: Select algorithm to use for optimizing constants. Default
    is `Optim.BFGS(linesearch=LineSearches.BackTracking())`.
- `optimizer_nrestarts`: How many different random starting positions to consider
    for optimization of constants.
- `optimizer_probability`: Probability of performing optimization of constants at
    the end of a given iteration.
- `optimizer_iterations`: How many optimization iterations to perform. This gets
    passed to `Optim.Options` as `iterations`. The default is 8.
- `optimizer_f_calls_limit`: How many function calls to allow during optimization.
    This gets passed to `Optim.Options` as `f_calls_limit`. The default is
    `10_000`.
- `optimizer_options`: General options for the constant optimization. For details
    we refer to the documentation on `Optim.Options` from the `Optim.jl` package.
    Options can be provided here as `NamedTuple`, e.g. `(iterations=16,)`, as a
    `Dict`, e.g. Dict(:x_tol => 1.0e-32,), or as an `Optim.Options` instance.
- `autodiff_backend`: The backend to use for differentiation, which should be
    an instance of `AbstractADType` (see `DifferentiationInterface.jl`).
    Default is `nothing`, which means `Optim.jl` will estimate gradients (likely
    with finite differences). You can also pass a symbolic version of the backend
    type, such as `:Zygote` for Zygote, `:Enzyme`, etc. Most backends will not
    work, and many will never work due to incompatibilities, though support for some
    is gradually being added.
- `output_file`: What file to store equations to, as a backup.
- `perturbation_factor`: When mutating a constant, either
    multiply or divide by (1+perturbation_factor)^(rand()+1).
- `probability_negate_constant`: Probability of negating a constant in the equation
    when mutating it.
- `mutation_weights`: Relative probabilities of the mutations. Pass either a
    `MutationWeights` struct or a `NamedTuple` of its keyword arguments.
    See the documentation of `MutationWeights` for the different weights.
- `crossover_probability`: Probability of performing crossover.
- `annealing`: Whether to use simulated annealing.
- `warmup_maxsize_by`: Whether to slowly increase the max size from 5 up to
    `maxsize`. If nonzero, specifies the fraction through the search
    at which the maxsize should be reached.
- `verbosity`: Whether to print debugging statements or
    not.
- `print_precision`: How many digits to print when printing
    equations. By default, this is 5.
- `save_to_file`: Whether to save equations to a file during the search.
- `bin_constraints`: See `constraints`. This is the same, but specified for binary
    operators only (for example, if you have an operator that is both a binary
    and unary operator).
- `una_constraints`: Likewise, for unary operators.
- `seed`: What random seed to use. `nothing` uses no seed.
- `progress`: Whether to use a progress bar output (`verbosity` will
    have no effect).
- `early_stop_condition`: Float - whether to stop early if the mean loss gets below this value.
    Function - a function taking (loss, complexity) as arguments and returning true or false.
- `timeout_in_seconds`: Float64 - the time in seconds after which to exit (as an alternative to the number of iterations).
- `max_evals`: Int (or Nothing) - the maximum number of evaluations of expressions to perform.
- `skip_mutation_failures`: Whether to simply skip over mutations that fail or are rejected, rather than to replace the mutated
    expression with the original expression and proceed normally.
- `nested_constraints`: Specifies how many times a combination of operators can be nested. For example,
    `[sin => [cos => 0], cos => [cos => 2]]` specifies that `cos` may never appear within a `sin`,
    but `sin` can be nested with itself an unlimited number of times. The second term specifies that `cos`
    can be nested up to 2 times within a `cos`, so that `cos(cos(cos(x)))` is allowed (as well as any combination
    of `+` or `-` within it), but `cos(cos(cos(cos(x))))` is not allowed. When an operator is not specified,
    it is assumed that it can be nested an unlimited number of times. This requires that there is no operator
    which is used both in the unary operators and the binary operators (e.g., `-` could be both subtract, and negation).
    For binary operators, both arguments are treated the same way, and the max of each argument is constrained.
- `deterministic`: Use a global counter for the birth time, rather than calls to `time()`. This gives
    perfect resolution, and is therefore deterministic. However, it is not thread safe, and must be used
    in serial mode.
- `define_helper_functions`: Whether to define helper functions
    for constructing and evaluating trees.
"""
430

431
"""
432
    Options(;kws...)
433

434
Construct options for `equation_search` and other functions.
435
The current arguments have been tuned using the median values from
436
https://github.com/MilesCranmer/PySR/discussions/115.
437

438
# Arguments
439
$(OPTION_DESCRIPTIONS)
440
"""
441
@unstable @save_kwargs DEFAULT_OPTIONS function Options(;
18,244✔
442
    binary_operators=Function[+, -, /, *],
443
    unary_operators=Function[],
444
    constraints=nothing,
445
    elementwise_loss::Union{Function,SupervisedLoss,Nothing}=nothing,
446
    loss_function::Union{Function,Nothing}=nothing,
447
    tournament_selection_n::Integer=12, #1 sampled from every tournament_selection_n per mutation
448
    tournament_selection_p::Real=0.86,
449
    topn::Integer=12, #samples to return per population
450
    complexity_of_operators=nothing,
451
    complexity_of_constants::Union{Nothing,Real}=nothing,
452
    complexity_of_variables::Union{Nothing,Real,AbstractVector}=nothing,
453
    parsimony::Real=0.0032,
454
    dimensional_constraint_penalty::Union{Nothing,Real}=nothing,
455
    dimensionless_constants_only::Bool=false,
456
    alpha::Real=0.100000,
457
    maxsize::Integer=20,
458
    maxdepth::Union{Nothing,Integer}=nothing,
459
    turbo::Bool=false,
460
    bumper::Bool=false,
461
    migration::Bool=true,
462
    hof_migration::Bool=true,
463
    should_simplify::Union{Nothing,Bool}=nothing,
464
    should_optimize_constants::Bool=true,
465
    output_file::Union{Nothing,AbstractString}=nothing,
466
    expression_type::Type=Expression,
467
    node_type::Type=default_node_type(expression_type),
468
    expression_options::NamedTuple=NamedTuple(),
469
    populations::Integer=15,
470
    perturbation_factor::Real=0.076,
471
    annealing::Bool=false,
472
    batching::Bool=false,
473
    batch_size::Integer=50,
474
    mutation_weights::Union{MutationWeights,AbstractVector,NamedTuple}=MutationWeights(),
475
    crossover_probability::Real=0.066,
476
    warmup_maxsize_by::Real=0.0,
477
    use_frequency::Bool=true,
478
    use_frequency_in_tournament::Bool=true,
479
    adaptive_parsimony_scaling::Real=20.0,
480
    population_size::Integer=33,
481
    ncycles_per_iteration::Integer=550,
482
    fraction_replaced::Real=0.00036,
483
    fraction_replaced_hof::Real=0.035,
484
    verbosity::Union{Integer,Nothing}=nothing,
485
    print_precision::Integer=5,
486
    save_to_file::Bool=true,
487
    probability_negate_constant::Real=0.01,
488
    seed=nothing,
489
    bin_constraints=nothing,
490
    una_constraints=nothing,
491
    progress::Union{Bool,Nothing}=nothing,
492
    terminal_width::Union{Nothing,Integer}=nothing,
493
    optimizer_algorithm::Union{AbstractString,Optim.AbstractOptimizer}=Optim.BFGS(;
494
        linesearch=LineSearches.BackTracking()
495
    ),
496
    optimizer_nrestarts::Integer=2,
497
    optimizer_probability::Real=0.14,
498
    optimizer_iterations::Union{Nothing,Integer}=nothing,
499
    optimizer_f_calls_limit::Union{Nothing,Integer}=nothing,
500
    optimizer_options::Union{Dict,NamedTuple,Optim.Options,Nothing}=nothing,
501
    autodiff_backend::Union{AbstractADType,Symbol,Nothing}=nothing,
502
    use_recorder::Bool=false,
503
    recorder_file::AbstractString="pysr_recorder.json",
504
    early_stop_condition::Union{Function,Real,Nothing}=nothing,
505
    timeout_in_seconds::Union{Nothing,Real}=nothing,
506
    max_evals::Union{Nothing,Integer}=nothing,
507
    skip_mutation_failures::Bool=true,
508
    nested_constraints=nothing,
509
    deterministic::Bool=false,
510
    # Not search options; just construction options:
511
    define_helper_functions::Bool=true,
512
    deprecated_return_state=nothing,
513
    # Deprecated args:
514
    fast_cycle::Bool=false,
515
    npopulations::Union{Nothing,Integer}=nothing,
516
    npop::Union{Nothing,Integer}=nothing,
517
    kws...,
518
)
519
    for k in keys(kws)
11,527✔
520
        !haskey(deprecated_options_mapping, k) && error("Unknown keyword argument: $k")
450✔
521
        new_key = deprecated_options_mapping[k]
450✔
522
        if startswith(string(new_key), "deprecated_")
55✔
523
            Base.depwarn("The keyword argument `$(k)` is deprecated.", :Options)
×
524
            if string(new_key) != "deprecated_return_state"
×
525
                # This one we actually want to use
526
                continue
×
527
            end
528
        else
529
            Base.depwarn(
30✔
530
                "The keyword argument `$(k)` is deprecated. Use `$(new_key)` instead.",
531
                :Options,
532
            )
533
        end
534
        # Now, set the new key to the old value:
535
        #! format: off
536
        k == :hofMigration && (hof_migration = kws[k]; true) && continue
30✔
537
        k == :shouldOptimizeConstants && (should_optimize_constants = kws[k]; true) && continue
36✔
538
        k == :hofFile && (output_file = kws[k]; true) && continue
24✔
539
        k == :perturbationFactor && (perturbation_factor = kws[k]; true) && continue
24✔
540
        k == :batchSize && (batch_size = kws[k]; true) && continue
24✔
541
        k == :crossoverProbability && (crossover_probability = kws[k]; true) && continue
24✔
542
        k == :warmupMaxsizeBy && (warmup_maxsize_by = kws[k]; true) && continue
24✔
543
        k == :useFrequency && (use_frequency = kws[k]; true) && continue
24✔
544
        k == :useFrequencyInTournament && (use_frequency_in_tournament = kws[k]; true) && continue
24✔
545
        k == :ncyclesperiteration && (ncycles_per_iteration = kws[k]; true) && continue
24✔
546
        k == :fractionReplaced && (fraction_replaced = kws[k]; true) && continue
24✔
547
        k == :fractionReplacedHof && (fraction_replaced_hof = kws[k]; true) && continue
30✔
548
        k == :probNegate && (probability_negate_constant = kws[k]; true) && continue
18✔
549
        k == :optimize_probability && (optimizer_probability = kws[k]; true) && continue
18✔
550
        k == :probPickFirst && (tournament_selection_p = kws[k]; true) && continue
18✔
551
        k == :earlyStopCondition && (early_stop_condition = kws[k]; true) && continue
18✔
552
        k == :return_state && (deprecated_return_state = kws[k]; true) && continue
18✔
553
        k == :stateReturn && (deprecated_return_state = kws[k]; true) && continue
18✔
554
        k == :enable_autodiff && continue
18✔
555
        k == :ns && (tournament_selection_n = kws[k]; true) && continue
18✔
556
        k == :loss && (elementwise_loss = kws[k]; true) && continue
24✔
557
        if k == :mutationWeights
12✔
558
            if typeof(kws[k]) <: AbstractVector
12✔
559
                _mutation_weights = kws[k]
6✔
560
                if length(_mutation_weights) < length(mutations)
6✔
561
                    # Pad with zeros:
562
                    _mutation_weights = vcat(
24✔
563
                        _mutation_weights,
564
                        zeros(length(mutations) - length(_mutation_weights))
565
                    )
566
                end
567
                mutation_weights = MutationWeights(_mutation_weights...)
6✔
568
            else
569
                mutation_weights = kws[k]
6✔
570
            end
571
            continue
12✔
572
        end
573
        #! format: on
574
        error(
6✔
575
            "Unknown deprecated keyword argument: $k. Please update `Options(;)` to transfer this key.",
576
        )
577
    end
30✔
578
    fast_cycle && Base.depwarn("`fast_cycle` is deprecated and has no effect.", :Options)
9,892✔
579
    if npop !== nothing
9,868✔
580
        Base.depwarn("`npop` is deprecated. Use `population_size` instead.", :Options)
30✔
581
        population_size = npop
30✔
582
    end
583
    if npopulations !== nothing
9,868✔
584
        Base.depwarn("`npopulations` is deprecated. Use `populations` instead.", :Options)
36✔
585
        populations = npopulations
36✔
586
    end
587
    if optimizer_algorithm isa AbstractString
9,868✔
588
        Base.depwarn(
30✔
589
            "The `optimizer_algorithm` argument should be an `AbstractOptimizer`, not a string.",
590
            :Options,
591
        )
592
        optimizer_algorithm = if optimizer_algorithm == "NelderMead"
30✔
593
            Optim.NelderMead(; linesearch=LineSearches.BackTracking())
24✔
594
        else
595
            Optim.BFGS(; linesearch=LineSearches.BackTracking())
36✔
596
        end
597
    end
598

599
    if elementwise_loss === nothing
9,868✔
600
        elementwise_loss = L2DistLoss()
676✔
601
    else
602
        if loss_function !== nothing
9,192✔
603
            error("You cannot specify both `elementwise_loss` and `loss_function`.")
×
604
        end
605
    end
606

607
    if should_simplify === nothing
9,892✔
608
        should_simplify = (
9,868✔
609
            loss_function === nothing &&
610
            nested_constraints === nothing &&
611
            constraints === nothing &&
612
            bin_constraints === nothing &&
613
            una_constraints === nothing
614
        )
615
    end
616

617
    is_testing = parse(Bool, get(ENV, "SYMBOLIC_REGRESSION_IS_TESTING", "false"))
18,062✔
618

619
    if output_file === nothing
9,892✔
620
        # "%Y-%m-%d_%H%M%S.%f"
621
        date_time_str = Dates.format(Dates.now(), "yyyy-mm-dd_HHMMSS.sss")
9,892✔
622
        output_file = "hall_of_fame_" * date_time_str * ".csv"
11,529✔
623
        if is_testing
9,892✔
624
            tmpdir = mktempdir()
9,794✔
625
            output_file = joinpath(tmpdir, output_file)
9,794✔
626
        end
627
    end
628

629
    @assert maxsize > 3
9,892✔
630
    @assert warmup_maxsize_by >= 0.0f0
9,892✔
631
    @assert length(unary_operators) <= max_ops
9,892✔
632
    @assert length(binary_operators) <= max_ops
9,892✔
633

634
    # Make sure nested_constraints contains functions within our operator set:
635
    _nested_constraints = build_nested_constraints(;
9,880✔
636
        binary_operators, unary_operators, nested_constraints
637
    )
638

639
    if typeof(constraints) <: Tuple
9,868✔
640
        constraints = collect(constraints)
26✔
641
    end
642
    if constraints !== nothing
9,868✔
643
        @assert bin_constraints === nothing
24✔
644
        @assert una_constraints === nothing
24✔
645
        # TODO: This is redundant with the checks in equation_search
646
        for op in binary_operators
24✔
647
            @assert !(op in unary_operators)
90✔
648
        end
90✔
649
        for op in unary_operators
24✔
650
            @assert !(op in binary_operators)
34✔
651
        end
20✔
652
        bin_constraints = constraints
24✔
653
        una_constraints = constraints
24✔
654
    end
655

656
    _una_constraints, _bin_constraints = build_constraints(;
9,892✔
657
        una_constraints, bin_constraints, unary_operators, binary_operators
658
    )
659

660
    complexity_mapping = ComplexityMapping(
14,839✔
661
        complexity_of_operators,
662
        complexity_of_variables,
663
        complexity_of_constants,
664
        binary_operators,
665
        unary_operators,
666
    )
667

668
    if maxdepth === nothing
9,892✔
669
        maxdepth = maxsize
9,862✔
670
    end
671

672
    if define_helper_functions
9,892✔
673
        # We call here so that mapped operators, like ^
674
        # are correctly overloaded, rather than overloading
675
        # operators like "safe_pow", etc.
676
        OperatorEnum(;
9,796✔
677
            binary_operators=binary_operators,
678
            unary_operators=unary_operators,
679
            define_helper_functions=true,
680
            empty_old_operators=true,
681
        )
682
    end
683

684
    binary_operators = map(binopmap, binary_operators)
9,893✔
685
    unary_operators = map(unaopmap, unary_operators)
9,903✔
686

687
    operators = OperatorEnum(;
9,892✔
688
        binary_operators=binary_operators,
689
        unary_operators=unary_operators,
690
        define_helper_functions=define_helper_functions,
691
        empty_old_operators=false,
692
    )
693

694
    early_stop_condition = if typeof(early_stop_condition) <: Real
9,868✔
695
        # Need to make explicit copy here for this to work:
696
        stopping_point = Float64(early_stop_condition)
42✔
697
        (loss, complexity) -> loss < stopping_point
1,181,742,228✔
698
    else
699
        early_stop_condition
9,846✔
700
    end
701

702
    # Parse optimizer options
703
    if !isa(optimizer_options, Optim.Options)
9,868✔
704
        optimizer_iterations = isnothing(optimizer_iterations) ? 8 : optimizer_iterations
9,856✔
705
        optimizer_f_calls_limit = if isnothing(optimizer_f_calls_limit)
9,856✔
706
            10_000
9,856✔
707
        else
708
            optimizer_f_calls_limit
3,268✔
709
        end
710
        extra_kws = hasfield(Optim.Options, :show_warnings) ? (; show_warnings=false) : ()
9,856✔
711
        optimizer_options = Optim.Options(;
9,880✔
712
            iterations=optimizer_iterations,
713
            f_calls_limit=optimizer_f_calls_limit,
714
            extra_kws...,
715
            (isnothing(optimizer_options) ? () : optimizer_options)...,
716
        )
717
    else
718
        @assert optimizer_iterations === nothing && optimizer_f_calls_limit === nothing
12✔
719
    end
720
    if hasfield(Optim.Options, :show_warnings) && optimizer_options.show_warnings
9,892✔
721
        @warn "Optimizer warnings are turned on. This might result in a lot of warnings being printed from NaNs, as these are common during symbolic regression"
12✔
722
    end
723

724
    set_mutation_weights = create_mutation_weights(mutation_weights)
9,868✔
725

726
    @assert print_precision > 0
9,892✔
727

728
    _autodiff_backend = if autodiff_backend isa Union{Nothing,AbstractADType}
9,868✔
729
        autodiff_backend
9,850✔
730
    else
731
        ADTypes.Auto(autodiff_backend)
3,286✔
732
    end
733

734
    options = Options{
9,892✔
735
        typeof(complexity_mapping),
736
        operator_specialization(typeof(operators)),
737
        node_type,
738
        expression_type,
739
        typeof(expression_options),
740
        turbo,
741
        bumper,
742
        deprecated_return_state,
743
        typeof(_autodiff_backend),
744
    }(
745
        operators,
746
        _bin_constraints,
747
        _una_constraints,
748
        complexity_mapping,
749
        tournament_selection_n,
750
        tournament_selection_p,
751
        parsimony,
752
        dimensional_constraint_penalty,
753
        dimensionless_constants_only,
754
        alpha,
755
        maxsize,
756
        maxdepth,
757
        Val(turbo),
758
        Val(bumper),
759
        migration,
760
        hof_migration,
761
        should_simplify,
762
        should_optimize_constants,
763
        output_file,
764
        populations,
765
        perturbation_factor,
766
        annealing,
767
        batching,
768
        batch_size,
769
        set_mutation_weights,
770
        crossover_probability,
771
        warmup_maxsize_by,
772
        use_frequency,
773
        use_frequency_in_tournament,
774
        adaptive_parsimony_scaling,
775
        population_size,
776
        ncycles_per_iteration,
777
        fraction_replaced,
778
        fraction_replaced_hof,
779
        topn,
780
        verbosity,
781
        print_precision,
782
        save_to_file,
783
        probability_negate_constant,
784
        length(unary_operators),
785
        length(binary_operators),
786
        seed,
787
        elementwise_loss,
788
        loss_function,
789
        node_type,
790
        expression_type,
791
        expression_options,
792
        progress,
793
        terminal_width,
794
        optimizer_algorithm,
795
        optimizer_probability,
796
        optimizer_nrestarts,
797
        optimizer_options,
798
        _autodiff_backend,
799
        recorder_file,
800
        tournament_selection_p,
801
        early_stop_condition,
802
        Val(deprecated_return_state),
803
        timeout_in_seconds,
804
        max_evals,
805
        skip_mutation_failures,
806
        _nested_constraints,
807
        deterministic,
808
        define_helper_functions,
809
        use_recorder,
810
    )
811

812
    return options
9,892✔
813
end
814

815
end
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc