• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

stfc / janus-core / 13289975070

12 Feb 2025 04:20PM UTC coverage: 94.236%. First build
13289975070

push

github

web-flow
Adds post-process cli guide, specifying vafs by element name (#400)

* Adds post-process cli guide, specifying vafs by element name

* filter_atoms -> atoms_filter

* Check allowed symbols

* Comment on traj in user guide

* Add RDF user guide example

* Separate invalid symbol test, check passing just str throws error

---------

Co-authored-by: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com>
Co-authored-by: Alin Marin Elena <alin@elena.re>

24 of 26 new or added lines in 3 files covered. (92.31%)

2338 of 2481 relevant lines covered (94.24%)

2.83 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.52
/janus_core/processing/post_process.py
1
"""Module for post-processing trajectories."""
2

3
from __future__ import annotations
3✔
4

5
from collections.abc import Sequence
3✔
6
from itertools import combinations_with_replacement
3✔
7

8
from ase import Atoms
3✔
9
from ase.geometry.analysis import Analysis
3✔
10
import numpy as np
3✔
11
from numpy import float64
3✔
12
from numpy.typing import NDArray
3✔
13

14
from janus_core.helpers.janus_types import (
3✔
15
    MaybeSequence,
16
    PathLike,
17
    SliceLike,
18
)
19
from janus_core.helpers.utils import slicelike_to_startstopstep
3✔
20

21

22
def compute_rdf(
    data: MaybeSequence[Atoms],
    ana: Analysis | None = None,
    /,
    *,
    filenames: MaybeSequence[PathLike] | None = None,
    by_elements: bool = False,
    rmax: float = 2.5,
    nbins: int = 50,
    elements: MaybeSequence[int | str] | None = None,
    index: SliceLike = (0, None, 1),
    volume: float | None = None,
) -> (
    tuple[NDArray[float64], NDArray[float64]]
    | dict[tuple[str, str], tuple[NDArray[float64], NDArray[float64]]]
):
    """
    Compute the rdf of data.

    Parameters
    ----------
    data
        Dataset to compute RDF of.
    ana
        ASE Analysis object for data reuse.
    filenames
        Filenames to output data to. May be a single path (including a plain
        ``str``) or a sequence of paths. Must match number of RDFs computed.
        Each file gets one ``distance rdf`` pair per line.
    by_elements
        Split RDF into pairwise by elements group. Default is False.
        N.B. mixed RDFs (e.g. C-H) include all self-RDFs (e.g. C-C),
        to get the pure (C-H) RDF subtract the self-RDFs.
    rmax
        Maximum distance of RDF.
    nbins
        Number of bins to divide RDF.
    elements
        Make partial RDFs. If `by_elements` is true will filter to
        only display pairs in list. A single symbol or index may be
        passed on its own.
    index
        Images to analyze as:
        `index` if `int`,
        `start`, `stop`, `step` if `tuple`,
        `slice` if `slice` or `range`.
    volume
        Volume of cell for normalisation. Only needs to be provided
        if aperiodic cell. Default is (2*rmax)**3.

    Returns
    -------
    tuple | dict
        If `by_elements` is true returns a `dict` mapping each element pair
        to its ``(distances, rdf)`` tuple. Otherwise returns a single
        ``(distances, rdf)`` tuple for the total system filtered by
        `elements`. In both cases the RDF is averaged over the selected
        images and the distances are taken from the first selected image.

    Raises
    ------
    ValueError
        If the number of `filenames` does not match the number of RDFs.
    """
    index = slicelike_to_startstopstep(index)

    if not isinstance(data, Sequence):
        data = [data]

    if elements is not None and not isinstance(elements, Sequence):
        elements = (elements,)

    if (  # If aperiodic, assume volume of a cube encompassing rmax sphere.
        not all(data[0].get_pbc()) and volume is None
    ):
        volume = (2 * rmax) ** 3

    if ana is None:
        ana = Analysis(data)

    if by_elements:
        if elements is None:
            elements = tuple(sorted(set(data[0].get_chemical_symbols())))
        elif isinstance(elements, str):
            # A lone symbol is itself a str Sequence; wrap it so that e.g.
            # "Cl" is not split into "C" and "l" when forming pairs below.
            elements = (elements,)

        # Per pair: list of per-image results, indexed below as (rdf, dists).
        per_pair = {
            pair: ana.get_rdf(
                rmax=rmax,
                nbins=nbins,
                elements=pair,
                imageIdx=slice(*index),
                return_dists=True,
                volume=volume,
            )
            for pair in combinations_with_replacement(elements, 2)
        }

        # Compute RDF average over images; distances from the first image.
        rdf = {
            pair: (
                per_image[0][1],
                np.average([rdf_i[0] for rdf_i in per_image], axis=0),
            )
            for pair, per_image in per_pair.items()
        }

        if filenames is not None:
            if isinstance(filenames, str) or not isinstance(filenames, Sequence):
                filenames = (filenames,)

            assert isinstance(filenames, Sequence)

            if len(filenames) != len(rdf):
                raise ValueError(
                    f"Different number of file names ({len(filenames)}) "
                    f"to number of samples ({len(rdf)})"
                )

            for (dists, rdfs), out_path in zip(rdf.values(), filenames, strict=True):
                with open(out_path, "w", encoding="utf-8") as out_file:
                    for dist, rdf_i in zip(dists, rdfs, strict=True):
                        print(dist, rdf_i, file=out_file)

    else:
        per_image = ana.get_rdf(
            rmax=rmax,
            nbins=nbins,
            elements=elements,
            imageIdx=slice(*index),
            return_dists=True,
            volume=volume,
        )

        assert isinstance(per_image, list)

        # Compute RDF average over images; distances from the first image.
        rdf = per_image[0][1], np.average([rdf_i[0] for rdf_i in per_image], axis=0)

        if filenames is not None:
            # A plain str is a Sequence too, but names a single file.
            if isinstance(filenames, Sequence) and not isinstance(filenames, str):
                if len(filenames) != 1:
                    raise ValueError(
                        f"Different number of file names ({len(filenames)}) "
                        "to number of samples (1)"
                    )
                filenames = filenames[0]

            with open(filenames, "w", encoding="utf-8") as out_file:
                for dist, rdf_i in zip(*rdf, strict=True):
                    print(dist, rdf_i, file=out_file)

    return rdf
159

160

161
def compute_vaf(
    data: Sequence[Atoms],
    filenames: MaybeSequence[PathLike] | None = None,
    *,
    use_velocities: bool = False,
    fft: bool = False,
    index: SliceLike = (0, None, 1),
    atoms_filter: MaybeSequence[MaybeSequence[int | str | None]] = ((None,),),
    time_step: float = 1.0,
) -> tuple[NDArray[float64], list[NDArray[float64]]]:
    """
    Compute the velocity autocorrelation function (VAF) of `data`.

    Parameters
    ----------
    data
        Dataset to compute VAF of.
    filenames
        If present, dump resultant VAF(s) to file(s), one per entry in
        `atoms_filter`. A single path (including a plain ``str``) is
        accepted. Each file gets one ``lag vaf`` pair per line.
    use_velocities
        Compute VAF using velocities rather than momenta.
        Default is False.
    fft
        Compute the fourier transformed VAF (along the lag axis). The
        returned VAFs are then complex and `lags` holds the frequencies
        from ``numpy.fft.fftfreq``.
        Default is False.
    index
        Images to analyze as `start`, `stop`, `step`.
        Default is all images.
    atoms_filter
        Compute the VAF averaged over subsets of the system.
        Default is all atoms.
    time_step
        Time step for scaling lags to align with input data.
        Default is 1 (i.e. no scaling).

    Returns
    -------
    lags : numpy.ndarray
        Lags at which the VAFs have been computed.
    vafs : list[numpy.ndarray]
        Computed VAF(s).

    Raises
    ------
    ValueError
        If the number of `filenames` does not match the number of VAFs, if a
        filter contains symbols absent from the first selected image, or if
        a filter mixes element symbols and indices.

    Notes
    -----
    `atoms_filter` is given as a series of sequences of atoms or elements,
    where each value in the series denotes a VAF subset to calculate and
    each sequence determines the atoms (by index or element)
    to be included in that VAF.

    E.g.

    .. code-block: Python

        # Species indices in cell
        na = (1, 3, 5, 7)
        # Species by name
        cl = ('Cl',)

        compute_vaf(..., atoms_filter=(na, cl))

    Would compute separate VAFs for each species.

    Note that a bare ``str`` inside the series is treated as a sequence of
    characters, so wrap single symbols in a tuple as above.

    By default, one VAF will be computed for all atoms in the structure.
    """
    # Ensure if passed scalars they are turned into correct dimensionality
    if isinstance(atoms_filter, str) or not isinstance(atoms_filter, Sequence):
        atoms_filter = (atoms_filter,)
    if isinstance(atoms_filter[0], str) or not isinstance(atoms_filter[0], Sequence):
        atoms_filter = (atoms_filter,)

    if filenames:
        # A plain str is a Sequence too, but names a single file.
        if isinstance(filenames, str) or not isinstance(filenames, Sequence):
            filenames = (filenames,)

        # Validate every form of `filenames` before doing any work.
        if len(filenames) != len(atoms_filter):
            raise ValueError(
                f"Different number of file names ({len(filenames)}) "
                f"to number of samples ({len(atoms_filter)})"
            )

    # Extract requested data
    index = slicelike_to_startstopstep(index)
    data = data[slice(*index)]

    if use_velocities:
        momenta = np.asarray([datum.get_velocities() for datum in data])
    else:
        momenta = np.asarray([datum.get_momenta() for datum in data])

    n_steps = len(momenta)
    n_atoms = len(momenta[0])

    atom_symbols = data[0].get_chemical_symbols()
    symbols = set(atom_symbols)

    # Resolve each filter to the atom indices it selects.
    filtered_atoms = []
    for atoms in atoms_filter:
        if any(atom is None for atom in atoms):
            # If atoms_filter not specified use all atoms.
            filtered_atoms.append(range(n_atoms))
        elif all(isinstance(a, str) for a in atoms):
            # If all symbols, get the matching indices.
            requested = set(atoms)
            if requested.difference(symbols):
                raise ValueError(
                    f"{requested.difference(symbols)} not allowed in VAF"
                    f", allowed symbols are {symbols}"
                )
            filtered_atoms.append(
                [i for i, symbol in enumerate(atom_symbols) if symbol in requested]
            )
        elif all(isinstance(a, (int, np.integer)) for a in atoms):
            filtered_atoms.append(atoms)
        else:
            raise ValueError(
                "Cannot mix element symbols and indices in vaf atoms_filter"
            )

    # Map every atom used by any filter to its row in the per-atom VAF array,
    # so each atom's VAF is computed once even if several filters share it.
    used_atoms = {atom for atoms in filtered_atoms for atom in atoms}
    used_atoms = {j: i for i, j in enumerate(used_atoms)}

    # Per-atom autocorrelation for non-negative lags of each of the 3 vector
    # components, summed over components -> shape (n_used_atoms, n_steps).
    vafs = np.sum(
        np.asarray(
            [
                [
                    np.correlate(momenta[:, j, i], momenta[:, j, i], "full")[
                        n_steps - 1 :
                    ]
                    for i in range(3)
                ]
                for j in used_atoms
            ]
        ),
        axis=1,
    )

    # Normalise by the number of overlapping samples at each lag.
    vafs /= n_steps - np.arange(n_steps)

    lags = np.arange(n_steps) * time_step

    if fft:
        # Transform along the lag (last) axis, not across atoms.
        vafs = np.fft.fft(vafs, axis=-1)
        lags = np.fft.fftfreq(n_steps, time_step)

    averaged = [
        np.average([vafs[used_atoms[i]] for i in atoms], axis=0)
        for atoms in filtered_atoms
    ]

    if filenames:
        for vaf, filename in zip(averaged, filenames, strict=True):
            with open(filename, "w", encoding="utf-8") as out_file:
                for lag, dat in zip(lags, vaf, strict=True):
                    print(lag, dat, file=out_file)

    return lags, averaged
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc