
MilesCranmer / SymbolicRegression.jl, build 9704727222

27 Jun 2024 11:01PM UTC. Coverage: 95.922% (+1.3%) from 94.617%.

Pull Request #326 (github / web-flow): BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)
Merge 1f104aaf8 into ceddaa424

301 of 307 new or added lines in 17 files covered (98.05%)
1 existing line in 1 file now uncovered
2611 of 2722 relevant lines covered (95.92%)
35611300.15 hits per line

Source file: /src/MLJInterface.jl (95.96% covered)
module MLJInterfaceModule

using Optim: Optim
using LineSearches: LineSearches
using MLJModelInterface: MLJModelInterface as MMI
using ADTypes: AbstractADType
using DynamicExpressions:
    eval_tree_array,
    string_tree,
    AbstractExpressionNode,
    AbstractExpression,
    Node,
    Expression,
    default_node_type,
    get_tree
using DynamicQuantities:
    QuantityArray,
    UnionAbstractQuantity,
    AbstractDimensions,
    SymbolicDimensions,
    Quantity,
    DEFAULT_DIM_BASE_TYPE,
    ustrip,
    dimension
using LossFunctions: SupervisedLoss
using Compat: allequal, stack
using ..InterfaceDynamicQuantitiesModule: get_dimensions_type
using ..CoreModule: Options, Dataset, MutationWeights, LOSS_TYPE
using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS
using ..ComplexityModule: compute_complexity
using ..HallOfFameModule: HallOfFame, format_hall_of_fame
using ..UtilsModule: subscriptify

import ..equation_search

abstract type AbstractSRRegressor <: MMI.Deterministic end

# TODO: To reduce code re-use, we could forward these defaults from
#       `equation_search`, similar to what we do for `Options`.

"""Generate an `SRRegressor` struct containing all the fields in `Options`."""
function modelexpr(model_name::Symbol)
    struct_def = :(Base.@kwdef mutable struct $(model_name){D<:AbstractDimensions,L} <:
                                 AbstractSRRegressor
        niterations::Int = 10
        parallelism::Symbol = :multithreading
        numprocs::Union{Int,Nothing} = nothing
        procs::Union{Vector{Int},Nothing} = nothing
        addprocs_function::Union{Function,Nothing} = nothing
        heap_size_hint_in_bytes::Union{Integer,Nothing} = nothing
        runtests::Bool = true
        loss_type::L = Nothing
        selection_method::Function = choose_best
        dimensions_type::Type{D} = SymbolicDimensions{DEFAULT_DIM_BASE_TYPE}
    end)
    # TODO: store `procs` from initial run if parallelism is `:multiprocessing`
    fields = last(last(struct_def.args).args).args

    # Add everything from `Options` constructor directly to struct:
    for (i, option) in enumerate(DEFAULT_OPTIONS)
        insert!(fields, i, Expr(:(=), option.args...))
    end

    # We also need to create the `get_options` function, based on this:
    constructor = :(Options(;))
    constructor_fields = last(constructor.args).args
    for option in DEFAULT_OPTIONS
        symb = getsymb(first(option.args))
        push!(constructor_fields, Expr(:kw, symb, Expr(:(.), :m, Core.QuoteNode(symb))))
    end

    return quote
        $struct_def
        function get_options(m::$(model_name))
            return $constructor
        end
    end
end
function getsymb(ex::Symbol)
    return ex
end
function getsymb(ex::Expr)
    for arg in ex.args
        isa(arg, Symbol) && return arg
        s = getsymb(arg)
        isa(s, Symbol) && return s
    end
    return nothing
end

"""Get an equivalent `Options()` object for a particular regressor."""
function get_options(::AbstractSRRegressor) end

eval(modelexpr(:SRRegressor))
eval(modelexpr(:MultitargetSRRegressor))
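
# As a rough sketch (hypothetical and abbreviated; the real field list is
# spliced in from `DEFAULT_OPTIONS`), the expression returned by
# `modelexpr(:SRRegressor)` evaluates to something like:
#
#     Base.@kwdef mutable struct SRRegressor{D<:AbstractDimensions,L} <: AbstractSRRegressor
#         binary_operators = [+, -, /, *]  # ...and every other `Options` kwarg,
#         # followed by the MLJ-specific fields defined above:
#         niterations::Int = 10
#         # ...
#     end
#     get_options(m::SRRegressor) = Options(; binary_operators=m.binary_operators, #= ... =#)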

# Cleaning already taken care of by `Options` and `equation_search`
function full_report(
    m::AbstractSRRegressor, fitresult; v_with_strings::Val{with_strings}=Val(true)
) where {with_strings}
    _, hof = fitresult.state
    # TODO: Adjust baseline loss
    formatted = format_hall_of_fame(hof, fitresult.options)
    equation_strings = if with_strings
        get_equation_strings_for(
            m, formatted.trees, fitresult.options, fitresult.variable_names
        )
    else
        nothing
    end
    best_idx = dispatch_selection_for(
        m, formatted.trees, formatted.losses, formatted.scores, formatted.complexities
    )
    return (;
        best_idx=best_idx,
        equations=formatted.trees,
        equation_strings=equation_strings,
        losses=formatted.losses,
        complexities=formatted.complexities,
        scores=formatted.scores,
    )
end
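
# E.g. (hypothetical values): `full_report(m, fitresult)` returns a NamedTuple
# along the lines of
#     (; best_idx=3, equations=[...], equation_strings=["x₁", "x₁ + 0.5", ...],
#        losses=[...], complexities=[1, 3, ...], scores=[...])
# with one entry per expression on the Pareto frontier.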

MMI.clean!(::AbstractSRRegressor) = ""

# TODO: Enable `verbosity` being passed to `equation_search`
function MMI.fit(m::AbstractSRRegressor, verbosity, X, y, w=nothing)
    return MMI.update(m, verbosity, nothing, nothing, X, y, w)
end
function MMI.update(
    m::AbstractSRRegressor, verbosity, old_fitresult, old_cache, X, y, w=nothing
)
    options = old_fitresult === nothing ? get_options(m) : old_fitresult.options
    return _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, nothing)
end
function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, classes)
    if isnothing(classes) && MMI.istable(X) && haskey(X, :classes)
        if !(X isa NamedTuple)
            error("Classes can only be specified with named tuples.")
        end
        new_X = Base.structdiff(X, (; X.classes))
        new_classes = X.classes
        return _update(
            m, verbosity, old_fitresult, old_cache, new_X, y, w, options, new_classes
        )
    end
    if !isnothing(old_fitresult)
        @assert(
            old_fitresult.has_classes == !isnothing(classes),
            "If the first fit used classes, the second fit must also use classes."
        )
    end
    # To speed up iterative fits, we cache the types:
    types = if isnothing(old_fitresult)
        (;
            T=Any,
            X_t=Any,
            y_t=Any,
            w_t=Any,
            state=Any,
            X_units=Any,
            y_units=Any,
            X_units_clean=Any,
            y_units_clean=Any,
        )
    else
        old_fitresult.types
    end
    X_t::types.X_t, variable_names, X_units::types.X_units = get_matrix_and_info(
        X, m.dimensions_type
    )
    y_t::types.y_t, y_variable_names, y_units::types.y_units = format_input_for(
        m, y, m.dimensions_type
    )
    X_units_clean::types.X_units_clean = clean_units(X_units)
    y_units_clean::types.y_units_clean = clean_units(y_units)
    w_t::types.w_t = if w !== nothing && isa(m, MultitargetSRRegressor)
        @assert(isa(w, AbstractVector) && ndims(w) == 1, "Unexpected input for `w`.")
        repeat(w', size(y_t, 1))
    else
        w
    end
    search_state::types.state = equation_search(
        X_t,
        y_t;
        niterations=m.niterations,
        weights=w_t,
        variable_names=variable_names,
        options=options,
        parallelism=m.parallelism,
        numprocs=m.numprocs,
        procs=m.procs,
        addprocs_function=m.addprocs_function,
        heap_size_hint_in_bytes=m.heap_size_hint_in_bytes,
        runtests=m.runtests,
        saved_state=(old_fitresult === nothing ? nothing : old_fitresult.state),
        return_state=true,
        loss_type=m.loss_type,
        X_units=X_units_clean,
        y_units=y_units_clean,
        verbosity=verbosity,
        extra=isnothing(classes) ? (;) : (; classes),
        # Help out with inference:
        v_dim_out=isa(m, SRRegressor) ? Val(1) : Val(2),
    )
    fitresult = (;
        state=search_state,
        num_targets=isa(m, SRRegressor) ? 1 : size(y_t, 1),
        options=options,
        variable_names=variable_names,
        y_variable_names=y_variable_names,
        y_is_table=MMI.istable(y),
        has_classes=!isnothing(classes),
        X_units=X_units_clean,
        y_units=y_units_clean,
        types=(
            T=hof_eltype(search_state[2]),
            X_t=typeof(X_t),
            y_t=typeof(y_t),
            w_t=typeof(w_t),
            state=typeof(search_state),
            X_units=typeof(X_units),
            y_units=typeof(y_units),
            X_units_clean=typeof(X_units_clean),
            y_units_clean=typeof(y_units_clean),
        ),
    )::(old_fitresult === nothing ? Any : typeof(old_fitresult))
    return (fitresult, nothing, full_report(m, fitresult))
end
hof_eltype(::Type{H}) where {T,H<:HallOfFame{T}} = T
hof_eltype(::Type{V}) where {V<:Vector} = hof_eltype(eltype(V))
hof_eltype(h) = hof_eltype(typeof(h))

function clean_units(units)
    !isa(units, AbstractDimensions) && error("Unexpected units.")
    iszero(units) && return nothing
    return units
end
function clean_units(units::Vector)
    !all(Base.Fix2(isa, AbstractDimensions), units) && error("Unexpected units.")
    all(iszero, units) && return nothing
    return units
end
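
# Note: `clean_units` maps all-dimensionless units to `nothing`, which
# downstream code treats as "no dimensional analysis"; any non-zero
# dimensions are passed through unchanged.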

function get_matrix_and_info(X, ::Type{D}) where {D}
    sch = MMI.istable(X) ? MMI.schema(X) : nothing
    Xm_t = MMI.matrix(X; transpose=true)
    colnames = if sch === nothing
        [map(i -> "x$(subscriptify(i))", axes(Xm_t, 1))...]
    else
        [string.(sch.names)...]
    end
    D_promoted = get_dimensions_type(Xm_t, D)
    Xm_t_strip, X_units = unwrap_units_single(Xm_t, D_promoted)
    return Xm_t_strip, colnames, X_units
end
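
# E.g. a table with columns `(:a, :b)` keeps those names, while a plain
# matrix gets auto-generated names "x₁", "x₂", ... via `subscriptify`.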

function format_input_for(::SRRegressor, y, ::Type{D}) where {D}
    @assert(
        !(MMI.istable(y) || (length(size(y)) == 2 && size(y, 2) > 1)),
        "For multi-output regression, please use `MultitargetSRRegressor`."
    )
    y_t = vec(y)
    colnames = nothing
    D_promoted = get_dimensions_type(y_t, D)
    y_t_strip, y_units = unwrap_units_single(y_t, D_promoted)
    return y_t_strip, colnames, y_units
end
function format_input_for(::MultitargetSRRegressor, y, ::Type{D}) where {D}
    @assert(
        MMI.istable(y) || (length(size(y)) == 2 && size(y, 2) > 1),
        "For single-output regression, please use `SRRegressor`."
    )
    return get_matrix_and_info(y, D)
end
function validate_variable_names(variable_names, fitresult)
    @assert(
        variable_names == fitresult.variable_names,
        "Variable names do not match fitted regressor."
    )
    return nothing
end
function validate_units(X_units, old_X_units)
    @assert(
        all(X_units .== old_X_units),
        "Units of new data do not match units of fitted regressor."
    )
    return nothing
end

# TODO: Test whether this conversion poses any issues in data normalization...
function dimension_with_fallback(q::UnionAbstractQuantity{T}, ::Type{D}) where {T,D}
    return dimension(convert(Quantity{T,D}, q))::D
end
function dimension_with_fallback(_, ::Type{D}) where {D}
    return D()
end
function prediction_warn()
    @warn "Evaluation failed either due to NaNs detected or due to unfinished search. Using 0s for prediction."
end

wrap_units(v, ::Nothing, ::Integer) = v
wrap_units(v, ::Nothing, ::Nothing) = v
wrap_units(v, y_units, i::Integer) = (yi -> Quantity(yi, y_units[i])).(v)
wrap_units(v, y_units, ::Nothing) = (yi -> Quantity(yi, y_units)).(v)
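
# Note: `wrap_units` re-attaches output dimensions after evaluation. A
# `nothing` dimension passes `v` through untouched; otherwise each element is
# wrapped as `Quantity(yi, y_units)`, indexed per-target (`y_units[i]`) in
# the multi-target case.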

function prediction_fallback(::Type{T}, ::SRRegressor, Xnew_t, fitresult, _) where {T}
    prediction_warn()
    out = fill!(similar(Xnew_t, T, axes(Xnew_t, 2)), zero(T))
    return wrap_units(out, fitresult.y_units, nothing)
end
function prediction_fallback(
    ::Type{T}, ::MultitargetSRRegressor, Xnew_t, fitresult, prototype
) where {T}
    prediction_warn()
    out_cols = [
        wrap_units(
            fill!(similar(Xnew_t, T, axes(Xnew_t, 2)), zero(T)), fitresult.y_units, i
        ) for i in 1:(fitresult.num_targets)
    ]
    out_matrix = hcat(out_cols...)
    if !fitresult.y_is_table
        return out_matrix
    else
        return MMI.table(out_matrix; names=fitresult.y_variable_names, prototype)
    end
end

compat_ustrip(A::QuantityArray) = ustrip(A)
compat_ustrip(A) = ustrip.(A)

"""
    unwrap_units_single(::AbstractArray, ::Type{<:AbstractDimensions})

Remove units from some features in a matrix, and return, as a tuple,
(1) the matrix with stripped units, and (2) the dimensions for those features.
"""
function unwrap_units_single(A::AbstractMatrix, ::Type{D}) where {D}
    dims = D[dimension_with_fallback(first(row), D) for row in eachrow(A)]
    @inbounds for (i, row) in enumerate(eachrow(A))
        all(xi -> dimension_with_fallback(xi, D) == dims[i], row) ||
            error("Inconsistent units in feature $i of matrix.")
    end
    return stack(compat_ustrip, eachrow(A); dims=1)::AbstractMatrix, dims
end
function unwrap_units_single(v::AbstractVector, ::Type{D}) where {D}
    dims = dimension_with_fallback(first(v), D)
    all(xi -> dimension_with_fallback(xi, D) == dims, v) ||
        error("Inconsistent units in vector.")
    return compat_ustrip(v)::AbstractVector, dims
end
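
# E.g. (a hypothetical sketch, using DynamicQuantities): for a vector
#     v = [1.5us"km", 2.0us"km"]
# `unwrap_units_single(v, SymbolicDimensions)` would return the stripped
# values `[1.5, 2.0]` together with the shared dimensions of "km"; mixed
# units within one feature raise an error.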

function MMI.fitted_params(m::AbstractSRRegressor, fitresult)
    report = full_report(m, fitresult)
    return (;
        best_idx=report.best_idx,
        equations=report.equations,
        equation_strings=report.equation_strings,
    )
end

function eval_tree_mlj(
    tree::AbstractExpression,
    X_t,
    classes,
    m::AbstractSRRegressor,
    ::Type{T},
    fitresult,
    i,
    prototype,
) where {T}
    out, completed = if isnothing(classes)
        eval_tree_array(tree, X_t, fitresult.options)
    else
        eval_tree_array(tree, X_t, classes, fitresult.options)
    end
    if completed
        return wrap_units(out, fitresult.y_units, i)
    else
        return prediction_fallback(T, m, X_t, fitresult, prototype)
    end
end

function MMI.predict(
    m::M, fitresult, Xnew; idx=nothing, classes=nothing
) where {M<:AbstractSRRegressor}
    return _predict(m, fitresult, Xnew, idx, classes)
end
function _predict(m::M, fitresult, Xnew, idx, classes) where {M<:AbstractSRRegressor}
    if Xnew isa NamedTuple && (haskey(Xnew, :idx) || haskey(Xnew, :data))
        @assert(
            haskey(Xnew, :idx) && haskey(Xnew, :data) && length(keys(Xnew)) == 2,
            "If specifying an equation index during prediction, you must use a named tuple with keys `idx` and `data`."
        )
        return _predict(m, fitresult, Xnew.data, Xnew.idx, classes)
    end
    if isnothing(classes) && MMI.istable(Xnew) && haskey(Xnew, :classes)
        if !(Xnew isa NamedTuple)
            error("Classes can only be specified with named tuples.")
        end
        Xnew2 = Base.structdiff(Xnew, (; Xnew.classes))
        return _predict(m, fitresult, Xnew2, idx, Xnew.classes)
    end

    if fitresult.has_classes
        @assert(
            !isnothing(classes),
            "Classes must be specified if the model was fit with classes."
        )
    end

    params = full_report(m, fitresult; v_with_strings=Val(false))
    prototype = MMI.istable(Xnew) ? Xnew : nothing
    Xnew_t, variable_names, X_units = get_matrix_and_info(Xnew, m.dimensions_type)
    T = promote_type(eltype(Xnew_t), fitresult.types.T)

    if isempty(params.equations) || any(isempty, params.equations)
        @warn "Equations not found. Returning 0s for prediction."
        return prediction_fallback(T, m, Xnew_t, fitresult, prototype)
    end

    X_units_clean = clean_units(X_units)
    validate_variable_names(variable_names, fitresult)
    validate_units(X_units_clean, fitresult.X_units)

    idx = idx === nothing ? params.best_idx : idx

    if M <: SRRegressor
        return eval_tree_mlj(
            params.equations[idx], Xnew_t, classes, m, T, fitresult, nothing, prototype
        )
    elseif M <: MultitargetSRRegressor
        outs = [
            eval_tree_mlj(
                params.equations[i][idx[i]], Xnew_t, classes, m, T, fitresult, i, prototype
            ) for i in eachindex(idx, params.equations)
        ]
        out_matrix = reduce(hcat, outs)
        if !fitresult.y_is_table
            return out_matrix
        else
            return MMI.table(out_matrix; names=fitresult.y_variable_names, prototype)
        end
    end
end
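
# E.g. to predict with a specific Pareto-frontier index rather than the one
# chosen by `selection_method` (see the Operations section of the docstrings
# below):
#     predict(mach, (data=Xnew, idx=2))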

function get_equation_strings_for(::SRRegressor, trees, options, variable_names)
    return (t -> string_tree(t, options; variable_names=variable_names)).(trees)
end
function get_equation_strings_for(::MultitargetSRRegressor, trees, options, variable_names)
    return [
        (t -> string_tree(t, options; variable_names=variable_names)).(ts) for ts in trees
    ]
end

function choose_best(; trees, losses::Vector{L}, scores, complexities) where {L<:LOSS_TYPE}
    # Same as in PySR:
    # https://github.com/MilesCranmer/PySR/blob/e74b8ad46b163c799908b3aa4d851cf8457c79ef/pysr/sr.py#L2318-L2332
    # threshold = 1.5 * minimum_loss
    # Then, we get max score of those below the threshold.
    threshold = 1.5 * minimum(losses)
    return argmax([
        (losses[i] <= threshold) ? scores[i] : typemin(L) for i in eachindex(losses)
    ])
end
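
# A custom `selection_method` takes the same keyword arguments and returns an
# integer index. E.g. a (hypothetical) rule that ignores scores and simply
# picks the lowest-loss expression:
#     choose_lowest_loss(; trees, losses, scores, complexities) = argmin(losses)
#     model = SRRegressor(; selection_method=choose_lowest_loss)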

function dispatch_selection_for(m::SRRegressor, trees, losses, scores, complexities)::Int
    length(trees) == 0 && return 0
    return m.selection_method(;
        trees=trees, losses=losses, scores=scores, complexities=complexities
    )
end
function dispatch_selection_for(
    m::MultitargetSRRegressor, trees, losses, scores, complexities
)
    any(t -> length(t) == 0, trees) && return fill(0, length(trees))
    return [
        m.selection_method(;
            trees=trees[i], losses=losses[i], scores=scores[i], complexities=complexities[i]
        ) for i in eachindex(trees)
    ]
end

MMI.metadata_pkg(
    AbstractSRRegressor;
    name="SymbolicRegression",
    uuid="8254be44-1295-4e6a-a16d-46603ac705cb",
    url="https://github.com/MilesCranmer/SymbolicRegression.jl",
    julia=true,
    license="Apache-2.0",
    is_wrapper=false,
)

const input_scitype = Union{
    MMI.Table(MMI.Continuous),
    AbstractMatrix{<:MMI.Continuous},
    MMI.Table(MMI.Continuous, MMI.Count),
}

# TODO: Allow for Count data, and coerce it into Continuous as needed.
MMI.metadata_model(
    SRRegressor;
    input_scitype,
    target_scitype=AbstractVector{<:MMI.Continuous},
    supports_weights=true,
    reports_feature_importances=false,
    load_path="SymbolicRegression.MLJInterfaceModule.SRRegressor",
    human_name="Symbolic Regression via Evolutionary Search",
)
MMI.metadata_model(
    MultitargetSRRegressor;
    input_scitype,
    target_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}},
    supports_weights=true,
    reports_feature_importances=false,
    load_path="SymbolicRegression.MLJInterfaceModule.MultitargetSRRegressor",
    human_name="Multi-Target Symbolic Regression via Evolutionary Search",
)

function tag_with_docstring(model_name::Symbol, description::String, bottom_matter::String)
    docstring = """$(MMI.doc_header(eval(model_name)))

    $(description)

    # Hyper-parameters
    """

    # TODO: These ones are copied (or written) manually:
    append_arguments = """- `niterations::Int=10`: The number of iterations of the search to perform.
        More iterations will improve the results.
    - `parallelism=:multithreading`: What parallelism mode to use.
        The options are `:multithreading`, `:multiprocessing`, and `:serial`.
        By default, multithreading will be used. Multithreading uses less memory,
        but multiprocessing can handle multi-node compute. If using `:multithreading`
        mode, the number of threads available to Julia is used. If using
        `:multiprocessing`, `numprocs` processes will be created dynamically if
        `procs` is unset. If you have already allocated processes, pass them
        to the `procs` argument and they will be used.
        You may also pass a string instead of a symbol, like `"multithreading"`.
    - `numprocs::Union{Int, Nothing}=nothing`: The number of processes to use,
        if you want `equation_search` to set this up automatically. By default
        this will be `4`, but can be any number (you should pick a number <=
        the number of cores available).
    - `procs::Union{Vector{Int}, Nothing}=nothing`: If you have set up
        a distributed run manually with `procs = addprocs()` and `@everywhere`,
        pass the `procs` to this keyword argument.
    - `addprocs_function::Union{Function, Nothing}=nothing`: If using multiprocessing
        (`parallelism=:multiprocessing`) and you are not passing `procs` manually,
        then processes will be allocated dynamically using `addprocs`. However,
        you may also pass a custom function to use instead of `addprocs`.
        This function should take a single positional argument,
        which is the number of processes to use, as well as the `lazy` keyword argument.
        For example, if set up on a slurm cluster, you could pass
        `addprocs_function = addprocs_slurm`, which will set up slurm processes.
    - `heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing`: On Julia 1.9+, you may set the `--heap-size-hint`
        flag on Julia processes, recommending garbage collection once a process
        is close to the recommended size. This is important for long-running distributed
        jobs where each process has independent memory, and can help avoid
        out-of-memory errors. By default, this is set to `Sys.free_memory() / numprocs`.
    - `runtests::Bool=true`: Whether to run (quick) tests before starting the
        search, to see if there will be any problems during the equation search
        related to the host environment.
    - `loss_type::Type=Nothing`: If you would like to use a different type
        for the loss than for the data you passed, specify the type here.
        Note that if you pass complex data `::Complex{L}`, then the loss
        type will automatically be set to `L`.
    - `selection_method::Function`: Function to select an expression from
        the Pareto frontier for use in `predict`.
        See `SymbolicRegression.MLJInterfaceModule.choose_best` for an example.
        This function should return a single integer specifying
        the index of the expression to use. By default, this maximizes
        the score (a pound-for-pound rating) of expressions reaching the threshold
        of 1.5x the minimum loss. To override this at prediction time, you can pass
        a named tuple with keys `data` and `idx` to `predict`. See the Operations
        section for details.
    - `dimensions_type::AbstractDimensions`: The type of dimensions to use when storing
        the units of the data. By default this is `DynamicQuantities.SymbolicDimensions`.
    """

    bottom = """
    # Operations

    - `predict(mach, Xnew)`: Return predictions of the target given features `Xnew`, which
      should have the same scitype as `X` above. The expression used for prediction is defined
      by the `selection_method` function, which can be seen by viewing `report(mach).best_idx`.
    - `predict(mach, (data=Xnew, idx=i))`: Return predictions of the target given features
      `Xnew`, which should have the same scitype as `X` above. By passing a named tuple with keys
      `data` and `idx`, you are able to specify the equation you wish to evaluate in `idx`.

    $(bottom_matter)
    """

    # Remove common indentation:
    docstring = replace(docstring, r"^    " => "")
    extra_arguments = replace(append_arguments, r"^    " => "")
    bottom = replace(bottom, r"^    " => "")

    # Add parameter descriptions:
    docstring = docstring * OPTION_DESCRIPTIONS
    docstring = docstring * extra_arguments
    docstring = docstring * bottom
    return quote
        @doc $docstring $model_name
    end
end

#https://arxiv.org/abs/2305.01582
eval(
    tag_with_docstring(
        :SRRegressor,
        replace(
            """
    Single-target Symbolic Regression regressor (`SRRegressor`) searches
    for symbolic expressions that predict a single target variable from
    a set of input variables. All data is assumed to be `Continuous`.
    The search is performed using an evolutionary algorithm.
    This algorithm is described in the paper
    https://arxiv.org/abs/2305.01582.

    # Training data

    In MLJ or MLJBase, bind an instance `model` to data with

        mach = machine(model, X, y)

    OR

        mach = machine(model, X, y, w)

    Here:

    - `X` is any table of input features (eg, a `DataFrame`) whose columns are of scitype
      `Continuous`; check column scitypes with `schema(X)`. Variable names in discovered
      expressions will be taken from the column names of `X`, if available. Units in columns
      of `X` (use `DynamicQuantities` for units) will trigger dimensional analysis to be used.

    - `y` is the target, which can be any `AbstractVector` whose element scitype is
      `Continuous`; check the scitype with `scitype(y)`. Units in `y` (use `DynamicQuantities`
      for units) will trigger dimensional analysis to be used.

    - `w` is the observation weights, which can either be `nothing` (default) or an
      `AbstractVector` whose element scitype is `Count` or `Continuous`.

    Train the machine using `fit!(mach)`, inspect the discovered expressions with
    `report(mach)`, and predict on new data with `predict(mach, Xnew)`.
    Note that unlike other regressors, symbolic regression stores a list of
    trained models. The model chosen from this list is defined by the
    `selection_method` keyword argument, which by default balances accuracy
    and complexity. You can override this at prediction time by passing a named
    tuple with keys `data` and `idx`.

    """,
            r"^    " => "",
        ),
        replace(
            """
    # Fitted parameters

    The fields of `fitted_params(mach)` are:

    - `best_idx::Int`: The index of the best expression in the Pareto frontier,
      as determined by the `selection_method` function. Override in `predict` by passing
      a named tuple with keys `data` and `idx`.
    - `equations::Vector{Node{T}}`: The expressions discovered by the search, represented
      in a dominating Pareto frontier (i.e., the best expressions found for
      each complexity). `T` is equal to the element type
      of the passed data.
    - `equation_strings::Vector{String}`: The expressions discovered by the search,
      represented as strings for easy inspection.

    # Report

    The fields of `report(mach)` are:

    - `best_idx::Int`: The index of the best expression in the Pareto frontier,
      as determined by the `selection_method` function. Override in `predict` by passing
      a named tuple with keys `data` and `idx`.
    - `equations::Vector{Node{T}}`: The expressions discovered by the search, represented
      in a dominating Pareto frontier (i.e., the best expressions found for
      each complexity).
    - `equation_strings::Vector{String}`: The expressions discovered by the search,
      represented as strings for easy inspection.
    - `complexities::Vector{Int}`: The complexity of each expression in the Pareto frontier.
    - `losses::Vector{L}`: The loss of each expression in the Pareto frontier, according
      to the loss function specified in the model. The type `L` is the loss type, which
      is usually the same as the element type of data passed (i.e., `T`), but can differ
      if complex data types are passed.
    - `scores::Vector{L}`: A metric which considers both the complexity and loss of an expression,
      equal to the change in the log-loss divided by the change in complexity, relative to
      the previous expression along the Pareto frontier. A larger score aims to indicate
      an expression is more likely to be the true expression generating the data, but
      this is very problem-dependent and generally several other factors should be considered.

    # Examples

    ```julia
    using MLJ
    SRRegressor = @load SRRegressor pkg=SymbolicRegression
    X, y = @load_boston
    model = SRRegressor(binary_operators=[+, -, *], unary_operators=[exp], niterations=100)
    mach = machine(model, X, y)
    fit!(mach)
    y_hat = predict(mach, X)
    # View the equation used:
    r = report(mach)
    println("Equation used:", r.equation_strings[r.best_idx])
    ```

    With units and variable names:

    ```julia
    using MLJ
    using DynamicQuantities
    SRRegressor = @load SRRegressor pkg=SymbolicRegression

    X = (; x1=rand(32) .* us"km/h", x2=rand(32) .* us"km")
    y = @. X.x2 / X.x1 + 0.5us"h"
    model = SRRegressor(binary_operators=[+, -, *, /])
    mach = machine(model, X, y)
    fit!(mach)
    y_hat = predict(mach, X)
    # View the equation used:
    r = report(mach)
    println("Equation used:", r.equation_strings[r.best_idx])
    ```

    See also [`MultitargetSRRegressor`](@ref).
    """,
            r"^    " => "",
        ),
    ),
)
eval(
    tag_with_docstring(
        :MultitargetSRRegressor,
        replace(
            """
    Multi-target Symbolic Regression regressor (`MultitargetSRRegressor`)
    conducts several searches for expressions that predict each target variable
    from a set of input variables. All data is assumed to be `Continuous`.
    The search is performed using an evolutionary algorithm.
    This algorithm is described in the paper
    https://arxiv.org/abs/2305.01582.

    # Training data

    In MLJ or MLJBase, bind an instance `model` to data with

        mach = machine(model, X, y)

    OR

        mach = machine(model, X, y, w)

    Here:

    - `X` is any table of input features (eg, a `DataFrame`) whose columns are of scitype
      `Continuous`; check column scitypes with `schema(X)`. Variable names in discovered
      expressions will be taken from the column names of `X`, if available. Units in columns
      of `X` (use `DynamicQuantities` for units) will trigger dimensional analysis to be used.

    - `y` is the target, which can be any table of target variables whose element
      scitype is `Continuous`; check the scitype with `schema(y)`. Units in columns of
      `y` (use `DynamicQuantities` for units) will trigger dimensional analysis to be used.

    - `w` is the observation weights, which can either be `nothing` (default) or an
      `AbstractVector` whose element scitype is `Count` or `Continuous`. The same
      weights are used for all targets.

    Train the machine using `fit!(mach)`, inspect the discovered expressions with
    `report(mach)`, and predict on new data with `predict(mach, Xnew)`.
    Note that unlike other regressors, symbolic regression stores a list of lists of
    trained models. The models chosen from each of these lists are defined by the
    `selection_method` keyword argument, which by default balances accuracy
    and complexity. You can override this at prediction time by passing a named
    tuple with keys `data` and `idx`.

    """,
            r"^    " => "",
        ),
        replace(
            """
    # Fitted parameters

    The fields of `fitted_params(mach)` are:

    - `best_idx::Vector{Int}`: The index of the best expression in each Pareto frontier,
      as determined by the `selection_method` function. Override in `predict` by passing
      a named tuple with keys `data` and `idx`.
    - `equations::Vector{Vector{Node{T}}}`: The expressions discovered by the search, represented
      in a dominating Pareto frontier (i.e., the best expressions found for
      each complexity). The outer vector is indexed by target variable, and the inner
      vector is ordered by increasing complexity. `T` is equal to the element type
      of the passed data.
    - `equation_strings::Vector{Vector{String}}`: The expressions discovered by the search,
      represented as strings for easy inspection.

    # Report

    The fields of `report(mach)` are:

    - `best_idx::Vector{Int}`: The index of the best expression in each Pareto frontier,
      as determined by the `selection_method` function. Override in `predict` by passing
      a named tuple with keys `data` and `idx`.
    - `equations::Vector{Vector{Node{T}}}`: The expressions discovered by the search, represented
      in a dominating Pareto frontier (i.e., the best expressions found for
      each complexity). The outer vector is indexed by target variable, and the inner
      vector is ordered by increasing complexity.
    - `equation_strings::Vector{Vector{String}}`: The expressions discovered by the search,
      represented as strings for easy inspection.
    - `complexities::Vector{Vector{Int}}`: The complexity of each expression in each Pareto frontier.
    - `losses::Vector{Vector{L}}`: The loss of each expression in each Pareto frontier, according
      to the loss function specified in the model. The type `L` is the loss type, which
      is usually the same as the element type of data passed (i.e., `T`), but can differ
      if complex data types are passed.
    - `scores::Vector{Vector{L}}`: A metric which considers both the complexity and loss of an expression,
      equal to the change in the log-loss divided by the change in complexity, relative to
      the previous expression along the Pareto frontier. A larger score aims to indicate
      an expression is more likely to be the true expression generating the data, but
      this is very problem-dependent and generally several other factors should be considered.

    # Examples

    ```julia
    using MLJ
    MultitargetSRRegressor = @load MultitargetSRRegressor pkg=SymbolicRegression
    X = (a=rand(100), b=rand(100), c=rand(100))
    Y = (y1=(@. cos(X.c) * 2.1 - 0.9), y2=(@. X.a * X.b + X.c))
    model = MultitargetSRRegressor(binary_operators=[+, -, *], unary_operators=[exp], niterations=100)
    mach = machine(model, X, Y)
    fit!(mach)
    y_hat = predict(mach, X)
    # View the equations used:
    r = report(mach)
    for (output_index, (eq, i)) in enumerate(zip(r.equation_strings, r.best_idx))
        println("Equation used for ", output_index, ": ", eq[i])
    end
    ```

    See also [`SRRegressor`](@ref).
    """,
            r"^    " => "",
        ),
    ),
)

end