• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MilesCranmer / SymbolicRegression.jl / 9763114573

02 Jul 2024 02:43PM UTC coverage: 96.083% (+1.4%) from 94.697%
9763114573

Pull #326

github

web-flow
Merge 7a70dfb88 into c5ed5d0b9
Pull Request #326: BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)

352 of 357 new or added lines in 19 files covered. (98.6%)

60 existing lines in 12 files now uncovered.

2625 of 2732 relevant lines covered (96.08%)

66999448.2 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.12
/src/Dataset.jl
1
module DatasetModule
2

3
using DynamicQuantities:
4
    AbstractDimensions,
5
    Dimensions,
6
    SymbolicDimensions,
7
    Quantity,
8
    uparse,
9
    sym_uparse,
10
    DEFAULT_DIM_BASE_TYPE
11

12
using ..UtilsModule: subscriptify, get_base_type, @constfield
13
using ..ProgramConstantsModule: BATCH_DIM, FEATURE_DIM, DATA_TYPE, LOSS_TYPE
14
using ...InterfaceDynamicQuantitiesModule: get_si_units, get_sym_units
15

16
import ...deprecate_varmap
17

18
"""
19
    Dataset{T<:DATA_TYPE,L<:LOSS_TYPE}
20

21
# Fields
22

23
- `X::AbstractMatrix{T}`: The input features, with shape `(nfeatures, n)`.
24
- `y::Union{AbstractVector{T},Nothing}`: The desired output values, with shape `(n,)`, or `nothing` if no output values were given.
25
- `index::Int`: The index of the output feature corresponding to this
26
    dataset, if any.
27
- `n::Int`: The number of samples.
28
- `nfeatures::Int`: The number of features.
29
- `weighted::Bool`: Whether the dataset is non-uniformly weighted.
30
- `weights::Union{AbstractVector{T},Nothing}`: If the dataset is weighted,
31
    these specify the per-sample weight (with shape `(n,)`).
32
- `extra::NamedTuple`: Extra information to pass to a custom evaluation
33
    function. Since this is an arbitrary named tuple, you could pass
34
    any sort of dataset you wish to here.
35
- `avg_y`: The average value of `y` (weighted, if `weights` are passed), or `nothing` if `y` is `nothing`.
36
- `use_baseline`: Whether to use a baseline loss. This will be set to `false`
37
    if the baseline loss is calculated to be `Inf`.
38
- `baseline_loss`: The loss of a constant function which predicts the average
39
    value of `y`. This is loss-dependent and should be updated with
40
    `update_baseline_loss!`.
41
- `variable_names::Array{String,1}`: The names of the features,
42
    with shape `(nfeatures,)`.
43
- `display_variable_names::Array{String,1}`: A version of `variable_names`
44
    but for printing to the terminal (e.g., with unicode versions).
45
- `y_variable_name::String`: The name of the output variable.
46
- `X_units`: Unit information of `X`. When used, this is a vector
47
    of `DynamicQuantities.Quantity{<:Any,<:Dimensions}` with shape `(nfeatures,)`.
48
- `y_units`: Unit information of `y`. When used, this is a single
49
    `DynamicQuantities.Quantity{<:Any,<:Dimensions}`.
50
- `X_sym_units`: Unit information of `X`. When used, this is a vector
51
    of `DynamicQuantities.Quantity{<:Any,<:SymbolicDimensions}` with shape `(nfeatures,)`.
52
- `y_sym_units`: Unit information of `y`. When used, this is a single
53
    `DynamicQuantities.Quantity{<:Any,<:SymbolicDimensions}`.
54
"""
55
mutable struct Dataset{
    T<:DATA_TYPE,
    L<:LOSS_TYPE,
    AX<:AbstractMatrix{T},
    AY<:Union{AbstractVector{T},Nothing},
    AW<:Union{AbstractVector{T},Nothing},
    NT<:NamedTuple,
    XU<:Union{AbstractVector{<:Quantity},Nothing},
    YU<:Union{Quantity,Nothing},
    XUS<:Union{AbstractVector{<:Quantity},Nothing},
    YUS<:Union{Quantity,Nothing},
}
    # --- Data and its bookkeeping ---
    @constfield X::AX
    @constfield y::AY
    @constfield index::Int
    @constfield n::Int
    @constfield nfeatures::Int
    @constfield weighted::Bool
    @constfield weights::AW
    @constfield extra::NT
    @constfield avg_y::Union{T,Nothing}

    # --- Baseline loss ---
    # These are the only two fields not wrapped in `@constfield`.
    use_baseline::Bool
    baseline_loss::L

    # --- Naming ---
    @constfield variable_names::Vector{String}
    @constfield display_variable_names::Vector{String}
    @constfield y_variable_name::String

    # --- Units (each is `nothing` when no unit information is attached) ---
    @constfield X_units::XU
    @constfield y_units::YU
    @constfield X_sym_units::XUS
    @constfield y_sym_units::YUS
end
86

87
"""
88
    Dataset(X::AbstractMatrix{T},
89
            y::Union{AbstractVector{T},Nothing}=nothing,
90
            loss_type::Type=Nothing;
            index::Int=1,
91
            weights::Union{AbstractVector{T}, Nothing}=nothing,
92
            variable_names::Union{Array{String, 1}, Nothing}=nothing,
            display_variable_names=variable_names,
93
            y_variable_name::Union{String,Nothing}=nothing,
94
            extra::NamedTuple=NamedTuple(),
95
            X_units::Union{AbstractVector, Nothing}=nothing,
96
            y_units=nothing,
97
    ) where {T<:DATA_TYPE}
98

99
Construct a dataset to pass between internal functions.
100
"""
101
function Dataset(
    X::AbstractMatrix{T},
    y::Union{AbstractVector{T},Nothing}=nothing,
    loss_type::Type{L}=Nothing;
    index::Int=1,
    weights::Union{AbstractVector{T},Nothing}=nothing,
    variable_names::Union{Array{String,1},Nothing}=nothing,
    display_variable_names=variable_names,
    y_variable_name::Union{String,Nothing}=nothing,
    extra::NamedTuple=NamedTuple(),
    X_units::Union{AbstractVector,Nothing}=nothing,
    y_units=nothing,
    # Deprecated:
    varMap=nothing,
    kws...,
) where {T<:DATA_TYPE,L}
    # Overview of what this constructor fills in:
    #   * `n` / `nfeatures` are read off `X` along `BATCH_DIM` / `FEATURE_DIM`.
    #   * `variable_names` defaults to "x1", "x2", ...; `display_variable_names`
    #     defaults to `variable_names`, or to "x" followed by `subscriptify(i)`.
    #   * `y_variable_name` defaults to "y", or "target" if "y" is already a
    #     feature name.
    #   * `avg_y` is the (weighted) mean of `y`, or `nothing` without `y`.
    #   * The loss type is `L`, unless `L === Nothing`, in which case it is
    #     `T` (or `get_base_type(T)` for complex `T`).
    #   * `use_baseline` starts as `true` and `baseline_loss` as `one(loss type)`.
    # Raises an error (via `error_on_mismatched_size`) when the number of `X`
    # units differs from `nfeatures`.
    Base.require_one_based_indexing(X)
    y !== nothing && Base.require_one_based_indexing(y)

    # Deprecation warning:
    variable_names = deprecate_varmap(variable_names, varMap, :Dataset)
    # The default `display_variable_names=variable_names` was bound *before*
    # the deprecated `varMap` was resolved above, so names supplied via `varMap`
    # would otherwise never reach the display names. Make the display names
    # follow the resolved names in that case, exactly as they would have if the
    # names had been passed through `variable_names`.
    if varMap !== nothing && display_variable_names === nothing
        display_variable_names = variable_names
    end

    if haskey(kws, :loss_type)
        Base.depwarn(
            "The `loss_type` keyword argument is deprecated. Pass as an argument instead.",
            :Dataset,
        )
        # Re-dispatch with the loss type in its positional slot.
        return Dataset(
            X,
            y,
            kws[:loss_type];
            index,
            weights,
            variable_names,
            display_variable_names,
            y_variable_name,
            extra,
            X_units,
            y_units,
        )
    end

    n = size(X, BATCH_DIM)
    nfeatures = size(X, FEATURE_DIM)
    weighted = weights !== nothing
    variable_names = if variable_names === nothing
        ["x$(i)" for i in 1:nfeatures]
    else
        variable_names
    end
    display_variable_names = if display_variable_names === nothing
        ["x$(subscriptify(i))" for i in 1:nfeatures]
    else
        display_variable_names
    end

    # Avoid clashing with a feature that is itself called "y".
    y_variable_name = if y_variable_name === nothing
        ("y" ∉ variable_names) ? "y" : "target"
    else
        y_variable_name
    end
    avg_y = if y === nothing
        nothing
    else
        if weighted
            sum(y .* weights) / sum(weights)
        else
            sum(y) / n
        end
    end
    out_loss_type = if L === Nothing
        T <: Complex ? get_base_type(T) : T
    else
        L
    end

    # Initial values for the two mutable fields of the struct.
    use_baseline = true
    baseline = one(out_loss_type)
    y_si_units = get_si_units(T, y_units)
    y_sym_units = get_sym_units(T, y_units)

    # TODO: Refactor
    # This basically just ensures that if the `y` units are set,
    # then the `X` units are set as well.
    X_si_units = let _X = get_si_units(T, X_units)
        if _X === nothing && y_si_units !== nothing
            get_si_units(T, [one(T) for _ in 1:nfeatures])
        else
            _X
        end
    end
    X_sym_units = let _X = get_sym_units(T, X_units)
        if _X === nothing && y_sym_units !== nothing
            get_sym_units(T, [one(T) for _ in 1:nfeatures])
        else
            _X
        end
    end

    error_on_mismatched_size(nfeatures, X_si_units)

    return Dataset{
        T,
        out_loss_type,
        typeof(X),
        typeof(y),
        typeof(weights),
        typeof(extra),
        typeof(X_si_units),
        typeof(y_si_units),
        typeof(X_sym_units),
        typeof(y_sym_units),
    }(
        X,
        y,
        index,
        n,
        nfeatures,
        weighted,
        weights,
        extra,
        avg_y,
        use_baseline,
        baseline,
        variable_names,
        display_variable_names,
        y_variable_name,
        X_si_units,
        y_si_units,
        X_sym_units,
        y_sym_units,
    )
end
232
"""
    Dataset(X::AbstractMatrix, y=nothing; weights=nothing, kws...)

Fallback constructor for inputs that do not share a single element type.

The element types of `X`, `y`, and `weights` (those that are not `nothing`)
are promoted to a common type `T`, every element is converted to `T`, and
construction is retried with the converted arrays. All other keyword
arguments are forwarded unchanged.

# Throws
- `ArgumentError`: if the conversion leaves the types of `X`, `y`, and
  `weights` unchanged. Retrying would then dispatch to this same method
  with identical argument types and recurse without end.
"""
function Dataset(
    X::AbstractMatrix,
    y::Union{<:AbstractVector,Nothing}=nothing;
    weights::Union{<:AbstractVector,Nothing}=nothing,
    kws...,
)
    T = promote_type(
        eltype(X),
        (y === nothing) ? eltype(X) : eltype(y),
        (weights === nothing) ? eltype(X) : eltype(weights),
    )
    to_T = Base.Fix1(convert, T)
    X_conv = to_T.(X)
    y_conv = y === nothing ? nothing : to_T.(y)
    weights_conv = weights === nothing ? nothing : to_T.(weights)

    # This method only runs when no more specific method accepted the inputs.
    # If conversion changed none of the argument types, the retry below would
    # land here again with the very same types: fail loudly instead of
    # overflowing the stack.
    if typeof(X_conv) === typeof(X) &&
        typeof(y_conv) === typeof(y) &&
        typeof(weights_conv) === typeof(weights)
        throw(
            ArgumentError(
                "Cannot construct a `Dataset` from the given `X`, `y`, and `weights`: " *
                "promoting their element types to `$(T)` did not produce a supported input.",
            ),
        )
    end
    return Dataset(X_conv, y_conv; weights=weights_conv, kws...)
end
252

253
# Without any `X` units there is nothing to validate.
error_on_mismatched_size(_, ::Nothing) = nothing

"""
    error_on_mismatched_size(nfeatures, X_units::AbstractVector)

Check that `X_units` holds exactly `nfeatures` entries. Returns `nothing`
when the lengths agree and raises an error otherwise.
"""
function error_on_mismatched_size(nfeatures, X_units::AbstractVector)
    nfeatures == length(X_units) && return nothing
    return error(
        "Number of features ($(nfeatures)) does not match number of units ($(length(X_units)))",
    )
end
264

265
"""
    has_units(dataset::Dataset)

Return `true` when `dataset.X_units` or `dataset.y_units` is set
(i.e., is not `nothing`), and `false` when both are `nothing`.
"""
function has_units(dataset::Dataset)
    no_units = dataset.X_units === nothing && dataset.y_units === nothing
    return !no_units
end
268

269
# Used for Enzyme
270
"""
    Base.fill!(d::Dataset, val)

Apply `_fill!` with `val` to `d.X`, `d.y`, `d.weights`, and `d.extra`
(in that order), then return `d`.
"""
function Base.fill!(d::Dataset, val)
    foreach((d.X, d.y, d.weights, d.extra)) do field
        _fill!(field, val)
    end
    return d
end
277
# Arrays are overwritten in place with `val`.
_fill!(x::AbstractArray, val) = fill!(x, val)

# Named tuples are walked, filling every contained value in turn.
function _fill!(x::NamedTuple, val)
    for item in values(x)
        _fill!(item, val)
    end
    return nothing
end

# Nothing to fill.
_fill!(::Nothing, val) = nothing

# Anything else is left untouched and handed back as-is.
_fill!(x, val) = x
×
281

282
end
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc