• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

globus-labs / cascade / 18732618520

22 Oct 2025 11:21PM UTC coverage: 25.34% (-70.0%) from 95.34%
18732618520

Pull #70

github

miketynes
fix init, logging init, waiting
Pull Request #70: Academy proto

261 of 1030 relevant lines covered (25.34%)

0.25 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

54.41
/cascade/dynamics.py
1
"""Interface to run a dynamics protocol using a learned forcefield"""
2
from pathlib import Path
1✔
3
from typing import Any, Callable, Optional
1✔
4
from dataclasses import dataclass, field
1✔
5
from tempfile import TemporaryDirectory
1✔
6
from uuid import uuid4
1✔
7

8
from ase import Atoms
1✔
9
from ase.io import Trajectory, read
1✔
10
from ase.optimize.optimize import Dynamics, Optimizer
1✔
11

12
from cascade.learning.base import BaseLearnableForcefield, State
1✔
13

14

15
# TODO (wardlt): Consider having the state of the `Dynamics` class stored here. Some dynamics classes (e.g., NPT) have state
16
@dataclass
class Progress:
    """Tracks where one atomic structure is within a :class:`DynamicsProtocol`"""

    atoms: Atoms
    """Most recent atomic structure"""
    name: str = field(default_factory=lambda: str(uuid4()))
    """Identifier for this trajectory; a random UUID string unless provided"""
    stage: int = 0
    """Index of the active stage within the overall :class:`DynamicsProtocol`."""
    timestep: int = 0
    """Number of timesteps already completed within the active stage"""

    def update(self, new_atoms: Atoms, steps_completed: int, finished_stage: bool):
        """Record the outcome of the latest chunk of dynamics

        Args:
            new_atoms: Structure at the end of the chunk; a copy is stored so later edits to it do not leak in
            steps_completed: Timesteps run since the previous update. Ignored when ``finished_stage`` is True
            finished_stage: Whether the structure has completed its current stage of the overall protocol
        """
        self.atoms = new_atoms.copy()

        if not finished_stage:
            # Still inside the same stage: accumulate the step count
            self.timestep += steps_completed
            return

        # Move on to the next stage and restart its timestep counter
        self.stage += 1
        self.timestep = 0
44

45

46
@dataclass
class DynamicsStage:
    """One stage of a :class:`DynamicsProtocol`: which ASE driver to run, for how long, and with what settings"""

    driver: type[Dynamics]
    """ASE ``Dynamics`` subclass used to run this stage"""
    timesteps: int | None = None
    """Upper limit on the number of timesteps spent in this stage.

    Use ``None`` to run until :attr:`driver` reports the dynamics as converged"""
    driver_kwargs: dict[str, Any] = field(default_factory=dict)
    """Keyword arguments supplied when constructing :attr:`driver`"""
    run_kwargs: dict[str, Any] = field(default_factory=dict)
    """Keyword arguments supplied to the driver's ``run`` method"""
    post_fun: Optional[Callable[[Atoms], None]] = None
    """Optional function applied to the structure once the stage finishes. It modifies its argument in place"""
60

61

62
class DynamicsProtocol:
    """A protocol for running several stages of dynamics calls together

    Args:
        stages: List of dynamics to be run in sequential order
        scratch_dir: Directory in which to write temporary files
    """

    stages: list[DynamicsStage]
    """List of dynamics processes to run sequentially"""

    def __init__(self, stages: list[DynamicsStage], scratch_dir: Path | None = None):
        # Shallow copy so later edits to the caller's list do not change this protocol
        self.stages = stages.copy()
        self.scratch_dir = scratch_dir

    @staticmethod
    def _is_stage_done(dyn: Dynamics, stage: DynamicsStage, converged: bool, total_timesteps: int) -> bool:
        """Decide whether a stage has finished after a chunk of dynamics

        Args:
            dyn: Driver which just ran
            stage: Stage being run
            converged: Value returned by the driver's ``run`` method
            total_timesteps: Timesteps completed within this stage, including the latest run
        Returns:
            Whether the trajectory should advance past this stage
        """
        # Only trust the convergence flag from optimizers
        if converged and isinstance(dyn, Optimizer):
            return True

        # Without a timestep limit, only a converged optimizer can end the stage.
        # Comparing against None previously raised TypeError here.
        # NOTE(review): a non-optimizer stage with ``timesteps=None`` therefore never completes; confirm intended
        if stage.timesteps is None:
            return False

        return total_timesteps >= stage.timesteps

    # TODO (wardlt): We might need to run dynamics with a physics code, which will require changing the interface
    def run_dynamics(self,
                     start: Progress,
                     model_msg: bytes | State,
                     learner: BaseLearnableForcefield,
                     max_timesteps: int,
                     max_frames: int | None = None,
                     device: str | None = None) -> tuple[bool, list[Atoms]]:
        """Run dynamics for a maximum number of timesteps using a particular forcefield

        Runs dynamics until the end of the current stage or until the maximum number of timesteps is reached.
        The structure in ``start.atoms`` is modified in place: the forcefield calculator is attached to it,
        and it is advanced by the dynamics.

        Args:
            start: Starting point of the dynamic trajectory
            model_msg: Serialized form of the forcefield used for training
            learner: Class used to generate the ASE calculator object
            max_timesteps: Maximum number of timesteps to run dynamics
            max_frames: Approximate maximum number of frames from the atomistic simulation to return for auditing.
                Must be positive if provided. ``None`` records every timestep
            device: Device to use for evaluating forcefield
        Returns:
            - Whether the current stage of the protocol is complete
            - Frames selected for auditing the dynamics, ending with the final structure
        """
        stage = self.stages[start.stage]  # Pick the current stage

        # Create a temporary directory in which to run the dynamics
        with TemporaryDirectory(dir=self.scratch_dir, prefix='cascade-dyn_', suffix=f'_{start.name}') as tmp_name:
            tmp_dir = Path(tmp_name)
            dyn = stage.driver(start.atoms, logfile=str(tmp_dir / 'dyn.log'), **stage.driver_kwargs)

            # Attach the calculator
            calc = learner.make_calculator(model_msg, device)
            atoms = start.atoms
            atoms.calc = calc

            # Attach the trajectory writer. Clamp the interval to at least 1: when max_frames > max_timesteps
            # the integer division yields 0, which ASE treats as "call only at step 0"
            traj_freq = 1 if max_frames is None else max(1, max_timesteps // max_frames)
            traj_path = str(tmp_dir / 'run.traj')
            with Trajectory(traj_path, mode='w', atoms=atoms) as traj:
                dyn.attach(traj, traj_freq)

                # Run dynamics, then check if we have finished
                converged = dyn.run(steps=max_timesteps, **stage.run_kwargs)
                total_timesteps = max_timesteps + start.timestep  # Total progress along this stage
                done = self._is_stage_done(dyn, stage, converged, total_timesteps)

                if done and stage.post_fun is not None:
                    stage.post_fun(atoms)

            # Read in the trajectory (after the writer is closed), then append the current frame
            # if it differs from the last one written (e.g., after post_fun or an off-interval final step)
            traj_atoms = read(traj_path, ':')
            if traj_atoms[-1] != atoms:
                traj_atoms.append(atoms)

            return done, traj_atoms
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc