• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

stfc / janus-core / 13183562836

06 Feb 2025 04:38PM UTC coverage: 95.078% (-0.1%) from 95.207%
13183562836

Pull #400

github

web-flow
Merge 7a0241e8f into 47569ebf4
Pull Request #400: Adds post-process cli guide, specifying vafs by element name

20 of 22 new or added lines in 3 files covered. (90.91%)

3 existing lines in 2 files now uncovered.

2318 of 2438 relevant lines covered (95.08%)

2.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

82.76
/janus_core/processing/post_process.py
1
"""Module for post-processing trajectories."""
2

3
from __future__ import annotations
3✔
4

5
from collections.abc import Sequence
3✔
6
from itertools import combinations_with_replacement
3✔
7

8
from ase import Atoms
3✔
9
from ase.geometry.analysis import Analysis
3✔
10
import numpy as np
3✔
11
from numpy import float64
3✔
12
from numpy.typing import NDArray
3✔
13

14
from janus_core.helpers.janus_types import (
3✔
15
    MaybeSequence,
16
    PathLike,
17
    SliceLike,
18
)
19
from janus_core.helpers.utils import slicelike_to_startstopstep
3✔
20

21

22
def compute_rdf(
    data: MaybeSequence[Atoms],
    ana: Analysis | None = None,
    /,
    *,
    filenames: MaybeSequence[PathLike] | None = None,
    by_elements: bool = False,
    rmax: float = 2.5,
    nbins: int = 50,
    elements: MaybeSequence[int | str] | None = None,
    index: SliceLike = (0, None, 1),
    volume: float | None = None,
) -> (
    tuple[NDArray[float64], NDArray[float64]]
    | dict[tuple[str, str], tuple[NDArray[float64], NDArray[float64]]]
):
    """
    Compute the rdf of data.

    Parameters
    ----------
    data
        Dataset to compute RDF of.
    ana
        ASE Analysis object for data reuse.
    filenames
        Filenames to output data to. Must match number of RDFs computed.
    by_elements
        Split RDF into pairwise by elements group. Default is False.
        N.B. mixed RDFs (e.g. C-H) include all self-RDFs (e.g. C-C),
        to get the pure (C-H) RDF subtract the self-RDFs.
    rmax
        Maximum distance of RDF.
    nbins
        Number of bins to divide RDF.
    elements
        Make partial RDFs. If `by_elements` is true will filter to
        only display pairs in list.
    index
        Images to analyze as:
        `index` if `int`,
        `start`, `stop`, `step` if `tuple`,
        `slice` if `slice` or `range`.
    volume
        Volume of cell for normalisation. Only needs to be provided
        if aperiodic cell. Default is (2*rmax)**3.

    Returns
    -------
    tuple[NDArray[float64], NDArray[float64]] | dict[tuple[str, str], tuple[NDArray[float64], NDArray[float64]]]
        If `by_elements` is true, returns a `dict` mapping element pairs to
        their ``(distances, rdf)`` arrays. Otherwise returns the
        ``(distances, rdf)`` of the total system filtered by elements.

    Raises
    ------
    ValueError
        If the number of `filenames` does not match the number of RDFs computed.
    """
    index = slicelike_to_startstopstep(index)

    if not isinstance(data, Sequence):
        data = [data]

    if elements is not None and not isinstance(elements, Sequence):
        elements = (elements,)

    # If aperiodic, assume volume of a cube encompassing the rmax sphere.
    if not all(data[0].get_pbc()) and volume is None:
        volume = (2 * rmax) ** 3

    if ana is None:
        ana = Analysis(data)

    if by_elements:
        # Default to all element pairs present in the first image.
        elements = (
            tuple(sorted(set(data[0].get_chemical_symbols())))
            if elements is None
            else elements
        )

        rdf = {
            element: ana.get_rdf(
                rmax=rmax,
                nbins=nbins,
                elements=element,
                imageIdx=slice(*index),
                return_dists=True,
                volume=volume,
            )
            for element in combinations_with_replacement(elements, 2)
        }

        # Compute RDF average over images; distances are shared, so take
        # those of the first image.
        rdf = {
            element: (rdf[0][1], np.average([rdf_i[0] for rdf_i in rdf], axis=0))
            for element, rdf in rdf.items()
        }

        if filenames is not None:
            # str is itself a Sequence; wrap so one filename stays one entry.
            if isinstance(filenames, str) or not isinstance(filenames, Sequence):
                filenames = (filenames,)

            assert isinstance(filenames, Sequence)

            if len(filenames) != len(rdf):
                raise ValueError(
                    f"Different number of file names ({len(filenames)}) "
                    f"to number of samples ({len(rdf)})"
                )

            for (dists, rdfs), out_path in zip(rdf.values(), filenames, strict=True):
                with open(out_path, "w", encoding="utf-8") as out_file:
                    for dist, rdf_i in zip(dists, rdfs, strict=True):
                        print(dist, rdf_i, file=out_file)

    else:
        rdf = ana.get_rdf(
            rmax=rmax,
            nbins=nbins,
            elements=elements,
            imageIdx=slice(*index),
            return_dists=True,
            volume=volume,
        )

        assert isinstance(rdf, list)

        # Compute RDF average over images.
        rdf = rdf[0][1], np.average([rdf_i[0] for rdf_i in rdf], axis=0)

        if filenames is not None:
            # Exclude str here: a bare string filename is a Sequence too, and
            # would otherwise fail the len() == 1 check below.
            if not isinstance(filenames, str) and isinstance(filenames, Sequence):
                if len(filenames) != 1:
                    raise ValueError(
                        f"Different number of file names ({len(filenames)}) "
                        "to number of samples (1)"
                    )
                filenames = filenames[0]

            with open(filenames, "w", encoding="utf-8") as out_file:
                for dist, rdf_i in zip(*rdf, strict=True):
                    print(dist, rdf_i, file=out_file)

    return rdf
159

160

161
def compute_vaf(
    data: Sequence[Atoms],
    filenames: MaybeSequence[PathLike] | None = None,
    *,
    use_velocities: bool = False,
    fft: bool = False,
    index: SliceLike = (0, None, 1),
    filter_atoms: MaybeSequence[MaybeSequence[int | str | None]] = ((None,),),
    time_step: float = 1.0,
) -> tuple[NDArray[float64], list[NDArray[float64]]]:
    """
    Compute the velocity autocorrelation function (VAF) of `data`.

    Parameters
    ----------
    data
        Dataset to compute VAF of.
    filenames
        If present, dump resultant VAF to file.
    use_velocities
        Compute VAF using velocities rather than momenta.
        Default is False.
    fft
        Compute the fourier transformed VAF.
        Default is False.
    index
        Images to analyze as `start`, `stop`, `step`.
        Default is all images.
    filter_atoms
        Compute the VAF averaged over subsets of the system.
        Default is all atoms.
    time_step
        Time step for scaling lags to align with input data.
        Default is 1 (i.e. no scaling).

    Returns
    -------
    lags : numpy.ndarray
        Lags at which the VAFs have been computed.
    vafs : list[numpy.ndarray]
        Computed VAF(s).

    Raises
    ------
    ValueError
        If the number of `filenames` does not match the number of VAF subsets,
        or if a subset mixes element symbols and atom indices.

    Notes
    -----
    `filter_atoms` is given as a series of sequences of atoms or elements,
    where each value in the series denotes a VAF subset to calculate and
    each sequence determines the atoms (by index or element)
    to be included in that VAF.

    E.g.

    .. code-block: Python

        # Species indices in cell
        na = (1, 3, 5, 7)
        # Species by name
        cl = ('Cl',)

        compute_vaf(..., filter_atoms=(na, cl))

    Would compute separate VAFs for each species.

    By default, one VAF will be computed for all atoms in the structure.
    """
    # Ensure if passed scalars they are turned into correct dimensionality
    if isinstance(filter_atoms, str) or not isinstance(filter_atoms, Sequence):
        filter_atoms = (filter_atoms,)
    if isinstance(filter_atoms[0], str) or not isinstance(filter_atoms[0], Sequence):
        filter_atoms = (filter_atoms,)

    if filenames is not None:
        # str is itself a Sequence: wrap it so a single file name is not
        # iterated character-by-character when zipped with the VAFs below.
        if isinstance(filenames, str) or not isinstance(filenames, Sequence):
            filenames = (filenames,)

        # Validate up front, whether a scalar or a sequence was passed,
        # rather than failing after the VAFs have been computed.
        if len(filenames) != len(filter_atoms):
            raise ValueError(
                f"Different number of file names ({len(filenames)}) "
                f"to number of samples ({len(filter_atoms)})"
            )

    # Extract requested data
    index = slicelike_to_startstopstep(index)
    data = data[slice(*index)]

    if use_velocities:
        momenta = np.asarray([datum.get_velocities() for datum in data])
    else:
        momenta = np.asarray([datum.get_momenta() for datum in data])

    n_steps = len(momenta)
    n_atoms = len(momenta[0])

    filtered_atoms = []
    symbols = data[0].get_chemical_symbols()
    for atoms in filter_atoms:
        if isinstance(atoms, str):
            # Wrap a bare element symbol so the `in` test below matches whole
            # symbols rather than substrings of e.g. 'Cl'.
            atoms = (atoms,)
        if any(atom is None for atom in atoms):
            # If filter_atoms not specified use all atoms.
            filtered_atoms.append(range(n_atoms))
        elif all(isinstance(a, str) for a in atoms):
            # If all symbols, get the matching indices.
            filtered_atoms.append(
                [i for i in range(len(symbols)) if symbols[i] in atoms]
            )
        elif all(isinstance(a, int) for a in atoms):
            filtered_atoms.append(atoms)
        else:
            raise ValueError(
                "Cannot mix element symbols and indices in vaf filter_atoms"
            )

    # Map each atom index used by any subset to its row in the VAF array,
    # so each atom's autocorrelation is computed only once.
    used_atoms = {atom for atoms in filtered_atoms for atom in atoms}
    used_atoms = {j: i for i, j in enumerate(used_atoms)}

    # Autocorrelate each Cartesian component per atom, keep non-negative
    # lags, then sum the three components: shape (n_used_atoms, n_steps).
    vafs = np.sum(
        np.asarray(
            [
                [
                    np.correlate(momenta[:, j, i], momenta[:, j, i], "full")[
                        n_steps - 1 :
                    ]
                    for i in range(3)
                ]
                for j in used_atoms
            ]
        ),
        axis=1,
    )

    # Normalise by the number of overlapping samples at each lag.
    vafs /= n_steps - np.arange(n_steps)

    lags = np.arange(n_steps) * time_step

    if fft:
        # Transform along the time (lag) axis; `vafs` is
        # (n_used_atoms, n_steps), so axis=0 would mix atoms, not lags.
        vafs = np.fft.fft(vafs, axis=-1)
        lags = np.fft.fftfreq(n_steps, time_step)

    vafs = (
        lags,
        [
            np.average([vafs[used_atoms[i]] for i in atoms], axis=0)
            for atoms in filtered_atoms
        ],
    )
    if filenames:
        for vaf, filename in zip(vafs[1], filenames, strict=True):
            with open(filename, "w", encoding="utf-8") as out_file:
                for lag, dat in zip(lags, vaf, strict=True):
                    print(lag, dat, file=out_file)

    return vafs
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc