• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MilesCranmer / SymbolicRegression.jl / 9686354911

26 Jun 2024 08:31PM UTC coverage: 93.22% (-1.4%) from 94.617%
9686354911

Pull #326

github

web-flow
Merge 6f8229c9f into ceddaa424
Pull Request #326: BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)

275 of 296 new or added lines in 17 files covered. (92.91%)

34 existing lines in 5 files now uncovered.

2530 of 2714 relevant lines covered (93.22%)

32081968.55 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.89
/src/Options.jl
1
module OptionsModule
2

3
using DispatchDoctor: @unstable
4
using Optim: Optim
5
using Dates: Dates
6
using StatsBase: StatsBase
7
using DynamicExpressions: OperatorEnum, Node, Expression, default_node_type
8
using ADTypes: AbstractADType, ADTypes
9
using Distributed: nworkers
10
using LossFunctions: L2DistLoss, SupervisedLoss
11
using Optim: Optim
12
using LineSearches: LineSearches
13
#TODO - eventually move some of these
14
# into the SR call itself, rather than
15
# passing huge options at once.
16
using ..OperatorsModule:
17
    plus,
18
    pow,
19
    safe_pow,
20
    mult,
21
    sub,
22
    safe_log,
23
    safe_log10,
24
    safe_log2,
25
    safe_log1p,
26
    safe_sqrt,
27
    safe_acosh,
28
    atanh_clip
29
using ..MutationWeightsModule: MutationWeights, mutations
30
import ..OptionsStructModule: Options
31
using ..OptionsStructModule: ComplexityMapping, operator_specialization
32
using ..UtilsModule: max_ops, @save_kwargs
33

34
"""
    build_constraints(una_constraints, bin_constraints,
                      unary_operators, binary_operators, nuna, nbin)

Build per-operator subtree-size constraints from user-passed constraint specs.

`una_constraints` and `bin_constraints` may each be:
- `nothing`: no constraints, giving `fill(-1, nuna)` / `fill((-1, -1), nbin)`;
- a `Dict` or an `Array` of `op => constraint` pairs (e.g. `[(^) => (-1, 3)]`),
  which is expanded into a vector aligned with `unary_operators` /
  `binary_operators`, using `-1` / `(-1, -1)` for operators that are not listed.
  User values are converted to `Int` / `Tuple{Int,Int}`;
- an already-expanded `Vector{Int}` / `Vector{Tuple{Int,Int}}`, returned as-is.

Returns `(una_constraints::Vector{Int}, bin_constraints::Vector{Tuple{Int,Int}})`.
"""
function build_constraints(
    una_constraints, bin_constraints, unary_operators, binary_operators, nuna, nbin
)::Tuple{Vector{Int},Vector{Tuple{Int,Int}}}
    # Expect format ((*)=>(-1, 3)), etc.
    # TODO: Need to disable simplification if (*, -, +, /) are constrained?
    #  Or, just quit simplification if constraints violated.

    # Inputs already in the final positional form are passed through untouched.
    is_bin_constraints_already_done = bin_constraints isa Vector{Tuple{Int,Int}}
    is_una_constraints_already_done = una_constraints isa Vector{Int}

    # Arrays of `op => constraint` pairs become dicts so they can be looked up by operator.
    if bin_constraints isa Array && !is_bin_constraints_already_done
        bin_constraints = Dict(bin_constraints)
    end
    if una_constraints isa Array && !is_una_constraints_already_done
        una_constraints = Dict(una_constraints)
    end

    if una_constraints === nothing
        una_constraints = fill(-1, nuna)
    elseif !is_una_constraints_already_done
        una_lookup = una_constraints::Dict
        # The typed comprehension converts each user-supplied value to `Int`.
        una_constraints = Int[get(una_lookup, op, -1) for op in unary_operators]
    end

    if bin_constraints === nothing
        bin_constraints = fill((-1, -1), nbin)
    elseif !is_bin_constraints_already_done
        bin_lookup = bin_constraints::Dict
        # The typed comprehension converts each user-supplied value to `Tuple{Int,Int}`.
        bin_constraints = Tuple{Int,Int}[
            get(bin_lookup, op, (-1, -1)) for op in binary_operators
        ]
    end

    return una_constraints, bin_constraints
end
92

93
"""
    binopmap(op)

Normalize a user-supplied binary operator. The aliases `plus`, `mult`, `sub`,
and `div` become `+`, `*`, `-`, and `/`; both `^` and `pow` become `safe_pow`.
Any other operator is returned unchanged.
"""
function binopmap(op::F) where {F}
    op == plus && return (+)
    op == mult && return (*)
    op == sub && return (-)
    op == div && return (/)
    op == (^) && return safe_pow
    op == pow && return safe_pow
    return op
end
109
"""
    inverse_binopmap(op)

Map `safe_pow` back to `^`; every other operator is returned unchanged.
"""
inverse_binopmap(op::F) where {F} = op == safe_pow ? (^) : op
115

116
"""
    unaopmap(op)

Swap selected Base unary operators for their safe counterparts:
`log`, `log10`, `log2`, `log1p`, `sqrt`, `acosh`, and `atanh` become
`safe_log`, `safe_log10`, `safe_log2`, `safe_log1p`, `safe_sqrt`,
`safe_acosh`, and `atanh_clip`. Any other operator is returned unchanged.
"""
function unaopmap(op::F) where {F}
    op == log && return safe_log
    op == log10 && return safe_log10
    op == log2 && return safe_log2
    op == log1p && return safe_log1p
    op == sqrt && return safe_sqrt
    op == acosh && return safe_acosh
    op == atanh && return atanh_clip
    return op
end
134
"""
    inverse_unaopmap(op)

Inverse of `unaopmap`: map each safe variant (`safe_log`, `safe_log10`,
`safe_log2`, `safe_log1p`, `safe_sqrt`, `safe_acosh`, `atanh_clip`) back to the
corresponding Base function. Any other operator is returned unchanged.
"""
function inverse_unaopmap(op::F) where {F}
    op == safe_log && return log
    op == safe_log10 && return log10
    op == safe_log2 && return log2
    op == safe_log1p && return log1p
    op == safe_sqrt && return sqrt
    op == safe_acosh && return acosh
    op == atanh_clip && return atanh
    return op
end
152

153
"""
    create_mutation_weights(w)

Normalize the `mutation_weights` option into a `MutationWeights`.

- `w::MutationWeights` is returned unchanged.
- `w::NamedTuple` is forwarded as keyword arguments to `MutationWeights`.
- `w::AbstractVector` is zero-padded to `length(mutations)` and passed
  positionally to `MutationWeights`. This matches the handling of the deprecated
  `mutationWeights` keyword in `Options`.
"""
create_mutation_weights(w::MutationWeights) = w
create_mutation_weights(w::NamedTuple) = MutationWeights(; w...)
function create_mutation_weights(w::AbstractVector)
    # `Options` accepts a vector for `mutation_weights`; missing trailing
    # weights default to zero, exactly as in the deprecated keyword path.
    n_missing = length(mutations) - length(w)
    padded = n_missing > 0 ? vcat(w, zeros(n_missing)) : w
    return MutationWeights(padded...)
end
155

156
# Maps deprecated `Options` keyword names to their current replacements.
# Targets prefixed with `deprecated_` have no direct replacement: `Options`
# warns about them and only forwards the value for `deprecated_return_state`.
const deprecated_options_mapping = let renames = (
        :mutationWeights => :mutation_weights,
        :hofMigration => :hof_migration,
        :shouldOptimizeConstants => :should_optimize_constants,
        :hofFile => :output_file,
        :perturbationFactor => :perturbation_factor,
        :batchSize => :batch_size,
        :crossoverProbability => :crossover_probability,
        :warmupMaxsizeBy => :warmup_maxsize_by,
        :useFrequency => :use_frequency,
        :useFrequencyInTournament => :use_frequency_in_tournament,
        :ncyclesperiteration => :ncycles_per_iteration,
        :fractionReplaced => :fraction_replaced,
        :fractionReplacedHof => :fraction_replaced_hof,
        :probNegate => :probability_negate_constant,
        :optimize_probability => :optimizer_probability,
        :probPickFirst => :tournament_selection_p,
        :earlyStopCondition => :early_stop_condition,
        :stateReturn => :deprecated_return_state,
        :return_state => :deprecated_return_state,
        :enable_autodiff => :deprecated_enable_autodiff,
        :ns => :tournament_selection_n,
        :loss => :elementwise_loss,
    )
    Base.ImmutableDict(renames...)
end
180

181
const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators (functions) to use.
    Each operator should be defined for two input scalars,
    and one output scalar. All operators
    need to be defined over the entire real line (excluding infinity - these
    are stopped before they are input), or return `NaN` where not defined.
    For speed, define it so it takes two reals
    of the same type as input, and outputs the same type. For the SymbolicUtils
    simplification backend, you will need to define a generic method of the
    operator so it takes arbitrary types.
- `unary_operators`: Same, but for
    unary operators (one input scalar, gives an output scalar).
- `constraints`: Array (or tuple) of pairs specifying size constraints
    for each operator. The constraints for a binary operator should be a 2-tuple
    (e.g., `(-1, -1)`) and the constraints for a unary operator should be an `Int`.
    A size constraint is a limit to the size of the subtree
    in each argument of an operator. e.g., `[(^)=>(-1, 3)]` means that the
    `^` operator can have arbitrary size (`-1`) in its left argument,
    but a maximum size of `3` in its right argument. Default is
    no constraints.
- `batching`: Whether to evolve based on small mini-batches of data,
    rather than the entire dataset.
- `batch_size`: What batch size to use if using batching.
- `elementwise_loss`: What elementwise loss function to use. Can be one of
    the following losses, or any other loss of type
    `SupervisedLoss`. You can also pass a function that takes
    a scalar target (left argument), and scalar predicted (right
    argument), and returns a scalar. This will be averaged
    over the predicted data. If weights are supplied, your
    function should take a third argument for the weight scalar.
    Defaults to `L2DistLoss()`, and cannot be combined with `loss_function`.
    Included losses:
        Regression:
            - `LPDistLoss{P}()`,
            - `L1DistLoss()`,
            - `L2DistLoss()` (mean square),
            - `LogitDistLoss()`,
            - `HuberLoss(d)`,
            - `L1EpsilonInsLoss(ϵ)`,
            - `L2EpsilonInsLoss(ϵ)`,
            - `PeriodicLoss(c)`,
            - `QuantileLoss(τ)`,
        Classification:
            - `ZeroOneLoss()`,
            - `PerceptronLoss()`,
            - `L1HingeLoss()`,
            - `SmoothedL1HingeLoss(γ)`,
            - `ModifiedHuberLoss()`,
            - `L2MarginLoss()`,
            - `ExpLoss()`,
            - `SigmoidLoss()`,
            - `DWDMarginLoss(q)`.
- `loss_function`: Alternatively, you may redefine the loss used
    as any function of `tree::AbstractExpressionNode{T}`, `dataset::Dataset{T,L}`,
    and `options::Options`, so long as you output a non-negative
    scalar of type `L`. This is useful if you want to use a loss
    that takes into account derivatives, or correlations across
    the dataset. This also means you could use a custom evaluation
    for a particular expression. If you are using
    `batching=true`, then your function should
    accept a fourth argument `idx`, which is either `nothing`
    (indicating that the full dataset should be used), or a vector
    of indices to use for the batch.
    For example,

        function my_loss(tree, dataset::Dataset{T,L}, options)::L where {T,L}
            prediction, flag = eval_tree_array(tree, dataset.X, options)
            if !flag
                return L(Inf)
            end
            return sum((prediction .- dataset.y) .^ 2) / dataset.n
        end

- `expression_type::Type{E}=Expression`: The type of expression to use.
    For example, `Expression`.
- `node_type::Type{N}=default_node_type(Expression)`: The type of node to use for the search.
    For example, `Node` or `GraphNode`. The default is computed by `default_node_type(expression_type)`.
- `populations`: How many populations of equations to use.
- `population_size`: How many equations in each population.
- `ncycles_per_iteration`: How many generations to consider per iteration.
- `tournament_selection_n`: Number of expressions considered in each tournament.
- `tournament_selection_p`: The fittest expression in a tournament is to be
    selected with probability `p`, the next fittest with probability `p*(1-p)`,
    and so forth.
- `topn`: Number of equations to return to the host process, and to
    consider for the hall of fame.
- `complexity_of_operators`: What complexity should be assigned to each operator,
    and the occurrence of a constant or variable. By default, this is 1
    for all operators. Can be a real number as well, in which case
    the complexity of an expression will be rounded to the nearest integer.
    Input this in the form of, e.g., `[(^) => 3, sin => 2]`.
- `complexity_of_constants`: What complexity should be assigned to use of a constant.
    By default, this is 1.
- `complexity_of_variables`: What complexity should be assigned to use of a variable,
    which can also be a vector indicating different per-variable complexity.
    By default, this is 1.
- `alpha`: The probability of accepting an equation mutation
    during regularized evolution is given by exp(-delta_loss/(alpha * T)),
    where T goes from 1 to 0. Thus, alpha=infinite is the same as no annealing.
- `maxsize`: Maximum size of equations during the search.
- `maxdepth`: Maximum depth of equations during the search, by default
    this is set equal to the maxsize.
- `parsimony`: A multiplicative factor for how much complexity is
    punished.
- `dimensional_constraint_penalty`: An additive factor if the dimensional
    constraint is violated.
- `dimensionless_constants_only`: Whether to only allow dimensionless
    constants.
- `use_frequency`: Whether to use a parsimony that adapts to the
    relative proportion of equations at each complexity; this will
    ensure that there are a balanced number of equations considered
    for every complexity.
- `use_frequency_in_tournament`: Whether to use the adaptive parsimony described
    above inside the score, rather than just at the mutation accept/reject stage.
- `adaptive_parsimony_scaling`: How much to scale the adaptive parsimony term
    in the loss. Increase this if the search is spending too much time
    optimizing the most complex equations.
- `turbo`: Whether to use `LoopVectorization.@turbo` to evaluate expressions.
    This can be significantly faster, but is only compatible with certain
    operators. *Experimental!*
- `bumper`: Whether to use Bumper.jl for faster evaluation. *Experimental!*
- `migration`: Whether to migrate equations between processes.
- `hof_migration`: Whether to migrate equations from the hall of fame
    to processes.
- `fraction_replaced`: What fraction of each population to replace with
    migrated equations at the end of each cycle.
- `fraction_replaced_hof`: What fraction to replace with hall of fame
    equations at the end of each cycle.
- `should_simplify`: Whether to simplify equations. By default (`nothing`),
    simplification is enabled only when no custom `loss_function` is passed
    and none of `constraints`, `bin_constraints`, `una_constraints`, or
    `nested_constraints` are set.
- `should_optimize_constants`: Whether to use an optimization algorithm
    to periodically optimize constants in equations.
- `optimizer_algorithm`: Select algorithm to use for optimizing constants. Default
    is `Optim.BFGS(linesearch=LineSearches.BackTracking())`.
- `optimizer_nrestarts`: How many different random starting positions to consider
    for optimization of constants.
- `optimizer_probability`: Probability of performing optimization of constants at
    the end of a given iteration.
- `optimizer_iterations`: How many optimization iterations to perform. This gets
    passed to `Optim.Options` as `iterations`. The default is 8.
- `optimizer_f_calls_limit`: How many function calls to allow during optimization.
    This gets passed to `Optim.Options` as `f_calls_limit`. The default is
    `10_000`.
- `optimizer_options`: General options for the constant optimization. For details
    we refer to the documentation on `Optim.Options` from the `Optim.jl` package.
    Options can be provided here as `NamedTuple`, e.g. `(iterations=16,)`, as a
    `Dict`, e.g. Dict(:x_tol => 1.0e-32,), or as an `Optim.Options` instance.
- `autodiff_backend`: The backend to use for differentiation, which should be
    an instance of `AbstractADType` (see `DifferentiationInterface.jl`).
    Default is `nothing`, which means `Optim.jl` will estimate gradients (likely
    with finite differences). You can also pass a symbolic version of the backend
    type, such as `:Zygote` for Zygote, `:Enzyme`, etc. Most backends will not
    work, and many will never work due to incompatibilities, though support for some
    is gradually being added.
- `output_file`: What file to store equations to, as a backup.
- `perturbation_factor`: When mutating a constant, either
    multiply or divide by (1+perturbation_factor)^(rand()+1).
- `probability_negate_constant`: Probability of negating a constant in the equation
    when mutating it.
- `mutation_weights`: Relative probabilities of the mutations. The struct
    `MutationWeights` should be passed to these options.
    See its documentation on `MutationWeights` for the different weights.
- `crossover_probability`: Probability of performing crossover.
- `annealing`: Whether to use simulated annealing.
- `warmup_maxsize_by`: Whether to slowly increase the max size from 5 up to
    `maxsize`. If nonzero, specifies the fraction through the search
    at which the maxsize should be reached.
- `verbosity`: Whether to print debugging statements or
    not.
- `print_precision`: How many digits to print when printing
    equations. By default, this is 5.
- `save_to_file`: Whether to save equations to a file during the search.
- `bin_constraints`: See `constraints`. This is the same, but specified for binary
    operators only (for example, if you have an operator that is both a binary
    and unary operator).
- `una_constraints`: Likewise, for unary operators.
- `seed`: What random seed to use. `nothing` uses no seed.
- `progress`: Whether to use a progress bar output (`verbosity` will
    have no effect).
- `early_stop_condition`: Float - whether to stop early if the mean loss gets below this value.
    Function - a function taking (loss, complexity) as arguments and returning true or false.
- `timeout_in_seconds`: Float64 - the time in seconds after which to exit (as an alternative to the number of iterations).
- `max_evals`: Int (or Nothing) - the maximum number of evaluations of expressions to perform.
- `skip_mutation_failures`: Whether to simply skip over mutations that fail or are rejected, rather than to replace the mutated
    expression with the original expression and proceed normally.
- `nested_constraints`: Specifies how many times a combination of operators can be nested. For example,
    `[sin => [cos => 0], cos => [cos => 2]]` specifies that `cos` may never appear within a `sin`,
    but `sin` can be nested with itself an unlimited number of times. The second term specifies that `cos`
    can be nested up to 2 times within a `cos`, so that `cos(cos(cos(x)))` is allowed (as well as any combination
    of `+` or `-` within it), but `cos(cos(cos(cos(x))))` is not allowed. When an operator is not specified,
    it is assumed that it can be nested an unlimited number of times. This requires that there is no operator
    which is used both in the unary operators and the binary operators (e.g., `-` could be both subtract, and negation).
    For binary operators, both arguments are treated the same way, and the max of each argument is constrained.
- `deterministic`: Use a global counter for the birth time, rather than calls to `time()`. This gives
    perfect resolution, and is therefore deterministic. However, it is not thread safe, and must be used
    in serial mode.
- `define_helper_functions`: Whether to define helper functions
    for constructing and evaluating trees.
"""
378

379
"""
380
    Options(;kws...)
381

382
Construct options for `equation_search` and other functions.
383
The current arguments have been tuned using the median values from
384
https://github.com/MilesCranmer/PySR/discussions/115.
385

386
# Arguments
387
$(OPTION_DESCRIPTIONS)
388
"""
389
@unstable @save_kwargs DEFAULT_OPTIONS function Options(;
17,840✔
390
    binary_operators=Function[+, -, /, *],
391
    unary_operators=Function[],
392
    constraints=nothing,
393
    elementwise_loss::Union{Function,SupervisedLoss,Nothing}=nothing,
394
    loss_function::Union{Function,Nothing}=nothing,
395
    tournament_selection_n::Integer=12, #1 sampled from every tournament_selection_n per mutation
396
    tournament_selection_p::Real=0.86,
397
    topn::Integer=12, #samples to return per population
398
    complexity_of_operators=nothing,
399
    complexity_of_constants::Union{Nothing,Real}=nothing,
400
    complexity_of_variables::Union{Nothing,Real,AbstractVector}=nothing,
401
    parsimony::Real=0.0032,
402
    dimensional_constraint_penalty::Union{Nothing,Real}=nothing,
403
    dimensionless_constants_only::Bool=false,
404
    alpha::Real=0.100000,
405
    maxsize::Integer=20,
406
    maxdepth::Union{Nothing,Integer}=nothing,
407
    turbo::Bool=false,
408
    bumper::Bool=false,
409
    migration::Bool=true,
410
    hof_migration::Bool=true,
411
    should_simplify::Union{Nothing,Bool}=nothing,
412
    should_optimize_constants::Bool=true,
413
    output_file::Union{Nothing,AbstractString}=nothing,
414
    expression_type::Type=Expression,
415
    node_type::Type=default_node_type(expression_type),
416
    expression_options::NamedTuple=NamedTuple(),
417
    populations::Integer=15,
418
    perturbation_factor::Real=0.076,
419
    annealing::Bool=false,
420
    batching::Bool=false,
421
    batch_size::Integer=50,
422
    mutation_weights::Union{MutationWeights,AbstractVector,NamedTuple}=MutationWeights(),
423
    crossover_probability::Real=0.066,
424
    warmup_maxsize_by::Real=0.0,
425
    use_frequency::Bool=true,
426
    use_frequency_in_tournament::Bool=true,
427
    adaptive_parsimony_scaling::Real=20.0,
428
    population_size::Integer=33,
429
    ncycles_per_iteration::Integer=550,
430
    fraction_replaced::Real=0.00036,
431
    fraction_replaced_hof::Real=0.035,
432
    verbosity::Union{Integer,Nothing}=nothing,
433
    print_precision::Integer=5,
434
    save_to_file::Bool=true,
435
    probability_negate_constant::Real=0.01,
436
    seed=nothing,
437
    bin_constraints=nothing,
438
    una_constraints=nothing,
439
    progress::Union{Bool,Nothing}=nothing,
440
    terminal_width::Union{Nothing,Integer}=nothing,
441
    optimizer_algorithm::Union{AbstractString,Optim.AbstractOptimizer}=Optim.BFGS(;
442
        linesearch=LineSearches.BackTracking()
443
    ),
444
    optimizer_nrestarts::Integer=2,
445
    optimizer_probability::Real=0.14,
446
    optimizer_iterations::Union{Nothing,Integer}=nothing,
447
    optimizer_f_calls_limit::Union{Nothing,Integer}=nothing,
448
    optimizer_options::Union{Dict,NamedTuple,Optim.Options,Nothing}=nothing,
449
    autodiff_backend::Union{AbstractADType,Symbol,Nothing}=nothing,
450
    use_recorder::Bool=false,
451
    recorder_file::AbstractString="pysr_recorder.json",
452
    early_stop_condition::Union{Function,Real,Nothing}=nothing,
453
    timeout_in_seconds::Union{Nothing,Real}=nothing,
454
    max_evals::Union{Nothing,Integer}=nothing,
455
    skip_mutation_failures::Bool=true,
456
    nested_constraints=nothing,
457
    deterministic::Bool=false,
458
    # Not search options; just construction options:
459
    define_helper_functions::Bool=true,
460
    deprecated_return_state=nothing,
461
    # Deprecated args:
462
    fast_cycle::Bool=false,
463
    npopulations::Union{Nothing,Integer}=nothing,
464
    npop::Union{Nothing,Integer}=nothing,
465
    kws...,
466
)
467
    for k in keys(kws)
11,290✔
468
        !haskey(deprecated_options_mapping, k) && error("Unknown keyword argument: $k")
450✔
469
        new_key = deprecated_options_mapping[k]
450✔
470
        if startswith(string(new_key), "deprecated_")
55✔
471
            Base.depwarn("The keyword argument `$(k)` is deprecated.", :Options)
×
472
            if string(new_key) != "deprecated_return_state"
×
473
                # This one we actually want to use
474
                continue
×
475
            end
476
        else
477
            Base.depwarn(
30✔
478
                "The keyword argument `$(k)` is deprecated. Use `$(new_key)` instead.",
479
                :Options,
480
            )
481
        end
482
        # Now, set the new key to the old value:
483
        #! format: off
484
        k == :hofMigration && (hof_migration = kws[k]; true) && continue
30✔
485
        k == :shouldOptimizeConstants && (should_optimize_constants = kws[k]; true) && continue
36✔
486
        k == :hofFile && (output_file = kws[k]; true) && continue
24✔
487
        k == :perturbationFactor && (perturbation_factor = kws[k]; true) && continue
24✔
488
        k == :batchSize && (batch_size = kws[k]; true) && continue
24✔
489
        k == :crossoverProbability && (crossover_probability = kws[k]; true) && continue
24✔
490
        k == :warmupMaxsizeBy && (warmup_maxsize_by = kws[k]; true) && continue
24✔
491
        k == :useFrequency && (use_frequency = kws[k]; true) && continue
24✔
492
        k == :useFrequencyInTournament && (use_frequency_in_tournament = kws[k]; true) && continue
24✔
493
        k == :ncyclesperiteration && (ncycles_per_iteration = kws[k]; true) && continue
24✔
494
        k == :fractionReplaced && (fraction_replaced = kws[k]; true) && continue
24✔
495
        k == :fractionReplacedHof && (fraction_replaced_hof = kws[k]; true) && continue
30✔
496
        k == :probNegate && (probability_negate_constant = kws[k]; true) && continue
18✔
497
        k == :optimize_probability && (optimizer_probability = kws[k]; true) && continue
18✔
498
        k == :probPickFirst && (tournament_selection_p = kws[k]; true) && continue
18✔
499
        k == :earlyStopCondition && (early_stop_condition = kws[k]; true) && continue
18✔
500
        k == :return_state && (deprecated_return_state = kws[k]; true) && continue
18✔
501
        k == :stateReturn && (deprecated_return_state = kws[k]; true) && continue
18✔
502
        k == :enable_autodiff && continue
18✔
503
        k == :ns && (tournament_selection_n = kws[k]; true) && continue
18✔
504
        k == :loss && (elementwise_loss = kws[k]; true) && continue
24✔
505
        if k == :mutationWeights
12✔
506
            if typeof(kws[k]) <: AbstractVector
12✔
507
                _mutation_weights = kws[k]
6✔
508
                if length(_mutation_weights) < length(mutations)
6✔
509
                    # Pad with zeros:
510
                    _mutation_weights = vcat(
24✔
511
                        _mutation_weights,
512
                        zeros(length(mutations) - length(_mutation_weights))
513
                    )
514
                end
515
                mutation_weights = MutationWeights(_mutation_weights...)
6✔
516
            else
517
                mutation_weights = kws[k]
6✔
518
            end
519
            continue
12✔
520
        end
521
        #! format: on
522
        error(
6✔
523
            "Unknown deprecated keyword argument: $k. Please update `Options(;)` to transfer this key.",
524
        )
525
    end
30✔
526
    fast_cycle && Base.depwarn("`fast_cycle` is deprecated and has no effect.", :Options)
9,690✔
527
    if npop !== nothing
9,663✔
UNCOV
528
        Base.depwarn("`npop` is deprecated. Use `population_size` instead.", :Options)
×
UNCOV
529
        population_size = npop
×
530
    end
531
    if npopulations !== nothing
9,663✔
532
        Base.depwarn("`npopulations` is deprecated. Use `populations` instead.", :Options)
6✔
533
        populations = npopulations
6✔
534
    end
535
    if optimizer_algorithm isa AbstractString
9,663✔
536
        Base.depwarn(
30✔
537
            "The `optimizer_algorithm` argument should be an `AbstractOptimizer`, not a string.",
538
            :Options,
539
        )
540
        optimizer_algorithm = if optimizer_algorithm == "NelderMead"
30✔
541
            Optim.NelderMead(; linesearch=LineSearches.BackTracking())
24✔
542
        else
543
            Optim.BFGS(; linesearch=LineSearches.BackTracking())
36✔
544
        end
545
    end
546

547
    if elementwise_loss === nothing
9,663✔
548
        elementwise_loss = L2DistLoss()
471✔
549
    else
550
        if loss_function !== nothing
9,192✔
551
            error("You cannot specify both `elementwise_loss` and `loss_function`.")
×
552
        end
553
    end
554

555
    if should_simplify === nothing
9,690✔
556
        should_simplify = (
9,663✔
557
            loss_function === nothing &&
558
            nested_constraints === nothing &&
559
            constraints === nothing &&
560
            bin_constraints === nothing &&
561
            una_constraints === nothing
562
        )
563
    end
564

565
    is_testing = parse(Bool, get(ENV, "SYMBOLIC_REGRESSION_IS_TESTING", "false"))
17,713✔
566

567
    if output_file === nothing
9,690✔
568
        # "%Y-%m-%d_%H%M%S.%f"
569
        date_time_str = Dates.format(Dates.now(), "yyyy-mm-dd_HHMMSS.sss")
9,690✔
570
        output_file = "hall_of_fame_" * date_time_str * ".csv"
11,293✔
571
        if is_testing
9,690✔
572
            tmpdir = mktempdir()
9,630✔
573
            output_file = joinpath(tmpdir, output_file)
9,630✔
574
        end
575
    end
576

577
    nuna = length(unary_operators)
9,690✔
578
    nbin = length(binary_operators)
9,690✔
579
    @assert maxsize > 3
9,690✔
580
    @assert warmup_maxsize_by >= 0.0f0
9,690✔
581
    @assert nuna <= max_ops && nbin <= max_ops
9,690✔
582

583
    # Make sure nested_constraints contains functions within our operator set:
584
    if nested_constraints !== nothing
9,663✔
585
        # Check that intersection of binary operators and unary operators is empty:
586
        for op in binary_operators
36✔
587
            if op ∈ unary_operators
137✔
588
                error(
6✔
589
                    "Operator $(op) is both a binary and unary operator. " *
590
                    "You can't use nested constraints.",
591
                )
592
            end
593
        end
140✔
594

595
        # Convert to dict:
596
        if !(typeof(nested_constraints) <: Dict)
36✔
597
            # Convert to dict:
598
            nested_constraints = Dict(
36✔
599
                [cons[1] => Dict(cons[2]...) for cons in nested_constraints]...
600
            )
601
        end
602
        for (op, nested_constraint) in nested_constraints
100✔
603
            if !(op ∈ binary_operators || op ∈ unary_operators)
84✔
604
                error("Operator $(op) is not in the operator set.")
×
605
            end
606
            for (nested_op, max_nesting) in nested_constraint
42✔
607
                if !(nested_op ∈ binary_operators || nested_op ∈ unary_operators)
84✔
608
                    error("Operator $(nested_op) is not in the operator set.")
×
609
                end
610
                @assert nested_op ∈ binary_operators || nested_op ∈ unary_operators
78✔
611
                @assert max_nesting >= -1 && typeof(max_nesting) <: Int
42✔
612
            end
35✔
613
        end
202✔
614

615
        # Lastly, we clean it up into a dict of (degree,op_idx) => max_nesting.
616
        new_nested_constraints = []
36✔
617
        # Dict()
618
        for (op, nested_constraint) in nested_constraints
100✔
619
            (degree, idx) = if op ∈ binary_operators
42✔
620
                2, findfirst(isequal(op), binary_operators)
6✔
621
            else
622
                1, findfirst(isequal(op), unary_operators)
78✔
623
            end
624
            new_max_nesting_dict = []
42✔
625
            # Dict()
626
            for (nested_op, max_nesting) in nested_constraint
42✔
627
                (nested_degree, nested_idx) = if nested_op ∈ binary_operators
42✔
628
                    2, findfirst(isequal(nested_op), binary_operators)
6✔
629
                else
630
                    1, findfirst(isequal(nested_op), unary_operators)
78✔
631
                end
632
                # new_max_nesting_dict[(nested_degree, nested_idx)] = max_nesting
633
                push!(new_max_nesting_dict, (nested_degree, nested_idx, max_nesting))
42✔
634
            end
35✔
635
            # new_nested_constraints[(degree, idx)] = new_max_nesting_dict
636
            push!(new_nested_constraints, (degree, idx, new_max_nesting_dict))
42✔
637
        end
202✔
638
        nested_constraints = new_nested_constraints
36✔
639
    end
640

641
    if typeof(constraints) <: Tuple
9,663✔
642
        constraints = collect(constraints)
26✔
643
    end
644
    if constraints !== nothing
9,663✔
645
        @assert bin_constraints === nothing
24✔
646
        @assert una_constraints === nothing
24✔
647
        # TODO: This is redundant with the checks in equation_search
648
        for op in binary_operators
24✔
649
            @assert !(op in unary_operators)
90✔
650
        end
90✔
651
        for op in unary_operators
24✔
652
            @assert !(op in binary_operators)
34✔
653
        end
20✔
654
        bin_constraints = constraints
24✔
655
        una_constraints = constraints
24✔
656
    end
657

658
    una_constraints, bin_constraints = build_constraints(
14,525✔
659
        una_constraints, bin_constraints, unary_operators, binary_operators, nuna, nbin
660
    )
661

662
    complexity_mapping = ComplexityMapping(
14,535✔
663
        complexity_of_operators,
664
        complexity_of_variables,
665
        complexity_of_constants,
666
        binary_operators,
667
        unary_operators,
668
    )
669

670
    if maxdepth === nothing
9,690✔
671
        maxdepth = maxsize
9,657✔
672
    end
673

674
    if define_helper_functions
9,690✔
675
        # We call here so that mapped operators, like ^
676
        # are correctly overloaded, rather than overloading
677
        # operators like "safe_pow", etc.
678
        OperatorEnum(;
9,594✔
679
            binary_operators=binary_operators,
680
            unary_operators=unary_operators,
681
            define_helper_functions=true,
682
            empty_old_operators=true,
683
        )
684
    end
685

686
    binary_operators = map(binopmap, binary_operators)
9,691✔
687
    unary_operators = map(unaopmap, unary_operators)
9,696✔
688

689
    operators = OperatorEnum(;
9,690✔
690
        binary_operators=binary_operators,
691
        unary_operators=unary_operators,
692
        define_helper_functions=define_helper_functions,
693
        empty_old_operators=false,
694
    )
695

696
    early_stop_condition = if typeof(early_stop_condition) <: Real
9,663✔
697
        # Need to make explicit copy here for this to work:
698
        stopping_point = Float64(early_stop_condition)
6✔
699
        (loss, complexity) -> loss < stopping_point
522,207✔
700
    else
701
        early_stop_condition
9,665✔
702
    end
703

704
    # Parse optimizer options
705
    if !isa(optimizer_options, Optim.Options)
9,663✔
706
        optimizer_iterations = isnothing(optimizer_iterations) ? 8 : optimizer_iterations
9,651✔
707
        optimizer_f_calls_limit = if isnothing(optimizer_f_calls_limit)
9,651✔
708
            10_000
9,651✔
709
        else
710
            optimizer_f_calls_limit
3,199✔
711
        end
712
        extra_kws = hasfield(Optim.Options, :show_warnings) ? (; show_warnings=false) : ()
9,651✔
713
        optimizer_options = Optim.Options(;
9,678✔
714
            iterations=optimizer_iterations,
715
            f_calls_limit=optimizer_f_calls_limit,
716
            extra_kws...,
717
            (isnothing(optimizer_options) ? () : optimizer_options)...,
718
        )
719
    else
720
        @assert optimizer_iterations === nothing && optimizer_f_calls_limit === nothing
12✔
721
    end
722
    if hasfield(Optim.Options, :show_warnings) && optimizer_options.show_warnings
9,690✔
723
        @warn "Optimizer warnings are turned on. This might result in a lot of warnings being printed from NaNs, as these are common during symbolic regression"
12✔
724
    end
725

726
    ## Create tournament weights:
727
    tournament_selection_weights =
9,663✔
728
        let n = tournament_selection_n, p = tournament_selection_p
729
            k = collect(0:(n - 1))
83,190✔
730
            prob_each = p * ((1 - p) .^ k)
12,917✔
731

732
            StatsBase.Weights(prob_each, sum(prob_each))
9,690✔
733
        end
734

735
    set_mutation_weights = create_mutation_weights(mutation_weights)
9,663✔
736

737
    @assert print_precision > 0
9,690✔
738

739
    _autodiff_backend = if autodiff_backend isa Union{Nothing,AbstractADType}
9,663✔
740
        autodiff_backend
9,657✔
741
    else
742
        ADTypes.Auto(autodiff_backend)
3,209✔
743
    end
744

745
    options = Options{
9,690✔
746
        typeof(complexity_mapping),
747
        operator_specialization(typeof(operators)),
748
        node_type,
749
        expression_type,
750
        typeof(expression_options),
751
        turbo,
752
        bumper,
753
        deprecated_return_state,
754
        typeof(tournament_selection_weights),
755
        typeof(_autodiff_backend),
756
    }(
757
        operators,
758
        bin_constraints,
759
        una_constraints,
760
        complexity_mapping,
761
        tournament_selection_n,
762
        tournament_selection_p,
763
        tournament_selection_weights,
764
        parsimony,
765
        dimensional_constraint_penalty,
766
        dimensionless_constants_only,
767
        alpha,
768
        maxsize,
769
        maxdepth,
770
        Val(turbo),
771
        Val(bumper),
772
        migration,
773
        hof_migration,
774
        should_simplify,
775
        should_optimize_constants,
776
        output_file,
777
        populations,
778
        perturbation_factor,
779
        annealing,
780
        batching,
781
        batch_size,
782
        set_mutation_weights,
783
        crossover_probability,
784
        warmup_maxsize_by,
785
        use_frequency,
786
        use_frequency_in_tournament,
787
        adaptive_parsimony_scaling,
788
        population_size,
789
        ncycles_per_iteration,
790
        fraction_replaced,
791
        fraction_replaced_hof,
792
        topn,
793
        verbosity,
794
        print_precision,
795
        save_to_file,
796
        probability_negate_constant,
797
        nuna,
798
        nbin,
799
        seed,
800
        elementwise_loss,
801
        loss_function,
802
        node_type,
803
        expression_type,
804
        expression_options,
805
        progress,
806
        terminal_width,
807
        optimizer_algorithm,
808
        optimizer_probability,
809
        optimizer_nrestarts,
810
        optimizer_options,
811
        _autodiff_backend,
812
        recorder_file,
813
        tournament_selection_p,
814
        early_stop_condition,
815
        Val(deprecated_return_state),
816
        timeout_in_seconds,
817
        max_evals,
818
        skip_mutation_failures,
819
        nested_constraints,
820
        deterministic,
821
        define_helper_functions,
822
        use_recorder,
823
    )
824

825
    return options
9,690✔
826
end
827

828
end
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc