MilesCranmer / SymbolicRegression.jl, build 9686354911 (26 Jun 2024 08:31PM UTC)
Pull Request #326: BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)
Merge 6f8229c9f into ceddaa424

Overall coverage: 93.22% (down 1.4% from 94.617%); 2530 of 2714 relevant lines covered.
275 of 296 new or added lines in 17 files covered (92.91%); 34 existing lines in 5 files are now uncovered.

Source file: /src/SymbolicRegression.jl (94.83% covered)

module SymbolicRegression

# Types
export Population,
    PopMember,
    HallOfFame,
    Options,
    Dataset,
    MutationWeights,
    Node,
    GraphNode,
    ParametricNode,
    Expression,
    ParametricExpression,
    NodeSampler,
    AbstractExpression,
    AbstractExpressionNode,
    SRRegressor,
    MultitargetSRRegressor,
    LOSS_TYPE,
    DATA_TYPE,

    #Functions:
    equation_search,
    s_r_cycle,
    calculate_pareto_frontier,
    count_nodes,
    compute_complexity,
    @parse_expression,
    parse_expression,
    print_tree,
    string_tree,
    eval_tree_array,
    eval_diff_tree_array,
    eval_grad_tree_array,
    differentiable_eval_tree_array,
    set_node!,
    copy_node,
    node_to_symbolic,
    node_type,
    symbolic_to_node,
    simplify_tree!,
    tree_mapreduce,
    combine_operators,
    gen_random_tree,
    gen_random_tree_fixed_size,
    @extend_operators,

    #Operators
    plus,
    sub,
    mult,
    square,
    cube,
    pow,
    safe_pow,
    safe_log,
    safe_log2,
    safe_log10,
    safe_log1p,
    safe_acosh,
    safe_sqrt,
    neg,
    greater,
    cond,
    relu,
    logical_or,
    logical_and,

    # special operators
    gamma,
    erf,
    erfc,
    atanh_clip

using Distributed
using Printf: @printf, @sprintf
using PackageExtensionCompat: @require_extensions
using Pkg: Pkg
using TOML: parsefile
using Random: seed!, shuffle!
using Reexport
using DynamicExpressions:
    Node,
    GraphNode,
    ParametricNode,
    Expression,
    ParametricExpression,
    NodeSampler,
    AbstractExpression,
    AbstractExpressionNode,
    @parse_expression,
    parse_expression,
    copy_node,
    set_node!,
    string_tree,
    print_tree,
    count_nodes,
    get_constants,
    set_constants!,
    index_constants,
    NodeIndex,
    eval_tree_array,
    differentiable_eval_tree_array,
    eval_diff_tree_array,
    eval_grad_tree_array,
    node_to_symbolic,
    symbolic_to_node,
    combine_operators,
    simplify_tree!,
    tree_mapreduce,
    set_default_variable_names!,
    node_type
using DynamicExpressions: with_type_parameters
@reexport using LossFunctions:
    MarginLoss,
    DistanceLoss,
    SupervisedLoss,
    ZeroOneLoss,
    LogitMarginLoss,
    PerceptronLoss,
    HingeLoss,
    L1HingeLoss,
    L2HingeLoss,
    SmoothedL1HingeLoss,
    ModifiedHuberLoss,
    L2MarginLoss,
    ExpLoss,
    SigmoidLoss,
    DWDMarginLoss,
    LPDistLoss,
    L1DistLoss,
    L2DistLoss,
    PeriodicLoss,
    HuberLoss,
    EpsilonInsLoss,
    L1EpsilonInsLoss,
    L2EpsilonInsLoss,
    LogitDistLoss,
    QuantileLoss,
    LogCoshLoss

# https://discourse.julialang.org/t/how-to-find-out-the-version-of-a-package-from-its-module/37755/15
const PACKAGE_VERSION = try
    root = pkgdir(@__MODULE__)
    if root !== nothing
        let project = parsefile(joinpath(root, "Project.toml"))
            VersionNumber(project["version"])
        end
    else
        VersionNumber(0, 0, 0)
    end
catch
    VersionNumber(0, 0, 0)
end

function deprecate_varmap(variable_names, varMap, func_name)
    if varMap !== nothing
        Base.depwarn("`varMap` is deprecated; use `variable_names` instead", func_name)
        @assert variable_names === nothing "Cannot pass both `varMap` and `variable_names`"
        variable_names = varMap
    end
    return variable_names
end
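
# Illustrative sketch (comment only, not part of the module): the shim above simply
# forwards a deprecated `varMap` value to `variable_names` after a depwarn, e.g.
#
#     deprecate_varmap(nothing, ["x1", "x2"], :equation_search)  # returns ["x1", "x2"]
#     deprecate_varmap(["a", "b"], nothing, :equation_search)    # returns ["a", "b"]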

using DispatchDoctor: @stable

@stable default_mode = "disable" begin
    include("Utils.jl")
    include("InterfaceDynamicQuantities.jl")
    include("Core.jl")
    include("InterfaceDynamicExpressions.jl")
    include("Recorder.jl")
    include("Complexity.jl")
    include("DimensionalAnalysis.jl")
    include("CheckConstraints.jl")
    include("AdaptiveParsimony.jl")
    include("MutationFunctions.jl")
    include("LossFunctions.jl")
    include("PopMember.jl")
    include("ConstantOptimization.jl")
    include("Population.jl")
    include("HallOfFame.jl")
    include("Mutate.jl")
    include("RegularizedEvolution.jl")
    include("SingleIteration.jl")
    include("ProgressBars.jl")
    include("Migration.jl")
    include("SearchUtils.jl")
    include("ExpressionBuilder.jl")
end

using .CoreModule:
    MAX_DEGREE,
    BATCH_DIM,
    FEATURE_DIM,
    DATA_TYPE,
    LOSS_TYPE,
    RecordType,
    Dataset,
    Options,
    MutationWeights,
    plus,
    sub,
    mult,
    square,
    cube,
    pow,
    safe_pow,
    safe_log,
    safe_log2,
    safe_log10,
    safe_log1p,
    safe_sqrt,
    safe_acosh,
    neg,
    greater,
    cond,
    relu,
    logical_or,
    logical_and,
    gamma,
    erf,
    erfc,
    atanh_clip,
    create_expression
using .UtilsModule: is_anonymous_function, recursive_merge, json3_write
using .ComplexityModule: compute_complexity
using .CheckConstraintsModule: check_constraints
using .AdaptiveParsimonyModule:
    RunningSearchStatistics, update_frequencies!, move_window!, normalize_frequencies!
using .MutationFunctionsModule:
    gen_random_tree,
    gen_random_tree_fixed_size,
    random_node,
    random_node_and_parent,
    crossover_trees
using .InterfaceDynamicExpressionsModule: @extend_operators
using .LossFunctionsModule: eval_loss, score_func, update_baseline_loss!
using .PopMemberModule: PopMember, reset_birth!
using .PopulationModule: Population, best_sub_pop, record_population, best_of_sample
using .HallOfFameModule:
    HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve
using .SingleIterationModule: s_r_cycle, optimize_and_simplify_population
using .ProgressBarsModule: WrappedProgressBar
using .RecorderModule: @recorder, find_iteration_from_record
using .MigrationModule: migrate!
using .SearchUtilsModule:
    SearchState,
    RuntimeOptions,
    WorkerAssignments,
    DefaultWorkerOutputType,
    assign_next_worker!,
    get_worker_output_type,
    extract_from_worker,
    @sr_spawner,
    StdinReader,
    watch_stream,
    close_reader!,
    check_for_user_quit,
    check_for_loss_threshold,
    check_for_timeout,
    check_max_evals,
    ResourceMonitor,
    start_work_monitor!,
    stop_work_monitor!,
    estimate_work_fraction,
    update_progress_bar!,
    print_search_state,
    init_dummy_pops,
    load_saved_hall_of_fame,
    load_saved_population,
    construct_datasets,
    save_to_file,
    get_cur_maxsize,
    update_hall_of_fame!
using .ExpressionBuilderModule: embed_metadata, strip_metadata

@stable default_mode = "disable" begin
    include("deprecates.jl")
    include("Configure.jl")
end

"""
285
    equation_search(X, y[; kws...])
286

287
Perform a distributed equation search for functions `f_i` which
288
describe the mapping `f_i(X[:, j]) ≈ y[i, j]`. Options are
289
configured using SymbolicRegression.Options(...),
290
which should be passed as a keyword argument to options.
291
One can turn off parallelism with `numprocs=0`,
292
which is useful for debugging and profiling.
293

294
# Arguments
295
- `X::AbstractMatrix{T}`:  The input dataset to predict `y` from.
296
    The first dimension is features, the second dimension is rows.
297
- `y::Union{AbstractMatrix{T}, AbstractVector{T}}`: The values to predict. The first dimension
298
    is the output feature to predict with each equation, and the
299
    second dimension is rows.
300
- `niterations::Int=10`: The number of iterations to perform the search.
301
    More iterations will improve the results.
302
- `weights::Union{AbstractMatrix{T}, AbstractVector{T}, Nothing}=nothing`: Optionally
303
    weight the loss for each `y` by this value (same shape as `y`).
304
- `options::Options=Options()`: The options for the search, such as
305
    which operators to use, evolution hyperparameters, etc.
306
- `variable_names::Union{Vector{String}, Nothing}=nothing`: The names
307
    of each feature in `X`, which will be used during printing of equations.
308
- `display_variable_names::Union{Vector{String}, Nothing}=variable_names`: Names
309
    to use when printing expressions during the search, but not when saving
310
    to an equation file.
311
- `y_variable_names::Union{String,AbstractVector{String},Nothing}=nothing`: The
312
    names of each output feature in `y`, which will be used during printing
313
    of equations.
314
- `parallelism=:multithreading`: What parallelism mode to use.
315
    The options are `:multithreading`, `:multiprocessing`, and `:serial`.
316
    By default, multithreading will be used. Multithreading uses less memory,
317
    but multiprocessing can handle multi-node compute. If using `:multithreading`
318
    mode, the number of threads available to julia are used. If using
319
    `:multiprocessing`, `numprocs` processes will be created dynamically if
320
    `procs` is unset. If you have already allocated processes, pass them
321
    to the `procs` argument and they will be used.
322
    You may also pass a string instead of a symbol, like `"multithreading"`.
323
- `numprocs::Union{Int, Nothing}=nothing`:  The number of processes to use,
324
    if you want `equation_search` to set this up automatically. By default
325
    this will be `4`, but can be any number (you should pick a number <=
326
    the number of cores available).
327
- `procs::Union{Vector{Int}, Nothing}=nothing`: If you have set up
328
    a distributed run manually with `procs = addprocs()` and `@everywhere`,
329
    pass the `procs` to this keyword argument.
330
- `addprocs_function::Union{Function, Nothing}=nothing`: If using multiprocessing
331
    (`parallelism=:multithreading`), and are not passing `procs` manually,
332
    then they will be allocated dynamically using `addprocs`. However,
333
    you may also pass a custom function to use instead of `addprocs`.
334
    This function should take a single positional argument,
335
    which is the number of processes to use, as well as the `lazy` keyword argument.
336
    For example, if set up on a slurm cluster, you could pass
337
    `addprocs_function = addprocs_slurm`, which will set up slurm processes.
338
- `heap_size_hint_in_bytes::Union{Int,Nothing}=nothing`: On Julia 1.9+, you may set the `--heap-size-hint`
339
    flag on Julia processes, recommending garbage collection once a process
340
    is close to the recommended size. This is important for long-running distributed
341
    jobs where each process has an independent memory, and can help avoid
342
    out-of-memory errors. By default, this is set to `Sys.free_memory() / numprocs`.
343
- `runtests::Bool=true`: Whether to run (quick) tests before starting the
344
    search, to see if there will be any problems during the equation search
345
    related to the host environment.
346
- `saved_state=nothing`: If you have already
347
    run `equation_search` and want to resume it, pass the state here.
348
    To get this to work, you need to have set return_state=true,
349
    which will cause `equation_search` to return the state. The second
350
    element of the state is the regular return value with the hall of fame.
351
    Note that you cannot change the operators or dataset, but most other options
352
    should be changeable.
353
- `return_state::Union{Bool, Nothing}=nothing`: Whether to return the
354
    state of the search for warm starts. By default this is false.
355
- `loss_type::Type=Nothing`: If you would like to use a different type
356
    for the loss than for the data you passed, specify the type here.
357
    Note that if you pass complex data `::Complex{L}`, then the loss
358
    type will automatically be set to `L`.
359
- `verbosity`: Whether to print debugging statements or not.
360
- `progress`: Whether to use a progress bar output. Only available for
361
    single target output.
362
- `X_units::Union{AbstractVector,Nothing}=nothing`: The units of the dataset,
363
    to be used for dimensional constraints. For example, if `X_units=["kg", "m"]`,
364
    then the first feature will have units of kilograms, and the second will
365
    have units of meters.
366
- `y_units=nothing`: The units of the output, to be used for dimensional constraints.
367
    If `y` is a matrix, then this can be a vector of units, in which case
368
    each element corresponds to each output feature.
369

370
# Returns
371
- `hallOfFame::HallOfFame`: The best equations seen during the search.
372
    hallOfFame.members gives an array of `PopMember` objects, which
373
    have their tree (equation) stored in `.tree`. Their score (loss)
374
    is given in `.score`. The array of `PopMember` objects
375
    is enumerated by size from `1` to `options.maxsize`.
376
"""
377
function equation_search(
    X::AbstractMatrix{T},
    y::AbstractMatrix{T};
    niterations::Int=10,
    weights::Union{AbstractMatrix{T},AbstractVector{T},Nothing}=nothing,
    options::Options=Options(),
    variable_names::Union{AbstractVector{String},Nothing}=nothing,
    display_variable_names::Union{AbstractVector{String},Nothing}=variable_names,
    y_variable_names::Union{String,AbstractVector{String},Nothing}=nothing,
    parallelism=:multithreading,
    numprocs::Union{Int,Nothing}=nothing,
    procs::Union{Vector{Int},Nothing}=nothing,
    addprocs_function::Union{Function,Nothing}=nothing,
    heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing,
    runtests::Bool=true,
    saved_state=nothing,
    return_state::Union{Bool,Nothing,Val}=nothing,
    loss_type::Type{L}=Nothing,
    verbosity::Union{Integer,Nothing}=nothing,
    progress::Union{Bool,Nothing}=nothing,
    X_units::Union{AbstractVector,Nothing}=nothing,
    y_units=nothing,
    extra::NamedTuple=NamedTuple(),
    v_dim_out::Val{DIM_OUT}=Val(nothing),
    # Deprecated:
    multithreaded=nothing,
    varMap=nothing,
) where {T<:DATA_TYPE,L,DIM_OUT}
    if multithreaded !== nothing
        error(
            "`multithreaded` is deprecated. Use the `parallelism` argument instead. " *
            "Choose one of :multithreaded, :multiprocessing, or :serial.",
        )
    end
    variable_names = deprecate_varmap(variable_names, varMap, :equation_search)

    if weights !== nothing
        @assert length(weights) == length(y)
        weights = reshape(weights, size(y))
    end

    datasets = construct_datasets(
        X,
        y,
        weights,
        variable_names,
        display_variable_names,
        y_variable_names,
        X_units,
        y_units,
        extra,
        L,
    )

    return equation_search(
        datasets;
        niterations=niterations,
        options=options,
        parallelism=parallelism,
        numprocs=numprocs,
        procs=procs,
        addprocs_function=addprocs_function,
        heap_size_hint_in_bytes=heap_size_hint_in_bytes,
        runtests=runtests,
        saved_state=saved_state,
        return_state=return_state,
        verbosity=verbosity,
        progress=progress,
        v_dim_out=Val(DIM_OUT),
    )
end
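
# Usage sketch for the method above (kept as a comment so nothing runs at include
# time). The operator set and `niterations` value are arbitrary choices for
# illustration, not defaults:
#
#     using SymbolicRegression
#     X = randn(Float32, 2, 100)                      # features × rows
#     y = 2 .* cos.(X[2, :]) .+ X[1, :] .^ 2 .- 2     # single output (vector form)
#     options = Options(; binary_operators=[+, -, *, /], unary_operators=[cos, exp])
#     hall_of_fame = equation_search(
#         X, y; niterations=40, options=options, parallelism=:multithreading
#     )
#     dominating = calculate_pareto_frontier(hall_of_fame)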

function equation_search(
    X::AbstractMatrix{T1}, y::AbstractMatrix{T2}; kw...
) where {T1<:DATA_TYPE,T2<:DATA_TYPE}
    U = promote_type(T1, T2)
    return equation_search(
        convert(AbstractMatrix{U}, X), convert(AbstractMatrix{U}, y); kw...
    )
end

function equation_search(
    X::AbstractMatrix{T1}, y::AbstractVector{T2}; kw...
) where {T1<:DATA_TYPE,T2<:DATA_TYPE}
    return equation_search(X, reshape(y, (1, size(y, 1))); kw..., v_dim_out=Val(1))
end

function equation_search(dataset::Dataset; kws...)
    return equation_search([dataset]; kws..., v_dim_out=Val(1))
end
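
# Sketch (comment only): the `Dataset` entry point above is convenient when you
# want to construct the dataset yourself, e.g. with variable names or weights.
# The keyword arguments shown are illustrative assumptions:
#
#     dataset = Dataset(X, y; variable_names=["x1", "x2"])
#     hall_of_fame = equation_search(dataset; niterations=20, options=options)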

function equation_search(
    datasets::Vector{D};
    niterations::Int=10,
    options::Options=Options(),
    parallelism=:multithreading,
    numprocs::Union{Int,Nothing}=nothing,
    procs::Union{Vector{Int},Nothing}=nothing,
    addprocs_function::Union{Function,Nothing}=nothing,
    heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing,
    runtests::Bool=true,
    saved_state=nothing,
    return_state::Union{Bool,Nothing,Val}=nothing,
    verbosity::Union{Int,Nothing}=nothing,
    progress::Union{Bool,Nothing}=nothing,
    v_dim_out::Val{DIM_OUT}=Val(nothing),
) where {DIM_OUT,T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L}}
    concurrency = if parallelism in (:multithreading, "multithreading")
        :multithreading
    elseif parallelism in (:multiprocessing, "multiprocessing")
        :multiprocessing
    elseif parallelism in (:serial, "serial")
        :serial
    else
        error(
            "Invalid parallelism mode: $parallelism. " *
            "You must choose one of :multithreading, :multiprocessing, or :serial.",
        )
        :serial
    end
    not_distributed = concurrency in (:multithreading, :serial)
    not_distributed &&
        procs !== nothing &&
        error(
            "`procs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.",
        )
    not_distributed &&
        numprocs !== nothing &&
        error(
            "`numprocs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.",
        )

    _return_state = if return_state isa Val
        first(typeof(return_state).parameters)
    else
        if options.return_state === Val(nothing)
            return_state === nothing ? false : return_state
        else
            @assert(
                return_state === nothing,
                "You cannot set `return_state` in both the `Options` and in the passed arguments."
            )
            first(typeof(options.return_state).parameters)
        end
    end

    dim_out = if DIM_OUT === nothing
        length(datasets) > 1 ? 2 : 1
    else
        DIM_OUT
    end
    _numprocs::Int = if numprocs === nothing
        if procs === nothing
            4
        else
            length(procs)
        end
    else
        if procs === nothing
            numprocs
        else
            @assert length(procs) == numprocs
            numprocs
        end
    end

    _verbosity = if verbosity === nothing && options.verbosity === nothing
        1
    elseif verbosity === nothing && options.verbosity !== nothing
        options.verbosity
    elseif verbosity !== nothing && options.verbosity === nothing
        verbosity
    else
        error(
            "You cannot set `verbosity` in both the search parameters `Options` and the call to `equation_search`.",
        )
        1
    end
    _progress::Bool = if progress === nothing && options.progress === nothing
        (_verbosity > 0) && length(datasets) == 1
    elseif progress === nothing && options.progress !== nothing
        options.progress
    elseif progress !== nothing && options.progress === nothing
        progress
    else
        error(
            "You cannot set `progress` in both the search parameters `Options` and the call to `equation_search`.",
        )
        false
    end

    _addprocs_function = addprocs_function === nothing ? addprocs : addprocs_function

    exeflags = if VERSION >= v"1.9" && concurrency == :multiprocessing
        heap_size_hint_in_megabytes = floor(
            Int, (
                if heap_size_hint_in_bytes === nothing
                    (Sys.free_memory() / _numprocs)
                else
                    heap_size_hint_in_bytes
                end
            ) / 1024^2
        )
        _verbosity > 0 &&
            heap_size_hint_in_bytes === nothing &&
            @info "Automatically setting `--heap-size-hint=$(heap_size_hint_in_megabytes)M` on each Julia process. You can configure this with the `heap_size_hint_in_bytes` parameter."

        `--heap-size-hint=$(heap_size_hint_in_megabytes)M`
    else
        ``
    end

    # Underscores here mean that we have mutated the variable
    return _equation_search(
        datasets,
        RuntimeOptions(;
            niterations=niterations,
            total_cycles=options.populations * niterations,
            numprocs=_numprocs,
            init_procs=procs,
            addprocs_function=_addprocs_function,
            exeflags=exeflags,
            runtests=runtests,
            verbosity=_verbosity,
            progress=_progress,
            parallelism=Val(concurrency),
            dim_out=Val(dim_out),
            return_state=Val(_return_state),
        ),
        options,
        saved_state,
    )
end
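
# Warm-start sketch (comment only), based on the `saved_state`/`return_state`
# behaviour documented above; the variable names are illustrative assumptions:
#
#     state = equation_search(X, y; niterations=5, options=options, return_state=true)
#     # ... later, resume with the same operators and dataset:
#     state = equation_search(
#         X, y; niterations=5, options=options, saved_state=state, return_state=true
#     )
#     hall_of_fame = state[2]   # second element is the regular return value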

@noinline function _equation_search(
    datasets::Vector{D}, ropt::RuntimeOptions, options::Options, saved_state
) where {D<:Dataset}
    _validate_options(datasets, ropt, options)
    state = _create_workers(datasets, ropt, options)
    _initialize_search!(state, datasets, ropt, options, saved_state)
    _warmup_search!(state, datasets, ropt, options)
    _main_search_loop!(state, datasets, ropt, options)
    _tear_down!(state, ropt, options)
    return _format_output(state, datasets, ropt, options)
end

function _validate_options(
    datasets::Vector{D}, ropt::RuntimeOptions, options::Options
) where {T,L,D<:Dataset{T,L}}
    example_dataset = first(datasets)
    nout = length(datasets)
    @assert nout >= 1
    @assert (nout == 1 || ropt.dim_out == 2)
    @assert options.populations >= 1
    if ropt.progress
        @assert(nout == 1, "You cannot display a progress bar for multi-output searches.")
        @assert(ropt.verbosity > 0, "You cannot display a progress bar with `verbosity=0`.")
    end
    if options.node_type <: GraphNode && ropt.verbosity > 0
        @warn "The `GraphNode` interface and mutation operators are experimental and will change in future versions."
    end
    if ropt.runtests
        test_option_configuration(ropt.parallelism, datasets, options, ropt.verbosity)
        test_dataset_configuration(example_dataset, options, ropt.verbosity)
    end
    for dataset in datasets
        update_baseline_loss!(dataset, options)
    end
    if options.define_helper_functions
        set_default_variable_names!(first(datasets).variable_names)
    end
    if options.seed !== nothing
        seed!(options.seed)
    end
    return nothing
end
@stable default_mode = "disable" function _create_workers(
    datasets::Vector{D}, ropt::RuntimeOptions, options::Options
) where {T,L,D<:Dataset{T,L}}
    stdin_reader = watch_stream(stdin)

    record = RecordType()
    @recorder record["options"] = "$(options)"

    nout = length(datasets)
    example_dataset = first(datasets)
    example_ex = create_expression(zero(T), options, example_dataset)
    NT = typeof(example_ex)
    PopType = Population{T,L,NT}
    HallOfFameType = HallOfFame{T,L,NT}
    WorkerOutputType = get_worker_output_type(
        Val(ropt.parallelism), PopType, HallOfFameType
    )
    ChannelType = ropt.parallelism == :multiprocessing ? RemoteChannel : Channel

    # Pointers to populations on each worker:
    worker_output = Vector{WorkerOutputType}[WorkerOutputType[] for j in 1:nout]
    # Initialize storage for workers
    tasks = [Task[] for j in 1:nout]
    # Set up a channel to send finished populations back to head node
    channels = [[ChannelType(1) for i in 1:(options.populations)] for j in 1:nout]
    (procs, we_created_procs) = if ropt.parallelism == :multiprocessing
        configure_workers(;
            procs=ropt.init_procs,
            ropt.numprocs,
            ropt.addprocs_function,
            options,
            project_path=splitdir(Pkg.project().path)[1],
            file=@__FILE__,
            ropt.exeflags,
            ropt.verbosity,
            example_dataset,
            ropt.runtests,
        )
    else
        Int[], false
    end
    # Get the next worker process to give a job:
    worker_assignment = WorkerAssignments()
    # Randomly order which order to check populations:
    # This is done so that we do work on all nout equally.
    task_order = [(j, i) for j in 1:nout for i in 1:(options.populations)]
    shuffle!(task_order)

    # Persistent storage of last-saved population for final return:
    last_pops = init_dummy_pops(options.populations, datasets, options)
    # Best 10 members from each population for migration:
    best_sub_pops = init_dummy_pops(options.populations, datasets, options)
    # TODO: Should really be one per population too.
    all_running_search_statistics = [
        RunningSearchStatistics(; options=options) for j in 1:nout
    ]
    # Records the number of evaluations:
    # Real numbers indicate use of batching.
    num_evals = [[0.0 for i in 1:(options.populations)] for j in 1:nout]

    halls_of_fame = Vector{HallOfFameType}(undef, nout)

    cycles_remaining = [ropt.total_cycles for j in 1:nout]
    cur_maxsizes = [
        get_cur_maxsize(; options, ropt.total_cycles, cycles_remaining=cycles_remaining[j])
        for j in 1:nout
    ]

    return SearchState{T,L,typeof(example_ex),WorkerOutputType,ChannelType}(;
        procs=procs,
        we_created_procs=we_created_procs,
        worker_output=worker_output,
        tasks=tasks,
        channels=channels,
        worker_assignment=worker_assignment,
        task_order=task_order,
        halls_of_fame=halls_of_fame,
        last_pops=last_pops,
        best_sub_pops=best_sub_pops,
        all_running_search_statistics=all_running_search_statistics,
        num_evals=num_evals,
        cycles_remaining=cycles_remaining,
        cur_maxsizes=cur_maxsizes,
        stdin_reader=stdin_reader,
        record=Ref(record),
    )
end
function _initialize_search!(
    state::SearchState{T,L,N}, datasets, ropt::RuntimeOptions, options::Options, saved_state
) where {T,L,N}
    nout = length(datasets)

    init_hall_of_fame = load_saved_hall_of_fame(saved_state)
    if init_hall_of_fame === nothing
        for j in 1:nout
            state.halls_of_fame[j] = HallOfFame(options, datasets[j])
        end
    else
        # Recompute losses for the hall of fame, in
        # case the dataset changed:
        for j in eachindex(init_hall_of_fame, datasets, state.halls_of_fame)
            hof = strip_metadata(init_hall_of_fame[j], options, datasets[j])
            for member in hof.members[hof.exists]
                score, result_loss = score_func(datasets[j], member, options)
                member.score = score
                member.loss = result_loss
            end
            state.halls_of_fame[j] = hof
        end
    end

    for j in 1:nout, i in 1:(options.populations)
        worker_idx = assign_next_worker!(
            state.worker_assignment; out=j, pop=i, parallelism=ropt.parallelism, state.procs
        )
        saved_pop = load_saved_population(saved_state; out=j, pop=i)
        new_pop =
            if saved_pop !== nothing && length(saved_pop.members) == options.population_size
                _saved_pop = strip_metadata(saved_pop, options, datasets[j])
                ## Update losses:
                for member in _saved_pop.members
                    score, result_loss = score_func(datasets[j], member, options)
                    member.score = score
                    member.loss = result_loss
                end
                copy_pop = copy(_saved_pop)
                @sr_spawner(
                    begin
                        (copy_pop, HallOfFame(options, datasets[j]), RecordType(), 0.0)
                    end,
                    parallelism = ropt.parallelism,
                    worker_idx = worker_idx
                )
            else
                if saved_pop !== nothing && ropt.verbosity > 0
                    @warn "Recreating population (output=$(j), population=$(i)), as the saved one doesn't have the correct number of members."
                end
                @sr_spawner(
                    begin
                        (
                            Population(
                                datasets[j];
                                population_size=options.population_size,
                                nlength=3,
                                options=options,
                                nfeatures=datasets[j].nfeatures,
                            ),
                            HallOfFame(options, datasets[j]),
                            RecordType(),
                            Float64(options.population_size),
                        )
                    end,
                    parallelism = ropt.parallelism,
                    worker_idx = worker_idx
                )
                # This involves population_size evaluations, on the full dataset:
            end
        push!(state.worker_output[j], new_pop)
    end
    return nothing
end
function _warmup_search!(
    state::SearchState{T,L,N}, datasets, ropt::RuntimeOptions, options::Options
) where {T,L,N}
    nout = length(datasets)
    for j in 1:nout, i in 1:(options.populations)
        dataset = datasets[j]
        running_search_statistics = state.all_running_search_statistics[j]
        cur_maxsize = state.cur_maxsizes[j]
        @recorder state.record[]["out$(j)_pop$(i)"] = RecordType()
        worker_idx = assign_next_worker!(
            state.worker_assignment; out=j, pop=i, parallelism=ropt.parallelism, state.procs
        )

        # TODO - why is this needed??
        # Multi-threaded doesn't like to fetch within a new task:
        c_rss = deepcopy(running_search_statistics)
        last_pop = state.worker_output[j][i]
        updated_pop = @sr_spawner(
            begin
                in_pop = first(
                    extract_from_worker(last_pop, Population{T,L,N}, HallOfFame{T,L,N})
                )
                _dispatch_s_r_cycle(
                    in_pop,
                    dataset,
                    options;
                    pop=i,
                    out=j,
                    iteration=0,
                    ropt.verbosity,
                    cur_maxsize,
                    running_search_statistics=c_rss,
                )::DefaultWorkerOutputType{Population{T,L,N},HallOfFame{T,L,N}}
            end,
            parallelism = ropt.parallelism,
            worker_idx = worker_idx
        )
        state.worker_output[j][i] = updated_pop
    end
    return nothing
end
function _main_search_loop!(
    state::SearchState{T,L,N}, datasets, ropt::RuntimeOptions, options::Options
) where {T,L,N}
    ropt.verbosity > 0 && @info "Started!"
    nout = length(datasets)
    start_time = time()
    if ropt.progress
        #TODO: need to iterate this on the max cycles remaining!
        sum_cycle_remaining = sum(state.cycles_remaining)
        progress_bar = WrappedProgressBar(
            1:sum_cycle_remaining; width=options.terminal_width
        )
    end
    last_print_time = time()
    last_speed_recording_time = time()
    num_evals_last = sum(sum, state.num_evals)
    num_evals_since_last = sum(sum, state.num_evals) - num_evals_last  # i.e., start at 0
    print_every_n_seconds = 5
    equation_speed = Float32[]

    if ropt.parallelism in (:multiprocessing, :multithreading)
        for j in 1:nout, i in 1:(options.populations)
            # Start listening for each population to finish:
            t = @async put!(state.channels[j][i], fetch(state.worker_output[j][i]))
            push!(state.tasks[j], t)
        end
    end
    kappa = 0
    resource_monitor = ResourceMonitor(;
        absolute_start_time=time(),
        # Storing n times as many monitoring intervals as populations seems like it will
        # help get accurate resource estimates:
        num_intervals_to_store=options.populations * 100 * nout,
    )
    while sum(state.cycles_remaining) > 0
        kappa += 1
        if kappa > options.populations * nout
            kappa = 1
        end
        # nout, populations:
        j, i = state.task_order[kappa]

        # Check if error on population:
        if ropt.parallelism in (:multiprocessing, :multithreading)
            if istaskfailed(state.tasks[j][i])
                fetch(state.tasks[j][i])
                error("Task failed for population")
            end
        end
        # Non-blocking check if a population is ready:
        population_ready = if ropt.parallelism in (:multiprocessing, :multithreading)
            # TODO: Implement type assertions based on parallelism.
            isready(state.channels[j][i])
        else
            true
        end
        # Don't start more if this output has finished its cycles:
        # TODO - this might skip extra cycles?
        population_ready &= (state.cycles_remaining[j] > 0)
        if population_ready
            start_work_monitor!(resource_monitor)
            # Take the fetch operation from the channel since its ready
            (cur_pop, best_seen, cur_record, cur_num_evals) = if ropt.parallelism in
                (:multiprocessing, :multithreading)
                take!(state.channels[j][i])
            else
                state.worker_output[j][i]
            end::DefaultWorkerOutputType{Population{T,L,N},HallOfFame{T,L,N}}
            state.last_pops[j][i] = copy(cur_pop)
            state.best_sub_pops[j][i] = best_sub_pop(cur_pop; topn=options.topn)
            @recorder state.record[] = recursive_merge(state.record[], cur_record)
            state.num_evals[j][i] += cur_num_evals
            dataset = datasets[j]
            cur_maxsize = state.cur_maxsizes[j]

            for member in cur_pop.members
                size = compute_complexity(member, options)
                update_frequencies!(state.all_running_search_statistics[j]; size)
            end
            #! format: off
            update_hall_of_fame!(state.halls_of_fame[j], cur_pop.members, options)
            update_hall_of_fame!(state.halls_of_fame[j], best_seen.members[best_seen.exists], options)
            #! format: on

            # Dominating pareto curve - must be better than all simpler equations
            dominating = calculate_pareto_frontier(state.halls_of_fame[j])

            if options.save_to_file
                save_to_file(dominating, nout, j, dataset, options)
            end
            ###################################################################
            # Migration #######################################################
            if options.migration
                best_of_each = Population([
                    member for pop in state.best_sub_pops[j] for member in pop.members
                ])
                migrate!(
                    best_of_each.members => cur_pop, options; frac=options.fraction_replaced
                )
            end
            if options.hof_migration && length(dominating) > 0
                migrate!(dominating => cur_pop, options; frac=options.fraction_replaced_hof)
            end
            ###################################################################

            state.cycles_remaining[j] -= 1
            if state.cycles_remaining[j] == 0
                break
            end
            worker_idx = assign_next_worker!(
                state.worker_assignment;
                out=j,
                pop=i,
                parallelism=ropt.parallelism,
                state.procs,
            )
            iteration = if options.use_recorder
                key = "out$(j)_pop$(i)"
                find_iteration_from_record(key, state.record[]) + 1
            else
                0
            end

            c_rss = deepcopy(state.all_running_search_statistics[j])
            in_pop = copy(cur_pop::Population{T,L,N})
            state.worker_output[j][i] = @sr_spawner(
                begin
                    _dispatch_s_r_cycle(
                        in_pop,
                        dataset,
                        options;
                        pop=i,
                        out=j,
                        iteration,
                        ropt.verbosity,
                        cur_maxsize,
                        running_search_statistics=c_rss,
                    )
                end,
                parallelism = ropt.parallelism,
                worker_idx = worker_idx
            )
            if ropt.parallelism in (:multiprocessing, :multithreading)
                state.tasks[j][i] = @async put!(
                    state.channels[j][i], fetch(state.worker_output[j][i])
                )
            end

            state.cur_maxsizes[j] = get_cur_maxsize(;
                options, ropt.total_cycles, cycles_remaining=state.cycles_remaining[j]
            )
            stop_work_monitor!(resource_monitor)
            move_window!(state.all_running_search_statistics[j])
            if ropt.progress
                head_node_occupation = estimate_work_fraction(resource_monitor)
                update_progress_bar!(
                    progress_bar,
                    only(state.halls_of_fame),
                    only(datasets),
                    options,
                    equation_speed,
                    head_node_occupation,
                    ropt.parallelism,
                )
            end
        end
        sleep(1e-6)

        ################################################################
        ## Search statistics
        elapsed_since_speed_recording = time() - last_speed_recording_time
        if elapsed_since_speed_recording > 1.0
            num_evals_since_last, num_evals_last = let s = sum(sum, state.num_evals)
                s - num_evals_last, s
            end
            current_speed = num_evals_since_last / elapsed_since_speed_recording
            push!(equation_speed, current_speed)
            average_over_m_measurements = 20 # 20 second running average
            if length(equation_speed) > average_over_m_measurements
                deleteat!(equation_speed, 1)
            end
            last_speed_recording_time = time()
        end
        ################################################################

        ################################################################
        ## Printing code
        elapsed = time() - last_print_time
        # Update if time has passed
        if elapsed > print_every_n_seconds
            if ropt.verbosity > 0 && !ropt.progress && length(equation_speed) > 0

                # Dominating pareto curve - must be better than all simpler equations
                head_node_occupation = estimate_work_fraction(resource_monitor)
                print_search_state(
                    state.halls_of_fame,
                    datasets;
                    options,
                    equation_speed,
                    ropt.total_cycles,
                    state.cycles_remaining,
                    head_node_occupation,
                    parallelism=ropt.parallelism,
                    width=options.terminal_width,
                )
            end
            last_print_time = time()
        end
        ################################################################

        ################################################################
        ## Early stopping code
        if any((
            check_for_loss_threshold(state.halls_of_fame, options),
            check_for_user_quit(state.stdin_reader),
            check_for_timeout(start_time, options),
            check_max_evals(state.num_evals, options),
        ))
            break
        end
        ################################################################
    end
    return nothing
end
function _tear_down!(state::SearchState, ropt::RuntimeOptions, options::Options)
    close_reader!(state.stdin_reader)
    # Safely close all processes or threads
    if ropt.parallelism == :multiprocessing
        state.we_created_procs && rmprocs(state.procs)
    elseif ropt.parallelism == :multithreading
        nout = length(state.worker_output)
        for j in 1:nout, i in eachindex(state.worker_output[j])
            wait(state.worker_output[j][i])
        end
    end
    @recorder json3_write(state.record[], options.recorder_file)
    return nothing
end
function _format_output(
    state::SearchState, datasets, ropt::RuntimeOptions, options::Options
)
    nout = length(datasets)
    out_hof = if ropt.dim_out == 1
        embed_metadata(only(state.halls_of_fame), options, only(datasets))
    else
        map(j -> embed_metadata(state.halls_of_fame[j], options, datasets[j]), 1:nout)
    end
    if ropt.return_state
        return (
            map(j -> embed_metadata(state.last_pops[j], options, datasets[j]), 1:nout),
            out_hof,
        )
    else
        return out_hof
    end
end

@stable default_mode = "disable" function _dispatch_s_r_cycle(
    in_pop::Population{T,L,N},
    dataset::Dataset,
    options::Options;
    pop::Int,
    out::Int,
    iteration::Int,
    verbosity,
    cur_maxsize::Int,
    running_search_statistics,
) where {T,L,N}
    record = RecordType()
    @recorder record["out$(out)_pop$(pop)"] = RecordType(
        "iteration$(iteration)" => record_population(in_pop, options)
    )
    num_evals = 0.0
    normalize_frequencies!(running_search_statistics)
    out_pop, best_seen, evals_from_cycle = s_r_cycle(
        dataset,
        in_pop,
        options.ncycles_per_iteration,
        cur_maxsize,
        running_search_statistics;
        verbosity=verbosity,
        options=options,
        record=record,
    )
    num_evals += evals_from_cycle
    out_pop, evals_from_optimize = optimize_and_simplify_population(
        dataset, out_pop, options, cur_maxsize, record
    )
    num_evals += evals_from_optimize
    if options.batching
        for i_member in 1:(options.maxsize + MAX_DEGREE)
            score, result_loss = score_func(dataset, best_seen.members[i_member], options)
            best_seen.members[i_member].score = score
            best_seen.members[i_member].loss = result_loss
            num_evals += 1
        end
    end
    return (out_pop, best_seen, record, num_evals)
end

include("MLJInterface.jl")
using .MLJInterfaceModule: SRRegressor, MultitargetSRRegressor
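
# MLJ-style usage sketch (comment only), mirroring the `SRRegressor` interface
# imported above. Hyperparameter values and the call sequence are illustrative
# assumptions, following the standard MLJ `machine`/`fit!`/`report` workflow:
#
#     import MLJ: machine, fit!, predict, report
#     X = (a=rand(500), b=rand(500))
#     y = 2 .* cos.(X.a .* 23.5) .- X.b .^ 2
#     model = SRRegressor(; niterations=40, binary_operators=[+, -, *], unary_operators=[cos])
#     mach = machine(model, X, y)
#     fit!(mach)
#     report(mach)           # discovered equations, losses, complexities
#     predict(mach, X)       # predictions from the selected expression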

function __init__()
    @require_extensions
end

macro ignore(args...) end
# Hack to get static analysis to work from within tests:
@ignore include("../test/runtests.jl")

include("precompile.jl")
redirect_stdout(devnull) do
    redirect_stderr(devnull) do
        do_precompilation(Val(:precompile))
    end
end

end #module SR