MilesCranmer / SymbolicRegression.jl · build 11204590927
06 Oct 2024 07:29PM UTC · coverage: 95.808% (+1.2%) from 94.617%

Pull Request #326 (github, via web-flow): BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)
Merge e2b369ea7 into 8f67533b9

466 of 482 new or added lines in 24 files covered (96.68%).
1 existing line in 1 file now uncovered.
2651 of 2767 relevant lines covered (95.81%).
73863189.31 hits per line
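
The headline figures are consistent with the raw line counts; as a quick check in Julia:

round(466 / 482 * 100; digits=2)    # 96.68  → new or added lines covered
round(2651 / 2767 * 100; digits=3)  # 95.808 → relevant lines covered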

Source File: /src/Configure.jl · 94.16% covered
The one newly uncovered line in this file is the `continue` under the `:early_stop_condition` branch of `move_functions_to_workers`.
const TEST_TYPE = Float32

function test_operator(op::F, x::T, y=nothing) where {F,T}
    local output
    try
        output = y === nothing ? op(x) : op(x, y)
    catch e
        error(
            "The operator `$(op)` is not well-defined over the " *
            ((T <: Complex) ? "complex plane, " : "real line, ") *
            "as it threw the error `$(typeof(e))` when evaluating the " *
            (y === nothing ? "input $(x). " : "inputs $(x) and $(y). ") *
            "You can work around this by returning " *
            "NaN for invalid inputs. For example, " *
            "`safe_log(x::T) where {T} = x > 0 ? log(x) : T(NaN)`.",
        )
    end
    if !isa(output, T)
        error(
            "The operator `$(op)` returned an output of type `$(typeof(output))`, " *
            "when it was given " *
            (y === nothing ? "an input $(x) " : "inputs $(x) and $(y) ") *
            "of type `$(T)`. " *
            "Please ensure that your operators return the same type as their inputs.",
        )
    end
    return nothing
end
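
As an illustrative sketch (not part of Configure.jl), the `safe_log` workaround quoted in the error message above passes this check, while the bare `log` would throw a `DomainError` for negative inputs and trigger the descriptive error instead:

# Illustration only: a NaN of the right type is accepted, a thrown error is not.
safe_log(x::T) where {T} = x > 0 ? log(x) : T(NaN)
test_operator(safe_log, -2.0f0)   # returns nothing: NaN32 is still a Float32
# test_operator(log, -2.0f0)      # would raise the "not well-defined" error above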

const TEST_INPUTS = collect(range(-100, 100; length=99))

function assert_operators_well_defined(T, options::Options)
    test_input = if T <: Complex
        (x -> convert(T, x)).(TEST_INPUTS .+ TEST_INPUTS .* im)
    else
        (x -> convert(T, x)).(TEST_INPUTS)
    end
    for x in test_input, y in test_input, op in options.operators.binops
        test_operator(op, x, y)
    end
    for x in test_input, op in options.operators.unaops
        test_operator(op, x)
    end
end
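
A hedged usage sketch (the `Options` keywords below are the usual SymbolicRegression.jl constructor arguments, shown only for illustration): with a handful of operators configured, this routine evaluates each unary operator on all 99 test inputs and each binary operator on every pair of them.

# Illustration only; assumes the standard Options constructor keywords.
options = Options(; binary_operators=(+, *, -), unary_operators=(cos, exp))
assert_operators_well_defined(Float32, options)   # returns without raising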

# Check for errors before they happen
function test_option_configuration(
    parallelism, datasets::Vector{D}, options::Options, verbosity
) where {T,D<:Dataset{T}}
    if options.deterministic && parallelism != :serial
        error("Determinism is only guaranteed for serial mode.")
    end
    if parallelism == :multithreading && Threads.nthreads() == 1
        verbosity > 0 &&
            @warn "You are using multithreading mode, but only one thread is available. Try starting julia with `--threads=auto`."
    end
    if any(d -> d.X_units !== nothing || d.y_units !== nothing, datasets) &&
        options.dimensional_constraint_penalty === nothing
        verbosity > 0 &&
            @warn "You are using dimensional constraints, but `dimensional_constraint_penalty` was not set. The default penalty of `1000.0` will be used."
    end

    for op in (options.operators.binops..., options.operators.unaops...)
        if is_anonymous_function(op)
            throw(
                AssertionError(
                    "Anonymous functions can't be used as operators for SymbolicRegression.jl",
                ),
            )
        end
    end

    assert_operators_well_defined(T, options)

    operator_intersection = intersect(options.operators.binops, options.operators.unaops)
    if length(operator_intersection) > 0
        throw(
            AssertionError(
                "Your configuration is invalid - $(operator_intersection) appear in both the binary operators and unary operators.",
            ),
        )
    end
end
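
To make the anonymous-operator restriction concrete, here is a small hypothetical sketch (names invented for illustration): a lambda is flagged by the `is_anonymous_function` check above once the search starts, while a named wrapper with the same body is accepted.

badop = (x, y) -> x * y^2   # anonymous: rejected with the AssertionError above
goodop(x, y) = x * y^2      # named: accepted as a binary operator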

# Check for errors before they happen
function test_dataset_configuration(
    dataset::Dataset{T}, options::Options, verbosity
) where {T<:DATA_TYPE}
    n = dataset.n
    if n != size(dataset.X, 2) ||
        (dataset.y !== nothing && n != size(dataset.y::AbstractArray{T}, 1))
        throw(
            AssertionError(
                "Dataset dimensions are invalid. Make sure X is of shape [features, rows], y is of shape [rows] and if there are weights, they are of shape [rows].",
            ),
        )
    end

    if size(dataset.X, 2) > 10000 && !options.batching && verbosity > 0
        @info "Note: you are running with more than 10,000 datapoints. You should consider turning on batching (`options.batching`), and also if you need that many datapoints. Unless you have a large amount of noise (in which case you should smooth your dataset first), generally < 10,000 datapoints is enough to find a functional form."
    end

    if !(typeof(options.elementwise_loss) <: SupervisedLoss) &&
        dataset.weighted &&
        !(3 in [m.nargs - 1 for m in methods(options.elementwise_loss)])
        throw(
            AssertionError(
                "When you create a custom loss function, and are using weights, you need to define your loss function with three scalar arguments: f(prediction, target, weight).",
            ),
        )
    end
end
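
As a sketch of the three-argument requirement enforced above (illustrative, not from the source): a custom elementwise loss used together with weights must accept `(prediction, target, weight)`.

# Hypothetical example loss; a two-argument version would fail the nargs check above.
my_weighted_l2(prediction, target, weight) = weight * abs2(prediction - target)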

""" Move custom operators and loss functions to workers, if undefined """
function move_functions_to_workers(
    procs, options::Options, dataset::Dataset{T}, verbosity
) where {T}
    # All the types of functions we need to move to workers:
    function_sets = (
        :unaops, :binops, :elementwise_loss, :early_stop_condition, :loss_function
    )

    for function_set in function_sets
        if function_set == :unaops
            ops = options.operators.unaops
            example_inputs = (zero(T),)
        elseif function_set == :binops
            ops = options.operators.binops
            example_inputs = (zero(T), zero(T))
        elseif function_set == :elementwise_loss
            if typeof(options.elementwise_loss) <: SupervisedLoss
                continue
            end
            ops = (options.elementwise_loss,)
            example_inputs = if dataset.weighted
                (zero(T), zero(T), zero(T))
            else
                (zero(T), zero(T))
            end
        elseif function_set == :early_stop_condition
            if !(typeof(options.early_stop_condition) <: Function)
                continue
            end
            ops = (options.early_stop_condition,)
            example_inputs = (zero(T), 0)
        elseif function_set == :loss_function
            if options.loss_function === nothing
                continue
            end
            ops = (options.loss_function,)
            example_inputs = (Node(T; val=zero(T)), dataset, options)
        else
            error("Invalid function set: $function_set")
        end
        for op in ops
            try
                test_function_on_workers(example_inputs, op, procs)
            catch e
                undefined_on_workers = isa(e.captured.ex, UndefVarError)
                if undefined_on_workers
                    copy_definition_to_workers(op, procs, options, verbosity)
                else
                    throw(e)
                end
            end
            test_function_on_workers(example_inputs, op, procs)
        end
    end
end
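
The probe inputs `(Node(T; val=zero(T)), dataset, options)` above imply that a full custom `loss_function` is called as `loss_function(tree, dataset, options)`. A hedged sketch of such a function (the body and the use of `eval_tree_array` are illustrative assumptions):

function my_loss_function(tree, dataset::Dataset{T,L}, options)::L where {T,L}
    prediction, completed = eval_tree_array(tree, dataset.X, options)
    !completed && return L(Inf)
    return sum(abs2, prediction .- dataset.y) / dataset.n
end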

function copy_definition_to_workers(op, procs, options::Options, verbosity)
    name = nameof(op)
    verbosity > 0 && @info "Copying definition of $op to workers..."
    src_ms = methods(op).ms
    # Thanks https://discourse.julialang.org/t/easy-way-to-send-custom-function-to-distributed-workers/22118/2
    @everywhere procs @eval function $name end
    for m in src_ms
        @everywhere procs @eval $m
    end
    verbosity > 0 && @info "Finished!"
    return nothing
end

function test_function_on_workers(example_inputs, op, procs)
    futures = []
    for proc in procs
        push!(futures, @spawnat proc op(example_inputs...))
    end
    for future in futures
        fetch(future)
    end
end
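
A minimal standalone illustration of the `@spawnat`/`fetch` pattern used here (assumes `Distributed` is loaded and workers have been added; an error raised on a worker resurfaces at the corresponding `fetch`):

using Distributed
futures = []
for p in workers()
    push!(futures, @spawnat p sum(abs2, 1:10))
end
foreach(fetch, futures)   # each fetch returns 385, or rethrows a RemoteException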

function activate_env_on_workers(procs, project_path::String, options::Options, verbosity)
    verbosity > 0 && @info "Activating environment on workers."
    @everywhere procs begin
        Base.MainInclude.eval(
            quote
                using Pkg
                Pkg.activate($$project_path)
            end,
        )
    end
end

function import_module_on_workers(procs, filename::String, options::Options, verbosity)
    loaded_modules_head_worker = [k.name for (k, _) in Base.loaded_modules]

    included_as_local = "SymbolicRegression" ∉ loaded_modules_head_worker
    expr = if included_as_local
        quote
            include($filename)
            using .SymbolicRegression
        end
    else
        quote
            using SymbolicRegression
        end
    end

    # Need to import any extension code, if loaded on head node
    relevant_extensions = [
        :Bumper,
        :CUDA,
        :ClusterManagers,
        :Enzyme,
        :LoopVectorization,
        :SymbolicUtils,
        :Zygote,
    ]
    filter!(m -> String(m) ∈ loaded_modules_head_worker, relevant_extensions)
    # HACK TODO – this workaround is very fragile. Likely need to submit a bug report
    #             to JuliaLang.

    for ext in relevant_extensions
        push!(
            expr.args,
            quote
                using $ext: $ext
            end,
        )
    end

    verbosity > 0 && if isempty(relevant_extensions)
        @info "Importing SymbolicRegression on workers."
    else
        @info "Importing SymbolicRegression on workers as well as extensions $(join(relevant_extensions, ',' * ' '))."
    end
    @everywhere procs Core.eval(Core.Main, $expr)
    verbosity > 0 && @info "Finished!"
    return nothing
end
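
For instance (an illustrative assumption about the final form), if Zygote is the only listed extension loaded on the head node and the package itself was loaded normally, the expression sent to each worker is roughly equivalent to:

quote
    using SymbolicRegression
    using Zygote: Zygote
end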

function test_module_on_workers(procs, options::Options, verbosity)
    verbosity > 0 && @info "Testing module on workers..."
    futures = []
    for proc in procs
        push!(
            futures,
            @spawnat proc SymbolicRegression.gen_random_tree(3, options, 5, TEST_TYPE)
        )
    end
    for future in futures
        fetch(future)
    end
    verbosity > 0 && @info "Finished!"
    return nothing
end

function test_entire_pipeline(
    procs, dataset::Dataset{T}, options::Options, verbosity
) where {T<:DATA_TYPE}
    futures = []
    verbosity > 0 && @info "Testing entire pipeline on workers..."
    for proc in procs
        push!(
            futures,
            @spawnat proc begin
                tmp_pop = Population(
                    dataset;
                    population_size=20,
                    nlength=3,
                    options=options,
                    nfeatures=dataset.nfeatures,
                )
                tmp_pop = s_r_cycle(
                    dataset,
                    tmp_pop,
                    5,
                    5,
                    RunningSearchStatistics(; options=options);
                    verbosity=verbosity,
                    options=options,
                    record=RecordType(),
                )[1]
                tmp_pop = optimize_and_simplify_population(
                    dataset, tmp_pop, options, options.maxsize, RecordType()
                )
            end
        )
    end
    for future in futures
        fetch(future)
    end
    verbosity > 0 && @info "Finished!"
    return nothing
end

function configure_workers(;
    procs::Union{Vector{Int},Nothing},
    numprocs::Int,
    addprocs_function::Function,
    options::Options,
    project_path,
    file,
    exeflags::Cmd,
    verbosity,
    example_dataset::Dataset,
    runtests::Bool,
)
    (procs, we_created_procs) = if procs === nothing
        (addprocs_function(numprocs; lazy=false, exeflags), true)
    else
        (procs, false)
    end

    if we_created_procs
        if VERSION < v"1.9.0"
            # On newer Julia, the environment is activated automatically
            activate_env_on_workers(procs, project_path, options, verbosity)
        end
        import_module_on_workers(procs, file, options, verbosity)
    end

    move_functions_to_workers(procs, options, example_dataset, verbosity)

    if runtests
        test_module_on_workers(procs, options, verbosity)
        test_entire_pipeline(procs, example_dataset, options, verbosity)
    end

    return (procs, we_created_procs)
end
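
A hedged call-site sketch for `configure_workers` (only the keyword names come from the signature above; every value is hypothetical):

procs, we_created_procs = configure_workers(;
    procs=nothing,                          # let the function spin up workers itself
    numprocs=4,
    addprocs_function=Distributed.addprocs,
    options=options,
    project_path=dirname(Base.active_project()),
    file=@__FILE__,
    exeflags=`--threads=1`,
    verbosity=1,
    example_dataset=dataset,
    runtests=true,
)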