• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

globus-labs / cascade / 18732618520

22 Oct 2025 11:21PM UTC coverage: 25.34% (-70.0%) from 95.34%
18732618520

Pull #70

github

miketynes
fix init, logging init, waiting
Pull Request #70: Academy proto

261 of 1030 relevant lines covered (25.34%)

0.25 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

16.07
/cascade/learning/torchani/ase.py
1
# -*- coding: utf-8 -*-
2
"""Tools for interfacing with `ASE`_.
3

4
.. _ASE:
5
    https://wiki.fysik.dtu.dk/ase
6
"""
7

8
# TODO (wardlt): TorchANI is in a code freeze as they refactor, so I'm copying my alterations over to here
9
#  They are originally from https://github.com/aiqm/torchani/blob/17204c6dccf6210753bc8c0ca4c92278b60719c9/torchani/ase.py
10

11
import torch
1✔
12
from ase.calculators.calculator import all_properties
1✔
13

14
from torchani import utils
1✔
15
import ase.calculators.calculator
1✔
16
import ase.units
1✔
17

18

19
class Calculator(ase.calculators.calculator.Calculator):
1✔
20
    """TorchANI calculator for ASE
21

22
    Arguments:
23
        species (:class:`collections.abc.Sequence` of :class:`str`):
24
            sequence of all supported species, in order.
25
        model (:class:`torch.nn.Module`): neural network potential model
26
            that convert coordinates into energies.
27
        overwrite (bool): After wrapping atoms into central box, whether
28
            to replace the original positions stored in :class:`ase.Atoms`
29
            object with the wrapped positions.
30
    """
31

32
    implemented_properties = ['energy', 'forces', 'stress', 'free_energy']
1✔
33

34
    def __init__(self, species, model, overwrite=False):
1✔
35
        super().__init__()
×
36
        self.species_to_tensor = utils.ChemicalSymbolsToInts(species)
×
37
        self.model = model
×
38
        # Since ANI is used in inference mode, no gradients on model parameters are required here
39
        for p in self.model.parameters():
×
40
            p.requires_grad_(False)
×
41
        self.overwrite = overwrite
×
42

43
        a_parameter = next(self.model.parameters())
×
44
        self.device = a_parameter.device
×
45
        self.dtype = a_parameter.dtype
×
46
        try:
×
47
            # We assume that the model has a "periodic_table_index" attribute
48
            # if it doesn't we set the calculator's attribute to false and we
49
            # assume that species will be correctly transformed by
50
            # species_to_tensor
51
            self.periodic_table_index = model.periodic_table_index
×
52
        except AttributeError:
×
53
            self.periodic_table_index = False
×
54

55
    def calculate(self, atoms=None, properties=all_properties,
1✔
56
                  system_changes=ase.calculators.calculator.all_changes):
57
        super().calculate(atoms, properties, system_changes)
×
58

59
        self.results.clear()  # Removes any previous calculation results <- LW Change
×
60

61
        cell = torch.tensor(self.atoms.get_cell(complete=True).array,
×
62
                            dtype=self.dtype, device=self.device)
63
        pbc = torch.tensor(self.atoms.get_pbc(), dtype=torch.bool,
×
64
                           device=self.device)
65
        pbc_enabled = pbc.any().item()
×
66

67
        if self.periodic_table_index:
×
68
            species = torch.tensor(self.atoms.get_atomic_numbers(), dtype=torch.long, device=self.device)
×
69
        else:
70
            species = self.species_to_tensor(self.atoms.get_chemical_symbols()).to(self.device)
×
71

72
        species = species.unsqueeze(0)
×
73
        coordinates = torch.tensor(self.atoms.get_positions())
×
74
        coordinates = coordinates.to(self.device).to(self.dtype) \
×
75
                                 .requires_grad_('forces' in properties)
76

77
        if pbc_enabled:
×
78
            coordinates = utils.map2central(cell, coordinates, pbc)
×
79
            if self.overwrite and atoms is not None:
×
80
                atoms.set_positions(coordinates.detach().cpu().reshape(-1, 3).numpy())
×
81

82
        if 'stress' in properties:
×
83
            scaling = torch.eye(3, requires_grad=True, dtype=self.dtype, device=self.device)
×
84
            coordinates = coordinates @ scaling
×
85
        coordinates = coordinates.unsqueeze(0)
×
86

87
        if pbc_enabled:
×
88
            if 'stress' in properties:
×
89
                cell = cell @ scaling
×
90
            energy = self.model((species, coordinates), cell=cell, pbc=pbc).energies
×
91
        else:
92
            energy = self.model((species, coordinates)).energies
×
93

94
        energy *= ase.units.Hartree
×
95
        self.results['energy'] = energy.item()
×
96
        self.results['free_energy'] = energy.item()
×
97

98
        if 'forces' in properties:
×
99
            forces = -torch.autograd.grad(energy.squeeze(), coordinates, retain_graph='stress' in properties)[0]
×
100
            self.results['forces'] = forces.squeeze(0).to('cpu').numpy()
×
101

102
        if 'stress' in properties:
×
103
            volume = self.atoms.get_volume()
×
104
            stress = torch.autograd.grad(energy.squeeze(), scaling)[0] / volume
×
105
            self.results['stress'] = stress.cpu().numpy()
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc