• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

globus-labs / cascade / 18732618520

22 Oct 2025 11:21PM UTC coverage: 25.34% (-70.0%) from 95.34%
18732618520

Pull #70

github

miketynes
fix init, logging init, waiting
Pull Request #70: Academy proto

261 of 1030 relevant lines covered (25.34%)

0.25 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

32.0
/cascade/auditor.py
1
"""Classes used to assess the quality of dynamics performed by surrogate models
2
and suggest how to improve them"""
3

4
import numpy as np
1✔
5
from scipy.stats import multivariate_normal
1✔
6
from ase import Atoms
1✔
7

8

9
class BaseAuditor:
    """Interface for judging how far a surrogate-generated trajectory can be trusted

    An auditor receives a segment of a trajectory produced by a machine-learned
    surrogate and reports two things:
        1) a score (float) expressing how much the segment should be trusted
        2) the indices of frames that deserve a closer look, for example by
           running higher-fidelity calculations on them when the trust in the
           segment turns out to be low.
    """

    def audit(
        self,
        atoms: list[Atoms],
        n_audits: int,
        sort_audits: bool = False,
    ) -> tuple[float, list[int]]:
        """Score a trajectory segment and select the frames with the highest UQ

        This base implementation does no work and always raises; subclasses
        are expected to override it.

        Args:
            atoms: frames of the trajectory segment as ase atoms. Each frame
                   is expected to carry the ensemble force predictions
                   ('forces_ens') set by the cascade EnsembleCalculator;
                   ForceThresholdAuditor reads them from
                   ``calc.results['forces_ens']``
            n_audits: number of frames to return
            sort_audits: whether to return frames in decreasing UQ order.
                         If false, implementations may use argpartition,
                         which is linear time but leaves the frames unordered
        Returns:
            p_any: an estimate of the probability that the force error on any
                   atom in any frame is above the threshold
            audit_frames: indices of the frames with the highest ensemble std,
                          aggregated by max
        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError
×
42

43

44
class ForceThresholdAuditor(BaseAuditor):
    """Determines the likelihood all calculations have error below the
    threshold, based on ensemble variance

    The covariance of the force predictions across the ensemble members is
    used as the covariance of a zero-mean multivariate normal error
    distribution over every force component of every frame. Samples drawn from
    that distribution are compared against the threshold.

    Args:
        threshold: in units of force in the simulation
        n_sample: number of samples to take when estimating the probability
                  of error being less than the threshold
    """

    def __init__(self,
                 threshold: float = 1,
                 n_sample: int = 100):
        self.threshold = threshold
        self.n_sample = n_sample

    def audit(self,
              atoms: list[Atoms],
              n_audits: int,
              sort_audits: bool = False) -> tuple[float, list[int]]:
        """Estimate the probability that any force error exceeds the threshold
        and pick the frames with the largest ensemble disagreement

        Args:
            atoms: frames to audit. Every frame must provide
                   ``calc.results['forces_ens']`` with one force array per
                   ensemble member, i.e. shape (n_models, n_atoms, 3), and all
                   frames must share that shape
            n_audits: number of frames to return. Values larger than the
                      number of frames are capped at the number of frames
            sort_audits: whether to return frames in decreasing UQ order.
                         If false, uses argpartition which is linear time and
                         returns the selected frames in no particular order
        Returns:
            p_any: fraction of the ``n_sample`` sampled error vectors in which
                   the force error magnitude on at least one atom of at least
                   one frame exceeds ``threshold``. The samples are random, so
                   the value can differ between calls
            audit_frames: integer array with the indices of the frames with
                          the highest ensemble std, aggregated by max over
                          atoms and spatial components
        Raises:
            ValueError: if ``n_audits`` is negative, if the stacked
                        predictions are not 4-dimensional (no frames, or
                        frames of differing shape), or if the ensemble has
                        fewer than two members
        """
        if n_audits < 0:
            raise ValueError(f'n_audits must be non-negative, got {n_audits}')

        force_preds = np.asarray([a.calc.results['forces_ens'] for a in atoms])
        if force_preds.ndim != 4:
            raise ValueError(
                'Expected stacked forces_ens of shape '
                '(n_frames, n_models, n_atoms, 3), '
                f'got an array with shape {force_preds.shape}'
            )

        n_frames, n_models, n_atoms, n_dim = force_preds.shape
        if n_models < 2:
            # a covariance cannot be estimated from a single ensemble member
            raise ValueError(
                f'Need at least 2 ensemble members to estimate variance, got {n_models}'
            )

        # Flatten so each row holds ONE model's prediction for every frame,
        # atom and spatial component. The ensemble axis has to be moved to the
        # front first: reshaping the (frame, model, ...) array directly would
        # interleave frames and models within a row.
        force_preds_flat = np.moveaxis(force_preds, 1, 0).reshape((n_models, -1))

        # build the error distribution: variables are the force components,
        # observations are the ensemble members
        force_cov = np.cov(force_preds_flat.T)
        force_err_dist = multivariate_normal(cov=force_cov, allow_singular=True)

        # take a sample from the error distribution; columns are ordered
        # (frame, atom, component) so consecutive groups form one force vector
        force_err_samples_flat = force_err_dist.rvs(self.n_sample)
        force_err_samples = np.reshape(
            force_err_samples_flat, (self.n_sample, n_frames * n_atoms, n_dim)
        )

        # find the magnitude of the error on each atom of each frame
        force_samples_mag = np.linalg.norm(force_err_samples, axis=-1)

        # the probability that any 1 magnitude exceeds the threshold
        p_any = (force_samples_mag > self.threshold).any(axis=1).mean()

        # get the frames with the highest UQ
        # take std over ens dimension and then max over remaining atom, spatial dims
        max_uq_by_frame = force_preds.std(axis=1).max(axis=(1, 2))

        # get the worst frames; never ask for more frames than exist
        n_audits = min(n_audits, n_frames)
        if n_audits == 0:
            # slicing with [-0:] would return every frame, so handle it explicitly
            top_uq_ix = np.empty(0, dtype=np.intp)
        elif sort_audits:
            top_uq_ix = np.argsort(max_uq_by_frame)[::-1][:n_audits]
        else:
            top_uq_ix = np.argpartition(max_uq_by_frame, -n_audits)[-n_audits:]
        return p_any, top_uq_ix
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc