• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MilesCranmer / SymbolicRegression.jl / 11204590927

06 Oct 2024 07:29PM UTC coverage: 95.808% (+1.2%) from 94.617%
11204590927

Pull #326

github

web-flow
Merge e2b369ea7 into 8f67533b9
Pull Request #326: BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)

466 of 482 new or added lines in 24 files covered. (96.68%)

1 existing line in 1 file now uncovered.

2651 of 2767 relevant lines covered (95.81%)

73863189.31 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.12
/src/Dataset.jl
1
module DatasetModule
2

3
using DynamicQuantities: Quantity
4

5
using ..UtilsModule: subscriptify, get_base_type, @constfield
6
using ..ProgramConstantsModule: BATCH_DIM, FEATURE_DIM, DATA_TYPE, LOSS_TYPE
7
using ...InterfaceDynamicQuantitiesModule: get_si_units, get_sym_units
8

9
import ...deprecate_varmap
10

"""
    Dataset{T<:DATA_TYPE,L<:LOSS_TYPE}

# Fields

- `X::AbstractMatrix{T}`: The input features, with shape `(nfeatures, n)`.
- `y::Union{AbstractVector{T},Nothing}`: The desired output values, with shape `(n,)`,
    or `nothing` if the dataset has no target.
- `index::Int`: The index of the output feature corresponding to this
    dataset, if any.
- `n::Int`: The number of samples.
- `nfeatures::Int`: The number of features.
- `weighted::Bool`: Whether the dataset is non-uniformly weighted.
- `weights::Union{AbstractVector{T},Nothing}`: If the dataset is weighted,
    these specify the per-sample weight (with shape `(n,)`).
- `extra::NamedTuple`: Extra information to pass to a custom evaluation
    function. Since this is an arbitrary named tuple, you could pass
    any sort of dataset you wish to here.
- `avg_y::Union{T,Nothing}`: The average value of `y` (weighted, if `weights`
    are passed), or `nothing` when `y` is `nothing`.
- `use_baseline::Bool`: Whether to use a baseline loss. This will be set to `false`
    if the baseline loss is calculated to be `Inf`.
- `baseline_loss::L`: The loss of a constant function which predicts the average
    value of `y`. This is loss-dependent and should be updated with
    `update_baseline_loss!`.
- `variable_names::Vector{String}`: The names of the features,
    with shape `(nfeatures,)`.
- `display_variable_names::Vector{String}`: A version of `variable_names`
    but for printing to the terminal (e.g., with unicode versions).
- `y_variable_name::String`: The name of the output variable.
- `X_units`: Unit information of `X`. When used, this is a vector
    of `DynamicQuantities.Quantity{<:Any,<:Dimensions}` with shape `(nfeatures,)`.
- `y_units`: Unit information of `y`. When used, this is a single
    `DynamicQuantities.Quantity{<:Any,<:Dimensions}`.
- `X_sym_units`: Unit information of `X`. When used, this is a vector
    of `DynamicQuantities.Quantity{<:Any,<:SymbolicDimensions}` with shape `(nfeatures,)`.
- `y_sym_units`: Unit information of `y`. When used, this is a single
    `DynamicQuantities.Quantity{<:Any,<:SymbolicDimensions}`.

Every field except `use_baseline` and `baseline_loss` is declared with
`@constfield`. Those two are left reassignable so the baseline can be updated
after construction (see `update_baseline_loss!`).
"""
mutable struct Dataset{
    T<:DATA_TYPE,
    L<:LOSS_TYPE,
    AX<:AbstractMatrix{T},
    AY<:Union{AbstractVector{T},Nothing},
    AW<:Union{AbstractVector{T},Nothing},
    NT<:NamedTuple,
    XU<:Union{AbstractVector{<:Quantity},Nothing},
    YU<:Union{Quantity,Nothing},
    XUS<:Union{AbstractVector{<:Quantity},Nothing},
    YUS<:Union{Quantity,Nothing},
}
    @constfield X::AX
    @constfield y::AY
    @constfield index::Int
    @constfield n::Int
    @constfield nfeatures::Int
    @constfield weighted::Bool
    @constfield weights::AW
    @constfield extra::NT
    @constfield avg_y::Union{T,Nothing}
    # Reassignable: these two are refreshed once the loss is known.
    use_baseline::Bool
    baseline_loss::L
    @constfield variable_names::Vector{String}
    @constfield display_variable_names::Vector{String}
    @constfield y_variable_name::String
    @constfield X_units::XU
    @constfield y_units::YU
    @constfield X_sym_units::XUS
    @constfield y_sym_units::YUS
end
79

"""
    Dataset(X::AbstractMatrix{T},
            y::Union{AbstractVector{T},Nothing}=nothing,
            loss_type::Type=Nothing;
            index::Int=1,
            weights::Union{AbstractVector{T}, Nothing}=nothing,
            variable_names::Union{Array{String, 1}, Nothing}=nothing,
            display_variable_names=variable_names,
            y_variable_name::Union{String,Nothing}=nothing,
            extra::NamedTuple=NamedTuple(),
            X_units::Union{AbstractVector, Nothing}=nothing,
            y_units=nothing,
    ) where {T<:DATA_TYPE}

Construct a dataset to pass between internal functions.

# Arguments
- `X`: The input features. The number of samples and features are read from
    its `BATCH_DIM` and `FEATURE_DIM` axes, respectively.
- `y`: The target values, or `nothing` if there is no target.
- `loss_type`: The type used for the baseline loss. When `Nothing` (the default),
    this is `T`, or `get_base_type(T)` when `T <: Complex`.

# Keywords
- `index`: The index of the output feature this dataset corresponds to.
- `weights`: Per-sample weights. Passing them marks the dataset as weighted,
    and `avg_y` becomes the weighted mean of `y`.
- `variable_names`: Feature names. Defaults to `["x1", "x2", ...]`.
- `display_variable_names`: Feature names used for printing. Defaults to
    `variable_names`; when that is `nothing`, names are generated as `"x"`
    followed by `subscriptify(i)`.
- `y_variable_name`: Name of the target. Defaults to `"y"`, or to `"target"`
    if `"y"` is already one of the feature names.
- `extra`: Extra data stored on the dataset as-is.
- `X_units`, `y_units`: Optional unit information. If `X_units` yields no units
    while `y_units` does, every feature is given the units derived from `one(T)`.

Passing `loss_type` as a keyword, or `varMap` instead of `variable_names`,
is deprecated.

An error is thrown if the number of `X_units` differs from the number of features.
"""
function Dataset(
    X::AbstractMatrix{T},
    y::Union{AbstractVector{T},Nothing}=nothing,
    loss_type::Type{L}=Nothing;
    index::Int=1,
    weights::Union{AbstractVector{T},Nothing}=nothing,
    variable_names::Union{Array{String,1},Nothing}=nothing,
    display_variable_names=variable_names,
    y_variable_name::Union{String,Nothing}=nothing,
    extra::NamedTuple=NamedTuple(),
    X_units::Union{AbstractVector,Nothing}=nothing,
    y_units=nothing,
    # Deprecated:
    varMap=nothing,
    kws...,
) where {T<:DATA_TYPE,L}
    Base.require_one_based_indexing(X)
    y !== nothing && Base.require_one_based_indexing(y)
    # Deprecation warning:
    variable_names = deprecate_varmap(variable_names, varMap, :Dataset)
    # NOTE(review): the `display_variable_names` default captured the
    # `variable_names` keyword *before* the line above. If `deprecate_varmap`
    # substitutes `varMap`, display names fall back to the generated ones
    # below — confirm whether that is intended.
    if haskey(kws, :loss_type)
        Base.depwarn(
            "The `loss_type` keyword argument is deprecated. Pass as an argument instead.",
            :Dataset,
        )
        return Dataset(
            X,
            y,
            kws[:loss_type];
            index,
            weights,
            variable_names,
            display_variable_names,
            y_variable_name,
            extra,
            X_units,
            y_units,
        )
    end

    n = size(X, BATCH_DIM)
    nfeatures = size(X, FEATURE_DIM)
    weighted = weights !== nothing
    variable_names = if variable_names === nothing
        ["x$(i)" for i in 1:nfeatures]
    else
        variable_names
    end
    display_variable_names = if display_variable_names === nothing
        ["x$(subscriptify(i))" for i in 1:nfeatures]
    else
        display_variable_names
    end

    # Avoid clashing with a feature that is already called "y".
    y_variable_name = if y_variable_name === nothing
        ("y" ∉ variable_names) ? "y" : "target"
    else
        y_variable_name
    end
    avg_y = if y === nothing
        nothing
    elseif weighted
        sum(y .* weights) / sum(weights)
    else
        sum(y) / n
    end
    # For complex data the default loss type is `get_base_type(T)`
    # (presumably the underlying real type).
    out_loss_type = if L === Nothing
        T <: Complex ? get_base_type(T) : T
    else
        L
    end

    use_baseline = true
    # Initial placeholder; presumably replaced later by `update_baseline_loss!`.
    baseline = one(out_loss_type)
    y_si_units = get_si_units(T, y_units)
    y_sym_units = get_sym_units(T, y_units)

    # Ensure that if the `y` units are set, the `X` units are set as well.
    X_si_units = _resolve_X_units(get_si_units, T, X_units, y_si_units, nfeatures)
    X_sym_units = _resolve_X_units(get_sym_units, T, X_units, y_sym_units, nfeatures)

    error_on_mismatched_size(nfeatures, X_si_units)

    return Dataset{
        T,
        out_loss_type,
        typeof(X),
        typeof(y),
        typeof(weights),
        typeof(extra),
        typeof(X_si_units),
        typeof(y_si_units),
        typeof(X_sym_units),
        typeof(y_sym_units),
    }(
        X,
        y,
        index,
        n,
        nfeatures,
        weighted,
        weights,
        extra,
        avg_y,
        use_baseline,
        baseline,
        variable_names,
        display_variable_names,
        y_variable_name,
        X_si_units,
        y_si_units,
        X_sym_units,
        y_sym_units,
    )
end

# Resolve feature units with `get_units` (`get_si_units` or `get_sym_units`).
# When that yields no feature units but the target has units, every feature is
# given the units derived from `one(T)`, so `X` and `y` units are set together.
function _resolve_X_units(
    get_units::F, ::Type{T}, X_units, y_resolved_units, nfeatures
) where {F,T}
    X_resolved_units = get_units(T, X_units)
    if X_resolved_units === nothing && y_resolved_units !== nothing
        return get_units(T, [one(T) for _ in 1:nfeatures])
    end
    return X_resolved_units
end
225
"""
    Dataset(X::AbstractMatrix, y=nothing, loss_type::Type=Nothing; weights=nothing, kws...)

Fallback constructor for inputs whose element types differ. `X`, `y`, and
`weights` are converted element-wise to the promoted element type. They are then
forwarded, together with `loss_type` and all remaining keywords, to the typed
`Dataset` constructor.
"""
function Dataset(
    X::AbstractMatrix,
    y::Union{<:AbstractVector,Nothing}=nothing,
    loss_type::Type=Nothing;
    weights::Union{<:AbstractVector,Nothing}=nothing,
    kws...,
)
    # Absent `y`/`weights` fall back to `eltype(X)` so they don't affect promotion.
    T = promote_type(
        eltype(X),
        (y === nothing) ? eltype(X) : eltype(y),
        (weights === nothing) ? eltype(X) : eltype(weights),
    )
    to_T = Base.Fix1(convert, T)
    X = to_T.(X)
    if y !== nothing
        y = to_T.(y)
    end
    if weights !== nothing
        weights = to_T.(weights)
    end
    # NOTE(review): if the converted element types still do not satisfy the
    # typed method's `T<:DATA_TYPE` bound, this call dispatches back to this
    # method (presumably recursing without end) — confirm whether a guard is wanted.
    return Dataset(X, y, loss_type; weights=weights, kws...)
end
245

246
# No feature units were supplied, so there is nothing to validate.
error_on_mismatched_size(_, ::Nothing) = nothing

# Check that exactly one unit was supplied per feature; raise an
# `ErrorException` describing both counts otherwise.
function error_on_mismatched_size(nfeatures, X_units::AbstractVector)
    nunits = length(X_units)
    nfeatures == nunits && return nothing
    return error(
        "Number of features ($(nfeatures)) does not match number of units ($(nunits))"
    )
end
257

258
# `true` unless both the feature units and the target units are absent.
has_units(dataset::Dataset) = !(isnothing(dataset.X_units) && isnothing(dataset.y_units))
261

262
# Used for Enzyme: overwrite, in place, the data held in `d.X`, `d.y`,
# `d.weights` and `d.extra` with `val` (via `_fill!`), then return `d`.
# `map` over a small tuple keeps each call statically dispatched and runs
# them in field order.
function Base.fill!(d::Dataset, val)
    map(buffer -> _fill!(buffer, val), (d.X, d.y, d.weights, d.extra))
    return d
end
270
# Arrays are overwritten in place; `fill!` returns the array itself.
_fill!(x::AbstractArray, val) = fill!(x, val)
# Named tuples are traversed recursively; the result is `nothing`.
_fill!(x::NamedTuple, val) = foreach(Base.Fix2(_fill!, val), values(x))
# Absent fields are skipped.
_fill!(::Nothing, _) = nothing
# Any other value is left untouched and handed back unchanged.
_fill!(x, _) = x
274

275
end
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc