• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MilesCranmer / SymbolicRegression.jl / 9704727222

27 Jun 2024 11:01PM UTC coverage: 95.922% (+1.3%) from 94.617%
9704727222

Pull #326

github

web-flow
Merge 1f104aaf8 into ceddaa424
Pull Request #326: BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)

301 of 307 new or added lines in 17 files covered. (98.05%)

1 existing line in 1 file now uncovered.

2611 of 2722 relevant lines covered (95.92%)

35611300.15 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.67
/src/InterfaceDynamicExpressions.jl
1
module InterfaceDynamicExpressionsModule
2

3
using Printf: @sprintf
4
using DynamicExpressions:
5
    DynamicExpressions as DE,
6
    AbstractOperatorEnum,
7
    OperatorEnum,
8
    GenericOperatorEnum,
9
    AbstractExpression,
10
    AbstractExpressionNode,
11
    ParametricExpression,
12
    Node,
13
    GraphNode
14
using DynamicQuantities: dimension, ustrip
15
using ..CoreModule: Options
16
using ..CoreModule.OptionsModule: inverse_binopmap, inverse_unaopmap
17
using ..UtilsModule: subscriptify
18

19
import ..deprecate_varmap
20

21
"""
22
    eval_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options; kws...)
23

24
Evaluate a binary tree (equation) over a given input data matrix. The
25
operators contain all of the operators used. This function fuses doublets
26
and triplets of operations for lower memory usage.
27

28
This function can be represented by the following pseudocode:
29

30
```
31
function eval(current_node)
32
    if current_node is leaf
33
        return current_node.value
34
    elif current_node is degree 1
35
        return current_node.operator(eval(current_node.left_child))
36
    else
37
        return current_node.operator(eval(current_node.left_child), eval(current_node.right_child))
38
```
39
The bulk of the code is for optimizations and pre-emptive NaN/Inf checks,
40
which speed up evaluation significantly.
41

42
# Arguments
43
- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The root node of the tree to evaluate.
44
- `X::AbstractArray`: The input data to evaluate the tree on.
45
- `options::Options`: Options used to define the operators used in the tree.
46

47
# Returns
48
- `(output, complete)::Tuple{AbstractVector, Bool}`: the result,
49
    which is a 1D array, as well as if the evaluation completed
50
    successfully (true/false). A `false` complete means an infinity
51
    or nan was encountered, and a large loss should be assigned
52
    to the equation.
53
"""
54
function DE.eval_tree_array(
174,786,170✔
55
    tree::Union{AbstractExpressionNode,AbstractExpression},
56
    X::AbstractMatrix,
57
    options::Options;
58
    kws...,
59
)
60
    A = expected_array_type(X)
139,373,639✔
61
    return DE.eval_tree_array(
105,213,776✔
62
        tree,
63
        X,
64
        DE.get_operators(tree, options);
65
        turbo=options.turbo,
66
        bumper=options.bumper,
67
        kws...,
68
    )::Tuple{A,Bool}
69
end
70
# `ParametricExpression` variant. It additionally takes a `classes` vector
# (presumably one class index per column of `X` — confirm against the caller) and
# forwards it to the DynamicExpressions evaluator together with the operators and
# `turbo`/`bumper` settings from `options`.
function DE.eval_tree_array(
    tree::ParametricExpression,
    X::AbstractMatrix,
    classes::AbstractVector{<:Integer},
    options::Options;
    kws...,
)
    output_type = expected_array_type(X)
    operators = DE.get_operators(tree, options)
    result = DE.eval_tree_array(
        tree, X, classes, operators; turbo=options.turbo, bumper=options.bumper, kws...
    )
    return result::Tuple{output_type,Bool}
end
88

89
# Return the type of a vector indexed like the columns of `X`, i.e. the type that
# `similar(X, axes(X, 2))` produces. The evaluation wrappers use it to type-assert
# their results, which improves type inference for callers.
function expected_array_type(X::AbstractArray)
    column_template = similar(X, axes(X, 2))
    return typeof(column_template)
end
93

94
"""
95
    eval_diff_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options, direction::Int)
96

97
Compute the forward derivative of an expression, using a similar
98
structure and optimization to eval_tree_array. `direction` is the index of a particular
99
variable in the expression. e.g., `direction=1` would indicate derivative with
100
respect to `x1`.
101

102
# Arguments
103

104
- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The expression tree to evaluate.
105
- `X::AbstractArray`: The data matrix, with each column being a data point.
106
- `options::Options`: The options containing the operators used to create the `tree`.
107
- `direction::Int`: The index of the variable to take the derivative with respect to.
108

109
# Returns
110

111
- `(evaluation, derivative, complete)::Tuple{AbstractVector, AbstractVector, Bool}`: the normal evaluation,
112
    the derivative, and whether the evaluation completed as normal (or encountered a nan or inf).
113
"""
114
function DE.eval_diff_tree_array(
144✔
115
    tree::Union{AbstractExpression,AbstractExpressionNode},
116
    X::AbstractArray,
117
    options::Options,
118
    direction::Int,
119
)
120
    A = expected_array_type(X)
168✔
121
    # TODO: Add `AbstractExpression` implementation in `Expression.jl`
122
    return DE.eval_diff_tree_array(
144✔
123
        DE.get_tree(tree), X, DE.get_operators(tree, options), direction
124
    )::Tuple{A,A,Bool}
125
end
126

127
"""
128
    eval_grad_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options; variable::Bool=false)
129

130
Compute the forward-mode derivative of an expression, using a similar
131
structure and optimization to eval_tree_array. `variable` specifies whether
132
we should take derivatives with respect to features (i.e., `X`), or with respect
133
to every constant in the expression.
134

135
# Arguments
136

137
- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The expression tree to evaluate.
138
- `X::AbstractArray`: The data matrix, with each column being a data point.
139
- `options::Options`: The options containing the operators used to create the `tree`.
140
- `variable::Bool`: Whether to take derivatives with respect to features (i.e., `X` - with `variable=true`),
141
    or with respect to every constant in the expression (`variable=false`).
142

143
# Returns
144

145
- `(evaluation, gradient, complete)::Tuple{AbstractVector, AbstractArray, Bool}`: the normal evaluation,
146
    the gradient, and whether the evaluation completed as normal (or encountered a nan or inf).
147
"""
148
function DE.eval_grad_tree_array(
154✔
149
    tree::Union{AbstractExpression,AbstractExpressionNode},
150
    X::AbstractArray,
151
    options::Options;
152
    kws...,
153
)
154
    A = expected_array_type(X)
112✔
155
    M = typeof(X)  # TODO: This won't work with StaticArrays!
84✔
156
    return DE.eval_grad_tree_array(
84✔
157
        tree, X, DE.get_operators(tree, options); kws...
158
    )::Tuple{A,M,Bool}
159
end
160

161
"""
162
    differentiable_eval_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::Options)
163

164
Evaluate an expression tree in a way that can be auto-differentiated.
165
"""
166
function DE.differentiable_eval_tree_array(
7,224✔
167
    tree::Union{AbstractExpression,AbstractExpressionNode},
168
    X::AbstractArray,
169
    options::Options,
170
)
171
    A = expected_array_type(X)
8,428✔
172
    # TODO: Add `AbstractExpression` implementation in `Expression.jl`
173
    return DE.differentiable_eval_tree_array(
7,224✔
174
        DE.get_tree(tree), X, DE.get_operators(tree, options)
175
    )::Tuple{A,Bool}
176
end
177

178
const WILDCARD_UNIT_STRING = "[?]"
179

180
"""
181
    string_tree(tree::AbstractExpressionNode, options::Options; kws...)
182

183
Convert an equation to a string.
184

185
# Arguments
186

187
- `tree::AbstractExpressionNode`: The equation to convert to a string.
188
- `options::Options`: The options holding the definition of operators.
189
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables
190
    to print for each feature.
191
"""
192
@inline function DE.string_tree(
959,232✔
193
    tree::Union{AbstractExpression,AbstractExpressionNode},
194
    options::Options;
195
    raw::Bool=true,
196
    X_sym_units=nothing,
197
    y_sym_units=nothing,
198
    variable_names=nothing,
199
    display_variable_names=variable_names,
200
    varMap=nothing,
201
    kws...,
202
)
203
    variable_names = deprecate_varmap(variable_names, varMap, :string_tree)
670,653✔
204

205
    if raw
569,333✔
206
        tree = tree isa GraphNode ? convert(Node, tree) : tree
449,358✔
207
        return DE.string_tree(
449,366✔
208
            tree,
209
            DE.get_operators(tree, options);
210
            f_variable=string_variable_raw,
211
            variable_names,
212
        )
213
    end
214

215
    vprecision = vals[options.print_precision]
119,974✔
216
    if X_sym_units !== nothing || y_sym_units !== nothing
119,974✔
217
        return DE.string_tree(
12,233✔
218
            tree,
219
            DE.get_operators(tree, options);
220
            f_variable=(feature, vname) -> string_variable(feature, vname, X_sym_units),
24,706✔
221
            f_constant=let
222
                unit_placeholder =
16,533✔
223
                    options.dimensionless_constants_only ? "" : WILDCARD_UNIT_STRING
224
                (val,) -> string_constant(val, vprecision, unit_placeholder)
35,450✔
225
            end,
226
            variable_names=display_variable_names,
227
            kws...,
228
        )
229
    else
230
        return DE.string_tree(
107,741✔
231
            tree,
232
            DE.get_operators(tree, options);
233
            f_variable=string_variable,
234
            f_constant=(val,) -> string_constant(val, vprecision, ""),
272,776✔
235
            variable_names=display_variable_names,
236
            kws...,
237
        )
238
    end
239
end
240
# Lookup table holding `Val(1)`, `Val(2)`, ..., `Val(8192)`. `string_tree` indexes it
# with `options.print_precision` to obtain a `Val` that `string_constant` dispatches on.
const vals = ntuple(i -> Val(i), 8192)
241
# Format feature `feature` for raw output. Returns its entry in `variable_names` when
# one exists, otherwise `"x"` followed by the plain feature index (e.g. `"x3"`).
function string_variable_raw(feature, variable_names)
    has_name = variable_names !== nothing && feature <= length(variable_names)
    return has_name ? variable_names[feature] : "x" * string(feature)
end
248
# Format feature `feature` for display. Uses its entry in `variable_names` when one
# exists, and otherwise falls back to `"x"` plus a subscripted index. When
# `variable_units` is given, the formatted units of that feature are appended.
function string_variable(feature, variable_names, variable_units=nothing)
    has_name = variable_names !== nothing && feature <= length(variable_names)
    name = has_name ? variable_names[feature] : "x" * subscriptify(feature)
    variable_units === nothing && return name
    return name * format_dimensions(variable_units[feature])
end
259
# Format a constant for display, followed by `unit_placeholder`.
# Real values are printed with `precision` significant digits via `sprint_precision`.
function string_constant(val::Real, ::Val{precision}, unit_placeholder) where {precision}
    return sprint_precision(val, Val(precision)) * unit_placeholder
end
# Any other value (e.g. complex) is printed in full, wrapped in parentheses.
function string_constant(val, ::Val{precision}, unit_placeholder) where {precision}
    return "(" * string(val) * ")" * unit_placeholder
end
266
# No units were provided for this feature, so no suffix is appended.
format_dimensions(::Nothing) = ""
269
# Format a unit quantity `u` as a bracketed suffix.
# - If its stripped numerical value is one, only the dimensions are printed,
#   and nothing at all is printed when those dimensions are zero (dimensionless).
# - Otherwise the full quantity is printed.
function format_dimensions(u)
    isone(ustrip(u)) || return "[" * string(u) * "]"
    dim = dimension(u)
    iszero(dim) && return ""
    return "[" * string(dim) * "]"
end
281
# Format `x` with `precision` significant digits (C-style `%g`). The function is
# `@generated` because `@sprintf` needs a literal format string. That string is built
# from the `Val` type parameter when the method is compiled.
@generated function sprint_precision(x, ::Val{precision}) where {precision}
    fmt = string("%.", precision, "g")
    return :(@sprintf($fmt, x))
end
285

286
"""
287
    print_tree(tree::AbstractExpressionNode, options::Options; kws...)
288

289
Print an equation
290

291
# Arguments
292

293
- `tree::AbstractExpressionNode`: The equation to convert to a string.
294
- `options::Options`: The options holding the definition of operators.
295
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables
296
    to print for each feature.
297
"""
298
function DE.print_tree(
33✔
299
    tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; kws...
300
)
301
    return DE.print_tree(tree, DE.get_operators(tree, options); kws...)
21✔
302
end
303
# Same as the `print_tree(tree, options)` method, but writes to the given `io` stream.
function DE.print_tree(
    io::IO, tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; kws...
)
    operators = DE.get_operators(tree, options)
    return DE.print_tree(io, tree, operators; kws...)
end
308

309
"""
310
    @extend_operators options
311

312
Extends all operators defined in this options object to work on the
313
`AbstractExpressionNode` type. While by default this is already done for operators defined
314
in `Base` when you create an options and pass `define_helper_functions=true`,
315
this does not apply to the user-defined operators. Thus, to do so, you must
316
apply this macro to the operator enum in the same module you have the operators
317
defined.
318
"""
319
macro extend_operators(options)
84✔
320
    operators = :($(options).operators)
84✔
321
    type_requirements = Options
84✔
322
    @gensym alias_operators
84✔
323
    return quote
84✔
324
        if !isa($(options), $type_requirements)
325
            error("You must pass an options type to `@extend_operators`.")
326
        end
327
        $alias_operators = $define_alias_operators($operators)
328
        $(DE).@extend_operators $alias_operators
329
    end |> esc
330
end
331
# Rebuild `operators` with aliases undone, so that the user doesn't need to write,
# e.g., `safe_pow(x1, 1.5)` and can use `x1 ^ 1.5` instead. The same enum flavor
# (`OperatorEnum` or `GenericOperatorEnum`) is reconstructed with
# `define_helper_functions=false` and `empty_old_operators=false`.
function define_alias_operators(operators)
    enum_type = operators isa OperatorEnum ? OperatorEnum : GenericOperatorEnum
    binary_operators = inverse_binopmap.(operators.binops)
    unary_operators = inverse_unaopmap.(operators.unaops)
    return enum_type(;
        binary_operators,
        unary_operators,
        define_helper_functions=false,
        empty_old_operators=false,
    )
end
342

343
# Make expressions callable as `tree(X, options)`. This evaluates `tree` on `X` with
# the operators and the `turbo`/`bumper` settings taken from `options`.
function (tree::Union{AbstractExpression,AbstractExpressionNode})(
    X, options::Options; kws...
)
    operators = DE.get_operators(tree, options)
    return tree(X, operators; turbo=options.turbo, bumper=options.bumper, kws...)
end
354
# Forward `_grad_evaluator` calls that receive `Options` to the operator-enum method,
# passing along `options.turbo`. Note that `bumper` is not forwarded here.
function DE.EvaluationHelpersModule._grad_evaluator(
    tree::Union{AbstractExpression,AbstractExpressionNode}, X, options::Options; kws...
)
    operators = DE.get_operators(tree, options)
    return DE.EvaluationHelpersModule._grad_evaluator(
        tree, X, operators; turbo=options.turbo, kws...
    )
end
361

NEW
362
# Fallback for bare `AbstractExpressionNode`s: returns the tree unchanged and ignores
# the operators.
# TODO: Move this definition to DynamicExpressions.jl
function combine_operators(tree::AbstractExpressionNode, ::AbstractOperatorEnum)
    return tree
end
364

365
end
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc