• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MilesCranmer / SymbolicRegression.jl / 9686354911

26 Jun 2024 08:31PM UTC coverage: 93.22% (-1.4%) from 94.617%
9686354911

Pull #326

github

web-flow
Merge 6f8229c9f into ceddaa424
Pull Request #326: BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)

275 of 296 new or added lines in 17 files covered. (92.91%)

34 existing lines in 5 files now uncovered.

2530 of 2714 relevant lines covered (93.22%)

32081968.55 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.65
/src/ExpressionBuilder.jl
1
module ExpressionBuilderModule
2

3
using DispatchDoctor: @unstable
4
using DynamicExpressions:
5
    AbstractExpressionNode,
6
    AbstractExpression,
7
    Expression,
8
    ParametricExpression,
9
    ParametricNode,
10
    constructorof,
11
    parse_expression,
12
    get_tree,
13
    get_contents,
14
    with_contents,
15
    with_metadata,
16
    count_constants,
17
    eval_tree_array
18
using Random: default_rng, AbstractRNG
19
using StatsBase: StatsBase
20
using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE
21
using ..HallOfFameModule: HallOfFame
22
using ..LossFunctionsModule: maybe_getindex
23
using ..InterfaceDynamicExpressionsModule: expected_array_type
24
using ..PopulationModule: Population
25
using ..PopMemberModule: PopMember
26

27
import DynamicExpressions: get_operators, string_tree
28
import ..CoreModule: create_expression
29
import ..MutationFunctionsModule:
30
    make_random_leaf, crossover_trees, mutate_constant, mutate_factor
31
import ..LossFunctionsModule: eval_tree_dispatch
32
import ..ConstantOptimizationModule: count_constants_for_optimization
33

34
# Build an expression from a bare value `t` (same element type `T` as the
# dataset): first wrap it in a leaf of `options.node_type`, then defer to the
# node-accepting `create_expression` method.
@unstable function create_expression(
    t::T, options::Options, dataset::Dataset{T,L}, ::Val{embed}=Val(false)
) where {T,L,embed}
    leaf = constructorof(options.node_type)(; val=t)
    return create_expression(leaf, options, dataset, Val(embed))
end
41
# Wrap the node `t` in an expression of type `options.expression_type`. The
# constructor keywords come from `init_params`; operators and variable names are
# only embedded when `embed` is true.
@unstable function create_expression(
    t::AbstractExpressionNode{T},
    options::Options,
    dataset::Dataset{T,L},
    ::Val{embed}=Val(false),
) where {T,L,embed}
    expression_ctor = constructorof(options.expression_type)
    ctor_kwargs = init_params(options, dataset, Val(embed))
    return expression_ctor(t; ctor_kwargs...)
end
51
"""
    create_expression(ex::AbstractExpression, options, dataset, [Val(embed)])

Return `ex` itself: it is already an expression, so nothing needs wrapping and
`options`, `dataset` and `embed` are ignored.
"""
create_expression(
    ex::AbstractExpression{T}, ::Options, ::Dataset{T,L}, ::Val{embed}=Val(false)
) where {T,L,embed} = ex
56
# Assemble the keyword arguments used to construct an expression.
# `operators` and `variable_names` are taken from `options`/`dataset` only when
# `embed` is true (otherwise both are `nothing`). Expression-type-specific
# entries are appended from `extra_init_params`, which also receives `kws`.
@unstable function init_params(
    options::Options, dataset::Dataset{T,L}, ::Val{embed}=Val(false); kws...
) where {T,L,embed}
    operators, variable_names = if embed
        (options.operators, dataset.variable_names)
    else
        (nothing, nothing)
    end
    extras = extra_init_params(
        options.expression_type, options, dataset, Val(embed); kws...
    )
    return (; operators, variable_names, extras...)
end
65
# Fallback for expression types that need no extra constructor keywords:
# any positional and keyword arguments are accepted and ignored.
extra_init_params(args...; kws...) = NamedTuple()
68
"""
    extra_init_params(::Type{<:ParametricExpression}, options, dataset, ::Val{embed};
                      parameters=nothing)

Extra constructor keywords for a `ParametricExpression`:

- `parameters`: the given matrix if one is passed. Otherwise it is a `randn(T, ...)`
  matrix with `options.expression_options.max_parameters` rows and one column
  per distinct value of `dataset.extra.classes`.
- `parameter_names`: `["p1", "p2", ...]` (one per parameter row) when `embed` is
  true, otherwise `nothing`.
"""
function extra_init_params(
    ::Type{<:ParametricExpression},
    options,
    dataset::Dataset{T,L},
    ::Val{embed};
    parameters=nothing,
) where {T,L,embed}
    n_params = options.expression_options.max_parameters
    n_classes = length(unique(dataset.extra.classes))
    parameter_names = embed ? map(i -> "p" * string(i), 1:n_params) : nothing
    param_matrix = if isnothing(parameters)
        randn(T, (n_params, n_classes))
    else
        parameters
    end
    return (; parameters=param_matrix, parameter_names)
end
83

84
@unstable begin
    # Return `ex` with full metadata attached: `init_params(..., Val(true))`
    # supplies `operators`, `variable_names`, and any type-specific entries.
    function embed_metadata(
        ex::AbstractExpression, options::Options, dataset::Dataset{T,L}
    ) where {T,L}
        return with_metadata(ex; init_params(options, dataset, Val(true))...)
    end
    # `ParametricExpression` method: forward the expression's own parameter matrix
    # to `init_params`. Without it, `extra_init_params` draws a fresh `randn`
    # matrix and the fitted parameters would be silently replaced. A copy is
    # forwarded so the result does not alias `ex`'s matrix, which is mutated in
    # place by `mutate_constant` and `crossover_trees`.
    function embed_metadata(
        ex::ParametricExpression, options::Options, dataset::Dataset{T,L}
    ) where {T,L}
        parameters = copy(ex.metadata.parameters)
        return with_metadata(ex; init_params(options, dataset, Val(true); parameters)...)
    end
    # Return a new `PopMember` whose tree has embedded metadata; score, loss,
    # `ref` and `parent` are carried over from `member`.
    function embed_metadata(
        member::PopMember, options::Options, dataset::Dataset{T,L}
    ) where {T,L}
        return PopMember(
            embed_metadata(member.tree, options, dataset),
            member.score,
            member.loss,
            nothing;
            member.ref,
            member.parent,
            deterministic=options.deterministic,
        )
    end
    # Return a new `Population` built from the embedded versions of `pop.members`.
    function embed_metadata(
        pop::Population, options::Options, dataset::Dataset{T,L}
    ) where {T,L}
        return Population(
            map(member -> embed_metadata(member, options, dataset), pop.members)
        )
    end
    # Return a new `HallOfFame` with embedded members; `hof.exists` is passed as-is.
    function embed_metadata(
        hof::HallOfFame, options::Options, dataset::Dataset{T,L}
    ) where {T,L}
        return HallOfFame(
            map(member -> embed_metadata(member, options, dataset), hof.members), hof.exists
        )
    end
    # Apply `embed_metadata` elementwise to a vector of hall-of-fames,
    # populations, or members.
    function embed_metadata(
        vec::Vector{H}, options::Options, dataset::Dataset{T,L}
    ) where {T,L,H<:Union{HallOfFame,Population,PopMember}}
        return map(elem -> embed_metadata(elem, options, dataset), vec)
    end
end
123

124
"""Strips all metadata except for top-level information"""
125
function strip_metadata(ex::Expression, options::Options, dataset::Dataset{T,L}) where {T,L}
    # Rebuild the metadata from the non-embedded keywords: with `Val(false)`,
    # `init_params` sets `operators` and `variable_names` to `nothing`.
    stripped_kwargs = init_params(options, dataset, Val(false))
    return with_metadata(ex; stripped_kwargs...)
end
NEW
128
"""
    strip_metadata(ex::ParametricExpression, options, dataset)

Strip metadata like the `Expression` method (operators and variable names become
`nothing`), but keep the expression's fitted `parameters` by forwarding them to
`init_params` rather than letting fresh random values be drawn.

A copy of the parameter matrix is forwarded. Otherwise the stripped expression
would share storage with `ex`, and that matrix is mutated in place elsewhere
(e.g. by `mutate_constant` and `crossover_trees`).
"""
function strip_metadata(
    ex::ParametricExpression, options::Options, dataset::Dataset{T,L}
) where {T,L}
    parameters = copy(ex.metadata.parameters)
    return with_metadata(ex; init_params(options, dataset, Val(false); parameters)...)
end
135
"""
    strip_metadata(member::PopMember, options, dataset)

Return a new `PopMember` whose tree has had its metadata stripped. The score,
loss, `ref` and `parent` are carried over from `member`.
"""
function strip_metadata(
    member::PopMember, options::Options, dataset::Dataset{T,L}
) where {T,L}
    stripped_tree = strip_metadata(member.tree, options, dataset)
    return PopMember(
        stripped_tree,
        member.score,
        member.loss,
        nothing;
        ref=member.ref,
        parent=member.parent,
        deterministic=options.deterministic,
    )
end
148
"""
    strip_metadata(pop::Population, options, dataset)

Return a new `Population` made of the stripped versions of `pop.members`.
"""
function strip_metadata(
    pop::Population, options::Options, dataset::Dataset{T,L}
) where {T,L}
    stripped_members = [strip_metadata(m, options, dataset) for m in pop.members]
    return Population(stripped_members)
end
153
"""
    strip_metadata(hof::HallOfFame, options, dataset)

Return a new `HallOfFame` whose members are stripped; `hof.exists` is passed
through as-is (not copied).
"""
function strip_metadata(
    hof::HallOfFame, options::Options, dataset::Dataset{T,L}
) where {T,L}
    stripped_members = [strip_metadata(m, options, dataset) for m in hof.members]
    return HallOfFame(stripped_members, hof.exists)
end
160

161
"""
    eval_tree_dispatch(tree::ParametricExpression, dataset, options, idx)

Evaluate a `ParametricExpression` on the slice `dataset.X[:, idx]` together with
the matching entries of `dataset.extra.classes` (both taken via
`maybe_getindex`). The result is type-asserted to be a
`Tuple{A,Bool}`, where `A = expected_array_type(dataset.X)`.
"""
function eval_tree_dispatch(
    tree::ParametricExpression{T}, dataset::Dataset{T}, options::Options, idx
) where {T<:DATA_TYPE}
    ArrT = expected_array_type(dataset.X)
    X_subset = maybe_getindex(dataset.X, :, idx)
    classes_subset = maybe_getindex(dataset.extra.classes, idx)
    result = eval_tree_array(tree, X_subset, classes_subset, options.operators)
    return result::Tuple{ArrT,Bool}
end
172

173
"""
    make_random_leaf(nfeatures, T, N, rng=default_rng(), options=nothing)

Create a random leaf node for a `ParametricNode{T}` tree. The kind of leaf is
chosen uniformly among:

 1. a constant with value `randn(rng, T)`;
 2. a feature leaf referencing a feature drawn from `1:nfeatures`;
 3. a parameter leaf referencing a parameter index drawn from
    `1:options.expression_options.max_parameters`.

Kind 3 needs `options`, so when `options === nothing` the choice is restricted
to kinds 1 and 2 instead of failing on `nothing.expression_options`.
"""
function make_random_leaf(
    nfeatures::Int,
    ::Type{T},
    ::Type{N},
    rng::AbstractRNG=default_rng(),
    options::Union{Options,Nothing}=nothing,
) where {T<:DATA_TYPE,N<:ParametricNode}
    # Parameter leaves are only possible when `max_parameters` is known.
    choices = options === nothing ? (1:2) : (1:3)
    choice = rand(rng, choices)
    if choice == 1
        return ParametricNode(; val=randn(rng, T))
    elseif choice == 2
        return ParametricNode(T; feature=rand(rng, 1:nfeatures))
    else
        # Explicit parameter leaf: degree 0, not a constant, no feature, and a
        # parameter index in `1:max_parameters`.
        tree = ParametricNode{T}()
        tree.val = zero(T)
        tree.degree = 0
        tree.feature = 0
        tree.constant = false
        tree.is_parameter = true
        tree.parameter = rand(
            rng, UInt16(1):UInt16(options.expression_options.max_parameters)
        )
        return tree
    end
end
198

199
"""
    crossover_trees(ex1::ParametricExpression, ex2::AbstractExpression, rng=default_rng())

Cross over the node trees of `ex1` and `ex2` (via `crossover_trees` on their
contents). Then randomly exchange parameter rows between the two results: each
row index in `1:min(size(p1, 1), size(p2, 1))` is swapped independently with
probability 1/2.

`ex2` must also carry `metadata.parameters`; in practice it should be a
`ParametricExpression`.

NOTE(review): `with_contents` presumably reuses the input's metadata, so the row
swap may act in place on the parameter matrices of the original `ex1`/`ex2`.
Confirm that callers pass copies.
"""
function crossover_trees(
    ex1::ParametricExpression{T}, ex2::AbstractExpression{T}, rng::AbstractRNG=default_rng()
) where {T}
    tree1 = get_contents(ex1)
    tree2 = get_contents(ex2)
    out1, out2 = crossover_trees(tree1, tree2, rng)
    ex1 = with_contents(ex1, out1)
    ex2 = with_contents(ex2, out2)

    # Randomly share parameters. Flip a coin per row so that a random subset of
    # rows is exchanged. Drawing all `n` indices without replacement would only
    # permute them, swapping every row, i.e. a deterministic full swap.
    params1 = ex1.metadata.parameters
    params2 = ex2.metadata.parameters
    num_params_switch = min(size(params1, 1), size(params2, 1))
    for param_idx in 1:num_params_switch
        rand(rng, Bool) || continue
        row2 = params2[param_idx, :]  # copy: this row is overwritten next
        params2[param_idx, :] .= @view params1[param_idx, :]
        params1[param_idx, :] .= row2
    end

    return ex1, ex2
end
223

224
"""
    count_constants_for_optimization(ex::ParametricExpression)

Count the constants in the tree of `ex`, i.e. `count_constants(get_tree(ex))`.
Only the tree is inspected; the metadata parameter matrix is not passed in.
"""
function count_constants_for_optimization(ex::ParametricExpression)
    tree = get_tree(ex)
    return count_constants(tree)
end
225

226
"""
    mutate_constant(ex::ParametricExpression, temperature, options, rng=default_rng())

With probability 1/2, mutate a constant inside the tree of `ex` (delegating to
`mutate_constant` on its contents) and return a new expression with those
contents.

Otherwise, pick one parameter row uniformly from
`1:options.expression_options.max_parameters` and multiply every column of that
row in place by `mutate_factor(T, temperature, options, rng)`, then return `ex`
itself.
"""
function mutate_constant(
    ex::ParametricExpression{T},
    temperature,
    options::Options,
    rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE}
    mutate_tree_constant = rand(rng, Bool)
    if !mutate_tree_constant
        # Scale one whole parameter row by a random factor, in place.
        row = rand(rng, 1:(options.expression_options.max_parameters))
        factor = mutate_factor(T, temperature, options, rng)
        ex.metadata.parameters[row, :] .*= factor
        return ex
    end
    # Ordinary mutation of a constant inside the expression tree.
    mutated_tree = mutate_constant(get_contents(ex), temperature, options, rng)
    return with_contents(ex, mutated_tree)
end
245

246
# Resolve the operator set for an expression from the search `options` by
# forwarding `options.operators` to the two-argument `get_operators`.
@unstable function get_operators(ex::AbstractExpression, options::Options)
    operators = options.operators
    return get_operators(ex, operators)
end
249
# Same as the expression method, for a bare node: forward `options.operators`
# to the two-argument `get_operators`.
@unstable function get_operators(ex::AbstractExpressionNode, options::Options)
    operators = options.operators
    return get_operators(ex, operators)
end
252

253
end
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc