• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MilesCranmer / SymbolicRegression.jl / 9639805727

24 Jun 2024 05:00AM UTC coverage: 94.475% (-0.1%) from 94.617%
9639805727

Pull #326

github

web-flow
Merge 3ba1556f8 into ceddaa424
Pull Request #326: BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)

239 of 250 new or added lines in 15 files covered. (95.6%)

4 existing lines in 3 files now uncovered.

2548 of 2697 relevant lines covered (94.48%)

46539295.05 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.62
/src/ExpressionBuilder.jl
1
module ExpressionBuilderModule
2

3
using DispatchDoctor: @unstable
4
using DynamicExpressions:
5
    AbstractExpressionNode,
6
    AbstractExpression,
7
    Expression,
8
    ParametricExpression,
9
    ParametricNode,
10
    constructorof,
11
    parse_expression,
12
    get_tree,
13
    get_contents,
14
    with_contents,
15
    with_metadata,
16
    count_constants
17
using Random: default_rng, AbstractRNG
18
using StatsBase: StatsBase
19
using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE
20
using ..HallOfFameModule: HallOfFame
21
using ..LossFunctionsModule: maybe_getindex
22
using ..InterfaceDynamicExpressionsModule: eval_tree_array, expected_array_type
23
using ..PopulationModule: Population
24
using ..PopMemberModule: PopMember
25

26
import DynamicExpressions: get_operators, string_tree
27
import ..CoreModule: create_expression
28
import ..MutationFunctionsModule:
29
    make_random_leaf, crossover_trees, mutate_constant, mutate_factor
30
import ..LossFunctionsModule: eval_tree_dispatch
31
import ..ConstantOptimizationModule: count_constants_for_optimization
32

33
# Wrap the raw scalar `t` in a leaf built by the constructor of `options.node_type`
# (passing it as `val`), then call `create_expression` again on that node.
# The `Val{embed}` flag is forwarded unchanged.
@unstable function create_expression(
    t::T, options::Options, dataset::Dataset{T,L}, ::Val{embed}=Val(false)
) where {T,L,embed}
    make_node = constructorof(options.node_type)
    constant_leaf = make_node(; val=t)
    return create_expression(constant_leaf, options, dataset, Val(embed))
end
40
# Build an expression around the tree `t` using the constructor of
# `options.expression_type`. The named tuple returned by `init_params` is splatted
# into the constructor as keyword arguments.
@unstable function create_expression(
    t::AbstractExpressionNode{T},
    options::Options,
    dataset::Dataset{T,L},
    ::Val{embed}=Val(false),
) where {T,L,embed}
    make_expression = constructorof(options.expression_type)
    metadata = init_params(options, dataset, Val(embed))
    return make_expression(t; metadata...)
end
50
"""
    create_expression(ex::AbstractExpression, options, dataset, Val(embed)=Val(false))

Return `ex` unchanged. The remaining arguments are accepted only so that this
method shares its call shape with the other `create_expression` methods.
"""
create_expression(
    ex::AbstractExpression{T}, ::Options, ::Dataset{T,L}, ::Val{embed}=Val(false)
) where {T,L,embed} = ex
55
# Assemble the keyword arguments used to construct or re-annotate an expression.
#
# `operators` and `variable_names` are taken from `options` / `dataset` only when
# `embed` is true and are `nothing` otherwise. Whatever `extra_init_params` returns
# for `options.expression_type` is appended after them; any extra keywords `kws`
# are forwarded to `extra_init_params`.
@unstable function init_params(
    options::Options, dataset::Dataset{T,L}, ::Val{embed}=Val(false); kws...
) where {T,L,embed}
    operators = embed ? options.operators : nothing
    variable_names = embed ? dataset.variable_names : nothing
    extras = extra_init_params(
        options.expression_type, options, dataset, Val(embed); kws...
    )
    return (; operators, variable_names, extras...)
end
64
"""
    extra_init_params(args...; kws...)

Fallback method: accept any arguments and keywords and contribute no additional
initialization parameters (an empty named tuple).
"""
extra_init_params(args...; kws...) = NamedTuple()
67
"""
    extra_init_params(::Type{<:ParametricExpression}, options, dataset, Val(embed);
                      parameters=nothing)

Return the named tuple `(; parameters, parameter_names)` for a parametric
expression.

- `parameters`: the given keyword value if it is not `nothing`; otherwise a fresh
  `randn(T, (num_params, num_classes))` matrix, where `num_params` is
  `options.expression_options.max_parameters` and `num_classes` is the number of
  unique entries in `dataset.extra.classes`.
- `parameter_names`: `["p1", "p2", ...]` (one per parameter) when `embed` is true,
  `nothing` otherwise.
"""
function extra_init_params(
    ::Type{<:ParametricExpression},
    options,
    dataset::Dataset{T,L},
    ::Val{embed};
    parameters=nothing,
) where {T,L,embed}
    num_params = options.expression_options.max_parameters
    num_classes = length(unique(dataset.extra.classes))
    parameter_names = embed ? ["p$i" for i in 1:num_params] : nothing
    # Bind the result to a new name rather than reassigning the keyword argument.
    resolved = if parameters === nothing
        randn(T, (num_params, num_classes))
    else
        parameters
    end
    return (; parameters=resolved, parameter_names)
end
82

83
# `embed_metadata` returns a copy of its argument in which every expression carries
# the "embedded" metadata from `init_params(options, dataset, Val(true))`, i.e. with
# `operators` and `variable_names` filled in. Methods exist for a single expression,
# a `PopMember`, a `Population`, a `HallOfFame`, and vectors of the latter three.
@unstable begin
    # Generic expression: replace the metadata with the embedded parameters.
    function embed_metadata(
        ex::AbstractExpression, options::Options, dataset::Dataset{T,L}
    ) where {T,L}
        return with_metadata(ex; init_params(options, dataset, Val(true))...)
    end
    # Parametric expression: forward the existing parameter matrix. Without this,
    # `extra_init_params` receives `parameters=nothing` and draws a new random
    # matrix, discarding the parameters already stored on `ex`. This mirrors the
    # `strip_metadata(::ParametricExpression, ...)` method. The matrix is copied so
    # the result does not alias `ex`, whose parameters are updated in place by
    # `mutate_constant` and `crossover_trees`.
    function embed_metadata(
        ex::ParametricExpression, options::Options, dataset::Dataset{T,L}
    ) where {T,L}
        parameters = copy(ex.metadata.parameters)
        return with_metadata(ex; init_params(options, dataset, Val(true); parameters)...)
    end
    # Population member: rebuild the member around the re-annotated tree, carrying
    # over score, loss, `ref` and `parent`.
    function embed_metadata(
        member::PopMember, options::Options, dataset::Dataset{T,L}
    ) where {T,L}
        return PopMember(
            embed_metadata(member.tree, options, dataset),
            member.score,
            member.loss,
            nothing;
            member.ref,
            member.parent,
            deterministic=options.deterministic,
        )
    end
    # Population: apply to every member.
    function embed_metadata(
        pop::Population, options::Options, dataset::Dataset{T,L}
    ) where {T,L}
        return Population(
            map(member -> embed_metadata(member, options, dataset), pop.members)
        )
    end
    # Hall of fame: apply to every member and keep the `exists` field as is.
    function embed_metadata(
        hof::HallOfFame, options::Options, dataset::Dataset{T,L}
    ) where {T,L}
        return HallOfFame(
            map(member -> embed_metadata(member, options, dataset), hof.members), hof.exists
        )
    end
    # Vector of halls of fame, populations, or members: apply element-wise.
    function embed_metadata(
        vec::Vector{H}, options::Options, dataset::Dataset{T,L}
    ) where {T,L,H<:Union{HallOfFame,Population,PopMember}}
        return map(elem -> embed_metadata(elem, options, dataset), vec)
    end
end
122

