• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MilesCranmer / SymbolicRegression.jl / 11204590927

06 Oct 2024 07:29PM UTC coverage: 95.808% (+1.2%) from 94.617%
11204590927

Pull #326

github

web-flow
Merge e2b369ea7 into 8f67533b9
Pull Request #326: BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)

466 of 482 new or added lines in 24 files covered. (96.68%)

1 existing line in 1 file now uncovered.

2651 of 2767 relevant lines covered (95.81%)

73863189.31 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.99
/src/ExpressionBuilder.jl
1
module ExpressionBuilderModule
2

3
using DispatchDoctor: @unstable
4
using DynamicExpressions:
5
    AbstractExpressionNode,
6
    AbstractExpression,
7
    Expression,
8
    ParametricExpression,
9
    ParametricNode,
10
    constructorof,
11
    get_tree,
12
    get_contents,
13
    get_metadata,
14
    with_contents,
15
    with_metadata,
16
    count_scalar_constants,
17
    eval_tree_array
18
using Random: default_rng, AbstractRNG
19
using StatsBase: StatsBase
20
using ..CoreModule: Options, Dataset, DATA_TYPE
21
using ..HallOfFameModule: HallOfFame
22
using ..LossFunctionsModule: maybe_getindex
23
using ..InterfaceDynamicExpressionsModule: expected_array_type
24
using ..PopulationModule: Population
25
using ..PopMemberModule: PopMember
26

27
import DynamicExpressions: get_operators
28
import ..CoreModule: create_expression
29
import ..MutationFunctionsModule:
30
    make_random_leaf, crossover_trees, mutate_constant, mutate_factor
31
import ..LossFunctionsModule: eval_tree_dispatch
32
import ..ConstantOptimizationModule: count_constants_for_optimization
33

34
# Scalar entry point: wrap the constant `t` in a leaf node built with the
# constructor of `options.node_type`, then defer to the node-based method.
@unstable function create_expression(
    t::T, options::Options, dataset::Dataset{T,L}, ::Val{embed}=Val(false)
) where {T,L,embed}
    make_node = constructorof(options.node_type)
    leaf = make_node(; val=t)
    return create_expression(leaf, options, dataset, Val(embed))
end
41
# Node entry point: wrap the tree `t` in the expression type named by
# `options.expression_type`, passing the metadata from `init_params` (called
# with no prototype) as keyword arguments.
@unstable function create_expression(
    t::AbstractExpressionNode{T},
    options::Options,
    dataset::Dataset{T,L},
    ::Val{embed}=Val(false),
) where {T,L,embed}
    build_expression = constructorof(options.expression_type)
    metadata = init_params(options, dataset, nothing, Val(embed))
    return build_expression(t; metadata...)
end
51
# Expression entry point: the input is already an `AbstractExpression`, so it
# is returned as-is; the options, dataset and `embed` flag are ignored.
function create_expression(
    ex::AbstractExpression{T},
    ::Options,
    ::Dataset{T,L},
    ::Val{embed}=Val(false),
) where {T,L,embed}
    return ex
end
56
# Assemble the keyword arguments (as a NamedTuple) used to build or refresh
# expression metadata.
#
# `prototype` is validated first via `consistency_checks`. When `embed` is
# true the operators and variable names are included; otherwise both entries
# are `nothing`. Entries from `extra_init_params` are splatted last.
@unstable function init_params(
    options::Options,
    dataset::Dataset{T,L},
    prototype::Union{Nothing,AbstractExpression},
    ::Val{embed},
) where {T,L,embed}
    consistency_checks(options, prototype)
    operators, variable_names = if embed
        (options.operators, dataset.variable_names)
    else
        (nothing, nothing)
    end
    extras = extra_init_params(
        options.expression_type, prototype, options, dataset, Val(embed)
    )
    return (; operators, variable_names, extras...)
end
71
# Generic fallback: the extra metadata is whatever is stored in
# `options.expression_options`, re-splatted into a fresh NamedTuple.
# The prototype, dataset and `embed` flag are not consulted.
function extra_init_params(
    ::Type{E},
    prototype::Union{Nothing,AbstractExpression},
    options::Options,
    dataset::Dataset{T},
    ::Val{embed},
) where {T,embed,E<:AbstractExpression}
    expression_options = options.expression_options
    return (; expression_options...)
end
80
# `ParametricExpression` specialisation: supplies `parameters` and
# `parameter_names`.
#
# Without a prototype, a fresh `randn` matrix of size
# (max_parameters, number of unique classes) is drawn; with one, its parameter
# array is copied. Names "p1", "p2", ... are produced only when `embed` is
# true; otherwise `parameter_names` is `nothing`.
function extra_init_params(
    ::Type{E},
    prototype::Union{Nothing,ParametricExpression},
    options::Options,
    dataset::Dataset{T},
    ::Val{embed},
) where {T,embed,E<:ParametricExpression}
    max_parameters = options.expression_options.max_parameters
    class_count = length(unique(dataset.extra.classes))
    parameter_names = if embed
        ["p$i" for i in 1:max_parameters]
    else
        nothing
    end
    parameters = if isnothing(prototype)
        randn(T, (max_parameters, class_count))
    else
        copy(get_metadata(prototype).parameters)
    end
    return (; parameters, parameter_names)
end
97

98
# No prototype: nothing to validate.
consistency_checks(::Options, prototype::Nothing) = nothing

# Validate that `prototype` is compatible with `options`.
#
# Checks performed:
#   * `prototype` is an instance of `options.expression_type`;
#   * for a `ParametricExpression`, the number of parameter names (when they
#     are present) and the first dimension of the parameter array both equal
#     `options.expression_options.max_parameters`.
#
# Returns `nothing`; a failed check raises an `AssertionError` via `@assert`.
# A `nothing` prototype is handled by the more specific method above, so no
# `prototype === nothing` branch is needed here.
function consistency_checks(options::Options, prototype)
    @assert(
        prototype isa options.expression_type,
        "Need prototype to be of type $(options.expression_type), but got $(prototype)::$(typeof(prototype))"
    )
    if prototype isa ParametricExpression
        if prototype.metadata.parameter_names !== nothing
            @assert(
                length(prototype.metadata.parameter_names) ==
                    options.expression_options.max_parameters,
                "Mismatch between options.expression_options.max_parameters=$(options.expression_options.max_parameters) and prototype.metadata.parameter_names=$(prototype.metadata.parameter_names)"
            )
        end
        @assert(
            size(prototype.metadata.parameters, 1) ==
                options.expression_options.max_parameters,
            "Mismatch between options.expression_options.max_parameters=$(options.expression_options.max_parameters) and size(prototype.metadata.parameters, 1)=$(size(prototype.metadata.parameters, 1))"
        )
    end
    return nothing
end
120

121
# `embed_metadata` rebuilds its argument so that every expression it contains
# carries the metadata produced by `init_params(..., Val(true))`, i.e. with
# `options.operators` and `dataset.variable_names` included.
@unstable begin
    # Single expression: refresh its metadata, using `ex` itself as prototype.
    function embed_metadata(
        ex::AbstractExpression, options::Options, dataset::Dataset{T,L}
    ) where {T,L}
        embedded = init_params(options, dataset, ex, Val(true))
        return with_metadata(ex; embedded...)
    end
    # Population member: new `PopMember` around the embedded tree, carrying
    # over score, loss, ref and parent.
    function embed_metadata(
        member::PopMember, options::Options, dataset::Dataset{T,L}
    ) where {T,L}
        embedded_tree = embed_metadata(member.tree, options, dataset)
        return PopMember(
            embedded_tree,
            member.score,
            member.loss,
            nothing;
            ref=member.ref,
            parent=member.parent,
            deterministic=options.deterministic,
        )
    end
    # Population: embed each member and wrap the result in a new `Population`.
    function embed_metadata(
        pop::Population, options::Options, dataset::Dataset{T,L}
    ) where {T,L}
        embedded_members = map(pop.members) do member
            embed_metadata(member, options, dataset)
        end
        return Population(embedded_members)
    end
    # Hall of fame: embed each member; `hof.exists` is passed through as-is.
    function embed_metadata(
        hof::HallOfFame, options::Options, dataset::Dataset{T,L}
    ) where {T,L}
        embedded_members = map(hof.members) do member
            embed_metadata(member, options, dataset)
        end
        return HallOfFame(embedded_members, hof.exists)
    end
    # Vector of halls of fame / populations / members: embed element-wise.
    function embed_metadata(
        vec::Vector{H}, options::Options, dataset::Dataset{T,L}
    ) where {T,L,H<:Union{HallOfFame,Population,PopMember}}
        return map(vec) do elem
            embed_metadata(elem, options, dataset)
        end
    end
end
160

161
"""
    strip_metadata(x, options, dataset)

Strip all metadata except for top-level information.

The metadata is rebuilt with `init_params(options, dataset, ex, Val(false))`,
which sets the `operators` and `variable_names` entries to `nothing`.
"""
function strip_metadata(ex::Expression, options::Options, dataset::Dataset{T,L}) where {T,L}
    stripped = init_params(options, dataset, ex, Val(false))
    return with_metadata(ex; stripped...)
end
165
# `ParametricExpression` variant: rebuild the metadata with `Val(false)`, using
# `ex` itself as the prototype passed to `init_params`.
function strip_metadata(
    ex::ParametricExpression, options::Options, dataset::Dataset{T,L}
) where {T,L}
    stripped = init_params(options, dataset, ex, Val(false))
    return with_metadata(ex; stripped...)
end
170
# Population member: new `PopMember` around the stripped tree, carrying over
# score, loss, ref and parent.
function strip_metadata(
    member::PopMember, options::Options, dataset::Dataset{T,L}
) where {T,L}
    stripped_tree = strip_metadata(member.tree, options, dataset)
    return PopMember(
        stripped_tree,
        member.score,
        member.loss,
        nothing;
        ref=member.ref,
        parent=member.parent,
        deterministic=options.deterministic,
    )
end
183
# Population: strip each member and wrap the result in a new `Population`.
function strip_metadata(
    pop::Population, options::Options, dataset::Dataset{T,L}
) where {T,L}
    stripped_members = map(pop.members) do member
        strip_metadata(member, options, dataset)
    end
    return Population(stripped_members)
end
188
# Hall of fame: strip each member; `hof.exists` is passed through unchanged.
function strip_metadata(
    hof::HallOfFame, options::Options, dataset::Dataset{T,L}
) where {T,L}
    stripped_members = map(hof.members) do member
        strip_metadata(member, options, dataset)
    end
    return HallOfFame(stripped_members, hof.exists)
end
195

196
# Evaluate a `ParametricExpression` on (a subset of) the dataset.
#
# `idx` selects columns of `dataset.X` and the matching entries of
# `dataset.extra.classes` through `maybe_getindex`. The result is asserted to
# be a `Tuple{A,Bool}`, where `A` is the array type reported by
# `expected_array_type(dataset.X)`.
function eval_tree_dispatch(
    tree::ParametricExpression{T}, dataset::Dataset{T}, options::Options, idx
) where {T<:DATA_TYPE}
    A = expected_array_type(dataset.X)
    X = maybe_getindex(dataset.X, :, idx)
    classes = maybe_getindex(dataset.extra.classes, idx)
    result = eval_tree_array(tree, X, classes, options.operators)
    return result::Tuple{A,Bool}
end
207

208
# Draw a random leaf for a parametric tree.
#
# One of three leaf kinds is picked uniformly:
#   1. a constant with value `randn(rng, T)`;
#   2. a feature leaf with index drawn from `1:nfeatures`;
#   3. a parameter leaf with index drawn from
#      `1:options.expression_options.max_parameters`.
#
# Kind 3 needs `options`. The signature keeps `nothing` as the default, so the
# requirement is checked only when that branch is taken, and an
# `ArgumentError` is thrown instead of an opaque field-access error on
# `nothing`.
function make_random_leaf(
    nfeatures::Int,
    ::Type{T},
    ::Type{N},
    rng::AbstractRNG=default_rng(),
    options::Union{Options,Nothing}=nothing,
) where {T<:DATA_TYPE,N<:ParametricNode}
    choice = rand(rng, 1:3)
    if choice == 1
        return ParametricNode(; val=randn(rng, T))
    elseif choice == 2
        return ParametricNode(T; feature=rand(rng, 1:nfeatures))
    else
        # Checked here (not at function entry) so the sequence of RNG draws is
        # the same as before for every call that previously succeeded.
        if options === nothing
            throw(
                ArgumentError(
                    "`options` must be provided to `make_random_leaf` for " *
                    "`ParametricNode` in order to sample a parameter leaf",
                ),
            )
        end
        tree = ParametricNode{T}()
        tree.val = zero(T)
        tree.degree = 0
        tree.feature = 0
        tree.constant = false
        tree.is_parameter = true
        tree.parameter = rand(
            rng, UInt16(1):UInt16(options.expression_options.max_parameters)
        )
        return tree
    end
end
233

234
# Crossover for parametric expressions.
#
# The underlying trees are crossed with the tree-level `crossover_trees`, the
# results are placed back into the expressions with `with_contents`, and then
# rows of the two parameter arrays are exchanged in place.
#
# NOTE(review): `ex2` is only typed as `AbstractExpression`, yet
# `ex2.metadata.parameters` is read below -- presumably callers always pass a
# parametric expression; confirm.
function crossover_trees(
    ex1::ParametricExpression{T}, ex2::AbstractExpression{T}, rng::AbstractRNG=default_rng()
) where {T}
    contents1 = get_contents(ex1)
    contents2 = get_contents(ex2)
    crossed1, crossed2 = crossover_trees(contents1, contents2, rng)
    ex1 = with_contents(ex1, crossed1)
    ex2 = with_contents(ex2, crossed2)

    # Only rows present in both parameter arrays can be exchanged.
    params1 = ex1.metadata.parameters
    params2 = ex2.metadata.parameters
    num_shared = min(size(params1, 1), size(params2, 1))
    # NOTE(review): sampling `num_shared` indices out of `1:num_shared` without
    # replacement yields every index exactly once (in random order), so every
    # shared row is swapped -- confirm this is the intended behaviour.
    rows_to_swap = StatsBase.sample(rng, 1:num_shared, num_shared; replace=false)
    for row in rows_to_swap
        # Indexing with `[row, :]` copies, so `saved` survives the overwrite.
        saved = params2[row, :]
        params2[row, :] .= params1[row, :]
        params1[row, :] .= saved
    end

    return ex1, ex2
end
258

259
# Number of values counted for optimization in a parametric expression: the
# scalar constants in its tree plus every entry of its parameter array.
function count_constants_for_optimization(ex::ParametricExpression)
    tree_constants = count_scalar_constants(get_tree(ex))
    parameter_entries = length(ex.metadata.parameters)
    return tree_constants + parameter_entries
end
262

263
# Constant mutation for parametric expressions.
#
# A coin flip picks one of two actions:
#   * mutate a constant inside the tree (delegating to the tree-level
#     `mutate_constant`) and return a new expression via `with_contents`; or
#   * pick one parameter row at random and scale that whole row in place by a
#     single factor from `mutate_factor`, returning `ex` itself.
function mutate_constant(
    ex::ParametricExpression{T},
    temperature,
    options::Options,
    rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE}
    mutate_tree = rand(rng, Bool)
    if mutate_tree
        # Normal mutation of an inner constant
        mutated = mutate_constant(get_contents(ex), temperature, options, rng)
        return with_contents(ex, mutated)
    end

    # Otherwise mutate parameters: draw the row first, then the factor.
    row = rand(rng, 1:(options.expression_options.max_parameters))
    # One factor is applied to every entry of the chosen row
    factor = mutate_factor(T, temperature, options, rng)
    ex.metadata.parameters[row, :] .*= factor
    return ex
end
282

283
# Convenience method: resolve operators for an expression, using
# `options.operators` as the second argument to `get_operators`.
@unstable function get_operators(ex::AbstractExpression, options::Options)
    operators = options.operators
    return get_operators(ex, operators)
end
286
# Convenience method: same as above but for a bare expression node.
@unstable function get_operators(ex::AbstractExpressionNode, options::Options)
    operators = options.operators
    return get_operators(ex, operators)
end
289

290
end
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc