• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MilesCranmer / SymbolicRegression.jl / 9686354911

26 Jun 2024 08:31PM UTC coverage: 93.22% (-1.4%) from 94.617%
9686354911

Pull #326

github

web-flow
Merge 6f8229c9f into ceddaa424
Pull Request #326: BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)

275 of 296 new or added lines in 17 files covered. (92.91%)

34 existing lines in 5 files now uncovered.

2530 of 2714 relevant lines covered (93.22%)

32081968.55 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.91
/src/InterfaceDynamicExpressions.jl
1
module InterfaceDynamicExpressionsModule
2

3
using Printf: @sprintf
4
using DynamicExpressions:
5
    DynamicExpressions as DE,
6
    AbstractOperatorEnum,
7
    OperatorEnum,
8
    GenericOperatorEnum,
9
    AbstractExpression,
10
    AbstractExpressionNode,
11
    ParametricExpression,
12
    Node,
13
    GraphNode
14
using DynamicQuantities: dimension, ustrip
15
using ..CoreModule: Options
16
using ..CoreModule.OptionsModule: inverse_binopmap, inverse_unaopmap
17
using ..UtilsModule: subscriptify
18

19
import ..deprecate_varmap
20

"""
    eval_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractMatrix, options::Options; kws...)

Evaluate a binary tree (equation) over a given input data matrix, using the
operators defined in `options`. Operations are fused into doublets and triplets
for lower memory usage.

Conceptually, evaluation follows this pseudocode:

```
function eval(current_node)
    if current_node is leaf
        return current_node.value
    elif current_node is degree 1
        return current_node.operator(eval(current_node.left_child))
    else
        return current_node.operator(eval(current_node.left_child), eval(current_node.right_child))
    end
end
```

Most of the implementation is devoted to optimizations and pre-emptive NaN/Inf
checks, which speed up evaluation significantly.

# Arguments
- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The root of the tree to evaluate.
- `X::AbstractMatrix`: The input data to evaluate the tree on; the output has one
    entry per column of `X`.
- `options::Options`: Options defining the operators used in the tree. Its `turbo`
    and `bumper` fields are forwarded as the `turbo` and `bumper` keywords.
- `kws...`: Additional keywords, forwarded to the underlying evaluator.

# Returns
- `(output, complete)::Tuple{AbstractVector, Bool}`: the 1D result array, and whether
    the evaluation completed successfully. A `false` `complete` means an infinity
    or NaN was encountered, and a large loss should be assigned to the equation.
"""
function DE.eval_tree_array(
    tree::Union{AbstractExpressionNode,AbstractExpression},
    X::AbstractMatrix,
    options::Options;
    kws...,
)
    # Concrete output type; asserted on the result to help callers' type inference.
    OutputArray = expected_array_type(X)
    operators = DE.get_operators(tree, options)
    result = DE.eval_tree_array(
        tree, X, operators; turbo=options.turbo, bumper=options.bumper, kws...
    )
    return result::Tuple{OutputArray,Bool}
end
70
function DE.eval_tree_array(
    tree::ParametricExpression,
    X::AbstractMatrix,
    classes::AbstractVector{<:Integer},
    options::Options;
    kws...,
)
    # Same as the generic method, but a `ParametricExpression` also receives the
    # integer `classes` vector (presumably one class label per column of `X` —
    # TODO confirm), passed positionally right after `X`.
    OutputArray = expected_array_type(X)
    operators = DE.get_operators(tree, options)
    result = DE.eval_tree_array(
        tree, X, classes, operators; turbo=options.turbo, bumper=options.bumper, kws...
    )
    return result::Tuple{OutputArray,Bool}
end
88

89
# Improve type inference by telling Julia the expected array returned: evaluation
# output has one entry per column of `X`, so its type is whatever `similar` gives
# for `X` over its second axis.
# NOTE(review): for plain `Array`s, `similar` allocates an uninitialised array of
# that length on every call just so we can read its type; consider a
# non-allocating alternative if this shows up in profiles.
expected_array_type(X::AbstractArray) = typeof(similar(X, axes(X, 2)))
93

"""
    eval_diff_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options, direction::Int)

Compute the forward derivative of an expression with respect to a single feature,
using a structure and optimizations similar to `eval_tree_array`. `direction` is the
index of the variable to differentiate against; e.g., `direction=1` means the
derivative with respect to `x1`.

# Arguments

- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The expression tree to evaluate.
- `X::AbstractArray`: The data matrix, with each column being a data point.
- `options::Options`: The options containing the operators used to create the `tree`.
- `direction::Int`: The index of the variable to take the derivative with respect to.

# Returns

- `(evaluation, derivative, complete)::Tuple{AbstractVector, AbstractVector, Bool}`:
    the normal evaluation, the derivative, and whether the evaluation completed as
    normal (`false` if a NaN or Inf was encountered).
"""
function DE.eval_diff_tree_array(
    tree::Union{AbstractExpression,AbstractExpressionNode},
    X::AbstractArray,
    options::Options,
    direction::Int,
)
    OutputArray = expected_array_type(X)
    # TODO: Add `AbstractExpression` implementation in `Expression.jl`.
    # Until then, differentiate the raw node tree held by the expression.
    raw_tree = DE.get_tree(tree)
    operators = DE.get_operators(tree, options)
    result = DE.eval_diff_tree_array(raw_tree, X, operators, direction)
    return result::Tuple{OutputArray,OutputArray,Bool}
end
126

"""
    eval_grad_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options; variable::Bool=false)

Compute the forward-mode derivative of an expression, using a structure and
optimizations similar to `eval_tree_array`. `variable` selects whether derivatives
are taken with respect to the features (i.e., `X`) or with respect to every
constant in the expression.

# Arguments

- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The expression tree to evaluate.
- `X::AbstractArray`: The data matrix, with each column being a data point.
- `options::Options`: The options containing the operators used to create the `tree`.
- `variable::Bool`: Whether to take derivatives with respect to features (i.e., `X` - with `variable=true`),
    or with respect to every constant in the expression (`variable=false`).
    Passed through, together with any other keywords, to the underlying method.

# Returns

- `(evaluation, gradient, complete)::Tuple{AbstractVector, AbstractArray, Bool}`:
    the normal evaluation, the gradient, and whether the evaluation completed as
    normal (`false` if a NaN or Inf was encountered).
"""
function DE.eval_grad_tree_array(
    tree::Union{AbstractExpression,AbstractExpressionNode},
    X::AbstractArray,
    options::Options;
    kws...,
)
    # The evaluation is a vector over the columns of `X`; the gradient is asserted
    # to have the same type as `X` itself.
    EvalArray = expected_array_type(X)
    GradArray = typeof(X)  # TODO: This won't work with StaticArrays!
    operators = DE.get_operators(tree, options)
    result = DE.eval_grad_tree_array(tree, X, operators; kws...)
    return result::Tuple{EvalArray,GradArray,Bool}
end
160

"""
    differentiable_eval_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options)

Evaluate an expression tree in a way that can be auto-differentiated, using the
operators defined in `options`.

Returns `(output, complete)::Tuple{AbstractVector, Bool}`, where `output` is
asserted to have the type `similar(X, axes(X, 2))` would produce.
"""
function DE.differentiable_eval_tree_array(
    tree::Union{AbstractExpression,AbstractExpressionNode},
    X::AbstractArray,
    options::Options,
)
    OutputArray = expected_array_type(X)
    # TODO: Add `AbstractExpression` implementation in `Expression.jl`.
    # Until then, evaluate the raw node tree held by the expression.
    raw_tree = DE.get_tree(tree)
    operators = DE.get_operators(tree, options)
    result = DE.differentiable_eval_tree_array(raw_tree, X, operators)
    return result::Tuple{OutputArray,Bool}
end
177

178
# Placeholder appended to constants when printing with units, since a constant's
# units are not known.
const WILDCARD_UNIT_STRING = "[?]"

"""
    string_tree(tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; kws...)

Convert an equation to a string.

# Arguments

- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The equation to convert to a string.
- `options::Options`: The options holding the definition of operators.

# Keywords

- `raw::Bool=true`: If `true`, variables are printed as `x1`, `x2`, ... (or taken from
    `variable_names`), and everything else uses the default formatting of
    `DynamicExpressions.string_tree`. If `false`, unnamed variables are printed via
    `subscriptify`, and constants are printed with `options.print_precision`
    significant digits.
- `X_sym_units=nothing`, `y_sym_units=nothing`: Units of the features and of the target.
    Only used when `raw=false`. If either is given, variables are suffixed with their
    units from `X_sym_units`, and constants with `"[?]"` (or with nothing, when
    `options.dimensionless_constants_only` is set).
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables
    to print for each feature.
- `display_variable_names=variable_names`: names used in place of `variable_names`
    when `raw=false`.
- `varMap=nothing`: Deprecated alias for `variable_names`.
- `kws...`: Forwarded to `DynamicExpressions.string_tree` when `raw=false`; ignored
    when `raw=true`.
"""
@inline function DE.string_tree(
    tree::Union{AbstractExpression,AbstractExpressionNode},
    options::Options;
    raw::Bool=true,
    X_sym_units=nothing,
    y_sym_units=nothing,
    variable_names=nothing,
    display_variable_names=variable_names,
    varMap=nothing,
    kws...,
)
    variable_names = deprecate_varmap(variable_names, varMap, :string_tree)
    if display_variable_names === nothing && varMap !== nothing
        # The default `display_variable_names=variable_names` is bound *before* the
        # deprecated `varMap` is folded into `variable_names` above. Without this,
        # names passed via `varMap` would be silently dropped when `raw=false`.
        display_variable_names = variable_names
    end

    if raw
        # Graph-structured trees are converted to plain `Node`s before printing.
        tree = tree isa GraphNode ? convert(Node, tree) : tree
        return DE.string_tree(
            tree,
            DE.get_operators(tree, options);
            f_variable=string_variable_raw,
            variable_names,
        )
    end

    # Lift the print precision into the type domain so `sprint_precision` can be
    # specialised on it.
    vprecision = vals[options.print_precision]
    if X_sym_units !== nothing || y_sym_units !== nothing
        return DE.string_tree(
            tree,
            DE.get_operators(tree, options);
            f_variable=(feature, vname) -> string_variable(feature, vname, X_sym_units),
            f_constant=let
                unit_placeholder =
                    options.dimensionless_constants_only ? "" : WILDCARD_UNIT_STRING
                (val,) -> string_constant(val, vprecision, unit_placeholder)
            end,
            variable_names=display_variable_names,
            kws...,
        )
    else
        return DE.string_tree(
            tree,
            DE.get_operators(tree, options);
            f_variable=string_variable,
            f_constant=(val,) -> string_constant(val, vprecision, ""),
            variable_names=display_variable_names,
            kws...,
        )
    end
end

# Lookup table mapping `options.print_precision` to a `Val`, so that the `@generated`
# `sprint_precision` can bake the precision into its format string.
const vals = ntuple(Val, 8192)
241
function string_variable_raw(feature, variable_names)
    # Use the user-supplied name when one exists for this feature; otherwise fall
    # back to a plain `x<feature>` label (e.g. "x3").
    has_name = variable_names !== nothing && feature <= length(variable_names)
    return has_name ? variable_names[feature] : "x" * string(feature)
end
248
function string_variable(feature, variable_names, variable_units=nothing)
    # Prefer the supplied name; otherwise build `x` followed by a subscripted index.
    has_name = variable_names !== nothing && feature <= length(variable_names)
    label = has_name ? variable_names[feature] : "x" * subscriptify(feature)
    variable_units === nothing && return label
    # Append this feature's formatted units (may be "" when dimensionless/unknown).
    return label * format_dimensions(variable_units[feature])
end
259
function string_constant(val, ::Val{precision}, unit_placeholder) where {precision}
    # Real constants honour the requested precision; anything else (e.g. complex
    # numbers) is shown via `string` and wrapped in parentheses.
    body = val isa Real ? sprint_precision(val, Val(precision)) : "(" * string(val) * ")"
    return body * unit_placeholder
end
266
# No units available: print nothing.
format_dimensions(::Nothing) = ""

# Render a unit for display: quantities with a non-unit numeric scale are shown in
# full; pure dimensions are shown in brackets, and dimensionless ones are omitted.
function format_dimensions(u)
    isone(ustrip(u)) || return "[" * string(u) * "]"
    dim = dimension(u)
    return iszero(dim) ? "" : "[" * string(dim) * "]"
end
281
@generated function sprint_precision(x, ::Val{precision}) where {precision}
    # `@sprintf` requires a literal format string, so build it once per precision
    # at generation time and splice it into the generated method body.
    format_spec = string("%.", precision, "g")
    return :(@sprintf($format_spec, x))
end
285

"""
    print_tree(tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; kws...)
    print_tree(io::IO, tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; kws...)

Print an equation (optionally to `io`) using the operators defined in `options`.

# Arguments

- `io::IO`: (optional) Stream to print to.
- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The equation to print.
- `options::Options`: The options holding the definition of operators.
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables
    to print for each feature (passed through `kws...`).
"""
function DE.print_tree(
    tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; kws...
)
    operators = DE.get_operators(tree, options)
    return DE.print_tree(tree, operators; kws...)
end
function DE.print_tree(
    io::IO, tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; kws...
)
    operators = DE.get_operators(tree, options)
    return DE.print_tree(io, tree, operators; kws...)
end
308

"""
    convert(::Type{N}, tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options) where {T,N<:AbstractExpressionNode{T}}

Convert an equation to the node type `N` (with base type `T`), using the operators
that `options` provides for `tree`.
"""
function Base.convert(
    ::Type{N}, tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options
) where {T,N<:AbstractExpressionNode{T}}
    operators = DE.get_operators(tree, options)
    return convert(N, tree, operators)
end
319

"""
    @extend_operators options

Extend all operators defined in this `Options` object so they work on the
`AbstractExpressionNode` type. Operators defined in `Base` are already extended
when you create an options object with `define_helper_functions=true`, but
user-defined operators are not. To extend those, apply this macro in the same
module where the operators are defined.
"""
macro extend_operators(options)
    @gensym alias_enum
    operators_expr = :($(options).operators)
    block = quote
        if !isa($(options), $Options)
            error("You must pass an options type to `@extend_operators`.")
        end
        # Undo SymbolicRegression's operator aliases, then let DynamicExpressions
        # define the node-level methods for the resulting operator enum.
        $alias_enum = $define_alias_operators($operators_expr)
        $(DE).@extend_operators $alias_enum
    end
    return esc(block)
end
342
function define_alias_operators(operators)
    # Rebuild an operator enum of the same flavour, with aliased operators mapped
    # back to their plain counterparts. This lets users write, e.g., `x1 ^ 1.5`
    # instead of `safe_pow(x1, 1.5)`.
    EnumType = operators isa OperatorEnum ? OperatorEnum : GenericOperatorEnum
    binary_operators = inverse_binopmap.(operators.binops)
    unary_operators = inverse_unaopmap.(operators.unaops)
    return EnumType(;
        binary_operators,
        unary_operators,
        define_helper_functions=false,
        empty_old_operators=false,
    )
end
353

354
function (tree::Union{AbstractExpression,AbstractExpressionNode})(
    X, options::Options; kws...
)
    # Calling a tree with `Options` resolves its operators and forwards the
    # `turbo`/`bumper` settings; any extra keywords are passed along as well.
    operators = DE.get_operators(tree, options)
    return tree(X, operators; turbo=options.turbo, bumper=options.bumper, kws...)
end
NEW
365
function DE.EvaluationHelpersModule._grad_evaluator(
    tree::Union{AbstractExpression,AbstractExpressionNode}, X, options::Options; kws...
)
    # Resolve the operators from `options`; only `options.turbo` is forwarded here
    # (not `bumper`), followed by any caller-supplied keywords.
    operators = DE.get_operators(tree, options)
    return DE.EvaluationHelpersModule._grad_evaluator(
        tree, X, operators; turbo=options.turbo, kws...
    )
end
372

NEW
373
# TODO: Move this definition to DynamicExpressions.jl
function combine_operators(tree::AbstractExpressionNode, ::AbstractOperatorEnum)
    # Fallback for bare node trees: return the tree unchanged.
    return tree
end
375

376
end
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc