MilesCranmer / SymbolicRegression.jl, build 11204590927

06 Oct 2024 07:29PM UTC. Coverage: 95.808% (+1.2%) from 94.617%.

Pull Request #326 (merge e2b369ea7 into 8f67533b9): BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)

466 of 482 new or added lines in 24 files covered (96.68%).
1 existing line in 1 file now uncovered.
2651 of 2767 relevant lines covered (95.81%).
73863189.31 hits per line.

Source file: /src/MLJInterface.jl (95.96% of lines covered)
module MLJInterfaceModule

using Optim: Optim
using LineSearches: LineSearches
using MLJModelInterface: MLJModelInterface as MMI
using ADTypes: AbstractADType
using DynamicExpressions:
    eval_tree_array,
    string_tree,
    AbstractExpressionNode,
    AbstractExpression,
    Node,
    Expression,
    default_node_type,
    get_tree
using DynamicQuantities:
    QuantityArray,
    UnionAbstractQuantity,
    AbstractDimensions,
    SymbolicDimensions,
    Quantity,
    DEFAULT_DIM_BASE_TYPE,
    ustrip,
    dimension
using LossFunctions: SupervisedLoss
using Compat: allequal, stack
using ..InterfaceDynamicQuantitiesModule: get_dimensions_type
using ..CoreModule: Options, Dataset, MutationWeights, LOSS_TYPE
using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS
using ..ComplexityModule: compute_complexity
using ..HallOfFameModule: HallOfFame, format_hall_of_fame
using ..UtilsModule: subscriptify, @ignore

import ..equation_search

abstract type AbstractSRRegressor <: MMI.Deterministic end

# For static analysis tools:
@ignore mutable struct SRRegressor <: AbstractSRRegressor
    selection_method::Function
end
@ignore mutable struct MultitargetSRRegressor <: AbstractSRRegressor
    selection_method::Function
end

# TODO: To reduce code re-use, we could forward these defaults from
#       `equation_search`, similar to what we do for `Options`.

"""Generate an `SRRegressor` struct containing all the fields in `Options`."""
function modelexpr(model_name::Symbol)
    struct_def = :(Base.@kwdef mutable struct $(model_name){D<:AbstractDimensions,L} <:
                                 AbstractSRRegressor
        niterations::Int = 10
        parallelism::Symbol = :multithreading
        numprocs::Union{Int,Nothing} = nothing
        procs::Union{Vector{Int},Nothing} = nothing
        addprocs_function::Union{Function,Nothing} = nothing
        heap_size_hint_in_bytes::Union{Integer,Nothing} = nothing
        runtests::Bool = true
        loss_type::L = Nothing
        selection_method::Function = choose_best
        dimensions_type::Type{D} = SymbolicDimensions{DEFAULT_DIM_BASE_TYPE}
    end)
    # TODO: store `procs` from initial run if parallelism is `:multiprocessing`
    fields = last(last(struct_def.args).args).args

    # Add everything from `Options` constructor directly to struct:
    for (i, option) in enumerate(DEFAULT_OPTIONS)
        insert!(fields, i, Expr(:(=), option.args...))
    end

    # We also need to create the `get_options` function, based on this:
    constructor = :(Options(;))
    constructor_fields = last(constructor.args).args
    for option in DEFAULT_OPTIONS
        symb = getsymb(first(option.args))
        push!(constructor_fields, Expr(:kw, symb, Expr(:(.), :m, Core.QuoteNode(symb))))
    end

    return quote
        $struct_def
        function get_options(m::$(model_name))
            return $constructor
        end
    end
end
function getsymb(ex::Symbol)
    return ex
end
function getsymb(ex::Expr)
    for arg in ex.args
        isa(arg, Symbol) && return arg
        s = getsymb(arg)
        isa(s, Symbol) && return s
    end
    return nothing
end
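# For instance, both `getsymb(:x)` and `getsymb(:(x::Int))` return `:x`.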

"""Get an equivalent `Options()` object for a particular regressor."""
function get_options(::AbstractSRRegressor) end

eval(modelexpr(:SRRegressor))
eval(modelexpr(:MultitargetSRRegressor))
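
#=
Illustrative sketch of the generated API (not part of the module): the two `eval`
calls above define keyword-constructed models whose fields mirror `Options`, plus a
`get_options` method to convert back. Assuming the package is loaded, one could write:

    model = SRRegressor(; niterations=20, parallelism=:serial)
    options = get_options(model)  # an `Options` built from the model's fields
    options isa Options           # true

Every default from `DEFAULT_OPTIONS` (e.g., `binary_operators`) becomes both a
struct field and a forwarded keyword in the `Options` constructor.
=#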

# Cleaning already taken care of by `Options` and `equation_search`
function full_report(
    m::AbstractSRRegressor, fitresult; v_with_strings::Val{with_strings}=Val(true)
) where {with_strings}
    _, hof = fitresult.state
    # TODO: Adjust baseline loss
    formatted = format_hall_of_fame(hof, fitresult.options)
    equation_strings = if with_strings
        get_equation_strings_for(
            m, formatted.trees, fitresult.options, fitresult.variable_names
        )
    else
        nothing
    end
    best_idx = dispatch_selection_for(
        m, formatted.trees, formatted.losses, formatted.scores, formatted.complexities
    )
    return (;
        best_idx=best_idx,
        equations=formatted.trees,
        equation_strings=equation_strings,
        losses=formatted.losses,
        complexities=formatted.complexities,
        scores=formatted.scores,
    )
end

MMI.clean!(::AbstractSRRegressor) = ""

# TODO: Enable `verbosity` being passed to `equation_search`
function MMI.fit(m::AbstractSRRegressor, verbosity, X, y, w=nothing)
    return MMI.update(m, verbosity, nothing, nothing, X, y, w)
end
function MMI.update(
    m::AbstractSRRegressor, verbosity, old_fitresult, old_cache, X, y, w=nothing
)
    options = old_fitresult === nothing ? get_options(m) : old_fitresult.options
    return _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, nothing)
end
function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, classes)
    if isnothing(classes) && MMI.istable(X) && haskey(X, :classes)
        if !(X isa NamedTuple)
            error("Classes can only be specified with named tuples.")
        end
        new_X = Base.structdiff(X, (; X.classes))
        new_classes = X.classes
        return _update(
            m, verbosity, old_fitresult, old_cache, new_X, y, w, options, new_classes
        )
    end
    if !isnothing(old_fitresult)
        @assert(
            old_fitresult.has_classes == !isnothing(classes),
            "If the first fit used classes, the second fit must also use classes."
        )
    end
    # To speed up iterative fits, we cache the types:
    types = if isnothing(old_fitresult)
        (;
            T=Any,
            X_t=Any,
            y_t=Any,
            w_t=Any,
            state=Any,
            X_units=Any,
            y_units=Any,
            X_units_clean=Any,
            y_units_clean=Any,
        )
    else
        old_fitresult.types
    end
    X_t::types.X_t, variable_names, X_units::types.X_units = get_matrix_and_info(
        X, m.dimensions_type
    )
    y_t::types.y_t, y_variable_names, y_units::types.y_units = format_input_for(
        m, y, m.dimensions_type
    )
    X_units_clean::types.X_units_clean = clean_units(X_units)
    y_units_clean::types.y_units_clean = clean_units(y_units)
    w_t::types.w_t = if w !== nothing && isa(m, MultitargetSRRegressor)
        @assert(isa(w, AbstractVector) && ndims(w) == 1, "Unexpected input for `w`.")
        repeat(w', size(y_t, 1))
    else
        w
    end
    search_state::types.state = equation_search(
        X_t,
        y_t;
        niterations=m.niterations,
        weights=w_t,
        variable_names=variable_names,
        options=options,
        parallelism=m.parallelism,
        numprocs=m.numprocs,
        procs=m.procs,
        addprocs_function=m.addprocs_function,
        heap_size_hint_in_bytes=m.heap_size_hint_in_bytes,
        runtests=m.runtests,
        saved_state=(old_fitresult === nothing ? nothing : old_fitresult.state),
        return_state=true,
        loss_type=m.loss_type,
        X_units=X_units_clean,
        y_units=y_units_clean,
        verbosity=verbosity,
        extra=isnothing(classes) ? (;) : (; classes),
        # Help out with inference:
        v_dim_out=isa(m, SRRegressor) ? Val(1) : Val(2),
    )
    fitresult = (;
        state=search_state,
        num_targets=isa(m, SRRegressor) ? 1 : size(y_t, 1),
        options=options,
        variable_names=variable_names,
        y_variable_names=y_variable_names,
        y_is_table=MMI.istable(y),
        has_classes=!isnothing(classes),
        X_units=X_units_clean,
        y_units=y_units_clean,
        types=(
            T=hof_eltype(search_state[2]),
            X_t=typeof(X_t),
            y_t=typeof(y_t),
            w_t=typeof(w_t),
            state=typeof(search_state),
            X_units=typeof(X_units),
            y_units=typeof(y_units),
            X_units_clean=typeof(X_units_clean),
            y_units_clean=typeof(y_units_clean),
        ),
    )::(old_fitresult === nothing ? Any : typeof(old_fitresult))
    return (fitresult, nothing, full_report(m, fitresult))
end
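# Recover the data type `T` from a `HallOfFame{T}` (or from a vector of them, as
# produced by multi-target searches); used to cache `types.T` in the fitresult above.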
hof_eltype(::Type{H}) where {T,H<:HallOfFame{T}} = T
hof_eltype(::Type{V}) where {V<:Vector} = hof_eltype(eltype(V))
hof_eltype(h) = hof_eltype(typeof(h))

function clean_units(units)
    !isa(units, AbstractDimensions) && error("Unexpected units.")
    iszero(units) && return nothing
    return units
end
function clean_units(units::Vector)
    !all(Base.Fix2(isa, AbstractDimensions), units) && error("Unexpected units.")
    all(iszero, units) && return nothing
    return units
end
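# Note: when every dimension is zero (dimensionless data), `clean_units` returns
# `nothing`, which disables dimensional analysis in `equation_search` downstream.
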
function get_matrix_and_info(X, ::Type{D}) where {D}
    sch = MMI.istable(X) ? MMI.schema(X) : nothing
    Xm_t = MMI.matrix(X; transpose=true)
    colnames = if sch === nothing
        [map(i -> "x$(subscriptify(i))", axes(Xm_t, 1))...]
    else
        [string.(sch.names)...]
    end
    D_promoted = get_dimensions_type(Xm_t, D)
    Xm_t_strip, X_units = unwrap_units_single(Xm_t, D_promoted)
    return Xm_t_strip, colnames, X_units
end

function format_input_for(::SRRegressor, y, ::Type{D}) where {D}
    @assert(
        !(MMI.istable(y) || (length(size(y)) == 2 && size(y, 2) > 1)),
        "For multi-output regression, please use `MultitargetSRRegressor`."
    )
    y_t = vec(y)
    colnames = nothing
    D_promoted = get_dimensions_type(y_t, D)
    y_t_strip, y_units = unwrap_units_single(y_t, D_promoted)
    return y_t_strip, colnames, y_units
end
function format_input_for(::MultitargetSRRegressor, y, ::Type{D}) where {D}
    @assert(
        MMI.istable(y) || (length(size(y)) == 2 && size(y, 2) > 1),
        "For single-output regression, please use `SRRegressor`."
    )
    return get_matrix_and_info(y, D)
end
function validate_variable_names(variable_names, fitresult)
    @assert(
        variable_names == fitresult.variable_names,
        "Variable names do not match fitted regressor."
    )
    return nothing
end
function validate_units(X_units, old_X_units)
    @assert(
        all(X_units .== old_X_units),
        "Units of new data do not match units of fitted regressor."
    )
    return nothing
end

# TODO: Test whether this conversion poses any issues in data normalization...
function dimension_with_fallback(q::UnionAbstractQuantity{T}, ::Type{D}) where {T,D}
    return dimension(convert(Quantity{T,D}, q))::D
end
function dimension_with_fallback(_, ::Type{D}) where {D}
    return D()
end
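# For example (illustrative): `dimension_with_fallback(1.5us"m", D)` returns the
# dimensions of meters as a `D`, while the unitless `dimension_with_fallback(1.5, D)`
# falls back to the empty (dimensionless) `D()`.
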
function prediction_warn()
    @warn "Evaluation failed either due to NaNs detected or due to unfinished search. Using 0s for prediction."
end

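# Re-attach units to a prediction vector `v`. The third argument is the target index
# for multi-target models, or `nothing` for single-target; `y_units === nothing`
# means the fit was unitless, so `v` passes through unchanged.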
wrap_units(v, ::Nothing, ::Integer) = v
wrap_units(v, ::Nothing, ::Nothing) = v
wrap_units(v, y_units, i::Integer) = (yi -> Quantity(yi, y_units[i])).(v)
wrap_units(v, y_units, ::Nothing) = (yi -> Quantity(yi, y_units)).(v)

function prediction_fallback(::Type{T}, ::SRRegressor, Xnew_t, fitresult, _) where {T}
    prediction_warn()
    out = fill!(similar(Xnew_t, T, axes(Xnew_t, 2)), zero(T))
    return wrap_units(out, fitresult.y_units, nothing)
end
function prediction_fallback(
    ::Type{T}, ::MultitargetSRRegressor, Xnew_t, fitresult, prototype
) where {T}
    prediction_warn()
    out_cols = [
        wrap_units(
            fill!(similar(Xnew_t, T, axes(Xnew_t, 2)), zero(T)), fitresult.y_units, i
        ) for i in 1:(fitresult.num_targets)
    ]
    out_matrix = hcat(out_cols...)
    if !fitresult.y_is_table
        return out_matrix
    else
        return MMI.table(out_matrix; names=fitresult.y_variable_names, prototype)
    end
end

compat_ustrip(A::QuantityArray) = ustrip(A)
compat_ustrip(A) = ustrip.(A)

"""
    unwrap_units_single(::AbstractArray, ::Type{<:AbstractDimensions})

Remove units from some features in a matrix, and return, as a tuple,
(1) the matrix with stripped units, and (2) the dimensions for those features.
"""
function unwrap_units_single(A::AbstractMatrix, ::Type{D}) where {D}
    dims = D[dimension_with_fallback(first(row), D) for row in eachrow(A)]
    @inbounds for (i, row) in enumerate(eachrow(A))
        all(xi -> dimension_with_fallback(xi, D) == dims[i], row) ||
            error("Inconsistent units in feature $i of matrix.")
    end
    return stack(compat_ustrip, eachrow(A); dims=1)::AbstractMatrix, dims
end
function unwrap_units_single(v::AbstractVector, ::Type{D}) where {D}
    dims = dimension_with_fallback(first(v), D)
    all(xi -> dimension_with_fallback(xi, D) == dims, v) ||
        error("Inconsistent units in vector.")
    return compat_ustrip(v)::AbstractVector, dims
end
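
#=
Illustrative usage (a sketch, using `DynamicQuantities` symbolic units):

    v = [1.0us"m", 2.0us"m"]
    v_strip, dims = unwrap_units_single(v, SymbolicDimensions{DEFAULT_DIM_BASE_TYPE})
    # v_strip == [1.0, 2.0]; `dims` is the shared length dimension.

Mixed units within one feature, e.g. `[1.0us"m", 2.0us"s"]`, raise the
"Inconsistent units" error above.
=#
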
function MMI.fitted_params(m::AbstractSRRegressor, fitresult)
    report = full_report(m, fitresult)
    return (;
        best_idx=report.best_idx,
        equations=report.equations,
        equation_strings=report.equation_strings,
    )
end

function eval_tree_mlj(
    tree::AbstractExpression,
    X_t,
    classes,
    m::AbstractSRRegressor,
    ::Type{T},
    fitresult,
    i,
    prototype,
) where {T}
    out, completed = if isnothing(classes)
        eval_tree_array(tree, X_t, fitresult.options)
    else
        eval_tree_array(tree, X_t, classes, fitresult.options)
    end
    if completed
        return wrap_units(out, fitresult.y_units, i)
    else
        return prediction_fallback(T, m, X_t, fitresult, prototype)
    end
end

function MMI.predict(
    m::M, fitresult, Xnew; idx=nothing, classes=nothing
) where {M<:AbstractSRRegressor}
    return _predict(m, fitresult, Xnew, idx, classes)
end
function _predict(m::M, fitresult, Xnew, idx, classes) where {M<:AbstractSRRegressor}
    if Xnew isa NamedTuple && (haskey(Xnew, :idx) || haskey(Xnew, :data))
        @assert(
            haskey(Xnew, :idx) && haskey(Xnew, :data) && length(keys(Xnew)) == 2,
            "If specifying an equation index during prediction, you must use a named tuple with keys `idx` and `data`."
        )
        return _predict(m, fitresult, Xnew.data, Xnew.idx, classes)
    end
    if isnothing(classes) && MMI.istable(Xnew) && haskey(Xnew, :classes)
        if !(Xnew isa NamedTuple)
            error("Classes can only be specified with named tuples.")
        end
        Xnew2 = Base.structdiff(Xnew, (; Xnew.classes))
        return _predict(m, fitresult, Xnew2, idx, Xnew.classes)
    end

    if fitresult.has_classes
        @assert(
            !isnothing(classes),
            "Classes must be specified if the model was fit with classes."
        )
    end

    params = full_report(m, fitresult; v_with_strings=Val(false))
    prototype = MMI.istable(Xnew) ? Xnew : nothing
    Xnew_t, variable_names, X_units = get_matrix_and_info(Xnew, m.dimensions_type)
    T = promote_type(eltype(Xnew_t), fitresult.types.T)

    if isempty(params.equations) || any(isempty, params.equations)
        @warn "Equations not found. Returning 0s for prediction."
        return prediction_fallback(T, m, Xnew_t, fitresult, prototype)
    end

    X_units_clean = clean_units(X_units)
    validate_variable_names(variable_names, fitresult)
    validate_units(X_units_clean, fitresult.X_units)

    idx = idx === nothing ? params.best_idx : idx

    if M <: SRRegressor
        return eval_tree_mlj(
            params.equations[idx], Xnew_t, classes, m, T, fitresult, nothing, prototype
        )
    elseif M <: MultitargetSRRegressor
        outs = [
            eval_tree_mlj(
                params.equations[i][idx[i]], Xnew_t, classes, m, T, fitresult, i, prototype
            ) for i in eachindex(idx, params.equations)
        ]
        out_matrix = reduce(hcat, outs)
        if !fitresult.y_is_table
            return out_matrix
        else
            return MMI.table(out_matrix; names=fitresult.y_variable_names, prototype)
        end
    end
end

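#=
Illustrative use of the prediction protocol above (a sketch; `mach` is a fitted
MLJ machine wrapping one of these regressors):

    predict(mach, Xnew)                # uses the expression chosen by `selection_method`
    predict(mach, (data=Xnew, idx=2))  # overrides the choice with expression 2
=#
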
function get_equation_strings_for(::SRRegressor, trees, options, variable_names)
    return (t -> string_tree(t, options; variable_names=variable_names)).(trees)
end
function get_equation_strings_for(::MultitargetSRRegressor, trees, options, variable_names)
    return [
        (t -> string_tree(t, options; variable_names=variable_names)).(ts) for ts in trees
    ]
end

function choose_best(; trees, losses::Vector{L}, scores, complexities) where {L<:LOSS_TYPE}
    # Same as in PySR:
    # https://github.com/MilesCranmer/PySR/blob/e74b8ad46b163c799908b3aa4d851cf8457c79ef/pysr/sr.py#L2318-L2332
    # threshold = 1.5 * minimum_loss
    # Then, we get max score of those below the threshold.
    threshold = 1.5 * minimum(losses)
    return argmax([
        (losses[i] <= threshold) ? scores[i] : typemin(L) for i in eachindex(losses)
    ])
end

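#=
Worked example of the rule above (illustrative numbers): with
`losses = [1.0, 0.2, 0.19]` and `scores = [0.5, 2.0, 0.1]`, the threshold is
`1.5 * 0.19 == 0.285`; only the last two candidates are below it, and of those
the second has the larger score:

    choose_best(;
        trees=fill(nothing, 3),
        losses=[1.0, 0.2, 0.19],
        scores=[0.5, 2.0, 0.1],
        complexities=[1, 3, 5],
    )  # == 2
=#
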
function dispatch_selection_for(m::SRRegressor, trees, losses, scores, complexities)::Int
    length(trees) == 0 && return 0
    return m.selection_method(;
        trees=trees, losses=losses, scores=scores, complexities=complexities
    )
end
function dispatch_selection_for(
    m::MultitargetSRRegressor, trees, losses, scores, complexities
)
    any(t -> length(t) == 0, trees) && return fill(0, length(trees))
    return [
        m.selection_method(;
            trees=trees[i], losses=losses[i], scores=scores[i], complexities=complexities[i]
        ) for i in eachindex(trees)
    ]
end

MMI.metadata_pkg(
    AbstractSRRegressor;
    name="SymbolicRegression",
    uuid="8254be44-1295-4e6a-a16d-46603ac705cb",
    url="https://github.com/MilesCranmer/SymbolicRegression.jl",
    julia=true,
    license="Apache-2.0",
    is_wrapper=false,
)

const input_scitype = Union{
    MMI.Table(MMI.Continuous),
    AbstractMatrix{<:MMI.Continuous},
    MMI.Table(MMI.Continuous, MMI.Count),
}

# TODO: Allow for Count data, and coerce it into Continuous as needed.
MMI.metadata_model(
    SRRegressor;
    input_scitype,
    target_scitype=AbstractVector{<:MMI.Continuous},
    supports_weights=true,
    reports_feature_importances=false,
    load_path="SymbolicRegression.MLJInterfaceModule.SRRegressor",
    human_name="Symbolic Regression via Evolutionary Search",
)
MMI.metadata_model(
    MultitargetSRRegressor;
    input_scitype,
    target_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}},
    supports_weights=true,
    reports_feature_importances=false,
    load_path="SymbolicRegression.MLJInterfaceModule.MultitargetSRRegressor",
    human_name="Multi-Target Symbolic Regression via Evolutionary Search",
)

function tag_with_docstring(model_name::Symbol, description::String, bottom_matter::String)
    docstring = """$(MMI.doc_header(eval(model_name)))

    $(description)

    # Hyper-parameters
    """

    # TODO: These ones are copied (or written) manually:
    append_arguments = """- `niterations::Int=10`: The number of iterations to perform the search.
        More iterations will improve the results.
    - `parallelism=:multithreading`: What parallelism mode to use.
        The options are `:multithreading`, `:multiprocessing`, and `:serial`.
        By default, multithreading will be used. Multithreading uses less memory,
        but multiprocessing can handle multi-node compute. If using `:multithreading`
        mode, the number of threads available to Julia is used. If using
        `:multiprocessing`, `numprocs` processes will be created dynamically if
        `procs` is unset. If you have already allocated processes, pass them
        to the `procs` argument and they will be used.
        You may also pass a string instead of a symbol, like `"multithreading"`.
    - `numprocs::Union{Int, Nothing}=nothing`: The number of processes to use,
        if you want `equation_search` to set this up automatically. By default
        this will be `4`, but can be any number (you should pick a number <=
        the number of cores available).
    - `procs::Union{Vector{Int}, Nothing}=nothing`: If you have set up
        a distributed run manually with `procs = addprocs()` and `@everywhere`,
        pass the `procs` to this keyword argument.
    - `addprocs_function::Union{Function, Nothing}=nothing`: If using multiprocessing
        (`parallelism=:multiprocessing`) and not passing `procs` manually,
        then processes will be allocated dynamically using `addprocs`. However,
        you may also pass a custom function to use instead of `addprocs`.
        This function should take a single positional argument,
        which is the number of processes to use, as well as the `lazy` keyword argument.
        For example, if set up on a slurm cluster, you could pass
        `addprocs_function = addprocs_slurm`, which will set up slurm processes.
    - `heap_size_hint_in_bytes::Union{Int,Nothing}=nothing`: On Julia 1.9+, you may set the `--heap-size-hint`
        flag on Julia processes, recommending garbage collection once a process
        is close to the recommended size. This is important for long-running distributed
        jobs where each process has an independent memory, and can help avoid
        out-of-memory errors. By default, this is set to `Sys.free_memory() / numprocs`.
    - `runtests::Bool=true`: Whether to run (quick) tests before starting the
        search, to see if there will be any problems during the equation search
        related to the host environment.
    - `loss_type::Type=Nothing`: If you would like to use a different type
        for the loss than for the data you passed, specify the type here.
        Note that if you pass complex data `::Complex{L}`, then the loss
        type will automatically be set to `L`.
    - `selection_method::Function`: Function to select an expression from
        the Pareto frontier for use in `predict`.
        See `SymbolicRegression.MLJInterfaceModule.choose_best` for an example.
        This function should return a single integer specifying
        the index of the expression to use. By default, this maximizes
        the score (a pound-for-pound rating) of expressions reaching the threshold
        of 1.5x the minimum loss. To override this at prediction time, you can pass
        a named tuple with keys `data` and `idx` to `predict`. See the Operations
        section for details.
    - `dimensions_type::AbstractDimensions`: The type of dimensions to use when storing
        the units of the data. By default this is `DynamicQuantities.SymbolicDimensions`.
    """

    bottom = """
    # Operations

    - `predict(mach, Xnew)`: Return predictions of the target given features `Xnew`, which
        should have same scitype as `X` above. The expression used for prediction is defined
        by the `selection_method` function, which can be seen by viewing `report(mach).best_idx`.
    - `predict(mach, (data=Xnew, idx=i))`: Return predictions of the target given features
        `Xnew`, which should have same scitype as `X` above. By passing a named tuple with keys
        `data` and `idx`, you are able to specify the equation you wish to evaluate in `idx`.

    $(bottom_matter)
    """

    # Remove common indentation:
    docstring = replace(docstring, r"^    " => "")
    extra_arguments = replace(append_arguments, r"^    " => "")
    bottom = replace(bottom, r"^    " => "")

    # Add parameter descriptions:
    docstring = docstring * OPTION_DESCRIPTIONS
    docstring = docstring * extra_arguments
    docstring = docstring * bottom
    return quote
        @doc $docstring $model_name
    end
end

# https://arxiv.org/abs/2305.01582
eval(
    tag_with_docstring(
        :SRRegressor,
        replace(
            """
    Single-target Symbolic Regression regressor (`SRRegressor`) searches
    for symbolic expressions that predict a single target variable from
    a set of input variables. All data is assumed to be `Continuous`.
    The search is performed using an evolutionary algorithm.
    This algorithm is described in the paper
    https://arxiv.org/abs/2305.01582.

    # Training data

    In MLJ or MLJBase, bind an instance `model` to data with

        mach = machine(model, X, y)

    OR

        mach = machine(model, X, y, w)

    Here:

    - `X` is any table of input features (eg, a `DataFrame`) whose columns are of scitype
      `Continuous`; check column scitypes with `schema(X)`. Variable names in discovered
      expressions will be taken from the column names of `X`, if available. Units in columns
      of `X` (use `DynamicQuantities` for units) will trigger dimensional analysis to be used.

    - `y` is the target, which can be any `AbstractVector` whose element scitype is
      `Continuous`; check the scitype with `scitype(y)`. Units in `y` (use `DynamicQuantities`
      for units) will trigger dimensional analysis to be used.

    - `w` is the observation weights, which can either be `nothing` (default) or an
      `AbstractVector` whose element scitype is `Count` or `Continuous`.

    Train the machine using `fit!(mach)`, inspect the discovered expressions with
    `report(mach)`, and predict on new data with `predict(mach, Xnew)`.
    Note that unlike other regressors, symbolic regression stores a list of
    trained models. The model chosen from this list is defined by the
    `selection_method` keyword argument, which by default balances accuracy
    and complexity. You can override this at prediction time by passing a named
    tuple with keys `data` and `idx`.

    """,
            r"^    " => "",
        ),
        replace(
            """
    # Fitted parameters

    The fields of `fitted_params(mach)` are:

    - `best_idx::Int`: The index of the best expression in the Pareto frontier,
      as determined by the `selection_method` function. Override in `predict` by passing
      a named tuple with keys `data` and `idx`.
    - `equations::Vector{Node{T}}`: The expressions discovered by the search, represented
      in a dominating Pareto frontier (i.e., the best expressions found for
      each complexity). `T` is equal to the element type
      of the passed data.
    - `equation_strings::Vector{String}`: The expressions discovered by the search,
      represented as strings for easy inspection.

    # Report

    The fields of `report(mach)` are:

    - `best_idx::Int`: The index of the best expression in the Pareto frontier,
      as determined by the `selection_method` function. Override in `predict` by passing
      a named tuple with keys `data` and `idx`.
    - `equations::Vector{Node{T}}`: The expressions discovered by the search, represented
      in a dominating Pareto frontier (i.e., the best expressions found for
      each complexity).
    - `equation_strings::Vector{String}`: The expressions discovered by the search,
      represented as strings for easy inspection.
    - `complexities::Vector{Int}`: The complexity of each expression in the Pareto frontier.
    - `losses::Vector{L}`: The loss of each expression in the Pareto frontier, according
      to the loss function specified in the model. The type `L` is the loss type, which
      is usually the same as the element type of data passed (i.e., `T`), but can differ
      if complex data types are passed.
    - `scores::Vector{L}`: A metric which considers both the complexity and loss of an expression,
      equal to the change in the log-loss divided by the change in complexity, relative to
      the previous expression along the Pareto frontier. A larger score aims to indicate
      an expression is more likely to be the true expression generating the data, but
      this is very problem-dependent and generally several other factors should be considered.

    # Examples

    ```julia
    using MLJ
    SRRegressor = @load SRRegressor pkg=SymbolicRegression
    X, y = @load_boston
    model = SRRegressor(binary_operators=[+, -, *], unary_operators=[exp], niterations=100)
    mach = machine(model, X, y)
    fit!(mach)
    y_hat = predict(mach, X)
    # View the equation used:
    r = report(mach)
    println("Equation used:", r.equation_strings[r.best_idx])
    ```

    With units and variable names:

    ```julia
    using MLJ
    using DynamicQuantities
    SRRegressor = @load SRRegressor pkg=SymbolicRegression

    X = (; x1=rand(32) .* us"km/h", x2=rand(32) .* us"km")
    y = @. X.x2 / X.x1 + 0.5us"h"
    model = SRRegressor(binary_operators=[+, -, *, /])
    mach = machine(model, X, y)
    fit!(mach)
    y_hat = predict(mach, X)
    # View the equation used:
    r = report(mach)
    println("Equation used:", r.equation_strings[r.best_idx])
    ```

    See also [`MultitargetSRRegressor`](@ref).
    """,
            r"^    " => "",
        ),
    ),
)
eval(
    tag_with_docstring(
        :MultitargetSRRegressor,
        replace(
            """
    Multi-target Symbolic Regression regressor (`MultitargetSRRegressor`)
    conducts several searches for expressions that predict each target variable
    from a set of input variables. All data is assumed to be `Continuous`.
    The search is performed using an evolutionary algorithm.
    This algorithm is described in the paper
    https://arxiv.org/abs/2305.01582.

    # Training data

    In MLJ or MLJBase, bind an instance `model` to data with

        mach = machine(model, X, y)

    OR

        mach = machine(model, X, y, w)

    Here:

    - `X` is any table of input features (eg, a `DataFrame`) whose columns are of scitype
      `Continuous`; check column scitypes with `schema(X)`. Variable names in discovered
      expressions will be taken from the column names of `X`, if available. Units in columns
      of `X` (use `DynamicQuantities` for units) will trigger dimensional analysis to be used.

    - `y` is the target, which can be any table of target variables whose element
      scitype is `Continuous`; check the scitype with `schema(y)`. Units in columns of
      `y` (use `DynamicQuantities` for units) will trigger dimensional analysis to be used.

    - `w` is the observation weights, which can either be `nothing` (default) or an
      `AbstractVector` whose element scitype is `Count` or `Continuous`. The same
      weights are used for all targets.

    Train the machine using `fit!(mach)`, inspect the discovered expressions with
    `report(mach)`, and predict on new data with `predict(mach, Xnew)`.
    Note that unlike other regressors, symbolic regression stores a list of lists of
    trained models. The models chosen from each of these lists are defined by the
    `selection_method` keyword argument, which by default balances accuracy
    and complexity. You can override this at prediction time by passing a named
    tuple with keys `data` and `idx`.

    """,
            r"^    " => "",
        ),
        replace(
            """
    # Fitted parameters

    The fields of `fitted_params(mach)` are:

    - `best_idx::Vector{Int}`: The index of the best expression in each Pareto frontier,
      as determined by the `selection_method` function. Override in `predict` by passing
      a named tuple with keys `data` and `idx`.
    - `equations::Vector{Vector{Node{T}}}`: The expressions discovered by the search, represented
      in a dominating Pareto frontier (i.e., the best expressions found for
      each complexity). The outer vector is indexed by target variable, and the inner
      vector is ordered by increasing complexity. `T` is equal to the element type
      of the passed data.
    - `equation_strings::Vector{Vector{String}}`: The expressions discovered by the search,
      represented as strings for easy inspection.

    # Report

    The fields of `report(mach)` are:

    - `best_idx::Vector{Int}`: The index of the best expression in each Pareto frontier,
      as determined by the `selection_method` function. Override in `predict` by passing
      a named tuple with keys `data` and `idx`.
    - `equations::Vector{Vector{Node{T}}}`: The expressions discovered by the search, represented
      in a dominating Pareto frontier (i.e., the best expressions found for
      each complexity). The outer vector is indexed by target variable, and the inner
      vector is ordered by increasing complexity.
    - `equation_strings::Vector{Vector{String}}`: The expressions discovered by the search,
      represented as strings for easy inspection.
    - `complexities::Vector{Vector{Int}}`: The complexity of each expression in each Pareto frontier.
    - `losses::Vector{Vector{L}}`: The loss of each expression in each Pareto frontier, according
      to the loss function specified in the model. The type `L` is the loss type, which
      is usually the same as the element type of data passed (i.e., `T`), but can differ
      if complex data types are passed.
    - `scores::Vector{Vector{L}}`: A metric which considers both the complexity and loss of an expression,
      equal to the change in the log-loss divided by the change in complexity, relative to
      the previous expression along the Pareto frontier. A larger score aims to indicate
      an expression is more likely to be the true expression generating the data, but
      this is very problem-dependent and generally several other factors should be considered.

    # Examples

    ```julia
    using MLJ
    MultitargetSRRegressor = @load MultitargetSRRegressor pkg=SymbolicRegression
    X = (a=rand(100), b=rand(100), c=rand(100))
    Y = (y1=(@. cos(X.c) * 2.1 - 0.9), y2=(@. X.a * X.b + X.c))
    model = MultitargetSRRegressor(binary_operators=[+, -, *], unary_operators=[exp], niterations=100)
    mach = machine(model, X, Y)
    fit!(mach)
    y_hat = predict(mach, X)
    # View the equations used:
    r = report(mach)
    for (output_index, (eq, i)) in enumerate(zip(r.equation_strings, r.best_idx))
        println("Equation used for ", output_index, ": ", eq[i])
    end
    ```

    See also [`SRRegressor`](@ref).
    """,
            r"^    " => "",
        ),
    ),
)

end