• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MilesCranmer / SymbolicRegression.jl / 11394658450

17 Oct 2024 11:32PM UTC coverage: 95.332% (+0.6%) from 94.757%
11394658450

Pull #355

github

web-flow
Merge a9e5332c7 into 3892a6659
Pull Request #355: Create `TemplateExpression` for providing a pre-defined functional structure and constraints

253 of 257 new or added lines in 15 files covered. (98.44%)

3 existing lines in 2 files now uncovered.

2818 of 2956 relevant lines covered (95.33%)

38419396.19 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.8
/src/OptionsStruct.jl
1
module OptionsStructModule
2

3
using DispatchDoctor: @unstable
4
using Optim: Optim
5
using DynamicExpressions:
6
    AbstractOperatorEnum, AbstractExpressionNode, AbstractExpression, OperatorEnum
7
using LossFunctions: SupervisedLoss
8

9
import ..MutationWeightsModule: AbstractMutationWeights
10

11
"""
12
This struct defines how complexity is calculated.
13

14
# Fields
15
- `use`: Shortcut indicating whether we use custom complexities,
16
    or just use 1 for everything.
17
- `binop_complexities`: Complexity of each binary operator.
18
- `unaop_complexities`: Complexity of each unary operator.
19
- `variable_complexity`: Complexity of using a variable.
20
- `constant_complexity`: Complexity of using a constant.
21
"""
22
struct ComplexityMapping{T<:Real,VC<:Union{T,AbstractVector{T}}}
23
    use::Bool
6,600✔
24
    binop_complexities::Vector{T}
25
    unaop_complexities::Vector{T}
26
    variable_complexity::VC
27
    constant_complexity::T
28
end
29

30
# The numeric type used for all complexity values in the mapping.
function Base.eltype(::ComplexityMapping{T}) where {T}
    return T
end

32
"""Promote type when defining complexity mapping."""
33
function ComplexityMapping(;
88✔
34
    binop_complexities::Vector{T1},
35
    unaop_complexities::Vector{T2},
36
    variable_complexity::Union{T3,AbstractVector{T3}},
37
    constant_complexity::T4,
38
) where {T1<:Real,T2<:Real,T3<:Real,T4<:Real}
39
    T = promote_type(T1, T2, T3, T4)
44✔
40
    vc = map(T, variable_complexity)
44✔
41
    return ComplexityMapping{T,typeof(vc)}(
44✔
42
        true,
43
        map(T, binop_complexities),
44
        map(T, unaop_complexities),
45
        vc,
46
        T(constant_complexity),
47
    )
48
end
49

50
# With no customization provided, we simply turn off the complexity
# mapping (`use = false`), leaving every node at the default cost.
function ComplexityMapping(
    ::Nothing, ::Nothing, ::Nothing, binary_operators, unary_operators
)
    return ComplexityMapping{Int,Int}(false, Int[], Int[], 0, 0)
end
58
"""
    ComplexityMapping(complexity_of_operators, complexity_of_variables,
                      complexity_of_constants, binary_operators, unary_operators)

Build a `ComplexityMapping` from user-provided settings, any of which may be
`nothing` (meaning: default complexity of 1). `complexity_of_operators` is any
collection of `operator => complexity` pairs convertible to a `Dict`; operators
not listed default to 1. All complexities are promoted to a common type.
"""
function ComplexityMapping(
    complexity_of_operators,
    complexity_of_variables,
    complexity_of_constants,
    binary_operators,
    unary_operators,
)
    _complexity_of_operators = if complexity_of_operators === nothing
        Dict{Function,Int64}()
    else
        # Convert to dict:
        Dict(complexity_of_operators)
    end

    # Numeric type contributed by the variable-complexity setting:
    VAR_T = if (complexity_of_variables !== nothing)
        if complexity_of_variables isa AbstractVector
            eltype(complexity_of_variables)
        else
            typeof(complexity_of_variables)
        end
    else
        Int
    end
    # Numeric type contributed by the constant-complexity setting:
    CONST_T = if (complexity_of_constants !== nothing)
        typeof(complexity_of_constants)
    else
        Int
    end
    # Use the public `valtype` accessor rather than reaching into
    # `eltype(dict).parameters[2]`, which depends on internal type layout.
    OP_T = valtype(_complexity_of_operators)

    T = promote_type(VAR_T, CONST_T, OP_T)

    # If not in dict, then just set it to 1.
    binop_complexities = T[
        get(_complexity_of_operators, op, one(T)) for op in binary_operators
    ]
    unaop_complexities = T[
        get(_complexity_of_operators, op, one(T)) for op in unary_operators
    ]

    variable_complexity = if complexity_of_variables !== nothing
        map(T, complexity_of_variables)
    else
        one(T)
    end
    constant_complexity = if complexity_of_constants !== nothing
        map(T, complexity_of_constants)
    else
        one(T)
    end

    return ComplexityMapping(;
        binop_complexities, unaop_complexities, variable_complexity, constant_complexity
    )
end
115

116
"""
117
Controls level of specialization we compile into `Options`.
118

119
Overload if needed for custom expression types.
120
"""
NEW
121
operator_specialization(
×
122
    ::Type{O}, ::Type{<:AbstractExpression}
123
) where {O<:AbstractOperatorEnum} = O
124
@unstable operator_specialization(::Type{<:OperatorEnum}, ::Type{<:AbstractExpression}) =
1,667✔
125
    OperatorEnum
126

127
"""
128
    AbstractOptions
129

130
An abstract type that stores all search hyperparameters for SymbolicRegression.jl.
131
The standard implementation is [`Options`](@ref).
132

133
You may wish to create a new subtypes of `AbstractOptions` to override certain functions
134
or create new behavior. Ensure that this new type has all properties of [`Options`](@ref).
135

136
For example, if we have new options that we want to add to `Options`:
137

138
```julia
139
Base.@kwdef struct MyNewOptions
140
    a::Float64 = 1.0
141
    b::Int = 3
142
end
143
```
144

145
we can create a combined options type that forwards properties to each corresponding type:
146

147
```julia
148
struct MyOptions{O<:SymbolicRegression.Options} <: SymbolicRegression.AbstractOptions
149
    new_options::MyNewOptions
150
    sr_options::O
151
end
152
const NEW_OPTIONS_KEYS = fieldnames(MyNewOptions)
153

154
# Constructor with both sets of parameters:
155
function MyOptions(; kws...)
156
    new_options_keys = filter(k -> k in NEW_OPTIONS_KEYS, keys(kws))
157
    new_options = MyNewOptions(; NamedTuple(new_options_keys .=> Tuple(kws[k] for k in new_options_keys))...)
158
    sr_options_keys = filter(k -> !(k in NEW_OPTIONS_KEYS), keys(kws))
159
    sr_options = SymbolicRegression.Options(; NamedTuple(sr_options_keys .=> Tuple(kws[k] for k in sr_options_keys))...)
160
    return MyOptions(new_options, sr_options)
161
end
162

163
# Make all `Options` available while also making `new_options` accessible
164
function Base.getproperty(options::MyOptions, k::Symbol)
165
    if k in NEW_OPTIONS_KEYS
166
        return getproperty(getfield(options, :new_options), k)
167
    else
168
        return getproperty(getfield(options, :sr_options), k)
169
    end
170
end
171

172
Base.propertynames(options::MyOptions) = (NEW_OPTIONS_KEYS..., fieldnames(SymbolicRegression.Options)...)
173
```
174

175
which would let you access `a` and `b` from `MyOptions` objects, as well as making
176
all properties of `Options` available for internal methods in SymbolicRegression.jl
177
"""
178
abstract type AbstractOptions end
179

180
# Concrete implementation of `AbstractOptions`. The flags `_turbo`,
# `_bumper`, and `_return_state` are lifted into the type domain via
# `Val`, and the operator/node/expression/mutation-weight types are
# carried as type parameters.
struct Options{
    CM<:ComplexityMapping,
    OP<:AbstractOperatorEnum,
    N<:AbstractExpressionNode,
    E<:AbstractExpression,
    EO<:NamedTuple,
    MW<:AbstractMutationWeights,
    _turbo,
    _bumper,
    _return_state,
    AD,
} <: AbstractOptions
    # Operators and structural constraints:
    operators::OP
    bin_constraints::Vector{Tuple{Int,Int}}
    una_constraints::Vector{Int}
    complexity_mapping::CM
    # Selection and objective settings:
    tournament_selection_n::Int
    tournament_selection_p::Float32
    parsimony::Float32
    dimensional_constraint_penalty::Union{Float32,Nothing}
    dimensionless_constants_only::Bool
    alpha::Float32
    maxsize::Int
    maxdepth::Int
    # Compile-time feature flags (type-level booleans):
    turbo::Val{_turbo}
    bumper::Val{_bumper}
    # Population / evolution settings:
    migration::Bool
    hof_migration::Bool
    should_simplify::Bool
    should_optimize_constants::Bool
    output_file::String
    populations::Int
    perturbation_factor::Float32
    annealing::Bool
    batching::Bool
    batch_size::Int
    mutation_weights::MW
    crossover_probability::Float32
    warmup_maxsize_by::Float32
    use_frequency::Bool
    use_frequency_in_tournament::Bool
    adaptive_parsimony_scaling::Float64
    population_size::Int
    ncycles_per_iteration::Int
    fraction_replaced::Float32
    fraction_replaced_hof::Float32
    topn::Int
    # Output / logging settings:
    verbosity::Union{Int,Nothing}
    print_precision::Int
    save_to_file::Bool
    probability_negate_constant::Float32
    # Cached operator counts (excluded from `print`; see `Base.print` below):
    nuna::Int
    nbin::Int
    seed::Union{Int,Nothing}
    # Loss specification — either an elementwise loss or a full custom
    # loss function (at most one expected to be active; NOTE(review): not
    # enforced here, presumably validated by the constructor elsewhere):
    elementwise_loss::Union{SupervisedLoss,Function}
    loss_function::Union{Nothing,Function}
    node_type::Type{N}
    expression_type::Type{E}
    expression_options::EO
    progress::Union{Bool,Nothing}
    terminal_width::Union{Int,Nothing}
    # Constant-optimization settings (Optim.jl):
    optimizer_algorithm::Optim.AbstractOptimizer
    optimizer_probability::Float32
    optimizer_nrestarts::Int
    optimizer_options::Optim.Options
    autodiff_backend::AD
    recorder_file::String
    prob_pick_first::Float32
    # Termination settings:
    early_stop_condition::Union{Function,Nothing}
    return_state::Val{_return_state}
    timeout_in_seconds::Union{Float64,Nothing}
    max_evals::Union{Int,Nothing}
    skip_mutation_failures::Bool
    nested_constraints::Union{Vector{Tuple{Int,Int,Vector{Tuple{Int,Int,Int}}}},Nothing}
    deterministic::Bool
    define_helper_functions::Bool
    use_recorder::Bool
end
258

259
function Base.print(io::IO, options::Options)
    # Operators are printed explicitly; every remaining field is filled in
    # automatically from `fieldnames(Options)`. A couple of verbose fields
    # are abbreviated, and the cached counts `nuna`/`nbin` are skipped.
    skipped = (:operators, :nuna, :nbin)
    abbreviated = (:optimizer_options, :mutation_weights)
    field_strings = String[]
    for k in fieldnames(Options)
        k in skipped && continue
        if k in abbreviated
            push!(field_strings, "$(k)=...")
        else
            push!(field_strings, "$(k)=$(getfield(options, k))")
        end
    end
    header =
        "Options(" *
        "binops=$(options.operators.binops), " *
        "unaops=$(options.operators.unaops), "
    return print(io, header * join(field_strings, ", ") * ")")
end
281
# Multi-line REPL display reuses the single-line `print` form.
function Base.show(io::IO, ::MIME"text/plain", options::Options)
    return Base.print(io, options)
end
282

283
# Fallback: custom option types are used as-is.
specialized_options(options::AbstractOptions) = options
# For `Options`, rebuild the struct so its operator type parameter is
# concrete (see `_specialized_options` below). Marked `@unstable` since
# the returned type depends on the runtime operators.
@unstable specialized_options(options::Options) = _specialized_options(options)
287
# Rebuild `options` as an `Options` whose second type parameter is the
# *concrete* `typeof(options.operators)`, so downstream code can
# specialize on the operator enum. `@generated` so the per-field copy
# unrolls at compile time for each input type `O`.
@generated function _specialized_options(options::O) where {O<:Options}
    # Return an options struct with concrete operators
    type_parameters = O.parameters
    # One `getfield` expression per field, spliced into the constructor call:
    fields = Any[:(getfield(options, $(QuoteNode(k)))) for k in fieldnames(O)]
    quote
        operators = getfield(options, :operators)
        # Keep all type parameters except the operator slot (index 2),
        # which is replaced by the concrete `typeof(operators)`:
        Options{$(type_parameters[1]),typeof(operators),$(type_parameters[3:end]...)}(
            $(fields...)
        )
    end
end
298

299
end
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc