MilesCranmer / SymbolicRegression.jl, build 11394658450

17 Oct 2024 11:32PM UTC coverage: 95.332% (+0.6%) from 94.757%

Pull Request #355: Create `TemplateExpression` for providing a pre-defined functional structure and constraints (merge a9e5332c7 into 3892a6659)

253 of 257 new or added lines in 15 files covered (98.44%).
3 existing lines in 2 files now uncovered.
2818 of 2956 relevant lines covered (95.33%).
38419396.19 hits per line.

Source file: /src/Dataset.jl (90.0% covered)

module DatasetModule

using DynamicQuantities: Quantity

using ..UtilsModule: subscriptify, get_base_type
using ..ProgramConstantsModule: BATCH_DIM, FEATURE_DIM, DATA_TYPE, LOSS_TYPE
using ...InterfaceDynamicQuantitiesModule: get_si_units, get_sym_units

import ...deprecate_varmap

"""
    Dataset{T<:DATA_TYPE,L<:LOSS_TYPE}

# Fields

- `X::AbstractMatrix{T}`: The input features, with shape `(nfeatures, n)`.
- `y::AbstractVector{T}`: The desired output values, with shape `(n,)`.
- `index::Int`: The index of the output feature corresponding to this
    dataset, if any.
- `n::Int`: The number of samples.
- `nfeatures::Int`: The number of features.
- `weights::Union{AbstractVector,Nothing}`: If the dataset is weighted,
    these specify the per-sample weight (with shape `(n,)`).
- `extra::NamedTuple`: Extra information to pass to a custom evaluation
    function. Since this is an arbitrary named tuple, you could pass
    any sort of dataset you wish to here.
- `avg_y`: The average value of `y` (weighted, if `weights` are passed).
- `use_baseline`: Whether to use a baseline loss. This will be set to `false`
    if the baseline loss is calculated to be `Inf`.
- `baseline_loss`: The loss of a constant function which predicts the average
    value of `y`. This is loss-dependent and should be updated with
    `update_baseline_loss!`.
- `variable_names::Array{String,1}`: The names of the features,
    with shape `(nfeatures,)`.
- `display_variable_names::Array{String,1}`: A version of `variable_names`
    but for printing to the terminal (e.g., with unicode versions).
- `y_variable_name::String`: The name of the output variable.
- `X_units`: Unit information of `X`. When used, this is a vector
    of `DynamicQuantities.Quantity{<:Any,<:Dimensions}` with shape `(nfeatures,)`.
- `y_units`: Unit information of `y`. When used, this is a single
    `DynamicQuantities.Quantity{<:Any,<:Dimensions}`.
- `X_sym_units`: Unit information of `X`. When used, this is a vector
    of `DynamicQuantities.Quantity{<:Any,<:SymbolicDimensions}` with shape `(nfeatures,)`.
- `y_sym_units`: Unit information of `y`. When used, this is a single
    `DynamicQuantities.Quantity{<:Any,<:SymbolicDimensions}`.
"""
mutable struct Dataset{
    T<:DATA_TYPE,
    L<:LOSS_TYPE,
    AX<:AbstractMatrix{T},
    AY<:Union{AbstractVector,Nothing},
    AW<:Union{AbstractVector,Nothing},
    NT<:NamedTuple,
    XU<:Union{AbstractVector{<:Quantity},Nothing},
    YU<:Union{Quantity,Nothing},
    XUS<:Union{AbstractVector{<:Quantity},Nothing},
    YUS<:Union{Quantity,Nothing},
}
    const X::AX
    const y::AY
    const index::Int
    const n::Int
    const nfeatures::Int
    const weights::AW
    const extra::NT
    const avg_y::Union{T,Nothing}
    use_baseline::Bool
    baseline_loss::L
    const variable_names::Array{String,1}
    const display_variable_names::Array{String,1}
    const y_variable_name::String
    const X_units::XU
    const y_units::YU
    const X_sym_units::XUS
    const y_sym_units::YUS
end
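
# Illustrative sketch (not part of the original file): storage is feature-major,
# so for a hypothetical `Dataset` `d` built from 3 features and 100 samples one
# would expect
#
#     size(d.X) == (3, 100)   # `(nfeatures, n)`
#     length(d.y) == 100      # `(n,)`
#     d.nfeatures == 3
#     d.n == 100
#     d.weights               # `nothing`, or a 100-element vector if weighted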

"""
    Dataset(X::AbstractMatrix{T},
            y::Union{AbstractVector{T},Nothing}=nothing,
            loss_type::Type=Nothing;
            weights::Union{AbstractVector, Nothing}=nothing,
            variable_names::Union{Array{String, 1}, Nothing}=nothing,
            y_variable_name::Union{String,Nothing}=nothing,
            extra::NamedTuple=NamedTuple(),
            X_units::Union{AbstractVector, Nothing}=nothing,
            y_units=nothing,
    ) where {T<:DATA_TYPE}

Construct a dataset to pass between internal functions.
"""
function Dataset(
    X::AbstractMatrix{T},
    y::Union{AbstractVector,Nothing}=nothing,
    loss_type::Type{L}=Nothing;
    index::Int=1,
    weights::Union{AbstractVector,Nothing}=nothing,
    variable_names::Union{Array{String,1},Nothing}=nothing,
    display_variable_names=variable_names,
    y_variable_name::Union{String,Nothing}=nothing,
    extra::NamedTuple=NamedTuple(),
    X_units::Union{AbstractVector,Nothing}=nothing,
    y_units=nothing,
    # Deprecated:
    varMap=nothing,
    kws...,
) where {T<:DATA_TYPE,L}
    Base.require_one_based_indexing(X)
    y !== nothing && Base.require_one_based_indexing(y)
    # Deprecation warning:
    variable_names = deprecate_varmap(variable_names, varMap, :Dataset)
    if haskey(kws, :loss_type)
        Base.depwarn(
            "The `loss_type` keyword argument is deprecated. Pass as an argument instead.",
            :Dataset,
        )
        return Dataset(
            X,
            y,
            kws[:loss_type];
            index,
            weights,
            variable_names,
            display_variable_names,
            y_variable_name,
            extra,
            X_units,
            y_units,
        )
    end

    n = size(X, BATCH_DIM)
    nfeatures = size(X, FEATURE_DIM)
    variable_names = if variable_names === nothing
        ["x$(i)" for i in 1:nfeatures]
    else
        variable_names
    end
    display_variable_names = if display_variable_names === nothing
        ["x$(subscriptify(i))" for i in 1:nfeatures]
    else
        display_variable_names
    end

    y_variable_name = if y_variable_name === nothing
        ("y" ∉ variable_names) ? "y" : "target"
    else
        y_variable_name
    end
    avg_y = if y === nothing || !(eltype(y) <: Number)
        nothing
    else
        if weights !== nothing
            sum(y .* weights) / sum(weights)
        else
            sum(y) / n
        end
    end
    out_loss_type = if L === Nothing
        T <: Complex ? get_base_type(T) : T
    else
        L
    end

    use_baseline = true
    baseline = one(out_loss_type)
    y_si_units = get_si_units(T, y_units)
    y_sym_units = get_sym_units(T, y_units)

    # TODO: Refactor
    # This basically just ensures that if the `y` units are set,
    # then the `X` units are set as well.
    X_si_units = let _X = get_si_units(T, X_units)
        if _X === nothing && y_si_units !== nothing
            get_si_units(T, [one(T) for _ in 1:nfeatures])
        else
            _X
        end
    end
    X_sym_units = let _X = get_sym_units(T, X_units)
        if _X === nothing && y_sym_units !== nothing
            get_sym_units(T, [one(T) for _ in 1:nfeatures])
        else
            _X
        end
    end

    error_on_mismatched_size(nfeatures, X_si_units)

    return Dataset{
        T,
        out_loss_type,
        typeof(X),
        typeof(y),
        typeof(weights),
        typeof(extra),
        typeof(X_si_units),
        typeof(y_si_units),
        typeof(X_sym_units),
        typeof(y_sym_units),
    }(
        X,
        y,
        index,
        n,
        nfeatures,
        weights,
        extra,
        avg_y,
        use_baseline,
        baseline,
        variable_names,
        display_variable_names,
        y_variable_name,
        X_si_units,
        y_si_units,
        X_sym_units,
        y_sym_units,
    )
end
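
# Usage sketch (illustrative only, not part of the original file; assumes
# `Float64 <: DATA_TYPE`, as in the default program constants):
#
#     X = randn(2, 50)                     # 2 features, 50 samples
#     y = 2 .* X[1, :] .- X[2, :]
#     d = Dataset(X, y; variable_names=["a", "b"], weights=ones(50))
#     d.variable_names == ["a", "b"]       # names are stored as given
#     is_weighted(d)                       # true, since `weights` were passed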

is_weighted(dataset::Dataset) = dataset.weights !== nothing

function error_on_mismatched_size(_, ::Nothing)
    return nothing
end
function error_on_mismatched_size(nfeatures, X_units::AbstractVector)
    if nfeatures != length(X_units)
        error(
            "Number of features ($(nfeatures)) does not match number of units ($(length(X_units)))",
        )
    end
    return nothing
end

function has_units(dataset::Dataset)
    return dataset.X_units !== nothing || dataset.y_units !== nothing
end

# Used for Enzyme
function Base.fill!(d::Dataset, val)
    _fill!(d.X, val)
    _fill!(d.y, val)
    _fill!(d.weights, val)
    _fill!(d.extra, val)
    return d
end
_fill!(x::AbstractArray, val) = fill!(x, val)
_fill!(x::NamedTuple, val) = foreach(v -> _fill!(v, val), values(x))
_fill!(::Nothing, val) = nothing
_fill!(x, val) = x
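
# Illustrative note (not part of the original file): with a `Dataset` `d`,
#
#     fill!(d, 0.0)
#
# recursively overwrites `d.X`, `d.y`, `d.weights`, and any arrays nested in
# `d.extra`, while `nothing` fields and non-array values are left untouched,
# presumably so that Enzyme can reset shadow copies of a dataset in one call.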

end