stfc / janus-core / build 13969502227

20 Mar 2025 12:38PM UTC coverage: 92.434% (-0.2%) from 92.614%

Pull Request #470: add weas for graphs
GitHub (web-flow): merge 50e7c8845 into 41627eef8

2712 of 2934 relevant lines covered (92.43%)

2.77 hits per line

Source file: /janus_core/processing/post_process.py (82.98% of lines covered)
"""Module for post-processing trajectories."""

from __future__ import annotations

from collections.abc import Sequence
from itertools import combinations_with_replacement

from ase import Atoms
from ase.geometry.analysis import Analysis
import numpy as np
from numpy import float64
from numpy.typing import NDArray

from janus_core.helpers.janus_types import (
    MaybeSequence,
    PathLike,
    SliceLike,
)
from janus_core.helpers.utils import build_file_dir, slicelike_to_startstopstep


def compute_rdf(
    data: MaybeSequence[Atoms],
    ana: Analysis | None = None,
    /,
    *,
    filenames: MaybeSequence[PathLike] | None = None,
    by_elements: bool = False,
    rmax: float = 2.5,
    nbins: int = 50,
    elements: MaybeSequence[int | str] | None = None,
    index: SliceLike = (0, None, 1),
    volume: float | None = None,
) -> NDArray[float64] | dict[tuple[str, str], NDArray[float64]]:
    """
    Compute the RDF of `data`.

    Parameters
    ----------
    data
        Dataset to compute RDF of.
    ana
        ASE Analysis object for data reuse.
    filenames
        Filenames to output data to. Must match number of RDFs computed.
    by_elements
        Split the RDF into pairwise groups by element. Default is False.
        N.B. mixed RDFs (e.g. C-H) include all self-RDFs (e.g. C-C);
        to get the pure C-H RDF, subtract the self-RDFs.
    rmax
        Maximum distance of RDF.
    nbins
        Number of bins to divide RDF.
    elements
        Make partial RDFs. If `by_elements` is true, only element pairs
        in the list are returned.
    index
        Images to analyze as:
        `index` if `int`,
        `start`, `stop`, `step` if `tuple`,
        `slice` if `slice` or `range`.
    volume
        Volume of cell for normalisation. Only needs to be provided
        if the cell is aperiodic. Default is (2*rmax)**3.

    Returns
    -------
    NDArray[float64] | dict[tuple[str, str], NDArray[float64]]
        If `by_elements` is true, returns a `dict` of RDFs keyed by element pair.
        Otherwise returns the RDF of the total system, filtered by elements.
    """
    index = slicelike_to_startstopstep(index)

    if not isinstance(data, Sequence):
        data = [data]

    if elements is not None and not isinstance(elements, Sequence):
        elements = (elements,)

    if (  # If aperiodic, assume volume of a cube encompassing rmax sphere.
        not all(data[0].get_pbc()) and volume is None
    ):
        volume = (2 * rmax) ** 3

    if ana is None:
        ana = Analysis(data)

    if by_elements:
        elements = (
            tuple(sorted(set(data[0].get_chemical_symbols())))
            if elements is None
            else elements
        )

        rdf = {
            element: ana.get_rdf(
                rmax=rmax,
                nbins=nbins,
                elements=element,
                imageIdx=slice(*index),
                return_dists=True,
                volume=volume,
            )
            for element in combinations_with_replacement(elements, 2)
        }

        # Compute RDF average
        rdf = {
            element: (rdf[0][1], np.average([rdf_i[0] for rdf_i in rdf], axis=0))
            for element, rdf in rdf.items()
        }

        if filenames is not None:
            if isinstance(filenames, str) or not isinstance(filenames, Sequence):
                filenames = (filenames,)

            assert isinstance(filenames, Sequence)

            if len(filenames) != len(rdf):
                raise ValueError(
                    f"Different number of file names ({len(filenames)}) "
                    f"to number of samples ({len(rdf)})"
                )

            for (dists, rdfs), out_path in zip(rdf.values(), filenames, strict=True):
                build_file_dir(out_path)
                with open(out_path, "w", encoding="utf-8") as out_file:
                    for dist, rdf_i in zip(dists, rdfs, strict=True):
                        print(dist, rdf_i, file=out_file)

    else:
        rdf = ana.get_rdf(
            rmax=rmax,
            nbins=nbins,
            elements=elements,
            imageIdx=slice(*index),
            return_dists=True,
            volume=volume,
        )

        assert isinstance(rdf, list)

        # Compute RDF average
        rdf = rdf[0][1], np.average([rdf_i[0] for rdf_i in rdf], axis=0)

        if filenames is not None:
            if isinstance(filenames, Sequence) and not isinstance(filenames, str):
                if len(filenames) != 1:
                    raise ValueError(
                        f"Different number of file names ({len(filenames)}) "
                        "to number of samples (1)"
                    )
                filenames = filenames[0]

            build_file_dir(filenames)
            with open(filenames, "w", encoding="utf-8") as out_file:
                for dist, rdf_i in zip(*rdf, strict=True):
                    print(dist, rdf_i, file=out_file)

    return rdf
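
# Usage sketch (illustrative only): how compute_rdf might be called on a
# trajectory read with ASE. The trajectory file name is a hypothetical
# placeholder; per the implementation above, each call returns the bin
# distances alongside the averaged RDF.
#
#     from ase.io import read
#     from janus_core.processing.post_process import compute_rdf
#
#     frames = read("md-traj.extxyz", index=":")  # hypothetical trajectory
#     dists, rdf = compute_rdf(frames, rmax=5.0, nbins=100)
#     rdf_by_pair = compute_rdf(frames, by_elements=True, rmax=5.0, nbins=100)
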


def compute_vaf(
    data: Sequence[Atoms],
    filenames: MaybeSequence[PathLike] | None = None,
    *,
    use_velocities: bool = False,
    fft: bool = False,
    index: SliceLike = (0, None, 1),
    atoms_filter: MaybeSequence[MaybeSequence[int | str | None]] = ((None,),),
    time_step: float = 1.0,
) -> tuple[NDArray[float64], list[NDArray[float64]]]:
    """
    Compute the velocity autocorrelation function (VAF) of `data`.

    Parameters
    ----------
    data
        Dataset to compute VAF of.
    filenames
        If present, dump resultant VAF to file.
    use_velocities
        Compute VAF using velocities rather than momenta.
        Default is False.
    fft
        Compute the Fourier-transformed VAF.
        Default is False.
    index
        Images to analyze as `start`, `stop`, `step`.
        Default is all images.
    atoms_filter
        Compute the VAF averaged over subsets of the system.
        Default is all atoms.
    time_step
        Time step for scaling lags to align with input data.
        Default is 1 (i.e. no scaling).

    Returns
    -------
    lags : numpy.ndarray
        Lags at which the VAFs have been computed.
    vafs : list[numpy.ndarray]
        Computed VAF(s).

    Notes
    -----
    `atoms_filter` is given as a series of sequences of atoms or elements,
    where each value in the series denotes a VAF subset to calculate and
    each sequence determines the atoms (by index or element)
    to be included in that VAF.

    E.g.

    .. code-block:: python

        # Species indices in cell
        na = (1, 3, 5, 7)
        # Species by name
        cl = ('Cl',)

        compute_vaf(..., atoms_filter=(na, cl))

    would compute separate VAFs for each species.

    By default, one VAF will be computed for all atoms in the structure.
    """
    # Ensure scalars passed in are wrapped to the correct dimensionality.
    if isinstance(atoms_filter, str) or not isinstance(atoms_filter, Sequence):
        atoms_filter = (atoms_filter,)
    if isinstance(atoms_filter[0], str) or not isinstance(atoms_filter[0], Sequence):
        atoms_filter = (atoms_filter,)
    if filenames:
        if isinstance(filenames, str) or not isinstance(filenames, Sequence):
            filenames = (filenames,)

        if len(filenames) != len(atoms_filter):
            raise ValueError(
                f"Different number of file names ({len(filenames)}) "
                f"to number of samples ({len(atoms_filter)})"
            )

    # Extract requested data
    index = slicelike_to_startstopstep(index)
    data = data[slice(*index)]

    if use_velocities:
        momenta = np.asarray([datum.get_velocities() for datum in data])
    else:
        momenta = np.asarray([datum.get_momenta() for datum in data])

    n_steps = len(momenta)
    n_atoms = len(momenta[0])

    filtered_atoms = []
    atom_symbols = data[0].get_chemical_symbols()
    symbols = set(atom_symbols)
    for atoms in atoms_filter:
        if any(atom is None for atom in atoms):
            # If atoms_filter not specified use all atoms.
            filtered_atoms.append(range(n_atoms))
        elif all(isinstance(a, str) for a in atoms):
            # If all symbols, get the matching indices.
            atoms = set(atoms)
            if atoms.difference(symbols):
                raise ValueError(
                    f"{atoms.difference(symbols)} not allowed in VAF"
                    f", allowed symbols are {symbols}"
                )
            filtered_atoms.append(
                [i for i in range(len(atom_symbols)) if atom_symbols[i] in atoms]
            )
        elif all(isinstance(a, int) for a in atoms):
            filtered_atoms.append(atoms)
        else:
            raise ValueError(
                "Cannot mix element symbols and indices in vaf atoms_filter"
            )

    used_atoms = {atom for atoms in filtered_atoms for atom in atoms}
    used_atoms = {j: i for i, j in enumerate(used_atoms)}

    # Autocorrelate each Cartesian component of every used atom over all lags,
    # then sum the three components to give one correlation per atom.
    vafs = np.sum(
        np.asarray(
            [
                [
                    np.correlate(momenta[:, j, i], momenta[:, j, i], "full")[
                        n_steps - 1 :
                    ]
                    for i in range(3)
                ]
                for j in used_atoms
            ]
        ),
        axis=1,
    )

    # Normalise by the number of overlapping steps contributing to each lag.
    vafs /= n_steps - np.arange(n_steps)

    lags = np.arange(n_steps) * time_step

    if fft:
        # Transform along the lag (last) axis.
        vafs = np.fft.fft(vafs, axis=-1)
        lags = np.fft.fftfreq(n_steps, time_step)

    vafs = (
        lags,
        [
            np.average([vafs[used_atoms[i]] for i in atoms], axis=0)
            for atoms in filtered_atoms
        ],
    )
    if filenames:
        for vaf, filename in zip(vafs[1], filenames, strict=True):
            build_file_dir(filename)
            with open(filename, "w", encoding="utf-8") as out_file:
                for lag, dat in zip(lags, vaf, strict=True):
                    print(lag, dat, file=out_file)

    return vafs
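
The sketch below is a usage illustration only, not part of the module above. It assumes a hypothetical NaCl-like trajectory file ("md-traj.extxyz") that stores momenta and was written every 0.5 fs, with atoms 1, 3, 5 and 7 being Na; the output file names are also placeholders.

    from ase.io import read

    from janus_core.processing.post_process import compute_vaf

    # Hypothetical MD trajectory; index=":" loads every frame as an Atoms object.
    frames = read("md-traj.extxyz", index=":")

    # One VAF per subset: Na selected by atom indices, Cl selected by symbol.
    lags, (vaf_na, vaf_cl) = compute_vaf(
        frames,
        filenames=("vaf-na.dat", "vaf-cl.dat"),  # hypothetical output paths
        atoms_filter=((1, 3, 5, 7), ("Cl",)),
        time_step=0.5,
    )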