123
"""Strips all metadata except for top-level information"""
124
function strip_metadata(ex::Expression, options::Options, dataset::Dataset{T,L}) where {T,L}
    # With `Val(false)`, `init_params` sets `operators` and `variable_names` to
    # `nothing`, so the returned expression no longer carries them.
    bare_metadata = init_params(options, dataset, Val(false))
    return with_metadata(ex; bare_metadata...)
end
NEW
127
"""
    strip_metadata(ex::ParametricExpression, options, dataset)

Return `ex` with its metadata replaced by the non-embedded `init_params`, while
forwarding the expression's current `parameters` so they are kept rather than
re-drawn.
"""
function strip_metadata(
    ex::ParametricExpression, options::Options, dataset::Dataset{T,L}
) where {T,L}
    current_parameters = ex.metadata.parameters
    bare_metadata = init_params(
        options, dataset, Val(false); parameters=current_parameters
    )
    return with_metadata(ex; bare_metadata...)
end
134
"""
    strip_metadata(member::PopMember, options, dataset)

Return a new `PopMember` built around the stripped version of `member.tree`,
carrying over the member's score, loss, `ref` and `parent`.
"""
function strip_metadata(
    member::PopMember, options::Options, dataset::Dataset{T,L}
) where {T,L}
    stripped_tree = strip_metadata(member.tree, options, dataset)
    return PopMember(
        stripped_tree,
        member.score,
        member.loss,
        nothing;
        ref=member.ref,
        parent=member.parent,
        deterministic=options.deterministic,
    )
end
147
"""
    strip_metadata(pop::Population, options, dataset)

Return a new `Population` whose members are the stripped versions of
`pop.members`.
"""
function strip_metadata(
    pop::Population, options::Options, dataset::Dataset{T,L}
) where {T,L}
    stripped_members = map(pop.members) do member
        strip_metadata(member, options, dataset)
    end
    return Population(stripped_members)
end
152
"""
    strip_metadata(hof::HallOfFame, options, dataset)

Return a new `HallOfFame` whose members are the stripped versions of
`hof.members`; the `exists` field is passed through as is.
"""
function strip_metadata(
    hof::HallOfFame, options::Options, dataset::Dataset{T,L}
) where {T,L}
    stripped_members = map(hof.members) do member
        strip_metadata(member, options, dataset)
    end
    return HallOfFame(stripped_members, hof.exists)
end
159

160
"""
    eval_tree_dispatch(tree::ParametricExpression, dataset, options, idx)

Evaluate `tree` with `eval_tree_array`, passing `dataset.X` and
`dataset.extra.classes` (each routed through `maybe_getindex` with `idx`) together
with `options.operators`. The result is type-asserted to `Tuple{A,Bool}`, where
`A = expected_array_type(dataset.X)`.
"""
function eval_tree_dispatch(
    tree::ParametricExpression{T}, dataset::Dataset{T}, options::Options, idx
) where {T<:DATA_TYPE}
    A = expected_array_type(dataset.X)
    features = maybe_getindex(dataset.X, :, idx)
    classes = maybe_getindex(dataset.extra.classes, idx)
    result = eval_tree_array(tree, features, classes, options.operators)
    return result::Tuple{A,Bool}
end
171

172
"""
    make_random_leaf(nfeatures, T, N, rng=default_rng(), options=nothing)

Create a random leaf for a `ParametricNode` tree. One of three kinds is chosen
uniformly at random:

1. a constant leaf with value `randn(rng, T)`;
2. a feature leaf with a feature index drawn from `1:nfeatures`;
3. a parameter leaf with a parameter index drawn from
   `1:options.expression_options.max_parameters`.

# Throws
- `ArgumentError` if the parameter kind is drawn while `options === nothing`,
  since the parameter count cannot be determined without it.
"""
function make_random_leaf(
    nfeatures::Int,
    ::Type{T},
    ::Type{N},
    rng::AbstractRNG=default_rng(),
    options::Union{Options,Nothing}=nothing,
) where {T<:DATA_TYPE,N<:ParametricNode}
    choice = rand(rng, 1:3)
    if choice == 1
        return ParametricNode(; val=randn(rng, T))
    elseif choice == 2
        return ParametricNode(T; feature=rand(rng, 1:nfeatures))
    else
        # Checked only in this branch so that the other two kinds keep working
        # (and consume the RNG exactly as before) when `options` is omitted.
        options === nothing && throw(
            ArgumentError(
                "make_random_leaf for ParametricNode requires `options` to draw " *
                "a parameter index, got `nothing`",
            ),
        )
        tree = ParametricNode{T}()
        tree.val = zero(T)
        tree.degree = 0
        tree.feature = 0
        tree.constant = false
        tree.is_parameter = true
        tree.parameter = rand(
            rng, UInt16(1):UInt16(options.expression_options.max_parameters)
        )
        return tree
    end
end
197

198
"""
    crossover_trees(ex1::ParametricExpression, ex2::AbstractExpression,
                    rng=default_rng())

Cross over the underlying trees of `ex1` and `ex2`, then exchange rows of their
parameter matrices in place. Returns the two resulting expressions.
"""
function crossover_trees(
    ex1::ParametricExpression{T}, ex2::AbstractExpression{T}, rng::AbstractRNG=default_rng()
) where {T}
    # Tree-level crossover on the raw contents, then re-wrap.
    child1, child2 = crossover_trees(get_contents(ex1), get_contents(ex2), rng)
    ex1 = with_contents(ex1, child1)
    ex2 = with_contents(ex2, child2)

    # Exchange parameter rows between the two expressions.
    params1 = ex1.metadata.parameters
    params2 = ex2.metadata.parameters
    num_shared = min(size(params1, 1), size(params2, 1))
    # NOTE(review): this draws `num_shared` indices out of `1:num_shared` without
    # replacement, so every shared row is visited exactly once (in random order).
    rows_to_swap = StatsBase.sample(rng, 1:num_shared, num_shared; replace=false)
    for row in rows_to_swap
        # Indexing with `[row, :]` copies, so `saved_row` survives the overwrite.
        saved_row = params2[row, :]
        params2[row, :] .= params1[row, :]
        params1[row, :] .= saved_row
    end

    return ex1, ex2
end
222

223
"""
    count_constants_for_optimization(ex::ParametricExpression)

Count the constants in the tree of `ex` (via `get_tree` and `count_constants`).
"""
function count_constants_for_optimization(ex::ParametricExpression)
    tree = get_tree(ex)
    return count_constants(tree)
end
11,279✔
224

225
"""
    mutate_constant(ex::ParametricExpression, temperature, options, rng=default_rng())

With equal probability, either mutate a constant inside the tree of `ex`
(returning a new expression via `with_contents`), or scale one randomly chosen row
of `ex.metadata.parameters` in place by a factor from `mutate_factor` and return
`ex` itself.
"""
function mutate_constant(
    ex::ParametricExpression{T},
    temperature,
    options::Options,
    rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE}
    if rand(rng, Bool)
        # Ordinary mutation of a constant inside the tree.
        mutated_tree = mutate_constant(get_contents(ex), temperature, options, rng)
        return with_contents(ex, mutated_tree)
    end

    # Otherwise scale a whole parameter row (every column of it) by one factor.
    row = rand(rng, 1:(options.expression_options.max_parameters))
    scale = mutate_factor(T, temperature, options, rng)
    params = ex.metadata.parameters
    params[row, :] .*= scale
    return ex
end
244

245
# For parametric expressions the operators are always read from `options`; the
# expression argument is used for dispatch only.
@unstable function get_operators(::ParametricExpression, options::Options)
    return options.operators
end
61,985✔
246
"""
    string_tree(tree::ParametricExpression, options::Options; kws...)

Look up the operators for `tree` with `get_operators(tree, options)` and forward
to the `string_tree(tree, operators; kws...)` method.
"""
function string_tree(tree::ParametricExpression, options::Options; kws...)
    operators = get_operators(tree, options)
    return string_tree(tree, operators; kws...)
end
249

250
end
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc