• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

materialsproject / pymatgen / 4075885785

pending completion
4075885785

push

github

Shyue Ping Ong
Merge branch 'master' of github.com:materialsproject/pymatgen

96 of 96 new or added lines in 27 files covered. (100.0%)

81013 of 102710 relevant lines covered (78.88%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.41
/pymatgen/core/trajectory.py
1
# Copyright (c) Pymatgen Development Team.
2
# Distributed under the terms of the MIT License.
3

4
"""
1✔
5
This module provides classes to define a simulation trajectory, which could come from
6
either relaxation or molecular dynamics.
7
"""
8

9
from __future__ import annotations
1✔
10

11
import itertools
1✔
12
import warnings
1✔
13
from fnmatch import fnmatch
1✔
14
from pathlib import Path
1✔
15
from typing import Any, Dict, List, Sequence, Tuple, Union
1✔
16

17
import numpy as np
1✔
18
from monty.io import zopen
1✔
19
from monty.json import MSONable
1✔
20

21
from pymatgen.core.structure import (
1✔
22
    Composition,
23
    DummySpecies,
24
    Element,
25
    Lattice,
26
    Species,
27
    Structure,
28
)
29
from pymatgen.io.vasp.outputs import Vasprun, Xdatcar
1✔
30

31
# Type aliases referenced by the annotations in this module. These are plain
# module-level assignments, so they are evaluated at import time.
Vector3D = Tuple[float, float, float]
Matrix3D = Tuple[Vector3D, Vector3D, Vector3D]

# A single frame's site properties: property name -> one value per site.
_SitePropsDict = Dict[Any, Sequence[Any]]
# Site properties are either given per frame (a list) or shared by all frames (one dict).
SitePropsType = Union[List[_SitePropsDict], _SitePropsDict]

# Module metadata.
__author__ = "Eric Sivonxay, Shyam Dwaraknath, Mingjian Wen"
__version__ = "0.1"
__date__ = "Jun 29, 2022"
1✔
38

39

40
class Trajectory(MSONable):
    """
    Trajectory of a relaxation or molecular dynamics simulation.

    Provides basic functions such as slicing trajectory, combining trajectories, and
    obtaining displacements.
    """

    def __init__(
        self,
        lattice: Lattice | Matrix3D | list[Lattice] | list[Matrix3D] | np.ndarray,
        species: list[str | Element | Species | DummySpecies | Composition],
        frac_coords: list[list[Vector3D]] | np.ndarray,
        *,
        site_properties: SitePropsType | None = None,
        frame_properties: list[dict] | None = None,
        constant_lattice: bool = True,
        time_step: int | float | None = None,
        coords_are_displacement: bool = False,
        base_positions: list[list[Vector3D]] | np.ndarray | None = None,
    ):
        """
        In below, `N` denotes the number of sites in the structure, and `M` denotes the
        number of frames in the trajectory.

        Args:
            lattice: shape (3, 3) or (M, 3, 3). Lattice of the structures in the
                trajectory; should be used together with `constant_lattice`.
                If `constant_lattice=True`, this should be a single lattice that is
                common for all structures in the trajectory (e.g. in an NVT run).
                If `constant_lattice=False`, this should be a list of lattices,
                each for one structure in the trajectory (e.g. in an NPT run or a
                relaxation that allows changing the cell size). If a single lattice
                is given with `constant_lattice=False`, it is replicated for every
                frame and a warning is issued.
            species: shape (N,). List of species on each site. Can take in flexible
                input, including:
                i.  A sequence of element / species specified either as string
                    symbols, e.g. ["Li", "Fe2+", "P", ...] or atomic numbers,
                    e.g., (3, 56, ...) or actual Element or Species objects.
                ii. List of dict of elements/species and occupancies, e.g.,
                    [{"Fe" : 0.5, "Mn":0.5}, ...]. This allows the setup of
                    disordered structures.
            frac_coords: shape (M, N, 3). fractional coordinates of the sites.
            site_properties: Properties associated with the sites. This should be a
                list of `M` dicts or a single dict. If a list of dicts, each provides
                the site properties for a frame. Each value in a dict should be a
                sequence of length `N`, giving the properties of the `N` sites.
                For example, for a trajectory with `M=2` and `N=4`, the
                `site_properties` can be: [{"magmom":[5,5,5,5]}, {"magmom":[5,5,5,5]}].
                If a single dict, the site properties in the dict apply to all frames
                in the trajectory. For example, for a trajectory with `M=2` and `N=4`,
                {"magmom":[2,2,2,2]} means that, through the entire trajectory,
                the magmom are kept constant at 2 for all four atoms.
            frame_properties: Properties associated with the structure (e.g. total
                energy). This should be a sequence of `M` dicts, with each dict
                providing the properties for a frame. For example, for a trajectory with
                `M=2`, the `frame_properties` can be [{'energy':1.0}, {'energy':2.0}].
            constant_lattice: Whether the lattice changes during the simulation.
                Should be used together with `lattice`. See usage there.
            time_step: Time step of MD simulation in femto-seconds. Should be `None`
                for relaxation trajectory.
            coords_are_displacement: Whether `frac_coords` are given in displacements
                (True) or positions (False). Note, if this is `True`, `frac_coords`
                of a frame (say i) should be relative to the previous frame (i.e.
                i-1), but not relative to the `base_positions`.
            base_positions: shape (N, 3). The starting positions of all atoms in the
                trajectory. Used to reconstruct positions when converting from
                displacements to positions. Only used if
                `coords_are_displacement=True`; otherwise it is ignored and the first
                frame of `frac_coords` is used instead.
        """
        if isinstance(lattice, Lattice):
            lattice = lattice.matrix
        elif isinstance(lattice, list) and isinstance(lattice[0], Lattice):
            lattice = [x.matrix for x in lattice]  # type: ignore
        lattice = np.asarray(lattice)

        if not constant_lattice and lattice.shape == (3, 3):
            # A single lattice was given for a variable-lattice trajectory: replicate it
            # once per frame so that `self.lattice[i]` is valid for every frame.
            self.lattice = np.tile(lattice, (len(frac_coords), 1, 1))
            warnings.warn(
                "Get `constant_lattice=False`, but only get a single `lattice`. "
                "Use this single `lattice` as the lattice for all frames."
            )
        else:
            self.lattice = lattice

        self.constant_lattice = constant_lattice

        if coords_are_displacement:
            if base_positions is None:
                warnings.warn(
                    "Without providing an array of starting positions, the positions "
                    "for each time step will not be available."
                )
            self.base_positions = base_positions
        else:
            self.base_positions = frac_coords[0]  # type: ignore[assignment]
        self.coords_are_displacement = coords_are_displacement

        self.species = species
        self.frac_coords = np.asarray(frac_coords)
        self.time_step = time_step

        self._check_site_props(site_properties)
        self.site_properties = site_properties

        self._check_frame_props(frame_properties)
        self.frame_properties = frame_properties

    def get_structure(self, i: int) -> Structure:
        """
        Get structure at specified index.

        Args:
            i: Index of structure.

        Returns:
            A pymatgen Structure object.
        """
        return self[i]

    def to_positions(self):
        """
        Convert displacements between consecutive frames into positions.

        `base_positions` and `frac_coords` should both be in fractional coords or
        absolute coords. Does nothing if the trajectory is already in positions.

        This is the opposite operation of `to_displacements()`.

        Raises:
            ValueError: If the trajectory holds displacements but no `base_positions`
                were provided, so absolute positions cannot be reconstructed.
        """
        if self.coords_are_displacement:
            if self.base_positions is None:
                raise ValueError(
                    "Cannot convert displacements to positions: `base_positions` "
                    "was not provided."
                )
            # Position of frame i = base position + sum of displacements 0..i.
            cumulative_displacements = np.cumsum(self.frac_coords, axis=0)
            self.frac_coords = np.asarray(self.base_positions) + cumulative_displacements
            self.coords_are_displacement = False

    def to_displacements(self):
        """
        Converts positions of trajectory into displacements between consecutive frames.

        `base_positions` and `frac_coords` should both be in fractional coords. Does
        not work for absolute coords because the atoms are to be wrapped into the
        simulation box. Does nothing if the trajectory is already in displacements.

        After conversion, `frac_coords` is a numpy array of shape (M, N, 3) whose
        first frame is all zeros.

        This is the opposite operation of `to_positions()`.
        """
        if not self.coords_are_displacement:
            frac_coords = np.asarray(self.frac_coords)
            displacements = frac_coords - np.roll(frac_coords, 1, axis=0)
            # The first frame has no predecessor; np.roll wrapped the last frame there.
            displacements[0] = 0

            # Deal with PBC.
            # For example - If in one frame an atom has fractional coordinates of
            # [0, 0, 0.98] and in the next its coordinates are [0, 0, 0.01], this atom
            # will have moved 0.03*c, but if we only subtract the positions, we would
            # get a displacement vector of [0, 0, -0.97]. Therefore, we can correct for
            # this by adding or subtracting 1 from the value.
            displacements = displacements - np.around(displacements)

            # Keep an ndarray (not a list of arrays) so that e.g. `as_dict` and fancy
            # indexing keep working while in displacement mode.
            self.frac_coords = displacements
            self.coords_are_displacement = True

    def extend(self, trajectory: Trajectory):
        """
        Append a trajectory to the current one.

        The lattice, coords, and all other properties are combined. Both this
        trajectory and `trajectory` are converted to positions in place first.

        Args:
            trajectory: Trajectory to append.

        Raises:
            ValueError: If the time steps or species of the two trajectories differ.
        """
        if self.time_step != trajectory.time_step:
            raise ValueError(
                "Cannot extend trajectory. Time steps of the trajectories are "
                f"incompatible: {self.time_step} and {trajectory.time_step}."
            )

        if self.species != trajectory.species:
            raise ValueError(
                "Cannot extend trajectory. Species in the trajectories are "
                f"incompatible: {self.species} and {trajectory.species}."
            )

        # Ensure both trajectories are in positions before combining
        self.to_positions()
        trajectory.to_positions()

        self.site_properties = self._combine_site_props(
            self.site_properties,
            trajectory.site_properties,
            len(self),
            len(trajectory),
        )

        self.frame_properties = self._combine_frame_props(
            self.frame_properties,
            trajectory.frame_properties,
            len(self),
            len(trajectory),
        )

        self.lattice, self.constant_lattice = self._combine_lattice(
            self.lattice,
            trajectory.lattice,
            len(self),
            len(trajectory),
        )

        # Note, this should be after the other self._combine... method calls, since
        # len(self) is used there.
        self.frac_coords = np.concatenate((self.frac_coords, trajectory.frac_coords))

    def __iter__(self):
        """
        Iterator of the trajectory, yielding a pymatgen structure for each frame.
        """
        for i in range(len(self)):
            yield self[i]

    def __len__(self):
        """
        Number of frames in the trajectory.
        """
        return len(self.frac_coords)

    def __getitem__(self, frames: int | slice | list[int]) -> Structure | Trajectory:
        """
        Get a subset of the trajectory.

        The output depends on the type of the input `frames`. If an int (including a
        numpy integer) is given, return a pymatgen Structure at the specified frame.
        If a list, numpy array, or a slice, return a new trajectory with a subset of
        frames.

        Args:
            frames: Indices of the trajectory to return.

        Returns:
            Subset of trajectory

        Raises:
            IndexError: If any requested frame index is out of range.
            ValueError: If `frames` is of an unsupported type.
        """
        # Convert to position mode if not ready
        self.to_positions()

        # For integer input, return the structure at that frame
        if isinstance(frames, (int, np.integer)):
            frames = int(frames)
            if not -len(self) <= frames < len(self):
                raise IndexError(f"Frame index {frames} out of range.")

            lattice = self.lattice if self.constant_lattice else self.lattice[frames]

            return Structure(
                Lattice(lattice),
                self.species,
                self.frac_coords[frames],
                site_properties=self._get_site_props(frames),  # type: ignore
                to_unit_cell=True,
            )

        # For slice input, return a trajectory
        if isinstance(frames, (slice, list, np.ndarray)):
            if isinstance(frames, slice):
                start, stop, step = frames.indices(len(self))
                selected = list(range(start, stop, step))
            else:
                # Get rid of frames that exceed trajectory length
                selected = [i for i in frames if i < len(self)]

                if len(selected) < len(frames):
                    # Valid indices are < len(self), so anything >= len(self) is bad.
                    bad_frames = [i for i in frames if i >= len(self)]
                    raise IndexError(f"Frame index {bad_frames} out of range.")

            lattice = self.lattice if self.constant_lattice else self.lattice[selected]
            frac_coords = self.frac_coords[selected]

            if self.frame_properties is not None:
                frame_properties = [self.frame_properties[i] for i in selected]
            else:
                frame_properties = None

            return Trajectory(
                lattice,
                self.species,
                frac_coords,
                site_properties=self._get_site_props(selected),
                frame_properties=frame_properties,
                constant_lattice=self.constant_lattice,
                time_step=self.time_step,
                coords_are_displacement=False,
                base_positions=self.base_positions,
            )

        supported = [int, slice, list, np.ndarray]
        raise ValueError(f"Expect the type of frames be one of {supported}; {type(frames)}.")

    def write_Xdatcar(
        self,
        filename: str | Path = "XDATCAR",
        system: str | None = None,
        significant_figures: int = 6,
    ):
        """
        Writes to Xdatcar file.

        The trajectory is converted to positions (in place) before writing.

        Args:
            filename: Name of file to write.  It's prudent to end the filename with
                'XDATCAR', as most visualization and analysis software require this
                for autodetection.
            system: Description of system (e.g. 2D MoS2). Defaults to the reduced
                formula of the first frame.
            significant_figures: Number of digits after the decimal point used for
                the fractional coordinates in the output file.
        """
        # Ensure trajectory is in position form
        self.to_positions()

        # Build the first frame once; it provides the formula and species ordering.
        first_structure = self[0]

        if system is None:
            system = f"{first_structure.composition.reduced_formula}"

        lines = []
        format_str = f"{{:.{significant_figures}f}}"
        syms = [site.specie.symbol for site in first_structure]
        # Consecutive runs of the same symbol, as required by the POSCAR-style header.
        symbol_runs = [(sym, len(tuple(group))) for sym, group in itertools.groupby(syms)]
        site_symbols = [sym for sym, _ in symbol_runs]
        n_atoms = [count for _, count in symbol_runs]

        for si, frac_coords in enumerate(self.frac_coords):
            # Write the header block (system, scale, lattice, symbols, counts) only
            # once for a constant lattice, or before every frame otherwise.
            if si == 0 or not self.constant_lattice:
                lines.extend([system, "1.0"])

                _lattice = self.lattice if self.constant_lattice else self.lattice[si]

                for latt_vec in _lattice:
                    lines.append(" ".join(map(str, latt_vec)))

                lines.append(" ".join(site_symbols))
                lines.append(" ".join(map(str, n_atoms)))

            lines.append(f"Direct configuration=     {si + 1}")

            for frac_coord, specie in zip(frac_coords, self.species):
                lines.append(f'{" ".join(format_str.format(c) for c in frac_coord)} {specie}')

        xdatcar_string = "\n".join(lines) + "\n"

        with zopen(filename, "wt") as f:
            f.write(xdatcar_string)

    def as_dict(self) -> dict:
        """
        Return the trajectory as a MSONAble dict.

        Array-valued attributes (`lattice`, `frac_coords`, `base_positions`) are
        converted to nested lists.
        """
        base_positions = self.base_positions
        if base_positions is not None:
            base_positions = np.asarray(base_positions).tolist()

        return {
            "@module": type(self).__module__,
            "@class": type(self).__name__,
            "lattice": np.asarray(self.lattice).tolist(),
            "species": self.species,
            # np.asarray guards against frac_coords held as a list of arrays.
            "frac_coords": np.asarray(self.frac_coords).tolist(),
            "site_properties": self.site_properties,
            "frame_properties": self.frame_properties,
            "constant_lattice": self.constant_lattice,
            "time_step": self.time_step,
            "coords_are_displacement": self.coords_are_displacement,
            "base_positions": base_positions,
        }

    @classmethod
    def from_structures(
        cls,
        structures: list[Structure],
        constant_lattice: bool = True,
        **kwargs,
    ) -> Trajectory:
        """
        Create trajectory from a list of structures.

        Note: Assumes no atoms removed during simulation.

        Args:
            structures: pymatgen Structure objects.
            constant_lattice: Whether the lattice changes during the simulation,
                such as in an NPT MD simulation. If True, the lattice of the first
                structure is used for all frames.
            **kwargs: Additional keyword arguments passed to the constructor.

        Returns:
            A trajectory from the structures.
        """
        if constant_lattice:
            lattice = structures[0].lattice.matrix
        else:
            lattice = np.array([structure.lattice.matrix for structure in structures])

        species = structures[0].species
        frac_coords = [structure.frac_coords for structure in structures]
        site_properties = [structure.site_properties for structure in structures]

        return cls(
            lattice,
            species,  # type: ignore
            frac_coords,
            site_properties=site_properties,  # type: ignore
            constant_lattice=constant_lattice,
            **kwargs,
        )

    @classmethod
    def from_file(
        cls,
        filename: str | Path,
        constant_lattice: bool = True,
        **kwargs,
    ) -> Trajectory:
        """
        Create trajectory from XDATCAR or vasprun.xml file.

        Args:
            filename: Path to the file to read from. The file type is detected from
                the file name (`*XDATCAR*` or `vasprun*.xml*`).
            constant_lattice: Whether the lattice changes during the simulation,
                such as in an NPT MD simulation.
            **kwargs: Additional keyword arguments passed to the constructor.

        Returns:
            A trajectory from the file.

        Raises:
            ValueError: If the file name matches neither supported pattern.
        """
        fname = Path(filename).expanduser().resolve().name

        if fnmatch(fname, "*XDATCAR*"):
            structures = Xdatcar(filename).structures
        elif fnmatch(fname, "vasprun*.xml*"):
            structures = Vasprun(filename).structures
        else:
            supported = ("XDATCAR", "vasprun.xml")
            raise ValueError(f"Expect file to be one of {supported}; got {filename}.")

        return cls.from_structures(
            structures,
            constant_lattice=constant_lattice,
            **kwargs,
        )

    @staticmethod
    def _combine_lattice(lat1: np.ndarray, lat2: np.ndarray, len1: int, len2: int) -> tuple[np.ndarray, bool]:
        """
        Helper function to combine trajectory lattice.

        If both lattices are single (3, 3) matrices that agree (within
        `np.allclose` tolerance), the combined trajectory keeps a constant lattice.
        Otherwise every single lattice is expanded to one matrix per frame and the
        two per-frame stacks are concatenated.

        Args:
            lat1: Lattice of the first trajectory, shape (3, 3) or (len1, 3, 3).
            lat2: Lattice of the second trajectory, shape (3, 3) or (len2, 3, 3).
            len1: Number of frames in the first trajectory.
            len2: Number of frames in the second trajectory.

        Returns:
            (combined lattice, whether the combined lattice is constant).
        """
        if lat1.ndim == lat2.ndim == 2 and np.allclose(lat1, lat2):
            return lat1, True

        # Different (or per-frame) lattices: never silently drop the second one.
        if lat1.ndim == 2:
            lat1 = np.tile(lat1, (len1, 1, 1))
        if lat2.ndim == 2:
            lat2 = np.tile(lat2, (len2, 1, 1))

        return np.concatenate((lat1, lat2)), False

    @staticmethod
    def _combine_site_props(
        prop1: SitePropsType | None, prop2: SitePropsType | None, len1: int, len2: int
    ) -> SitePropsType | None:
        """
        Combine site properties.

        Either one of prop1 or prop2 can be None, dict, or a list of dict. All
        possibilities of combining them are considered. A single dict is expanded
        to one entry per frame and None to one `None` per frame, unless both inputs
        are None (result None) or both are the same dict (result that dict).
        """
        # special cases

        if prop1 is None and prop2 is None:
            return None

        if isinstance(prop1, dict) and prop1 == prop2:
            return prop1

        # general case

        assert prop1 is None or isinstance(prop1, (list, dict))
        assert prop2 is None or isinstance(prop2, (list, dict))

        p1_candidates = {
            "NoneType": [None] * len1,
            "dict": [prop1] * len1,
            "list": prop1,
        }
        p2_candidates = {
            "NoneType": [None] * len2,
            "dict": [prop2] * len2,
            "list": prop2,
        }
        p1_selected: list = p1_candidates[type(prop1).__name__]  # type: ignore
        p2_selected: list = p2_candidates[type(prop2).__name__]  # type: ignore

        return p1_selected + p2_selected

    @staticmethod
    def _combine_frame_props(prop1: list[dict] | None, prop2: list[dict] | None, len1: int, len2: int) -> list | None:
        """
        Combine frame properties.

        A missing (None) side is padded with one `None` per frame of that side; if
        both sides are None the result is None.
        """
        if prop1 is None and prop2 is None:
            return None
        if prop1 is None:
            return [None] * len1 + list(prop2)  # type: ignore
        if prop2 is None:
            return list(prop1) + [None] * len2  # type: ignore
        return list(prop1) + list(prop2)  # type:ignore

    def _check_site_props(self, site_props: SitePropsType | None):
        """
        Check data shape of site properties.

        A list must contain one dict per frame; every value in every dict must have
        one entry per site.
        """
        if site_props is None:
            return

        if isinstance(site_props, dict):
            site_props = [site_props]
        else:
            assert len(site_props) == len(
                self
            ), f"Size of the site properties {len(site_props)} does not equal to the number of frames {len(self)}."

        num_sites = len(self.frac_coords[0])
        for d in site_props:
            for k, v in d.items():
                assert len(v) == num_sites, (
                    f"Size of site property {k} ({len(v)}) does not equal to the "
                    f"number of sites in the structure {num_sites}."
                )

    def _check_frame_props(self, frame_props: list[dict] | None):
        """
        Check data shape of frame properties: there must be one entry per frame.
        """
        if frame_props is None:
            return

        assert len(frame_props) == len(
            self
        ), f"Size of the frame properties {len(frame_props)} does not equal to the number of frames {len(self)}."

    def _get_site_props(self, frames: int | list[int]) -> SitePropsType | None:
        """
        Slice site properties.

        Returns None if there are no site properties, the shared dict unchanged if a
        single dict applies to all frames, and otherwise the per-frame entry (for an
        integer index) or list of entries (for a list of indices).
        """
        if self.site_properties is None:
            return None
        if isinstance(self.site_properties, dict):
            return self.site_properties
        if isinstance(self.site_properties, list):
            if isinstance(frames, (int, np.integer)):
                return self.site_properties[frames]
            if isinstance(frames, list):
                return [self.site_properties[i] for i in frames]
            raise ValueError("Unexpected frames type.")
        raise ValueError("Unexpected site_properties type.")
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc