• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MilesCranmer / SymbolicRegression.jl / 9686354911

26 Jun 2024 08:31PM UTC coverage: 93.22% (-1.4%) from 94.617%
9686354911

Pull #326

github

web-flow
Merge 6f8229c9f into ceddaa424
Pull Request #326: BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)

275 of 296 new or added lines in 17 files covered. (92.91%)

34 existing lines in 5 files now uncovered.

2530 of 2714 relevant lines covered (93.22%)

32081968.55 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.33
/src/MLJInterface.jl
1
module MLJInterfaceModule
162✔
2

3
using Optim: Optim
4
using LineSearches: LineSearches
5
using MLJModelInterface: MLJModelInterface as MMI
6
using ADTypes: AbstractADType
7
using DynamicExpressions:
8
    eval_tree_array,
9
    string_tree,
10
    AbstractExpressionNode,
11
    AbstractExpression,
12
    Node,
13
    Expression,
14
    default_node_type,
15
    get_tree
16
using DynamicQuantities:
17
    QuantityArray,
18
    UnionAbstractQuantity,
19
    AbstractDimensions,
20
    SymbolicDimensions,
21
    Quantity,
22
    DEFAULT_DIM_BASE_TYPE,
23
    ustrip,
24
    dimension
25
using LossFunctions: SupervisedLoss
26
using Compat: allequal, stack
27
using ..InterfaceDynamicQuantitiesModule: get_dimensions_type
28
using ..CoreModule: Options, Dataset, MutationWeights, LOSS_TYPE
29
using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS
30
using ..ComplexityModule: compute_complexity
31
using ..HallOfFameModule: HallOfFame, format_hall_of_fame
32
using ..UtilsModule: subscriptify
33

34
import ..equation_search
35

36
"""
    AbstractSRRegressor

Supertype of the MLJ model structs generated by `modelexpr` in this module
(`SRRegressor` and `MultitargetSRRegressor`); it subtypes `MMI.Deterministic`.
"""
abstract type AbstractSRRegressor <: MMI.Deterministic end
37

38
# TODO: To reduce code duplication, we could forward these defaults from
#       `equation_search`, similar to what we do for `Options`.
40

41
"""Generate an `SRRegressor` struct containing all the fields in `Options`."""
42
function modelexpr(model_name::Symbol)
42✔
43
    struct_def = :(Base.@kwdef mutable struct $(model_name){D<:AbstractDimensions,L} <:
74✔
44
                                 AbstractSRRegressor
45
        niterations::Int = 10
46
        parallelism::Symbol = :multithreading
47
        numprocs::Union{Int,Nothing} = nothing
48
        procs::Union{Vector{Int},Nothing} = nothing
49
        addprocs_function::Union{Function,Nothing} = nothing
50
        heap_size_hint_in_bytes::Union{Integer,Nothing} = nothing
51
        runtests::Bool = true
52
        loss_type::L = Nothing
53
        selection_method::Function = choose_best
54
        dimensions_type::Type{D} = SymbolicDimensions{DEFAULT_DIM_BASE_TYPE}
55
    end)
56
    # TODO: store `procs` from initial run if parallelism is `:multiprocessing`
57
    fields = last(last(struct_def.args).args).args
42✔
58

59
    # Add everything from `Options` constructor directly to struct:
60
    for (i, option) in enumerate(DEFAULT_OPTIONS)
54✔
61
        insert!(fields, i, Expr(:(=), option.args...))
3,354✔
62
    end
5,004✔
63

64
    # We also need to create the `get_options` function, based on this:
65
    constructor = :(Options(;))
42✔
66
    constructor_fields = last(constructor.args).args
42✔
67
    for option in DEFAULT_OPTIONS
42✔
68
        symb = getsymb(first(option.args))
3,924✔
69
        push!(constructor_fields, Expr(:kw, symb, Expr(:(.), :m, Core.QuoteNode(symb))))
5,466✔
70
    end
2,526✔
71

72
    return quote
42✔
73
        $struct_def
74
        function get_options(m::$(model_name))
24✔
75
            return $constructor
24✔
76
        end
77
    end
78
end
79
"""
    getsymb(ex)

Return the first `Symbol` found in `ex` by a depth-first, left-to-right search of
its arguments, or `nothing` if there is none. For example,
`getsymb(:(niterations::Int)) === :niterations`, which is how the field name is
extracted from the left-hand side of an option declaration.
"""
getsymb(ex::Symbol) = ex
function getsymb(ex::Expr)
    for arg in ex.args
        found = getsymb(arg)
        isa(found, Symbol) && return found
    end
    return nothing
end
# Leaves that are neither `Symbol` nor `Expr` (numeric literals, `LineNumberNode`s,
# `QuoteNode`s, ...) hold no symbol. Without this fallback they raised a `MethodError`
# instead of letting the search continue with the next argument.
getsymb(_) = nothing
90

91
"""Get an equivalent `Options()` object for a particular regressor."""
92
function get_options(::AbstractSRRegressor) end
×
93

94
# Define both concrete regressor types (and their `get_options` methods) at load time.
for model_name in (:SRRegressor, :MultitargetSRRegressor)
    eval(modelexpr(model_name))
end
96

97
# Cleaning already taken care of by `Options` and `equation_search`
98
"""
    full_report(m::AbstractSRRegressor, fitresult; v_with_strings=Val(true))

Summarize the hall of fame held in `fitresult.state` as a named tuple with fields
`best_idx`, `equations`, `equation_strings`, `losses`, `complexities`, and `scores`.
`best_idx` comes from `dispatch_selection_for` (i.e. `m.selection_method`).
With `v_with_strings=Val(false)` the string rendering is skipped and
`equation_strings` is `nothing`.
"""
function full_report(
    m::AbstractSRRegressor, fitresult; v_with_strings::Val{with_strings}=Val(true)
) where {with_strings}
    hall_of_fame = fitresult.state[2]
    # TODO: Adjust baseline loss
    summary = format_hall_of_fame(hall_of_fame, fitresult.options)
    trees = summary.trees

    strings = if with_strings
        get_equation_strings_for(m, trees, fitresult.options, fitresult.variable_names)
    else
        nothing
    end
    chosen = dispatch_selection_for(
        m, trees, summary.losses, summary.scores, summary.complexities
    )

    return (;
        best_idx=chosen,
        equations=trees,
        equation_strings=strings,
        losses=summary.losses,
        complexities=summary.complexities,
        scores=summary.scores,
    )
end
123

124
# Input cleaning is already taken care of by `Options` and `equation_search`,
# so this always returns an empty string (no messages).
function MMI.clean!(::AbstractSRRegressor)
    return ""
end
25✔
125

126
# `fit` is a cold-start `update`: there is no previous fitresult or cache.
# `verbosity` is forwarded through `_update` to `equation_search`.
function MMI.fit(m::AbstractSRRegressor, verbosity, X, y, w=nothing)
    no_fitresult = nothing
    no_cache = nothing
    return MMI.update(m, verbosity, no_fitresult, no_cache, X, y, w)
end

# Warm starts reuse the `Options` of the previous fit; cold starts build them from `m`.
function MMI.update(
    m::AbstractSRRegressor, verbosity, old_fitresult, old_cache, X, y, w=nothing
)
    options = if old_fitresult === nothing
        get_options(m)
    else
        old_fitresult.options
    end
    classes = nothing
    return _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, classes)
end
136
"""
    _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, classes)

Shared implementation of `MMI.fit` and `MMI.update`: prepare the data, run
`equation_search` (continuing from `old_fitresult.state` when given), and return
`(fitresult, cache, report)` with `cache === nothing`.

When `classes` is `nothing` and `X` is a table with a `classes` key, that column is
removed from the features and passed to `equation_search` via `extra`. This is only
supported for `NamedTuple` tables.
"""
function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, classes)
    if isnothing(classes) && MMI.istable(X) && haskey(X, :classes)
        X isa NamedTuple || error("Classes can only be specified with named tuples.")
        features_only = Base.structdiff(X, (; X.classes))
        return _update(
            m, verbosity, old_fitresult, old_cache, features_only, y, w, options, X.classes
        )
    end
    is_warm_start = old_fitresult !== nothing

    # To speed up iterative fits, the concrete types recorded by the previous fit are
    # reused below as type assertions; a cold start asserts nothing (`Any`).
    types = if is_warm_start
        old_fitresult.types
    else
        (;
            T=Any,
            X_t=Any,
            y_t=Any,
            w_t=Any,
            state=Any,
            X_units=Any,
            y_units=Any,
            X_units_clean=Any,
            y_units_clean=Any,
        )
    end

    X_t::types.X_t, variable_names, X_units::types.X_units = get_matrix_and_info(
        X, m.dimensions_type
    )
    y_t::types.y_t, y_variable_names, y_units::types.y_units = format_input_for(
        m, y, m.dimensions_type
    )
    X_units_clean::types.X_units_clean = clean_units(X_units)
    y_units_clean::types.y_units_clean = clean_units(y_units)

    # A multi-target search takes one row of weights per target.
    w_t::types.w_t = if isnothing(w) || !isa(m, MultitargetSRRegressor)
        w
    else
        @assert(isa(w, AbstractVector) && ndims(w) == 1, "Unexpected input for `w`.")
        repeat(w', size(y_t, 1))
    end

    search_state::types.state = equation_search(
        X_t,
        y_t;
        niterations=m.niterations,
        weights=w_t,
        variable_names,
        options,
        parallelism=m.parallelism,
        numprocs=m.numprocs,
        procs=m.procs,
        addprocs_function=m.addprocs_function,
        heap_size_hint_in_bytes=m.heap_size_hint_in_bytes,
        runtests=m.runtests,
        saved_state=(is_warm_start ? old_fitresult.state : nothing),
        return_state=true,
        loss_type=m.loss_type,
        X_units=X_units_clean,
        y_units=y_units_clean,
        verbosity,
        extra=isnothing(classes) ? (;) : (; classes),
        # Help out with inference:
        v_dim_out=isa(m, SRRegressor) ? Val(1) : Val(2),
    )

    recorded_types = (
        T=hof_eltype(search_state[2]),
        X_t=typeof(X_t),
        y_t=typeof(y_t),
        w_t=typeof(w_t),
        state=typeof(search_state),
        X_units=typeof(X_units),
        y_units=typeof(y_units),
        X_units_clean=typeof(X_units_clean),
        y_units_clean=typeof(y_units_clean),
    )
    fitresult_type = is_warm_start ? typeof(old_fitresult) : Any
    fitresult = (;
        state=search_state,
        num_targets=isa(m, SRRegressor) ? 1 : size(y_t, 1),
        options,
        variable_names,
        y_variable_names,
        y_is_table=MMI.istable(y),
        X_units=X_units_clean,
        y_units=y_units_clean,
        types=recorded_types,
    )::fitresult_type
    report = full_report(m, fitresult)
    return (fitresult, nothing, report)
end
223
# Element type `T` of a `HallOfFame{T}`. For a `Vector` (presumably of halls of fame,
# one per target), the element type of its entries; values are resolved via their type.
function hof_eltype(::Type{H}) where {T,H<:HallOfFame{T}}
    return T
end
function hof_eltype(::Type{V}) where {V<:Vector}
    return hof_eltype(eltype(V))
end
hof_eltype(h) = hof_eltype(typeof(h))
30✔
226

227
# Normalize detected units: return `nothing` when the dimensions are all zero
# (`iszero`), otherwise return them unchanged. Anything that is not an
# `AbstractDimensions` (or a `Vector` of them) raises an error.
function clean_units(units)
    if !(units isa AbstractDimensions)
        error("Unexpected units.")
    end
    return iszero(units) ? nothing : units
end
function clean_units(units::Vector)
    if !all(u -> u isa AbstractDimensions, units)
        error("Unexpected units.")
    end
    return all(iszero, units) ? nothing : units
end
237

238
# Convert `X` (table or matrix) into a features-by-samples matrix with units stripped.
# Returns `(matrix, column_names, feature_dimensions)`. Column names come from the
# table schema if `X` is a table, otherwise `"x" * subscriptify(i)` for feature `i`.
function get_matrix_and_info(X, ::Type{D}) where {D}
    table_schema = MMI.istable(X) ? MMI.schema(X) : nothing
    X_transposed = MMI.matrix(X; transpose=true)
    colnames = if table_schema !== nothing
        [string.(table_schema.names)...]
    else
        [map(i -> "x$(subscriptify(i))", axes(X_transposed, 1))...]
    end
    D_promoted = get_dimensions_type(X_transposed, D)
    stripped, feature_units = unwrap_units_single(X_transposed, D_promoted)
    return stripped, colnames, feature_units
end
250

251
# Single-target: flatten `y` to a vector and strip its units.
# Returns `(vector, nothing, dimensions)`; a single target carries no column names.
function format_input_for(::SRRegressor, y, ::Type{D}) where {D}
    looks_multitarget = MMI.istable(y) || (length(size(y)) == 2 && size(y, 2) > 1)
    @assert(
        !looks_multitarget,
        "For multi-output regression, please use `MultitargetSRRegressor`."
    )
    y_vector = vec(y)
    D_promoted = get_dimensions_type(y_vector, D)
    y_stripped, y_dims = unwrap_units_single(y_vector, D_promoted)
    return y_stripped, nothing, y_dims
end
# Multi-target: the target is handled exactly like a feature table/matrix.
function format_input_for(::MultitargetSRRegressor, y, ::Type{D}) where {D}
    looks_multitarget = MMI.istable(y) || (length(size(y)) == 2 && size(y, 2) > 1)
    @assert(looks_multitarget, "For single-output regression, please use `SRRegressor`.")
    return get_matrix_and_info(y, D)
end
269
# Check that the feature names seen at prediction time equal those of the fit.
function validate_variable_names(variable_names, fitresult)
    names_match = variable_names == fitresult.variable_names
    @assert(names_match, "Variable names do not match fitted regressor.")
    return nothing
end
# Check that the (cleaned) feature units at prediction time equal those of the fit.
function validate_units(X_units, old_X_units)
    units_match = all(X_units .== old_X_units)
    @assert(units_match, "Units of new data do not match units of fitted regressor.")
    return nothing
end
283

284
# TODO: Test whether this conversion poses any issues in data normalization...

# Dimensions of one data element, expressed in dimension type `D`. Quantities are
# first converted to `Quantity{T,D}`; any other value maps to `D()` (presumably the
# dimensionless value of `D`).
function dimension_with_fallback(q::UnionAbstractQuantity{T}, ::Type{D}) where {T,D}
    converted = convert(Quantity{T,D}, q)
    return dimension(converted)::D
end
dimension_with_fallback(_, ::Type{D}) where {D} = D()
UNCOV
291
# Warning emitted whenever predictions fall back to zeros.
function prediction_warn()
    @warn "Evaluation failed either due to NaNs detected or due to unfinished search. Using 0s for prediction."
    return nothing
end
294

UNCOV
295
# Re-attach target units to raw predictions `v`. With no units (`nothing`), `v` is
# returned as-is. Otherwise every element becomes a `Quantity`, using `y_units[i]` for
# target `i` or `y_units` itself when `i === nothing`. The `Nothing` methods stay
# separate so they remain more specific than the unit-carrying ones.
wrap_units(v, ::Nothing, ::Integer) = v
wrap_units(v, ::Nothing, ::Nothing) = v
function wrap_units(v, y_units, i::Integer)
    attach(yi) = Quantity(yi, y_units[i])
    return attach.(v)
end
function wrap_units(v, y_units, ::Nothing)
    attach(yi) = Quantity(yi, y_units)
    return attach.(v)
end
48✔
299

UNCOV
300
# Zero-valued predictions (one per column of `Xnew_t`, i.e. per sample) returned after
# `prediction_warn()`. They keep the target units and, for multi-target models, the
# matrix-vs-table shape of the fitted target.
function prediction_fallback(::Type{T}, ::SRRegressor, Xnew_t, fitresult, _) where {T}
    prediction_warn()
    zero_predictions = similar(Xnew_t, T, axes(Xnew_t, 2))
    fill!(zero_predictions, zero(T))
    return wrap_units(zero_predictions, fitresult.y_units, nothing)
end
function prediction_fallback(
    ::Type{T}, ::MultitargetSRRegressor, Xnew_t, fitresult, prototype
) where {T}
    prediction_warn()
    sample_axis = axes(Xnew_t, 2)
    columns = map(1:(fitresult.num_targets)) do i
        zero_column = fill!(similar(Xnew_t, T, sample_axis), zero(T))
        wrap_units(zero_column, fitresult.y_units, i)
    end
    out_matrix = hcat(columns...)
    fitresult.y_is_table || return out_matrix
    return MMI.table(out_matrix; names=fitresult.y_variable_names, prototype)
end
321

322
# Strip units: a `QuantityArray` as a whole, any other container elementwise.
function compat_ustrip(A::QuantityArray)
    return ustrip(A)
end
function compat_ustrip(A)
    return broadcast(ustrip, A)
end
324

325
"""
326
    unwrap_units_single(::AbstractArray, ::Type{<:AbstractDimensions})
327

328
Remove units from some features in a matrix, and return, as a tuple,
329
(1) the matrix with stripped units, and (2) the dimensions for those features.
330
"""
331
function unwrap_units_single(A::AbstractMatrix, ::Type{D}) where {D}
84✔
332
    dims = D[dimension_with_fallback(first(row), D) for row in eachrow(A)]
222✔
333
    @inbounds for (i, row) in enumerate(eachrow(A))
168✔
334
        all(xi -> dimension_with_fallback(xi, D) == dims[i], row) ||
11,531✔
335
            error("Inconsistent units in feature $i of matrix.")
336
    end
300✔
337
    return stack(compat_ustrip, eachrow(A); dims=1)::AbstractMatrix, dims
84✔
338
end
339
function unwrap_units_single(v::AbstractVector, ::Type{D}) where {D}
16✔
340
    dims = dimension_with_fallback(first(v), D)
18✔
341
    all(xi -> dimension_with_fallback(xi, D) == dims, v) ||
1,734✔
342
        error("Inconsistent units in vector.")
343
    return compat_ustrip(v)::AbstractVector, dims
18✔
344
end
345

UNCOV
346
# MLJ `fitted_params`: the selected index plus the discovered equations, both as
# expression objects and as strings (a subset of `full_report`).
function MMI.fitted_params(m::AbstractSRRegressor, fitresult)
    report = full_report(m, fitresult)
    return (; report.best_idx, report.equations, report.equation_strings)
end
354

355
"""
    eval_tree_mlj(tree, X_t, classes, m, T, fitresult, i, prototype)

Evaluate one fitted expression on the features-by-samples matrix `X_t` (passing
`classes` through when given) and re-attach target units.

`i === nothing` is the single-target path. An integer `i` is target `i` of a
`MultitargetSRRegressor`, and one column of predictions is returned for it.

If evaluation does not complete, zeros are returned after a warning:
- single-target: the full `prediction_fallback` output;
- multi-target: a zero column carrying the units of target `i`, so the caller can
  still `hcat` the per-target columns together.
"""
function eval_tree_mlj(
    tree::AbstractExpression,
    X_t,
    classes,
    m::AbstractSRRegressor,
    ::Type{T},
    fitresult,
    i,
    prototype,
) where {T}
    out, completed = if isnothing(classes)
        eval_tree_array(tree, X_t, fitresult.options)
    else
        eval_tree_array(tree, X_t, classes, fitresult.options)
    end
    completed && return wrap_units(out, fitresult.y_units, i)

    if i === nothing
        return prediction_fallback(T, m, X_t, fitresult, prototype)
    end
    # Multi-target: only this target's column falls back to zeros. Returning the whole
    # multi-target fallback (matrix or table) here broke the caller's
    # `reduce(hcat, ...)` over per-target columns.
    prediction_warn()
    zero_column = fill!(similar(X_t, T, axes(X_t, 2)), zero(T))
    return wrap_units(zero_column, fitresult.y_units, i)
end
376

377
# MLJ `predict`. `Xnew` may be:
#   - a table or matrix of features;
#   - a `NamedTuple` `(data=Xnew, idx=i)` that selects equation `i` explicitly;
#   - a `NamedTuple` table with an extra `classes` column, forwarded to evaluation.
# Without an explicit index, `best_idx` from `selection_method` is used.
function MMI.predict(
    m::M, fitresult, Xnew; idx=nothing, classes=nothing
) where {M<:AbstractSRRegressor}
    # Unwrap `(data=..., idx=...)` and recurse with the index as a keyword.
    if Xnew isa NamedTuple && (haskey(Xnew, :idx) || haskey(Xnew, :data))
        @assert(
            haskey(Xnew, :idx) && haskey(Xnew, :data) && length(keys(Xnew)) == 2,
            "If specifying an equation index during prediction, you must use a named tuple with keys `idx` and `data`."
        )
        return MMI.predict(m, fitresult, Xnew.data; idx=Xnew.idx, classes=classes)
    end
    # Split off a `classes` column and recurse with it as a keyword.
    if isnothing(classes) && MMI.istable(Xnew) && haskey(Xnew, :classes)
        Xnew isa NamedTuple || error("Classes can only be specified with named tuples.")
        features_only = Base.structdiff(Xnew, (; Xnew.classes))
        return MMI.predict(m, fitresult, features_only; idx=idx, classes=Xnew.classes)
    end

    params = full_report(m, fitresult; v_with_strings=Val(false))
    prototype = MMI.istable(Xnew) ? Xnew : nothing
    Xnew_t, variable_names, X_units = get_matrix_and_info(Xnew, m.dimensions_type)
    T = promote_type(eltype(Xnew_t), fitresult.types.T)

    missing_equations = isempty(params.equations) || any(isempty, params.equations)
    if missing_equations
        @warn "Equations not found. Returning 0s for prediction."
        return prediction_fallback(T, m, Xnew_t, fitresult, prototype)
    end

    X_units_clean = clean_units(X_units)
    validate_variable_names(variable_names, fitresult)
    validate_units(X_units_clean, fitresult.X_units)

    chosen = isnothing(idx) ? params.best_idx : idx

    if M <: SRRegressor
        return eval_tree_mlj(
            params.equations[chosen], Xnew_t, classes, m, T, fitresult, nothing, prototype
        )
    elseif M <: MultitargetSRRegressor
        columns = map(eachindex(chosen, params.equations)) do i
            eval_tree_mlj(
                params.equations[i][chosen[i]],
                Xnew_t,
                classes,
                m,
                T,
                fitresult,
                i,
                prototype,
            )
        end
        out_matrix = reduce(hcat, columns)
        fitresult.y_is_table || return out_matrix
        return MMI.table(out_matrix; names=fitresult.y_variable_names, prototype)
    end
end
429

430
# Render fitted expressions as strings using the fitted variable names: a vector of
# strings for `SRRegressor`, and one such vector per target for `MultitargetSRRegressor`.
function get_equation_strings_for(::SRRegressor, trees, options, variable_names)
    render(t) = string_tree(t, options; variable_names=variable_names)
    return render.(trees)
end
function get_equation_strings_for(::MultitargetSRRegressor, trees, options, variable_names)
    render(t) = string_tree(t, options; variable_names=variable_names)
    return [render.(target_trees) for target_trees in trees]
end
438

439
"""
    choose_best(; trees, losses, scores, complexities)

Default `selection_method`. Among the expressions whose loss is at most 1.5× the
minimum loss, return the index of the one with the highest score (the first such
index on ties). `trees` and `complexities` are accepted but unused.
"""
function choose_best(; trees, losses::Vector{L}, scores, complexities) where {L<:LOSS_TYPE}
    # Same as in PySR:
    # https://github.com/MilesCranmer/PySR/blob/e74b8ad46b163c799908b3aa4d851cf8457c79ef/pysr/sr.py#L2318-L2332
    loss_cutoff = 1.5 * minimum(losses)
    # Expressions above the cutoff get the lowest possible score so they never win.
    eligible_scores = map(eachindex(losses)) do i
        losses[i] <= loss_cutoff ? scores[i] : typemin(L)
    end
    return argmax(eligible_scores)
end
449

450
# Apply `m.selection_method` to the hall of fame. Returns one index for `SRRegressor`
# and one index per target for `MultitargetSRRegressor`. `0` signals that no
# equations are available; for multi-target models, every index is `0` as soon as
# any target has none.
function dispatch_selection_for(m::SRRegressor, trees, losses, scores, complexities)::Int
    isempty(trees) && return 0
    return m.selection_method(; trees, losses, scores, complexities)
end
function dispatch_selection_for(
    m::MultitargetSRRegressor, trees, losses, scores, complexities
)
    any(isempty, trees) && return zeros(Int, length(trees))
    return map(eachindex(trees)) do i
        m.selection_method(;
            trees=trees[i], losses=losses[i], scores=scores[i], complexities=complexities[i]
        )
    end
end
466

467
# Package-level MLJ registry metadata, declared once on the abstract supertype.
MMI.metadata_pkg(
    AbstractSRRegressor;
    name="SymbolicRegression",
    url="https://github.com/MilesCranmer/SymbolicRegression.jl",
    uuid="8254be44-1295-4e6a-a16d-46603ac705cb",
    license="Apache-2.0",
    julia=true,
    is_wrapper=false,
)
476

477
# Accepted feature scitypes: a `Continuous` matrix, a `Continuous` table, or a table
# whose columns are `Continuous` or `Count`.
const input_scitype = Union{
    AbstractMatrix{<:MMI.Continuous},
    MMI.Table(MMI.Continuous),
    MMI.Table(MMI.Continuous, MMI.Count),
}
482

483
# TODO: Allow for Count data, and coerce it into Continuous as needed.
# Both regressors share their input scitype and weight/feature-importance support;
# they differ in target scitype, load path, and human-readable name.
let shared = (; input_scitype, supports_weights=true, reports_feature_importances=false)
    MMI.metadata_model(
        SRRegressor;
        shared...,
        target_scitype=AbstractVector{<:MMI.Continuous},
        load_path="SymbolicRegression.MLJInterfaceModule.SRRegressor",
        human_name="Symbolic Regression via Evolutionary Search",
    )
    MMI.metadata_model(
        MultitargetSRRegressor;
        shared...,
        target_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}},
        load_path="SymbolicRegression.MLJInterfaceModule.MultitargetSRRegressor",
        human_name="Multi-Target Symbolic Regression via Evolutionary Search",
    )
end
502

503
"""
    tag_with_docstring(model_name::Symbol, description::String, bottom_matter::String)

Return an expression that, when evaluated, attaches the MLJ-style docstring to the
model type named `model_name`. The docstring concatenates, in order:
1. the `MMI.doc_header` and `description`;
2. `OPTION_DESCRIPTIONS`;
3. the search-level hyper-parameters written out below;
4. an "Operations" section;
5. `bottom_matter`.
"""
function tag_with_docstring(model_name::Symbol, description::String, bottom_matter::String)
    docstring = """$(MMI.doc_header(eval(model_name)))

    $(description)

    # Hyper-parameters
    """

    # TODO: These ones are copied (or written) manually:
    append_arguments = """- `niterations::Int=10`: The number of iterations to perform the search.
        More iterations will improve the results.
    - `parallelism=:multithreading`: What parallelism mode to use.
        The options are `:multithreading`, `:multiprocessing`, and `:serial`.
        By default, multithreading will be used. Multithreading uses less memory,
        but multiprocessing can handle multi-node compute. If using `:multithreading`
        mode, the threads available to Julia are used. If using
        `:multiprocessing`, `numprocs` processes will be created dynamically if
        `procs` is unset. If you have already allocated processes, pass them
        to the `procs` argument and they will be used.
        You may also pass a string instead of a symbol, like `"multithreading"`.
    - `numprocs::Union{Int, Nothing}=nothing`: The number of processes to use,
        if you want `equation_search` to set this up automatically. By default
        this will be `4`, but can be any number (you should pick a number <=
        the number of cores available).
    - `procs::Union{Vector{Int}, Nothing}=nothing`: If you have set up
        a distributed run manually with `procs = addprocs()` and `@everywhere`,
        pass the `procs` to this keyword argument.
    - `addprocs_function::Union{Function, Nothing}=nothing`: If using multiprocessing
        (`parallelism=:multiprocessing`), and you are not passing `procs` manually,
        then they will be allocated dynamically using `addprocs`. However,
        you may also pass a custom function to use instead of `addprocs`.
        This function should take a single positional argument,
        which is the number of processes to use, as well as the `lazy` keyword argument.
        For example, if set up on a slurm cluster, you could pass
        `addprocs_function = addprocs_slurm`, which will set up slurm processes.
    - `heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing`: On Julia 1.9+, you may set the `--heap-size-hint`
        flag on Julia processes, recommending garbage collection once a process
        is close to the recommended size. This is important for long-running distributed
        jobs where each process has an independent memory, and can help avoid
        out-of-memory errors. By default, this is set to `Sys.free_memory() / numprocs`.
    - `runtests::Bool=true`: Whether to run (quick) tests before starting the
        search, to see if there will be any problems during the equation search
        related to the host environment.
    - `loss_type::Type=Nothing`: If you would like to use a different type
        for the loss than for the data you passed, specify the type here.
        Note that if you pass complex data `::Complex{L}`, then the loss
        type will automatically be set to `L`.
    - `selection_method::Function=choose_best`: Function to select an expression from
        the Pareto frontier for use in `predict`.
        See `SymbolicRegression.MLJInterfaceModule.choose_best` for an example.
        This function should return a single integer specifying
        the index of the expression to use. By default, this maximizes
        the score (a pound-for-pound rating) of expressions reaching the threshold
        of 1.5x the minimum loss. To override this at prediction time, you can pass
        a named tuple with keys `data` and `idx` to `predict`. See the Operations
        section for details.
    - `dimensions_type::Type{<:AbstractDimensions}`: The type of dimensions to use when storing
        the units of the data. By default this is `DynamicQuantities.SymbolicDimensions`.
    """

    bottom = """
    # Operations

    - `predict(mach, Xnew)`: Return predictions of the target given features `Xnew`, which
      should have same scitype as `X` above. The expression used for prediction is defined
      by the `selection_method` function, which can be seen by viewing `report(mach).best_idx`.
    - `predict(mach, (data=Xnew, idx=i))`: Return predictions of the target given features
      `Xnew`, which should have same scitype as `X` above. By passing a named tuple with keys
      `data` and `idx`, you are able to specify the equation you wish to evaluate in `idx`.

    $(bottom_matter)
    """

    # Remove common indentation. Julia already dedents triple-quoted literals; since
    # the regex has no `m` flag, `^` only matches at the very start of each string.
    docstring = replace(docstring, r"^    " => "")
    extra_arguments = replace(append_arguments, r"^    " => "")
    bottom = replace(bottom, r"^    " => "")

    # Add parameter descriptions:
    docstring = docstring * OPTION_DESCRIPTIONS
    docstring = docstring * extra_arguments
    docstring = docstring * bottom
    return quote
        @doc $docstring $model_name
    end
end
589

590
# Register the docstring for `SRRegressor`.
#
# `tag_with_docstring` returns a `quote @doc $docstring $model_name end`
# expression whose text also includes the generated `OPTION_DESCRIPTIONS`;
# `eval` runs it at module top level. The two strings passed below are
# presumably the description (top) and the `bottom_matter` spliced in after
# the "Operations" section -- verify against `tag_with_docstring`'s signature.
#
# No `replace(s, r"^    " => "")` is applied: Julia triple-quoted strings
# already strip their common leading indentation, and without the `m` flag
# `^` only matches at the very start of the string anyway.
#
# Algorithm reference: https://arxiv.org/abs/2305.01582
eval(
    tag_with_docstring(
        :SRRegressor,
        """
        Single-target Symbolic Regression regressor (`SRRegressor`) searches
        for symbolic expressions that predict a single target variable from
        a set of input variables. All data is assumed to be `Continuous`.
        The search is performed using an evolutionary algorithm.
        This algorithm is described in the paper
        https://arxiv.org/abs/2305.01582.

        # Training data

        In MLJ or MLJBase, bind an instance `model` to data with

            mach = machine(model, X, y)

        OR

            mach = machine(model, X, y, w)

        Here:

        - `X` is any table of input features (e.g., a `DataFrame`) whose columns are of scitype
          `Continuous`; check column scitypes with `schema(X)`. Variable names in discovered
          expressions will be taken from the column names of `X`, if available. Units in columns
          of `X` (use `DynamicQuantities` for units) will trigger dimensional analysis to be used.

        - `y` is the target, which can be any `AbstractVector` whose element scitype is
          `Continuous`; check the scitype with `scitype(y)`. Units in `y` (use `DynamicQuantities`
          for units) will trigger dimensional analysis to be used.

        - `w` is the observation weights which can either be `nothing` (default) or an
          `AbstractVector` whose element scitype is `Count` or `Continuous`.

        Train the machine using `fit!(mach)`, inspect the discovered expressions with
        `report(mach)`, and predict on new data with `predict(mach, Xnew)`.
        Note that unlike other regressors, symbolic regression stores a list of
        trained models. The model chosen from this list is determined by the
        `selection_method` keyword argument, which by default balances accuracy
        and complexity. You can override this at prediction time by passing a named
        tuple with keys `data` and `idx`.

        """,
        """
        # Fitted parameters

        The fields of `fitted_params(mach)` are:

        - `best_idx::Int`: The index of the best expression in the Pareto frontier,
          as determined by the `selection_method` function. Override in `predict` by passing
          a named tuple with keys `data` and `idx`.
        - `equations::Vector{<:AbstractExpression{T}}`: The expressions discovered by the search,
          represented in a dominating Pareto frontier (i.e., the best expressions found for
          each complexity). `T` is equal to the element type
          of the passed data.
        - `equation_strings::Vector{String}`: The expressions discovered by the search,
          represented as strings for easy inspection.

        # Report

        The fields of `report(mach)` are:

        - `best_idx::Int`: The index of the best expression in the Pareto frontier,
          as determined by the `selection_method` function. Override in `predict` by passing
          a named tuple with keys `data` and `idx`.
        - `equations::Vector{<:AbstractExpression{T}}`: The expressions discovered by the search,
          represented in a dominating Pareto frontier (i.e., the best expressions found for
          each complexity).
        - `equation_strings::Vector{String}`: The expressions discovered by the search,
          represented as strings for easy inspection.
        - `complexities::Vector{Int}`: The complexity of each expression in the Pareto frontier.
        - `losses::Vector{L}`: The loss of each expression in the Pareto frontier, according
          to the loss function specified in the model. The type `L` is the loss type, which
          is usually the same as the element type of data passed (i.e., `T`), but can differ
          if complex data types are passed.
        - `scores::Vector{L}`: A metric which considers both the complexity and loss of an expression,
          equal to the change in the log-loss divided by the change in complexity, relative to
          the previous expression along the Pareto frontier. A larger score aims to indicate
          an expression is more likely to be the true expression generating the data, but
          this is very problem-dependent and generally several other factors should be considered.

        # Examples

        ```julia
        using MLJ
        SRRegressor = @load SRRegressor pkg=SymbolicRegression
        X, y = @load_boston
        model = SRRegressor(binary_operators=[+, -, *], unary_operators=[exp], niterations=100)
        mach = machine(model, X, y)
        fit!(mach)
        y_hat = predict(mach, X)
        # View the equation used:
        r = report(mach)
        println("Equation used:", r.equation_strings[r.best_idx])
        ```

        With units and variable names:

        ```julia
        using MLJ
        using DynamicQuantities
        SRRegressor = @load SRRegressor pkg=SymbolicRegression

        X = (; x1=rand(32) .* us"km/h", x2=rand(32) .* us"km")
        y = @. X.x2 / X.x1 + 0.5us"h"
        model = SRRegressor(binary_operators=[+, -, *, /])
        mach = machine(model, X, y)
        fit!(mach)
        y_hat = predict(mach, X)
        # View the equation used:
        r = report(mach)
        println("Equation used:", r.equation_strings[r.best_idx])
        ```

        See also [`MultitargetSRRegressor`](@ref).
        """,
    ),
)
716
# Register the docstring for `MultitargetSRRegressor`.
#
# `tag_with_docstring` returns a `quote @doc $docstring $model_name end`
# expression whose text also includes the generated `OPTION_DESCRIPTIONS`;
# `eval` runs it at module top level. The two strings passed below are
# presumably the description (top) and the `bottom_matter` spliced in after
# the "Operations" section -- verify against `tag_with_docstring`'s signature.
#
# No `replace(s, r"^    " => "")` is applied: Julia triple-quoted strings
# already strip their common leading indentation, and without the `m` flag
# `^` only matches at the very start of the string anyway.
eval(
    tag_with_docstring(
        :MultitargetSRRegressor,
        """
        Multi-target Symbolic Regression regressor (`MultitargetSRRegressor`)
        conducts several searches for expressions that predict each target variable
        from a set of input variables. All data is assumed to be `Continuous`.
        The search is performed using an evolutionary algorithm.
        This algorithm is described in the paper
        https://arxiv.org/abs/2305.01582.

        # Training data

        In MLJ or MLJBase, bind an instance `model` to data with

            mach = machine(model, X, y)

        OR

            mach = machine(model, X, y, w)

        Here:

        - `X` is any table of input features (e.g., a `DataFrame`) whose columns are of scitype
          `Continuous`; check column scitypes with `schema(X)`. Variable names in discovered
          expressions will be taken from the column names of `X`, if available. Units in columns
          of `X` (use `DynamicQuantities` for units) will trigger dimensional analysis to be used.

        - `y` is the target, which can be any table of target variables whose element
          scitype is `Continuous`; check the scitype with `schema(y)`. Units in columns of
          `y` (use `DynamicQuantities` for units) will trigger dimensional analysis to be used.

        - `w` is the observation weights which can either be `nothing` (default) or an
          `AbstractVector` whose element scitype is `Count` or `Continuous`. The same
          weights are used for all targets.

        Train the machine using `fit!(mach)`, inspect the discovered expressions with
        `report(mach)`, and predict on new data with `predict(mach, Xnew)`.
        Note that unlike other regressors, symbolic regression stores a list of lists of
        trained models. The models chosen from each of these lists are determined by the
        `selection_method` keyword argument, which by default balances accuracy
        and complexity. You can override this at prediction time by passing a named
        tuple with keys `data` and `idx`.

        """,
        """
        # Fitted parameters

        The fields of `fitted_params(mach)` are:

        - `best_idx::Vector{Int}`: The index of the best expression in each Pareto frontier,
          as determined by the `selection_method` function. Override in `predict` by passing
          a named tuple with keys `data` and `idx`.
        - `equations::Vector{Vector{<:AbstractExpression{T}}}`: The expressions discovered by
          the search, represented in a dominating Pareto frontier (i.e., the best expressions
          found for each complexity). The outer vector is indexed by target variable, and the
          inner vector is ordered by increasing complexity. `T` is equal to the element type
          of the passed data.
        - `equation_strings::Vector{Vector{String}}`: The expressions discovered by the search,
          represented as strings for easy inspection.

        # Report

        The fields of `report(mach)` are:

        - `best_idx::Vector{Int}`: The index of the best expression in each Pareto frontier,
          as determined by the `selection_method` function. Override in `predict` by passing
          a named tuple with keys `data` and `idx`.
        - `equations::Vector{Vector{<:AbstractExpression{T}}}`: The expressions discovered by
          the search, represented in a dominating Pareto frontier (i.e., the best expressions
          found for each complexity). The outer vector is indexed by target variable, and the
          inner vector is ordered by increasing complexity.
        - `equation_strings::Vector{Vector{String}}`: The expressions discovered by the search,
          represented as strings for easy inspection.
        - `complexities::Vector{Vector{Int}}`: The complexity of each expression in each Pareto frontier.
        - `losses::Vector{Vector{L}}`: The loss of each expression in each Pareto frontier, according
          to the loss function specified in the model. The type `L` is the loss type, which
          is usually the same as the element type of data passed (i.e., `T`), but can differ
          if complex data types are passed.
        - `scores::Vector{Vector{L}}`: A metric which considers both the complexity and loss of an expression,
          equal to the change in the log-loss divided by the change in complexity, relative to
          the previous expression along the Pareto frontier. A larger score aims to indicate
          an expression is more likely to be the true expression generating the data, but
          this is very problem-dependent and generally several other factors should be considered.

        # Examples

        ```julia
        using MLJ
        MultitargetSRRegressor = @load MultitargetSRRegressor pkg=SymbolicRegression
        X = (a=rand(100), b=rand(100), c=rand(100))
        Y = (y1=(@. cos(X.c) * 2.1 - 0.9), y2=(@. X.a * X.b + X.c))
        model = MultitargetSRRegressor(binary_operators=[+, -, *], unary_operators=[exp], niterations=100)
        mach = machine(model, X, Y)
        fit!(mach)
        y_hat = predict(mach, X)
        # View the equations used:
        r = report(mach)
        for (output_index, (eq, i)) in enumerate(zip(r.equation_strings, r.best_idx))
            println("Equation used for ", output_index, ": ", eq[i])
        end
        ```

        See also [`SRRegressor`](@ref).
        """,
    ),
)
828

829
end
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc