• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MilesCranmer / SymbolicRegression.jl / 11204590927

06 Oct 2024 07:29PM UTC coverage: 95.808% (+1.2%) from 94.617%
11204590927

Pull #326

github

web-flow
Merge e2b369ea7 into 8f67533b9
Pull Request #326: BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)

466 of 482 new or added lines in 24 files covered. (96.68%)

1 existing line in 1 file now uncovered.

2651 of 2767 relevant lines covered (95.81%)

73863189.31 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.4
/src/Utils.jl
1
"""Useful functions to be used throughout the library."""
2
module UtilsModule
3

4
using Printf: @printf
5
using MacroTools: splitdef
6

7
"""
    @ignore args...

Swallow `args` entirely: they are never evaluated, and the macro expands to `nothing`.
"""
macro ignore(args...)
    return nothing
end
144✔
8

9
# Counter backing the deterministic branch of `get_birth_order`.
const pseudo_time = Ref(0)

"""
    get_birth_order(; deterministic::Bool=false)::Int

Return an integer birth-order stamp.

- `deterministic=true`: increment a module-level counter and return its new value. This
  has perfect resolution (each call returns the previous value plus one), but it is not
  thread safe: the read-modify-write of the counter is not atomic.
- `deterministic=false`: return the wall-clock `time()` scaled by `1e7` (ticks of 100 ns)
  and rounded to `Int`. Calls closer together than one tick may return equal values.
"""
function get_birth_order(; deterministic::Bool=false)::Int
    if deterministic
        # `pseudo_time` is a `const` Ref mutated through `[]`; no `global` declaration needed.
        pseudo_time[] += 1
        return pseudo_time[]
    else
        resolution = 1e7
        return round(Int, resolution * time())
    end
end
22

23
"""
    is_anonymous_function(op)

Heuristically detect compiler-generated (anonymous) functions from their name: returns
`true` when `nameof(op)` starts with `#` followed by a digit in `1`-`9` (e.g. `#12`).
"""
function is_anonymous_function(op)
    name = String(nameof(op))
    return occursin(r"^#[1-9]", name)
end
29

30
"""
    recursive_merge(xs...)

Merge `xs` from left to right:
- vectors are concatenated (along dimension 1),
- dictionaries are merged key-wise, recursing into values found under the same key,
- any other combination resolves to the last argument.

Calling it with no arguments throws an error.
"""
function recursive_merge(xs::AbstractVector...)
    return cat(xs...; dims=1)
end
function recursive_merge(xs::AbstractDict...)
    return merge(recursive_merge, xs...)
end
function recursive_merge(xs...)
    return xs[end]
end
function recursive_merge()
    return error("Unexpected input.")
end
6✔
34

35
"""
    get_base_type(::Type{Complex{T}}) where {T}

Return the component type `T` of the complex type `Complex{T}`.
"""
function get_base_type(::Type{Complex{T}}) where {T}
    return T
end
14✔
36

37
# Unicode subscript digits: `subscripts[d + 1]` is the subscript form of digit `d`.
const subscripts = ('₀', '₁', '₂', '₃', '₄', '₅', '₆', '₇', '₈', '₉')

"""
    subscriptify(number::Integer)

Render `number` with Unicode subscript digits, e.g. `subscriptify(12) == "₁₂"`.
Negative numbers are prefixed with the subscript minus sign `₋`.
"""
function subscriptify(number::Integer)
    # `digits` returns least-significant first and, for negative input, negative digits.
    # Take `abs` of each digit (not of `number`, which would overflow for `typemin`).
    chars = (subscripts[abs(d) + 1] for d in reverse(digits(number)))
    sign = number < 0 ? "₋" : ""
    return sign * join(chars)
end
41

"""
    split_string(s::String, n::Integer)

Break `s` into consecutive chunks of `n` characters; the final chunk may be shorter.
Chunks are counted in characters rather than bytes, so multi-byte UTF-8 characters are
never cut in half. A string of at most `n` characters is returned as a one-element vector.

```jldoctest
split_string("abcdefgh", 3)

# output

["abc", "def", "gh"]
```
"""
function split_string(s::String, n::Integer)
    length(s) <= n && return [s]
    # Byte offsets at which each character of `s` starts (valid string indices).
    starts = collect(eachindex(s))
    nchars = length(starts)
    chunks = String[]
    for first_char in 1:n:nchars
        last_char = min(first_char + n - 1, nchars)
        push!(chunks, s[starts[first_char]:starts[last_char]])
    end
    return chunks
end
59

"""
    MutableTuple(Val(S), T, data)

Tiny equivalent to `StaticArrays.MVector`: a mutable, fixed-length (`S`) vector with
element type `T`, stored in the tuple `data`.

This is so we don't have to load StaticArrays, which takes a long time.

Elements are read and written through a raw pointer to the object, so this is only
meaningful when `data` is an `NTuple{S,T}` of `isbits` values. NOTE(review): the
constructor does not check this; confirm at call sites.
"""
mutable struct MutableTuple{S,T,N} <: AbstractVector{T}
    data::N

    MutableTuple(::Val{_S}, ::Type{_T}, data::_N) where {_S,_T,_N} = new{_S,_T,_N}(data)
end
@inline Base.eltype(::MutableTuple{S,T}) where {S,T} = T
# Required `AbstractArray` interface method; enables `length`, `axes`, `checkbounds`
# and display.
@inline Base.size(::MutableTuple{S}) where {S} = (S,)
Base.@propagate_inbounds function Base.getindex(v::MutableTuple, i::Integer)
    # Elided under `@inbounds`; otherwise prevents reading memory past the tuple.
    @boundscheck checkbounds(v, i)
    T = eltype(v)
    # Trick from MArray.jl: `data` is the only field, so the object's address is the
    # address of the first tuple element.
    return GC.@preserve v unsafe_load(
        Base.unsafe_convert(Ptr{T}, pointer_from_objref(v)), i
    )
end
Base.@propagate_inbounds function Base.setindex!(v::MutableTuple, x, i::Integer)
    # Elided under `@inbounds`; otherwise prevents writing memory past the tuple.
    @boundscheck checkbounds(v, i)
    T = eltype(v)
    GC.@preserve v unsafe_store!(Base.unsafe_convert(Ptr{T}, pointer_from_objref(v)), x, i)
    return x
end
@inline Base.lastindex(::MutableTuple{S}) where {S} = S
@inline Base.firstindex(v::MutableTuple) = 1
# `pointer(v)` has no method for this type; the object's own address identifies the
# memory holding its elements.
Base.dataids(v::MutableTuple) = (UInt(pointer_from_objref(v)),)
"""Copy the `S` elements of a `MutableTuple` into a newly allocated `Vector{T}`."""
function _to_vec(v::MutableTuple{S,T}) where {S,T}
    x = Vector{T}(undef, S)
    @inbounds for i in 1:S
        x[i] = v[i]
    end
    return x
end
93

94
# Largest `k` supported by `bottomk_fast`.
const max_ops = 8192
# `vals[k] === Val(k)` for `1 <= k <= max_ops`: lifts a runtime `k` into the type domain
# with a tuple lookup.
const vals = ntuple(Val, max_ops)

"""
    bottomk_fast(x::AbstractVector{T}, k) where {T}

Return the bottom `k` elements of `x` and their indices, as the tuple
`(values::Vector{T}, indices::Vector{Int})`. `k` must lie in `1:max_ops`; any other
value throws a `BoundsError` from the `vals` lookup.
"""
function bottomk_fast(x::AbstractVector{T}, k) where {T}
    k_as_val = vals[k]
    result = _bottomk_dispatch(x, k_as_val)
    return result::Tuple{Vector{T},Vector{Int}}
end
100

101
"""
Worker for `bottomk_fast` with `k` as a type parameter, so the buffers below have a
fixed size known from the type. Returns `(values, indices)` as two `Vector`s.
"""
function _bottomk_dispatch(x::AbstractVector{T}, ::Val{k}) where {T,k}
    if k == 1
        # A single minimum: one linear scan, wrapped into one-element vectors.
        best_val, best_idx = findmin_fast(x)
        return ([best_val], [best_idx])
    end
    # Buffers seeded with index 1 and value `typemax(T)`; `_bottomk!` fills them in place.
    idx_buf = MutableTuple(Val(k), Int, ntuple(_ -> 1, Val(k)))
    val_buf = MutableTuple(Val(k), T, ntuple(_ -> typemax(T), Val(k)))
    _bottomk!(x, val_buf, idx_buf)
    return (_to_vec(val_buf), _to_vec(idx_buf))
end
110
"""
    _bottomk!(x, minval, indmin)

Scan `x` once, keeping the smallest values seen in `minval` (ascending) and their indices
in `indmin`. Both buffers are updated in place and must have matching indices; `minval`
must be sorted ascending on entry (e.g. all `typemax`). Returns `nothing`.

An element is only kept if it is strictly `<` the current last entry of `minval`. `NaN`s
and values tied with that entry are therefore skipped, and earlier indices win ties.
"""
function _bottomk!(x, minval, indmin)
    @inbounds for i in eachindex(x)
        x[i] < minval[end] || continue
        # Overwrite the largest kept value, then bubble the new value toward the front
        # (a single insertion-sort pass) to restore ascending order.
        minval[end] = x[i]
        indmin[end] = i
        for ki in lastindex(minval):-1:(firstindex(minval) + 1)
            # Everything before `ki` is already sorted, so once the new value is not
            # smaller than its left neighbour no further swap can happen.
            minval[ki] < minval[ki - 1] || break
            minval[ki], minval[ki - 1] = minval[ki - 1], minval[ki]
            indmin[ki], indmin[ki - 1] = indmin[ki - 1], indmin[ki]
        end
    end
    return nothing
end
127

128
# Thanks Chris Elrod
# https://discourse.julialang.org/t/why-is-minimum-so-much-faster-than-argmin/66814/9
"""
    findmin_fast(x::AbstractVector{T}) where {T}

Return `(minval, index)` for the smallest element of `x`, using a branch-free scan that
`@simd` can optimize.

The running minimum starts at `typemax(T)` and is replaced only by strictly smaller
elements, so `NaN`s are never selected. If no element qualifies (e.g. `x` is empty),
`(typemax(T), firstindex(x))` is returned, and that index need not be valid.
"""
function findmin_fast(x::AbstractVector{T}) where {T}
    # `firstindex` instead of a hard-coded 1 keeps the fallback index inside
    # `eachindex(x)` for non-empty, non-1-based vectors.
    indmin = firstindex(x)
    minval = typemax(T)
    @inbounds @simd for i in eachindex(x)
        newmin = x[i] < minval
        minval = newmin ? x[i] : minval
        indmin = newmin ? i : indmin
    end
    return minval, indmin
end
140

141
"""
    argmin_fast(x::AbstractVector{T}) where {T}

Return the index of the smallest element of `x`, as reported by `findmin_fast`.
"""
function argmin_fast(x::AbstractVector{T}) where {T}
    _, index = findmin_fast(x)
    return index
end
144

145
"""
    poisson_sample(λ::T) where {T}

Draw one sample from a Poisson distribution with mean `λ`, using Knuth's
multiplication method with uniforms from the global RNG.

Non-float `λ` (e.g. integers) is converted to the corresponding floating-point type.
A non-positive or `NaN` `λ` yields `0`.
"""
function poisson_sample(λ::T) where {T}
    # Draw uniforms in a float type: for `T = Int`, `rand(T)` would be an arbitrary
    # integer rather than a number in [0, 1).
    F = float(T)
    k, p, L = 0, one(F), exp(-F(λ))
    while p > L
        k += 1
        p *= rand(F)
    end
    # When λ <= 0 (so L >= 1) or λ is NaN, the loop never runs; clamp to 0 instead of -1.
    return max(k - 1, 0)
end
153

154
"""
    @threads_if flag ex

If `flag` is true at runtime, run `ex` under `Threads.@threads`; otherwise run `ex` as is.
`ex` is expected to be a `for` loop, as required by `Threads.@threads`. Both `flag` and
`ex` are resolved in the caller's scope.
"""
macro threads_if(flag, ex)
    threaded = :(Threads.@threads $ex)
    return esc(Expr(:if, flag, threaded, ex))
end
163

"""
    @save_kwargs variable function ... end

Define the function as written (keeping any docstring attached to it), and also define
`const variable` holding the function's keyword-argument expressions (names together
with their default values). These can then be reused to create the same kwargs elsewhere.

Splatted keywords (`kws...`) and keywords whose name starts with `deprecated` are not
recorded in `variable`.
"""
macro save_kwargs(log_variable, fdef)
    return esc(_save_kwargs(log_variable, fdef))
end
function _save_kwargs(log_variable::Symbol, fdef::Expr)
    def = splitdef(fdef)
    # Predicate deciding whether a keyword-argument expression gets recorded.
    is_saved = function (k)
        # Skip the splat `kws...`.
        k.head == :... && return false
        # Skip deprecated keywords (by the printed form of the first argument).
        return !startswith(string(first(k.args)), "deprecated")
    end
    kwargs = filter(is_saved, def[:kwargs])
    return quote
        $(Base).@__doc__ $fdef
        const $log_variable = $kwargs
    end
end
189

190
"""
    @constfield field

Allows using `const` fields in older versions of Julia: on Julia >= 1.8 the field
declaration is marked `const`, while on older versions it is passed through unchanged.
"""
macro constfield(ex)
    field = VERSION >= v"1.8.0" ? Expr(:const, ex) : ex
    return esc(field)
end
194

NEW
195
"""
    json3_write(args...)

Placeholder that always throws an `ErrorException` asking the user to load JSON3.jl.
NOTE(review): the real method is presumably supplied once JSON3.jl is loaded (e.g. via
a package extension); confirm.
"""
function json3_write(args...)
    throw(ErrorException("Please load the JSON3.jl package."))
end
×
196

"""
    PerThreadCache{T}

A cache that is efficient for multithreaded code, and works
by having a separate cache for each thread. This allows
us to avoid repeated locking. We only need to lock the cache
when resizing to the number of threads.

`T` must be constructible with no arguments (`T()`); a thread's slot is created lazily
on its first access.

NOTE(review): slots are keyed by `Threads.threadid()`. A task that migrates to another
thread after obtaining its slot could end up sharing it with another task; callers
presumably tolerate or avoid this. Confirm.
"""
struct PerThreadCache{T}
    x::Vector{T}
    # Concrete `RefValue` (not the abstract `Ref`) so `num_threads[]` reads are type-stable.
    num_threads::Base.RefValue{Int}
    lock::Threads.SpinLock

    PerThreadCache{T}() where {T} = new(Vector{T}(undef, 1), Ref(1), Threads.SpinLock())
end

# Upper bound on `Threads.threadid()`. On Julia >= 1.9, `threadid()` can exceed
# `Threads.nthreads()` (e.g. with interactive threads), so use `maxthreadid` when present.
@static if isdefined(Threads, :maxthreadid)
    _max_thread_id() = Threads.maxthreadid()
else
    _max_thread_id() = Threads.nthreads()
end

"""Return the calling thread's slot of `cache`, creating it on first use."""
function _get_thread_cache(cache::PerThreadCache{T}) where {T}
    max_id = _max_thread_id()
    if cache.num_threads[] < max_id
        Base.@lock cache.lock begin
            # Re-check under the lock: another thread may already have resized.
            # `num_threads[]` is written only after `resize!` returns, the intent being
            # that a thread reading a large-enough value above finds `cache.x` long enough.
            if cache.num_threads[] < max_id
                resize!(cache.x, max_id)
                cache.num_threads[] = max_id
            end
        end
    end
    tid = Threads.threadid()
    if !isassigned(cache.x, tid)
        cache.x[tid] = T()
    end
    return cache.x[tid]
end

"""
    get!(f, cache::PerThreadCache, key)

Call `get!(f, slot, key)` on the calling thread's own slot: return the value stored for
`key`, or compute it with `f()`, store it, and return it.
"""
function Base.get!(f::F, cache::PerThreadCache, key) where {F<:Function}
    thread_cache = _get_thread_cache(cache)
    return get!(f, thread_cache, key)
end
237

238
# https://discourse.julialang.org/t/performance-of-hasmethod-vs-try-catch-on-methoderror/99827/14
# Faster way to catch method errors: remember, per thread and per argument types, whether
# a call succeeded (`Good`), threw a `MethodError` (`Bad`), or is untried (`Undefined`).
@enum IsGood::Int8 begin
    Good
    Bad
    Undefined
end
const SafeFunctions = PerThreadCache{Dict{Type,IsGood}}()

"""
    safe_call(f, x::Tuple, default) -> (result, ok::Bool)

Call `f(x...)` and return `(result, true)`. If the call throws a `MethodError`, return
`(default, false)` instead. Any other exception is rethrown, including the `TypeError`
raised when the result is not of type `D` (the type of `default`).

The outcome is cached in the calling thread's slot of `SafeFunctions`, keyed on
`Tuple{F,T}` (the types of `f` and `x`). Later calls with the same types skip the
`try`/`catch`:
- after a success, `f(x...)` is called directly, so a later `MethodError` propagates;
- after a `MethodError`, `f` is not called again and `(default, false)` is returned.

NOTE(review): a `MethodError` raised inside an existing method of `f` also marks the
signature `Bad`; presumably acceptable for the intended use. Confirm.
"""
function safe_call(f::F, x::T, default::D) where {F,T<:Tuple,D}
    thread_cache = _get_thread_cache(SafeFunctions)
    status = get(thread_cache, Tuple{F,T}, Undefined)
    status == Good && return (f(x...)::D, true)
    status == Bad && return (default, false)

    output = try
        (f(x...)::D, true)
    catch e
        # Bare `rethrow()` keeps the original backtrace (`rethrow(e)` is discouraged).
        e isa MethodError || rethrow()
        (default, false)
    end
    thread_cache[Tuple{F,T}] = output[2] ? Good : Bad
    return output
end
266

267
end
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc