• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

MilesCranmer / SymbolicRegression.jl / 11204590927

06 Oct 2024 07:29PM UTC coverage: 95.808% (+1.2%) from 94.617%
11204590927

Pull #326

github

web-flow
Merge e2b369ea7 into 8f67533b9
Pull Request #326: BREAKING: Change expression types to `DynamicExpressions.Expression` (from `DynamicExpressions.Node`)

466 of 482 new or added lines in 24 files covered. (96.68%)

1 existing line in 1 file now uncovered.

2651 of 2767 relevant lines covered (95.81%)

73863189.31 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.33
/ext/SymbolicRegressionEnzymeExt.jl
1
module SymbolicRegressionEnzymeExt
2

3
using SymbolicRegression.LossFunctionsModule: eval_loss
4
using DynamicExpressions:
5
    AbstractExpression,
6
    AbstractExpressionNode,
7
    get_scalar_constants,
8
    set_scalar_constants!,
9
    extract_gradient,
10
    with_contents,
11
    get_contents
12
using ADTypes: AutoEnzyme
13
using Enzyme: autodiff, Reverse, Active, Const, Duplicated
14

15
import SymbolicRegression.ConstantOptimizationModule: GradEvaluator
16

17
# We prepare a copy of the tree and all arrays
18
function GradEvaluator(f::F, backend::AE) where {F,AE<:AutoEnzyme}
    # Build the scratch buffers that the Enzyme-backed evaluator pairs with the
    # primal tree and dataset (see the `Duplicated(...)` arguments at the call site).
    #
    # TODO: It is super inefficient to deepcopy; how can we skip this
    storage_dataset = deepcopy(f.dataset)
    storage_tree = copy(f.tree)

    # Only the second element of the result (the references) is kept here; the
    # first element is discarded.
    constants_and_refs = get_scalar_constants(storage_tree)
    storage_refs = constants_and_refs[2]

    extra = (;
        storage_tree=storage_tree,
        storage_refs=storage_refs,
        storage_dataset=storage_dataset,
    )
    return GradEvaluator(f, backend, extra)
end
25

26
# Compute the unregularized loss of `tree` on `dataset` (restricted by `idx`) and
# store it in `output[]`. The result is written in place and `nothing` is returned.
function evaluator(tree, dataset, options, idx, output)
    loss = eval_loss(tree, dataset, options; regularization=false, idx=idx)
    output[] = loss
    return nothing
end
30

31
# Run the zero-argument callable `f` inside a fresh `Task` constructed with `n` as
# its second argument (presumably the reserved stack size in bytes), then block
# until the task finishes and return its result.
function with_stacksize(f::F, n) where {F}
    task = Task(f, n)
    schedule(task)
    return fetch(task)
end
3✔
32

33
# Evaluate the loss at the constants `x` and, when `G` is not `nothing`, overwrite
# `G` with the values read back from the shadow tree after reverse-mode `autodiff`.
#
# - `_`: ignored.
# - `G`: buffer to overwrite, or `nothing` to skip the copy.
# - `x`: values written into the primal tree's scalar constants.
#
# Returns the value that `evaluator` stored in `output`.
function (g::GradEvaluator{<:Any,<:AutoEnzyme})(_, G, x::AbstractVector{T}) where {T}
    primal = g.f
    shadow = g.extra

    # Load the trial constants into the primal tree, then zero every shadow buffer
    # (presumably so results from an earlier call do not accumulate -- confirm).
    set_scalar_constants!(primal.tree, x, primal.refs)
    set_scalar_constants!(shadow.storage_tree, zero(x), shadow.storage_refs)
    fill!(shadow.storage_dataset, 0)

    # One-element buffers: `output` receives the loss, `doutput` is its shadow,
    # initialized to one.
    output = fill(zero(T), 1)
    doutput = fill(one(T), 1)

    stack_bytes = 32 * 1024 * 1024
    differentiate =
        () -> autodiff(
            Reverse,
            evaluator,
            Duplicated(primal.tree, shadow.storage_tree),
            Duplicated(primal.dataset, shadow.storage_dataset),
            Const(primal.options),
            Const(primal.idx),
            Duplicated(output, doutput),
        )
    with_stacksize(differentiate, stack_bytes)

    if !isnothing(G)
        # TODO: This is redundant since we already have the references.
        # Should just be able to extract from the references directly.
        G .= first(get_scalar_constants(shadow.storage_tree))
    end
    return output[]
end
60

61
end
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc