
MilesCranmer / SymbolicRegression.jl / build 11394658450

17 Oct 2024 11:32PM UTC coverage: 95.332% (+0.6%) from 94.757%

Pull Request #355: Create `TemplateExpression` for providing a pre-defined functional structure and constraints
Merge a9e5332c7 into 3892a6659

253 of 257 new or added lines in 15 files covered. (98.44%)

3 existing lines in 2 files now uncovered.

2818 of 2956 relevant lines covered (95.33%)

38419396.19 hits per line

Source File: /src/Configure.jl (88.39% covered)

const TEST_TYPE = Float32

function test_operator(op::F, x::T, y=nothing) where {F,T}
    local output
    try
        output = y === nothing ? op(x) : op(x, y)
    catch e
        error(
            "The operator `$(op)` is not well-defined over the " *
            ((T <: Complex) ? "complex plane, " : "real line, ") *
            "as it threw the error `$(typeof(e))` when evaluating the " *
            (y === nothing ? "input $(x). " : "inputs $(x) and $(y). ") *
            "You can work around this by returning " *
            "NaN for invalid inputs. For example, " *
            "`safe_log(x::T) where {T} = x > 0 ? log(x) : T(NaN)`.",
        )
    end
    if !isa(output, T)
        error(
            "The operator `$(op)` returned an output of type `$(typeof(output))`, " *
            "when it was given " *
            (y === nothing ? "an input $(x) " : "inputs $(x) and $(y) ") *
            "of type `$(T)`. " *
            "Please ensure that your operators return the same type as their inputs.",
        )
    end
    return nothing
end

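# Illustrative sketch (not part of this source file): the workaround suggested by the
# error message above. Guarding the operator's domain and returning a NaN of the input
# type keeps `test_operator` happy and preserves input/output type agreement.
safe_log(x::T) where {T} = x > 0 ? log(x) : T(NaN)
# safe_log(2.0) returns ≈ 0.6931; safe_log(-1.0) returns NaN instead of throwing a DomainError.
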
const TEST_INPUTS = collect(range(-100, 100; length=99))

function assert_operators_well_defined(T, options::AbstractOptions)
    test_input = if T <: Complex
        (x -> convert(T, x)).(TEST_INPUTS .+ TEST_INPUTS .* im)
    else
        (x -> convert(T, x)).(TEST_INPUTS)
    end
    for x in test_input, y in test_input, op in options.operators.binops
        test_operator(op, x, y)
    end
    for x in test_input, op in options.operators.unaops
        test_operator(op, x)
    end
end

# Check for errors before they happen
function test_option_configuration(
    parallelism, datasets::Vector{D}, options::AbstractOptions, verbosity
) where {T,D<:Dataset{T}}
    if options.deterministic && parallelism != :serial
        error("Determinism is only guaranteed for serial mode.")
    end
    if parallelism == :multithreading && Threads.nthreads() == 1
        verbosity > 0 &&
            @warn "You are using multithreading mode, but only one thread is available. Try starting julia with `--threads=auto`."
    end
    if any(d -> d.X_units !== nothing || d.y_units !== nothing, datasets) &&
        options.dimensional_constraint_penalty === nothing
        verbosity > 0 &&
            @warn "You are using dimensional constraints, but `dimensional_constraint_penalty` was not set. The default penalty of `1000.0` will be used."
    end

    for op in (options.operators.binops..., options.operators.unaops...)
        if is_anonymous_function(op)
            throw(
                AssertionError(
                    "Anonymous functions can't be used as operators for SymbolicRegression.jl",
                ),
            )
        end
    end

    assert_operators_well_defined(T, options)

    operator_intersection = intersect(options.operators.binops, options.operators.unaops)
    if length(operator_intersection) > 0
        throw(
            AssertionError(
                "Your configuration is invalid - $(operator_intersection) appear in both the binary operators and unary operators.",
            ),
        )
    end
end

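# Illustrative sketch (not part of this source file), assuming the package's exported
# `Options` constructor and its `binary_operators`/`unary_operators` keywords. The checks
# above reject anonymous operators such as `(x, y) -> x / (y + 1)`, and operators listed
# as both binary and unary, e.g. `Options(binary_operators=[+, -], unary_operators=[-])`.
# A configuration that passes these checks looks like:
using SymbolicRegression
options = Options(binary_operators=[+, -, *, /], unary_operators=[cos, exp])
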
# Check for errors before they happen
function test_dataset_configuration(
    dataset::Dataset{T}, options::AbstractOptions, verbosity
) where {T<:DATA_TYPE}
    n = dataset.n
    if n != size(dataset.X, 2) ||
        (dataset.y !== nothing && n != size(dataset.y::AbstractArray, 1))
        throw(
            AssertionError(
                "Dataset dimensions are invalid. Make sure X is of shape [features, rows], y is of shape [rows] and if there are weights, they are of shape [rows].",
            ),
        )
    end

    if size(dataset.X, 2) > 10000 && !options.batching && verbosity > 0
        @info "Note: you are running with more than 10,000 datapoints. You should consider turning on batching (`options.batching`), and also if you need that many datapoints. Unless you have a large amount of noise (in which case you should smooth your dataset first), generally < 10,000 datapoints is enough to find a functional form."
    end

    if !(typeof(options.elementwise_loss) <: SupervisedLoss) &&
        is_weighted(dataset) &&
        !(3 in [m.nargs - 1 for m in methods(options.elementwise_loss)])
        throw(
            AssertionError(
                "When you create a custom loss function, and are using weights, you need to define your loss function with three scalar arguments: f(prediction, target, weight).",
            ),
        )
    end
end

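# Illustrative sketch (not part of this source file): the three-scalar-argument form of
# a custom elementwise loss required by the weighted-dataset check above. The keyword
# name `elementwise_loss` matches the `options.elementwise_loss` field used in this file.
my_weighted_loss(prediction, target, weight) = weight * abs2(prediction - target)
# e.g. Options(binary_operators=[+, *], elementwise_loss=my_weighted_loss)
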
""" Move custom operators and loss functions to workers, if undefined """
function move_functions_to_workers(
    procs, options::AbstractOptions, dataset::Dataset{T}, verbosity
) where {T}
    # All the types of functions we need to move to workers:
    function_sets = (
        :unaops, :binops, :elementwise_loss, :early_stop_condition, :loss_function
    )

    for function_set in function_sets
        if function_set == :unaops
            ops = options.operators.unaops
            example_inputs = (zero(T),)
        elseif function_set == :binops
            ops = options.operators.binops
            example_inputs = (zero(T), zero(T))
        elseif function_set == :elementwise_loss
            if typeof(options.elementwise_loss) <: SupervisedLoss
                continue
            end
            ops = (options.elementwise_loss,)
            example_inputs = if is_weighted(dataset)
                (zero(T), zero(T), zero(T))
            else
                (zero(T), zero(T))
            end
        elseif function_set == :early_stop_condition
            if !(typeof(options.early_stop_condition) <: Function)
                continue
            end
            ops = (options.early_stop_condition,)
            example_inputs = (zero(T), 0)
        elseif function_set == :loss_function
            if options.loss_function === nothing
                continue
            end
            ops = (options.loss_function,)
            example_inputs = (Node(T; val=zero(T)), dataset, options)
        else
            error("Invalid function set: $function_set")
        end
        for op in ops
            try
                test_function_on_workers(example_inputs, op, procs)
            catch e
                undefined_on_workers = isa(e.captured.ex, UndefVarError)
                if undefined_on_workers
                    copy_definition_to_workers(op, procs, options, verbosity)
                else
                    throw(e)
                end
            end
            test_function_on_workers(example_inputs, op, procs)
        end
    end
end

function copy_definition_to_workers(op, procs, options::AbstractOptions, verbosity)
    name = nameof(op)
    verbosity > 0 && @info "Copying definition of $op to workers..."
    src_ms = methods(op).ms
    # Thanks https://discourse.julialang.org/t/easy-way-to-send-custom-function-to-distributed-workers/22118/2
    @everywhere procs @eval function $name end
    for m in src_ms
        @everywhere procs @eval $m
    end
    verbosity > 0 && @info "Finished!"
    return nothing
end

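# Illustrative sketch (not part of this source file) of the same pattern that
# `copy_definition_to_workers` uses: declare an empty generic function on each worker,
# then ship every method over with `@eval`. Assumes the Distributed standard library;
# `my_op` and `op_name` are hypothetical names used only for this example.
using Distributed
addprocs(2)

my_op(x) = x^2 + 1                      # defined only on the main process so far

op_name = nameof(my_op)
@everywhere workers() @eval function $op_name end   # declare the name on every worker
for m in methods(my_op).ms
    @everywhere workers() @eval $m                  # copy each method definition across
end

fetch(@spawnat first(workers()) my_op(3.0))         # == 10.0 once the definition has arrived
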
function test_function_on_workers(example_inputs, op, procs)
    futures = []
    for proc in procs
        push!(futures, @spawnat proc op(example_inputs...))
    end
    for future in futures
        fetch(future)
    end
end

function activate_env_on_workers(
    procs, project_path::String, options::AbstractOptions, verbosity
)
    verbosity > 0 && @info "Activating environment on workers."
    @everywhere procs begin
        Base.MainInclude.eval(
            quote
                using Pkg
                Pkg.activate($$project_path)
            end,
        )
    end
end

function import_module_on_workers(
    procs, filename::String, options::AbstractOptions, verbosity
)
    loaded_modules_head_worker = [k.name for (k, _) in Base.loaded_modules]

    included_as_local = "SymbolicRegression" ∉ loaded_modules_head_worker
    expr = if included_as_local
        quote
            include($filename)
            using .SymbolicRegression
        end
    else
        quote
            using SymbolicRegression
        end
    end

    # Need to import any extension code, if loaded on head node
    relevant_extensions = [
        :Bumper,
        :CUDA,
        :ClusterManagers,
        :Enzyme,
        :LoopVectorization,
        :SymbolicUtils,
        :Zygote,
    ]
    filter!(m -> String(m) ∈ loaded_modules_head_worker, relevant_extensions)
    # HACK TODO – this workaround is very fragile. Likely need to submit a bug report
    #             to JuliaLang.

    for ext in relevant_extensions
        push!(
            expr.args,
            quote
                using $ext: $ext
            end,
        )
    end

    verbosity > 0 && if isempty(relevant_extensions)
        @info "Importing SymbolicRegression on workers."
    else
        @info "Importing SymbolicRegression on workers as well as extensions $(join(relevant_extensions, ',' * ' '))."
    end
    @everywhere procs Core.eval(Core.Main, $expr)
    verbosity > 0 && @info "Finished!"
    return nothing
end

function test_module_on_workers(procs, options::AbstractOptions, verbosity)
    verbosity > 0 && @info "Testing module on workers..."
    futures = []
    for proc in procs
        push!(
            futures,
            @spawnat proc SymbolicRegression.gen_random_tree(3, options, 5, TEST_TYPE)
        )
    end
    for future in futures
        fetch(future)
    end
    verbosity > 0 && @info "Finished!"
    return nothing
end

function test_entire_pipeline(
    procs, dataset::Dataset{T}, options::AbstractOptions, verbosity
) where {T<:DATA_TYPE}
    futures = []
    verbosity > 0 && @info "Testing entire pipeline on workers..."
    for proc in procs
        push!(
            futures,
            @spawnat proc begin
                tmp_pop = Population(
                    dataset;
                    population_size=20,
                    nlength=3,
                    options=options,
                    nfeatures=dataset.nfeatures,
                )
                tmp_pop = s_r_cycle(
                    dataset,
                    tmp_pop,
                    5,
                    5,
                    RunningSearchStatistics(; options=options);
                    verbosity=verbosity,
                    options=options,
                    record=RecordType(),
                )[1]
                tmp_pop = optimize_and_simplify_population(
                    dataset, tmp_pop, options, options.maxsize, RecordType()
                )
            end
        )
    end
    for future in futures
        fetch(future)
    end
    verbosity > 0 && @info "Finished!"
    return nothing
end

function configure_workers(;
    procs::Union{Vector{Int},Nothing},
    numprocs::Int,
    addprocs_function::Function,
    options::AbstractOptions,
    project_path,
    file,
    exeflags::Cmd,
    verbosity,
    example_dataset::Dataset,
    runtests::Bool,
)
    (procs, we_created_procs) = if procs === nothing
        (addprocs_function(numprocs; lazy=false, exeflags), true)
    else
        (procs, false)
    end

    if we_created_procs
        import_module_on_workers(procs, file, options, verbosity)
    end

    move_functions_to_workers(procs, options, example_dataset, verbosity)

    if runtests
        test_module_on_workers(procs, options, verbosity)
        test_entire_pipeline(procs, example_dataset, options, verbosity)
    end

    return (procs, we_created_procs)
end