• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MilesCranmer / SymbolicRegression.jl / 9639805727

24 Jun 2024 05:00AM UTC coverage: 94.475% (-0.1%) from 94.617%
9639805727

Pull #326

github

web-flow
Merge 3ba1556f8 into ceddaa424
Pull Request #326: BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)

239 of 250 new or added lines in 15 files covered. (95.6%)

4 existing lines in 3 files now uncovered.

2548 of 2697 relevant lines covered (94.48%)

46539295.05 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.04
/src/InterfaceDynamicExpressions.jl
1
module InterfaceDynamicExpressionsModule
2

3
using Printf: @sprintf
4
using DynamicExpressions: DynamicExpressions
5
using DynamicExpressions:
6
    OperatorEnum, GenericOperatorEnum, AbstractExpressionNode, Node, GraphNode
7
using DynamicExpressions.StringsModule: needs_brackets
8
using DynamicQuantities: dimension, ustrip
9
using ..CoreModule: Options
10
using ..CoreModule.OptionsModule: inverse_binopmap, inverse_unaopmap
11
using ..UtilsModule: subscriptify
12

13
import DynamicExpressions:
14
    eval_tree_array,
15
    eval_diff_tree_array,
16
    eval_grad_tree_array,
17
    print_tree,
18
    string_tree,
19
    differentiable_eval_tree_array
20

21
import ..deprecate_varmap
22

23
"""
    eval_tree_array(tree::AbstractExpressionNode, X::AbstractMatrix, options::Options; kws...)

Evaluate a binary tree (equation) over a given input data matrix. The
operators contain all of the operators used. This function fuses doublets
and triplets of operations for lower memory usage.

This function can be represented by the following pseudocode:

```
function eval(current_node)
    if current_node is leaf
        return current_node.value
    elif current_node is degree 1
        return current_node.operator(eval(current_node.left_child))
    else
        return current_node.operator(eval(current_node.left_child), eval(current_node.right_child))
    end
end
```
The bulk of the code is for optimizations and pre-emptive NaN/Inf checks,
which speed up evaluation significantly.

# Arguments
- `tree::AbstractExpressionNode`: The root node of the tree to evaluate.
- `X::AbstractMatrix`: The input data to evaluate the tree on. Each column is
    one data point, so the output has one entry per column.
- `options::Options`: Options used to define the operators used in the tree.
    `options.operators` is used for evaluation, and `options.turbo` /
    `options.bumper` are forwarded as the `turbo` / `bumper` keywords.
- `kws...`: Any further keywords, forwarded to the `DynamicExpressions`
    method that takes an operator enum.

# Returns
- `(output, complete)::Tuple{AbstractVector, Bool}`: the result,
    which is a 1D array, as well as if the evaluation completed
    successfully (true/false). A `false` complete means an infinity
    or nan was encountered, and a large loss should be assigned
    to the equation.
"""
function eval_tree_array(
    tree::AbstractExpressionNode, X::AbstractMatrix, options::Options; kws...
)
    # Assert the concrete output array type so Julia can infer the return type.
    A = expected_array_type(X)
    return eval_tree_array(
        tree, X, options.operators; turbo=options.turbo, bumper=options.bumper, kws...
    )::Tuple{A,Bool}
end
64

65
# Array type that `similar` yields for `X` over the axis of its second dimension
# (one entry per data point/column). Evaluation wrappers use it in type assertions
# so Julia can infer the concrete array type they return.
expected_array_type(X::AbstractArray) = typeof(similar(X, axes(X, 2)))
69

70
"""
    eval_diff_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::Options, direction::Int)

Evaluate `tree` and, in the same pass, its forward derivative along one input
feature. The structure and optimizations are similar to `eval_tree_array`.
`direction` picks the feature: for example, `direction=1` differentiates with
respect to `x1`.

# Arguments

- `tree::AbstractExpressionNode`: Expression to evaluate and differentiate.
- `X::AbstractArray`: Input data, with each column being one data point.
- `options::Options`: Supplies the operators (`options.operators`) the `tree` was built from.
- `direction::Int`: Index of the feature to differentiate with respect to.

# Returns

- `(evaluation, derivative, complete)::Tuple{AbstractVector, AbstractVector, Bool}`:
    the plain evaluation, the derivative along `direction`, and whether the
    evaluation completed normally (rather than hitting a nan or inf).
"""
function eval_diff_tree_array(
    tree::AbstractExpressionNode, X::AbstractArray, options::Options, direction::Int
)
    result = eval_diff_tree_array(tree, X, options.operators, direction)
    # Both the value and the derivative share the per-data-point array type.
    OutT = expected_array_type(X)
    return result::Tuple{OutT,OutT,Bool}
end
96

97
"""
    eval_grad_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::Options; variable::Bool=false)

Evaluate `tree` together with its forward-mode gradient. The structure and
optimizations are similar to `eval_tree_array`. The `variable` keyword chooses
what to differentiate with respect to: the input features (`X`) or every
constant in the expression.

# Arguments

- `tree::AbstractExpressionNode`: Expression to evaluate and differentiate.
- `X::AbstractArray`: Input data, with each column being one data point.
- `options::Options`: Supplies the operators (`options.operators`) the `tree` was built from.
- `variable::Bool`: `true` differentiates with respect to the features in `X`;
    `false` differentiates with respect to every constant in the expression.

# Returns

- `(evaluation, gradient, complete)::Tuple{AbstractVector, AbstractArray, Bool}`:
    the plain evaluation, the gradient, and whether the evaluation completed
    normally (rather than hitting a nan or inf).
"""
function eval_grad_tree_array(
    tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...
)
    output = eval_grad_tree_array(tree, X, options.operators; kws...)
    VecT = expected_array_type(X)
    # The gradient is asserted to have the same array type as `X`.
    GradT = typeof(X)  # TODO: This won't work with StaticArrays!
    return output::Tuple{VecT,GradT,Bool}
end
125

126
"""
    differentiable_eval_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::Options)

Evaluate an expression tree in a way that can be auto-differentiated.

Delegates to the `DynamicExpressions` method using `options.operators`. It returns
`(output, complete)`, where `output` is asserted to have the array type of
`similar(X, axes(X, 2))`.
"""
function differentiable_eval_tree_array(
    tree::AbstractExpressionNode, X::AbstractArray, options::Options
)
    out = differentiable_eval_tree_array(tree, X, options.operators)
    return out::Tuple{expected_array_type(X),Bool}
end
137

138
# Suffix that `string_tree` appends to every printed constant when units are being
# displayed (`X_sym_units` or `y_sym_units` given) and
# `options.dimensionless_constants_only` is false.
const WILDCARD_UNIT_STRING = "[?]"
139

140
"""
    string_tree(tree::AbstractExpressionNode, options::Options; kws...)

Convert an equation to a string.

# Arguments

- `tree::AbstractExpressionNode`: The equation to convert to a string.
- `options::Options`: The options holding the definition of operators. When
    `raw=false`, `options.print_precision` also sets the significant digits used
    for real constants (`%.<p>g` formatting).

# Keywords

- `raw::Bool=true`: If `true`, variables are printed as their `variable_names`
    entry, or as `x1`, `x2`, ... when no name exists. Constants use the default
    formatting of `DynamicExpressions`, and a `GraphNode` is converted to a `Node`
    first. In this mode `display_variable_names`, the unit keywords and any extra
    `kws...` are not used.
- `X_sym_units=nothing`: Per-feature units. When given (or when `y_sym_units` is
    given) and `raw=false`, each variable is printed with its unit appended. Each
    constant is followed by `"[?]"`, unless `options.dimensionless_constants_only`
    is set.
- `y_sym_units=nothing`: Only checked for presence. Passing it switches on the
    unit-aware printing described above.
- `variable_names=nothing`: Names to print for each feature in raw mode. This
    is also the default for `display_variable_names`.
- `display_variable_names=variable_names`: Names printed when `raw=false`.
    Features without a name are printed as `x` followed by a subscripted index.
- `varMap=nothing`: Deprecated. It is reconciled with `variable_names` through
    `deprecate_varmap`.
- `kws...`: When `raw=false`, forwarded to the `DynamicExpressions` `string_tree`
    method that takes an operator enum.
"""
@inline function string_tree(
    tree::AbstractExpressionNode,
    options::Options;
    raw::Bool=true,
    X_sym_units=nothing,
    y_sym_units=nothing,
    variable_names=nothing,
    display_variable_names=variable_names,
    varMap=nothing,
    kws...,
)
    variable_names = deprecate_varmap(variable_names, varMap, :string_tree)

    if raw
        # Raw printing works on a plain `Node`, so a `GraphNode` is converted first.
        # Extra `kws...` are intentionally not forwarded in this mode.
        tree = tree isa GraphNode ? convert(Node, tree) : tree
        return string_tree(
            tree, options.operators; f_variable=string_variable_raw, variable_names
        )
    end

    # Fetch a pre-built `Val(print_precision)` so `sprint_precision` can generate a
    # `@sprintf` format specialized to this precision.
    vprecision = vals[options.print_precision]
    if X_sym_units !== nothing || y_sym_units !== nothing
        return string_tree(
            tree,
            options.operators;
            f_variable=(feature, vname) -> string_variable(feature, vname, X_sym_units),
            f_constant=let
                unit_placeholder =
                    options.dimensionless_constants_only ? "" : WILDCARD_UNIT_STRING
                (val,) -> string_constant(val, vprecision, unit_placeholder)
            end,
            variable_names=display_variable_names,
            kws...,
        )
    else
        return string_tree(
            tree,
            options.operators;
            f_variable=string_variable,
            f_constant=(val,) -> string_constant(val, vprecision, ""),
            variable_names=display_variable_names,
            kws...,
        )
    end
end
197
# Lookup table with `vals[p] === Val(p)` for 1 ≤ p ≤ 8192. `string_tree` indexes it
# with `options.print_precision` to obtain the `Val` handed to `string_constant`.
const vals = ntuple(i -> Val(i), 8192)
198
# Raw-mode variable printer. Returns the user-supplied name when one exists for
# `feature`; otherwise returns the plain ASCII name `x<feature>`.
function string_variable_raw(feature, variable_names)
    has_name = variable_names !== nothing && feature <= length(variable_names)
    return has_name ? variable_names[feature] : "x" * string(feature)
end
205
# Pretty (non-raw) variable printer. Uses the supplied name for `feature` when one
# exists; otherwise uses `x` followed by a subscripted index. When
# `variable_units` is given, the formatted unit of that feature is appended.
function string_variable(feature, variable_names, variable_units=nothing)
    named = variable_names !== nothing && feature <= length(variable_names)
    label = named ? variable_names[feature] : "x" * subscriptify(feature)
    variable_units === nothing && return label
    return label * format_dimensions(variable_units[feature])
end
216
# Print a constant followed by `unit_placeholder`. Real values are rounded to
# `precision` significant digits via `sprint_precision`; any other value is
# printed with `string` and wrapped in parentheses.
function string_constant(val, ::Val{precision}, unit_placeholder) where {precision}
    body = if val isa Real
        sprint_precision(val, Val(precision))
    else
        "(" * string(val) * ")"
    end
    return body * unit_placeholder
end
223
# No unit attached to this feature, so there is no suffix to print.
format_dimensions(::Nothing) = ""
226
# Render a unit as a bracketed suffix. If the stripped numeric value of `u` is one,
# only its dimension is printed, or nothing at all when that dimension is zero.
# Otherwise the full `u` is printed.
function format_dimensions(u)
    isone(ustrip(u)) || return "[" * string(u) * "]"
    dim = dimension(u)
    return iszero(dim) ? "" : "[" * string(dim) * "]"
end
238
# Format `x` with `precision` significant digits (`%.<precision>g`). `@sprintf`
# needs a literal format string, so this is a generated function that builds the
# format from the `Val` type parameter.
@generated function sprint_precision(x, ::Val{precision}) where {precision}
    format = string("%.", precision, "g")
    return :(@sprintf($format, x))
end
242

243
"""
    print_tree(tree::AbstractExpressionNode, options::Options; kws...)
    print_tree(io::IO, tree::AbstractExpressionNode, options::Options; kws...)

Print an equation, writing to `io` when it is given.

# Arguments

- `tree::AbstractExpressionNode`: The equation to print.
- `options::Options`: The options holding the definition of operators.
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables
    to print for each feature.

All keywords are forwarded unchanged to the `DynamicExpressions` method, together
with `options.operators`.
"""
print_tree(tree::AbstractExpressionNode, options::Options; kws...) =
    print_tree(tree, options.operators; kws...)

print_tree(io::IO, tree::AbstractExpressionNode, options::Options; kws...) =
    print_tree(io, tree, options.operators; kws...)
261

262
"""
    convert(::Type{<:AbstractExpressionNode{T}}, tree::AbstractExpressionNode, options::Options) where {T}

Convert an equation to the node type `N <: AbstractExpressionNode{T}`, i.e. to a
different base type `T`, using the operators in `options.operators`.

This method takes no keyword arguments.
"""
function Base.convert(
    ::Type{N}, tree::AbstractExpressionNode, options::Options
) where {T,N<:AbstractExpressionNode{T}}
    return convert(N, tree, options.operators)
end
272

273
"""
    @extend_operators options

Extends all operators defined in this options object to work on the
`AbstractExpressionNode` type. While by default this is already done for operators defined
in `Base` when you create an options and pass `define_helper_functions=true`,
this does not apply to the user-defined operators. Thus, to do so, you must
apply this macro to the options object in the same module you have the operators
defined.

The `options` expression is evaluated exactly once, so passing a constructor
call such as `@extend_operators Options(...)` builds only one `Options`.
"""
macro extend_operators(options)
    @gensym options_value alias_operators
    return quote
        # Bind the user's expression once; it was previously spliced (and evaluated) twice.
        $options_value = $options
        if !isa($options_value, $Options)
            error("You must pass an options type to `@extend_operators`.")
        end
        # Map aliased operators (e.g. `safe_pow`) back to their base forms before extending.
        $alias_operators = $define_alias_operators($(options_value).operators)
        $(DynamicExpressions).@extend_operators $alias_operators
    end |> esc
end
295
# Rebuild an operator enum of the same kind (`OperatorEnum` if `operators` is one,
# otherwise `GenericOperatorEnum`) with aliased operators mapped back to their base
# forms. Users can then write `x1 ^ 1.5` rather than `safe_pow(x1, 1.5)`. The enum
# is built with `define_helper_functions=false` and `empty_old_operators=false`.
function define_alias_operators(operators)
    enum_type = operators isa OperatorEnum ? OperatorEnum : GenericOperatorEnum
    binary = inverse_binopmap.(operators.binops)
    unary = inverse_unaopmap.(operators.unaops)
    return enum_type(;
        binary_operators=binary,
        unary_operators=unary,
        define_helper_functions=false,
        empty_old_operators=false,
    )
end
306

307
# Calling a tree with `Options` evaluates it with `options.operators`. The
# `turbo` and `bumper` settings from `options` are forwarded, and any extra
# keywords are passed through.
(tree::AbstractExpressionNode)(X, options::Options; kws...) =
    tree(X, options.operators; turbo=options.turbo, bumper=options.bumper, kws...)
310
# `Options`-based overload of DynamicExpressions' internal gradient evaluator.
# It delegates with `options.operators` and forwards `options.turbo` (but not
# `bumper`), followed by any extra keywords.
function DynamicExpressions.EvaluationHelpersModule._grad_evaluator(
    tree::AbstractExpressionNode, X, options::Options; kws...
)
    operators = options.operators
    return DynamicExpressions.EvaluationHelpersModule._grad_evaluator(
        tree, X, operators; turbo=options.turbo, kws...
    )
end
317

318
end
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc