• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

materialsproject / pymatgen / 4075885785

pending completion
4075885785

push

github

Shyue Ping Ong
Merge branch 'master' of github.com:materialsproject/pymatgen

96 of 96 new or added lines in 27 files covered. (100.0%)

81013 of 102710 relevant lines covered (78.88%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.68
/pymatgen/phonon/plotter.py
1
# Copyright (c) Pymatgen Development Team.
2
# Distributed under the terms of the MIT License.
3

4
"""
1✔
5
This module implements plotter for DOS and band structure.
6
"""
7

8
from __future__ import annotations
1✔
9

10
import logging
1✔
11
from collections import namedtuple
1✔
12

13
import matplotlib.pyplot as plt
1✔
14
import numpy as np
1✔
15
import scipy.constants as const
1✔
16
from monty.json import jsanitize
1✔
17

18
from pymatgen.electronic_structure.plotter import plot_brillouin_zone
1✔
19
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine
1✔
20
from pymatgen.phonon.gruneisen import GruneisenPhononBandStructureSymmLine
1✔
21
from pymatgen.util.plotting import add_fig_kwargs, get_ax_fig_plt, pretty_plot
1✔
22

23
logger = logging.getLogger(__name__)
1✔
24

25
FreqUnits = namedtuple("FreqUnits", ["factor", "label"])
1✔
26

27

28
def freq_units(units):
    """
    Look up the conversion factor and axis label for a frequency unit.

    Args:
        units: str, accepted values: thz, ev, mev, ha, cm-1, cm^-1. The lookup
            is case-insensitive and ignores leading/trailing whitespace.

    Returns:
        A FreqUnits namedtuple holding the conversion factor from THz to the
        required units (``factor``) and the label to use on plots (``label``).

    Raises:
        KeyError: if ``units`` is not one of the accepted values.
    """
    hz_to_ev = const.value("hertz-electron volt relationship")
    # "cm-1" and "cm^-1" are two spellings of the same unit; build it once.
    inverse_cm = FreqUnits(
        const.value("hertz-inverse meter relationship") * const.tera * const.centi,
        "cm^{-1}",
    )
    d = {
        "thz": FreqUnits(1, "THz"),
        "ev": FreqUnits(hz_to_ev * const.tera, "eV"),
        "mev": FreqUnits(hz_to_ev * const.tera / const.milli, "meV"),
        "ha": FreqUnits(const.value("hertz-hartree relationship") * const.tera, "Ha"),
        "cm-1": inverse_cm,
        "cm^-1": inverse_cm,
    }
    try:
        return d[units.lower().strip()]
    except KeyError:
        # The lookup failure itself carries no extra information, so suppress
        # the implicit exception context and show only the helpful message.
        raise KeyError(f"Value for units `{units}` unknown\nPossible values are:\n {list(d)}") from None
60

61

62
class PhononDosPlotter:
    """
    Class for plotting phonon DOSs. Note that the interface is extremely flexible
    given that there are many different ways in which people want to view
    DOS. The typical usage is::

        # Initializes plotter with some optional args. Defaults are usually
        # fine,
        plotter = PhononDosPlotter()

        # Adds a DOS with a label.
        plotter.add_dos("Total DOS", dos)

        # Alternatively, you can add a dict of DOSs. This is the typical
        # form returned by CompletePhononDos.get_element_dos().
        plotter.add_dos_dict(dos_dict)
    """

    def __init__(self, stack=False, sigma=None):
        """
        Args:
            stack: Whether to plot the DOS as a stacked area graph
            sigma: A float specifying a standard deviation for Gaussian smearing
                the DOS for nicer looking plots. Defaults to None for no
                smearing.
        """
        self.stack = stack
        self.sigma = sigma
        # label -> {"frequencies": ..., "densities": ...}, in insertion order
        self._doses = {}

    def add_dos(self, label, dos):
        """
        Adds a dos for plotting.

        Args:
            label:
                label for the DOS. Must be unique; adding the same label again
                replaces the earlier entry.
            dos:
                PhononDos object
        """
        # Smearing is applied once, when the DOS is added, if sigma is set.
        densities = dos.get_smeared_densities(self.sigma) if self.sigma else dos.densities
        self._doses[label] = {"frequencies": dos.frequencies, "densities": densities}

    def add_dos_dict(self, dos_dict, key_sort_func=None):
        """
        Add a dictionary of doses, with an optional sorting function for the
        keys.

        Args:
            dos_dict: dict of {label: Dos}
            key_sort_func: function used to sort the dos_dict keys. If not
                given, the dict's own iteration order is used.
        """
        keys = sorted(dos_dict, key=key_sort_func) if key_sort_func else list(dos_dict)
        for label in keys:
            self.add_dos(label, dos_dict[label])

    def get_dos_dict(self):
        """
        Returns the added doses as a json-serializable dict. Note that if you
        have specified smearing for the DOS plot, the densities returned will
        be the smeared densities, not the original densities.

        Returns:
            Dict of dos data. Generally of the form, {label: {'frequencies':..,
            'densities': ...}}
        """
        return jsanitize(self._doses)

    @staticmethod
    def _y_range_within(points, xlim):
        """
        Get the (min, max) of y over the (x, y) points lying strictly inside xlim.

        Returns None when no point falls inside the interval, so the caller
        can keep whatever limits are already set instead of failing.
        """
        low, high = xlim[0], xlim[1]
        inside = [y for x, y in points if low < x < high]
        if not inside:
            return None
        return (min(inside), max(inside))

    def get_plot(self, xlim=None, ylim=None, units="thz"):
        """
        Get a matplotlib plot showing the DOS.

        Args:
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits. If not given, they are set to
                the range of the densities that fall inside the x-axis limits.
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.

        Returns:
            The object returned by pretty_plot, with the DOS curves drawn on it.
        """
        u = freq_units(units)

        # Cycle through at least 3 and at most 9 colors of the palette.
        ncolors = min(9, max(3, len(self._doses)))

        import palettable

        colors = palettable.colorbrewer.qualitative.Set1_9.mpl_colors  # pylint: disable=E1101

        plt = pretty_plot(12, 8)

        # For stacked plots each curve is the running sum of all densities
        # added so far; otherwise the densities are used unchanged.
        running_total = None
        curves = []
        for key, dos in self._doses.items():
            frequencies = dos["frequencies"] * u.factor
            densities = dos["densities"]
            if self.stack:
                if running_total is None:
                    running_total = np.zeros(frequencies.shape)
                running_total += densities
                densities = running_total.copy()
            curves.append((key, frequencies, densities))

        # Draw in reverse insertion order: for stacked plots the largest
        # cumulative area is filled first and the smaller ones on top of it.
        curves.reverse()
        allpts = []
        for i, (key, frequencies, densities) in enumerate(curves):
            allpts.extend(zip(frequencies, densities))
            color = colors[i % ncolors]
            if self.stack:
                plt.fill(frequencies, densities, color=color, label=str(key))
            else:
                plt.plot(frequencies, densities, color=color, label=str(key), linewidth=3)

        if xlim:
            plt.xlim(xlim)
        if ylim:
            plt.ylim(ylim)
        else:
            # Fit the y-axis to the data visible in the current x-range. When
            # nothing is visible, leave the automatic limits untouched.
            y_range = self._y_range_within(allpts, plt.xlim())
            if y_range is not None:
                plt.ylim(y_range)

        # dashed vertical marker at zero frequency
        ylim = plt.ylim()
        plt.plot([0, 0], ylim, "k--", linewidth=2)

        plt.xlabel(rf"$\mathrm{{Frequencies\ ({u.label})}}$")
        plt.ylabel(r"$\mathrm{Density\ of\ states}$")

        plt.legend()
        leg = plt.gca().get_legend()
        ltext = leg.get_texts()  # all the text.Text instance in the legend
        plt.setp(ltext, fontsize=30)
        plt.tight_layout()
        return plt

    def save_plot(self, filename, img_format="eps", xlim=None, ylim=None, units="thz"):
        """
        Save matplotlib plot to a file.

        Args:
            filename: Filename to write to.
            img_format: Image format to use. Defaults to EPS.
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits.
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1
        """
        plt = self.get_plot(xlim, ylim, units=units)
        plt.savefig(filename, format=img_format)
        plt.close()

    def show(self, xlim=None, ylim=None, units="thz"):
        """
        Show the plot using matplotlib.

        Args:
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits.
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.
        """
        plt = self.get_plot(xlim, ylim, units=units)
        plt.show()
240

241

242
class PhononBSPlotter:
    """
    Class to plot or get data to facilitate the plot of band structure objects.
    """

    def __init__(self, bs):
        """
        Args:
            bs: A PhononBandStructureSymmLine object.

        Raises:
            ValueError: if bs is not a PhononBandStructureSymmLine.
        """
        if not isinstance(bs, PhononBandStructureSymmLine):
            raise ValueError(
                "PhononBSPlotter only works with PhononBandStructureSymmLine objects. "
                "A PhononBandStructure object (on a uniform grid for instance and "
                "not along symmetry lines won't work)"
            )
        self._bs = bs
        self._nb_bands = self._bs.nb_bands

    def _maketicks(self, plt):
        """
        utility private method to add ticks to a band structure

        Sets the x ticks and tick labels of the current axes from get_ticks()
        and draws a vertical line at each labelled position. Returns plt.
        """
        ticks = self.get_ticks()
        distances = ticks["distance"]
        labels = ticks["label"]

        # Sanitize: keep a tick only if its label differs from the previous one
        uniq_d = []
        uniq_l = []
        previous = None
        for i, (dist, label) in enumerate(zip(distances, labels)):
            if i > 0 and label == previous:
                logger.debug(f"Skipping label {label}")
            else:
                logger.debug(f"Adding label {dist} at {label}")
                uniq_d.append(dist)
                uniq_l.append(label)
            previous = label

        logger.debug(f"Unique labels are {list(zip(uniq_d, uniq_l))}")
        plt.gca().set_xticks(uniq_d)
        plt.gca().set_xticklabels(uniq_l)

        # One vertical line per labelled tick; a label repeated right after
        # itself gets a single line only.
        previous = None
        for i, (dist, label) in enumerate(zip(distances, labels)):
            if label is not None:
                if i != 0 and label == previous:
                    logger.debug(f"already print label... skipping label {label}")
                else:
                    logger.debug(f"Adding a line at {dist} for label {label}")
                    plt.axvline(dist, color="k")
            previous = label
        return plt

    def bs_plot_data(self):
        """
        Get the data nicely formatted for a plot

        Returns:
            A dict of the following format:
            ticks: A dict with the 'distance' at which there is a labelled
            qpoint (the x axis) and the corresponding 'label', as returned
            by get_ticks().
            distances: A list (one element for each branch) with the
            distances of the qpoints of that branch.
            frequency: A list (one element for each branch) of frequencies
            indexed as [branch][band][qpoint]. The data is stored by branch
            to facilitate the plotting.
            lattice: The reciprocal lattice, as a dict.
        """
        ticks = self.get_ticks()

        distance = []
        frequency = []
        for b in self._bs.branches:
            # branch bounds are inclusive on both ends
            span = range(b["start_index"], b["end_index"] + 1)
            distance.append([self._bs.distance[j] for j in span])
            frequency.append([[self._bs.bands[i][j] for j in span] for i in range(self._nb_bands)])

        return {
            "ticks": ticks,
            "distances": distance,
            "frequency": frequency,
            "lattice": self._bs.lattice_rec.as_dict(),
        }

    def get_plot(self, ylim=None, units="thz"):
        """
        Get a matplotlib object for the bandstructure plot.

        Args:
            ylim: Specify the y-axis (frequency) limits; by default None let
                the code choose.
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.

        Returns:
            The object returned by pretty_plot, with the bands drawn on it.
        """
        u = freq_units(units)

        plt = pretty_plot(12, 8)

        band_linewidth = 1

        data = self.bs_plot_data()
        for branch_dists, branch_freqs in zip(data["distances"], data["frequency"]):
            for i in range(self._nb_bands):
                plt.plot(
                    branch_dists,
                    [freq * u.factor for freq in branch_freqs[i]],
                    "b-",
                    linewidth=band_linewidth,
                )

        self._maketicks(plt)

        # plot y=0 line
        plt.axhline(0, linewidth=1, color="k")

        # Main X and Y Labels
        plt.xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
        ylabel = rf"$\mathrm{{Frequencies\ ({u.label})}}$"
        plt.ylabel(ylabel, fontsize=30)

        # X range (K)
        # last distance point
        x_max = data["distances"][-1][-1]
        plt.xlim(0, x_max)

        if ylim is not None:
            plt.ylim(ylim)

        plt.tight_layout()

        return plt

    def _get_weight(self, vec: np.ndarray, indices: list[list[int]]) -> np.ndarray:
        """
        compute the weight for each combination of sites according to the
        eigenvector

        Args:
            vec: flat vector with three consecutive components per atom.
            indices: groups of atom indices, one list per group.

        Returns:
            Array with one weight per group, normalized so the weights sum to 1.
        """
        num_atom = self._nb_bands // 3
        # magnitude of the displacement of each atom
        new_vec = np.array([np.linalg.norm(vec[3 * i : 3 * i + 3]) for i in range(num_atom)], dtype=float)

        # get the projectors for each group
        gw = []
        for comb in indices:
            projector = np.array([1.0 if j in comb else 0.0 for j in range(num_atom)], dtype=float)
            gw.append(np.dot(projector, new_vec))
        norm_f = sum(gw)
        return np.array(gw, dtype=float) / norm_f

    @staticmethod
    def _make_color(colors: list[float]) -> list[float]:
        """
        convert the eigendisplacements to rgb colors

        Args:
            colors: 2, 3 or 4 group weights.

        Returns:
            The three rgb components.

        Raises:
            ValueError: if the number of weights is not 2, 3 or 4.
        """
        n_groups = len(colors)
        # if there are two groups, use red and blue
        if n_groups == 2:
            return [colors[0], 0, colors[1]]
        # three groups map directly onto red, green and blue
        if n_groups == 3:
            return colors
        # if there are four groups, use cyan, magenta, yellow and black
        if n_groups == 4:
            cyan, magenta, yellow, black = colors[0], colors[1], colors[2], colors[3]
            r = (1 - cyan) * (1 - black)
            g = (1 - magenta) * (1 - black)
            b = (1 - yellow) * (1 - black)
            return [r, g, b]
        raise ValueError(f"Expected 2, 3 or 4 colors, got {len(colors)}")

    @staticmethod
    def _concat_branch_distances(distances) -> np.ndarray:
        """
        Join the per-branch distance lists into one flat array.

        Concatenation is used so that branches with different numbers of
        qpoints are handled too.
        """
        if not distances:
            return np.array([])
        return np.concatenate([np.asarray(branch, dtype=float) for branch in distances])

    def get_proj_plot(
        self,
        site_comb: str | list[list[int]] = "element",
        ylim: tuple[None | float, None | float] | None = None,
        units: str = "thz",
        rgb_labels: tuple[None | str] | None = None,
    ) -> plt.Axes:
        """
        Get a matplotlib object for the bandstructure plot projected along atomic
        sites.

        Args:
            site_comb: a list of list, for example, [[0],[1],[2,3,4]];
                the numbers in each sublist represents the indices of atoms;
                the atoms in a same sublist will be plotted in a same color;
                if not specified, unique elements are automatically grouped.
            ylim: Specify the y-axis (frequency) limits; by default None let
                the code choose.
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.
            rgb_labels: labels for the color legend, one per group. By default
                the element symbols are used when grouping by element and the
                group numbers otherwise.

        Returns:
            The matplotlib Axes the projected bands are drawn on.
        """
        from matplotlib.collections import LineCollection

        from pymatgen.electronic_structure.plotter import BSDOSPlotter

        structure = self._bs.structure
        unique_elements = structure.composition.elements
        if site_comb == "element":
            assert 2 <= len(unique_elements) <= 4, "the compound must have 2, 3 or 4 unique elements"
            # one group of site indices per unique element
            indices: list[list[int]] = [[] for _ in unique_elements]
            for i, ele in enumerate(structure.species):
                for j, unique_species in enumerate(unique_elements):
                    if ele == unique_species:
                        indices[j].append(i)
        else:
            assert isinstance(site_comb, list)
            assert 2 <= len(site_comb) <= 4, "the length of site_comb must be 2, 3 or 4"
            all_sites = structure.sites
            # every site index has to be used by one of the groups
            all_indices = {*range(len(all_sites))}
            for comb in site_comb:
                for idx in comb:
                    assert 0 <= idx < len(all_sites), "one or more indices in site_comb does not exist"
                    all_indices.remove(idx)
            if len(all_indices) != 0:
                raise Exception(f"not all {len(all_sites)} indices are included in site_comb")
            indices = site_comb  # type: ignore[assignment]
        assert rgb_labels is None or len(rgb_labels) == len(indices), "wrong number of rgb_labels"

        u = freq_units(units)
        fig, ax = plt.subplots(figsize=(12, 8), dpi=300)
        self._maketicks(plt)

        data = self.bs_plot_data()
        k_dist = self._concat_branch_distances(data["distances"])
        for d in range(1, len(k_dist)):
            # consider 2 k points each time so they connect
            colors = []
            for idx in range(self._nb_bands):
                eigenvec_1 = self._bs.eigendisplacements[idx][d - 1].flatten()
                eigenvec_2 = self._bs.eigendisplacements[idx][d].flatten()
                colors1 = self._get_weight(eigenvec_1, indices)
                colors2 = self._get_weight(eigenvec_2, indices)
                # color of a segment: mean of the weights at its two end points
                colors.append(self._make_color((colors1 + colors2) / 2))
            # seg[band] holds the two (distance, frequency) end points
            seg = np.zeros((self._nb_bands, 2, 2))
            seg[:, :, 1] = self._bs.bands[:, d - 1 : d + 1] * u.factor
            seg[:, 0, 0] = k_dist[d - 1]
            seg[:, 1, 0] = k_dist[d]
            ls = LineCollection(seg, colors=colors, linestyles="-", linewidths=2.5)
            ax.add_collection(ls)
        if ylim is None:
            # default: full frequency range with a 5% margin on both sides
            y_max: float = max(max(b) for b in self._bs.bands) * u.factor
            y_min: float = min(min(b) for b in self._bs.bands) * u.factor
            y_margin = (y_max - y_min) * 0.05
            ylim = (y_min - y_margin, y_max + y_margin)
        ax.set_ylim(ylim)
        xlim = [min(k_dist), max(k_dist)]
        ax.set_xlim(xlim)
        ax.set_xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=28)
        ylabel = rf"$\mathrm{{Frequencies\ ({u.label})}}$"
        ax.set_ylabel(ylabel, fontsize=28)
        ax.tick_params(labelsize=28)
        # make color legend
        labels: list[str]
        if rgb_labels is not None:
            labels = rgb_labels  # type: ignore[assignment]
        elif site_comb == "element":
            labels = [e.symbol for e in unique_elements]
        else:
            labels = [f"{i}" for i in range(len(site_comb))]
        if len(indices) == 2:
            BSDOSPlotter._rb_line(ax, labels[0], labels[1], "best")
        elif len(indices) == 3:
            BSDOSPlotter._rgb_triangle(ax, labels[0], labels[1], labels[2], "best")
        # for 4 combinations no color legend is drawn (build a color square?)
        return ax

    def show(self, ylim=None, units="thz"):
        """
        Show the plot using matplotlib.

        Args:
            ylim: Specify the y-axis (frequency) limits; by default None let
                the code choose.
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.
        """
        plt = self.get_plot(ylim, units=units)
        plt.show()

    def save_plot(self, filename, img_format="eps", ylim=None, units="thz"):
        """
        Save matplotlib plot to a file.

        Args:
            filename: Filename to write to.
            img_format: Image format to use. Defaults to EPS.
            ylim: Specifies the y-axis limits.
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.
        """
        plt = self.get_plot(ylim=ylim, units=units)
        plt.savefig(filename, format=img_format)
        plt.close()

    def show_proj(
        self,
        site_comb: str | list[list[int]] = "element",
        ylim: tuple[None | float, None | float] | None = None,
        units: str = "thz",
        rgb_labels: tuple[str] | None = None,
    ):
        """
        Show the projected plot using matplotlib.

        Args:
            site_comb: groups of atom indices to project on, or "element" to
                group the sites by element; passed on to get_proj_plot.
            ylim: Specify the y-axis (frequency) limits; by default None let
                the code choose.
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.
            rgb_labels: labels for the color legend; passed on to get_proj_plot.
        """
        self.get_proj_plot(site_comb=site_comb, ylim=ylim, units=units, rgb_labels=rgb_labels)
        plt.show()

    @staticmethod
    def _latexify(label):
        """Wrap the label in $...$ if it starts with a backslash or contains an underscore."""
        if label.startswith("\\") or "_" in label:
            return "$" + label + "$"
        return label

    def get_ticks(self):
        """
        Get all ticks and labels for a band structure plot.

        Returns:
            A dict with 'distance': a list of distance at which ticks should
            be set and 'label': a list of label for each of those ticks.
        """
        tick_distance = []
        tick_labels = []
        previous_label = self._bs.qpoints[0].label
        previous_branch = self._bs.branches[0]["name"]
        for i, c in enumerate(self._bs.qpoints):
            if c.label is None:
                continue
            tick_distance.append(self._bs.distance[i])
            # name of the first branch containing this qpoint, if any
            this_branch = next(
                (b["name"] for b in self._bs.branches if b["start_index"] <= i <= b["end_index"]),
                None,
            )
            if c.label != previous_label and previous_branch != this_branch:
                # The path jumps here: a new branch starts at a different
                # point. Merge both labels into a single tick "prev|current"
                # placed at the previous tick position.
                label1 = self._latexify(c.label)
                label0 = self._latexify(previous_label)
                tick_labels.pop()
                tick_distance.pop()
                tick_labels.append(label0 + "$\\mid$" + label1)
            else:
                tick_labels.append(self._latexify(c.label))
            previous_label = c.label
            previous_branch = this_branch
        return {"distance": tick_distance, "label": tick_labels}

    def plot_compare(self, other_plotter, units="thz"):
        """
        plot two band structure for comparison. One is in red the other in blue.
        The two band structures need to be defined on the same symmetry lines!
        and the distance between symmetry lines is the one of the band structure
        used to build the PhononBSPlotter

        Args:
            other_plotter: another PhononBSPlotter object defined along the same symmetry lines
            units: units for the frequencies. Accepted values thz, ev, mev, ha, cm-1, cm^-1.

        Returns:
            a matplotlib object with both band structures

        Raises:
            ValueError: if the two band structures have a different number of branches.
        """
        u = freq_units(units)

        data_orig = self.bs_plot_data()
        data = other_plotter.bs_plot_data()

        if len(data_orig["distances"]) != len(data["distances"]):
            raise ValueError("The two objects are not compatible.")

        # bands of this plotter in blue ...
        plt = self.get_plot(units=units)
        band_linewidth = 1
        # ... and the other bands in red, on the x coordinates of this plotter
        for i in range(other_plotter._nb_bands):
            for d, branch_dists in enumerate(data_orig["distances"]):
                plt.plot(
                    branch_dists,
                    [data["frequency"][d][i][j] * u.factor for j in range(len(branch_dists))],
                    "r-",
                    linewidth=band_linewidth,
                )

        return plt

    def plot_brillouin(self):
        """
        plot the Brillouin zone
        """
        # get labels and lines
        qpoints = self._bs.qpoints
        labels = {q.label: q.frac_coords for q in qpoints if q.label}

        # one line per branch, from its first to its last qpoint
        lines = [
            [qpoints[b["start_index"]].frac_coords, qpoints[b["end_index"]].frac_coords]
            for b in self._bs.branches
        ]

        plot_brillouin_zone(self._bs.lattice_rec, lines=lines, labels=labels)
656

657

658
class ThermoPlotter:
    """
    Plotter for thermodynamic properties obtained from phonon DOS.
    If the structure corresponding to the DOS is provided, it will be used to extract the
    formula unit and provide the plots in units of mol instead of mol-cell (mol-c).
    """

    def __init__(self, dos, structure=None):
        """
        Args:
            dos: A PhononDos object.
            structure: A Structure object corresponding to the structure used for the calculation.
        """
        self.dos = dos
        self.structure = structure

    def _plot_thermo(self, func, temperatures, factor=1, ax=None, ylabel=None, label=None, ylim=None, **kwargs):
        """
        Plots a thermodynamic property for a generic function from a PhononDos instance.

        Args:
            func: the thermodynamic function to be used to calculate the property.
                It is called as func(t, structure=self.structure) for each temperature.
            temperatures: a list of temperatures
            factor: a multiplicative factor applied to the thermodynamic property calculated. Used to change
                the units.
            ax: matplotlib :class:`Axes` or None if a new figure should be created.
            ylabel: label for the y axis
            label: label of the plot
            ylim: tuple specifying the y-axis limits.
            kwargs: kwargs passed to the matplotlib function 'plot'.
        Returns:
            matplotlib figure
        """
        ax, fig, plt = get_ax_fig_plt(ax)

        values = [func(t, structure=self.structure) * factor for t in temperatures]

        ax.plot(temperatures, values, label=label, **kwargs)

        if ylim:
            ax.set_ylim(ylim)

        ax.set_xlim((np.min(temperatures), np.max(temperatures)))
        # Work on the axes we were given rather than on pyplot's current
        # axes, which may be a different plot when `ax` is passed in.
        ylim = ax.get_ylim()
        if ylim[0] < 0 < ylim[1]:
            # mark y=0 when the visible range crosses it
            ax.plot(ax.get_xlim(), [0, 0], "k-", linewidth=1)

        ax.set_xlabel(r"$T$ (K)")
        if ylabel:
            ax.set_ylabel(ylabel)

        return fig

    @add_fig_kwargs
    def plot_cv(self, tmin, tmax, ntemp, ylim=None, **kwargs):
        """
        Plots the constant volume specific heat C_v in a temperature range.

        Args:
            tmin: minimum temperature
            tmax: maximum temperature
            ntemp: number of steps
            ylim: tuple specifying the y-axis limits.
            kwargs: kwargs passed to the matplotlib function 'plot'.
        Returns:
            matplotlib figure
        """
        temperatures = np.linspace(tmin, tmax, ntemp)

        # per mole of formula units if the structure is known, per mole of cells otherwise
        ylabel = r"$C_v$ (J/K/mol)" if self.structure else r"$C_v$ (J/K/mol-c)"

        return self._plot_thermo(self.dos.cv, temperatures, ylabel=ylabel, ylim=ylim, **kwargs)

    @add_fig_kwargs
    def plot_entropy(self, tmin, tmax, ntemp, ylim=None, **kwargs):
        """
        Plots the vibrational entropy in a temperature range.

        Args:
            tmin: minimum temperature
            tmax: maximum temperature
            ntemp: number of steps
            ylim: tuple specifying the y-axis limits.
            kwargs: kwargs passed to the matplotlib function 'plot'.
        Returns:
            matplotlib figure
        """
        temperatures = np.linspace(tmin, tmax, ntemp)

        # per mole of formula units if the structure is known, per mole of cells otherwise
        ylabel = r"$S$ (J/K/mol)" if self.structure else r"$S$ (J/K/mol-c)"

        return self._plot_thermo(self.dos.entropy, temperatures, ylabel=ylabel, ylim=ylim, **kwargs)

    @add_fig_kwargs
    def plot_internal_energy(self, tmin, tmax, ntemp, ylim=None, **kwargs):
        """
        Plots the vibrational internal energy in a temperature range.

        Args:
            tmin: minimum temperature
            tmax: maximum temperature
            ntemp: number of steps
            ylim: tuple specifying the y-axis limits.
            kwargs: kwargs passed to the matplotlib function 'plot'.
        Returns:
            matplotlib figure
        """
        temperatures = np.linspace(tmin, tmax, ntemp)

        # per mole of formula units if the structure is known, per mole of cells otherwise
        ylabel = r"$\Delta E$ (kJ/mol)" if self.structure else r"$\Delta E$ (kJ/mol-c)"

        # factor 1e-3 matches the "k" prefix of the axis label
        return self._plot_thermo(
            self.dos.internal_energy, temperatures, ylabel=ylabel, ylim=ylim, factor=1e-3, **kwargs
        )

    @add_fig_kwargs
    def plot_helmholtz_free_energy(self, tmin, tmax, ntemp, ylim=None, **kwargs):
        """
        Plots the vibrational contribution to the Helmholtz free energy in a temperature range.

        Args:
            tmin: minimum temperature
            tmax: maximum temperature
            ntemp: number of steps
            ylim: tuple specifying the y-axis limits.
            kwargs: kwargs passed to the matplotlib function 'plot'.
        Returns:
            matplotlib figure
        """
        temperatures = np.linspace(tmin, tmax, ntemp)

        # per mole of formula units if the structure is known, per mole of cells otherwise
        ylabel = r"$\Delta F$ (kJ/mol)" if self.structure else r"$\Delta F$ (kJ/mol-c)"

        # factor 1e-3 matches the "k" prefix of the axis label
        return self._plot_thermo(
            self.dos.helmholtz_free_energy, temperatures, ylabel=ylabel, ylim=ylim, factor=1e-3, **kwargs
        )

    @add_fig_kwargs
    def plot_thermodynamic_properties(self, tmin, tmax, ntemp, ylim=None, **kwargs):
        """
        Plots all the thermodynamic properties in a temperature range.

        Args:
            tmin: minimum temperature
            tmax: maximum temperature
            ntemp: number of steps
            ylim: tuple specifying the y-axis limits.
            kwargs: kwargs passed to the matplotlib function 'plot'.
        Returns:
            matplotlib figure
        """
        temperatures = np.linspace(tmin, tmax, ntemp)

        # suffix of the unit in the legend: mol if the structure is known, mol-c otherwise
        mol = "" if self.structure else "-c"

        # The first call creates the figure; the others draw on its first axes.
        fig = self._plot_thermo(
            self.dos.cv,
            temperatures,
            ylabel="Thermodynamic properties",
            ylim=ylim,
            label=rf"$C_v$ (J/K/mol{mol})",
            **kwargs,
        )
        shared_ax = fig.axes[0]
        self._plot_thermo(
            self.dos.entropy, temperatures, ylim=ylim, ax=shared_ax, label=rf"$S$ (J/K/mol{mol})", **kwargs
        )
        self._plot_thermo(
            self.dos.internal_energy,
            temperatures,
            ylim=ylim,
            ax=shared_ax,
            factor=1e-3,
            label=rf"$\Delta E$ (kJ/mol{mol})",
            **kwargs,
        )
        self._plot_thermo(
            self.dos.helmholtz_free_energy,
            temperatures,
            ylim=ylim,
            ax=shared_ax,
            factor=1e-3,
            label=rf"$\Delta F$ (kJ/mol{mol})",
            **kwargs,
        )

        fig.axes[0].legend(loc="best")

        return fig
867

868

869
class GruneisenPlotter:
    """
    Class to plot Gruneisenparameter Object
    """

    def __init__(self, gruneisen):
        """
        Class to plot information from Gruneisenparameter Object
        Args:
            gruneisen: GruneisenParameter Object
        """
        self._gruneisen = gruneisen

    @staticmethod
    def _gradient_colors(n_points):
        """
        Build one RGB tuple per data point, going from blue (first) to red (last).

        Args:
            n_points: number of points to colour.

        Returns:
            A list of n_points (r, g, b) tuples. Empty when n_points is 0.
        """
        # With a single point there is no gradient to span; clamping the
        # divisor to 1 avoids a ZeroDivisionError and colours that point blue.
        span = max(n_points - 1, 1)
        return [(1.0 / span * i, 0, 1.0 / span * (span - i)) for i in range(n_points)]

    def get_plot(self, marker="o", markersize=None, units="thz"):
        """
        Plot the Grüneisen parameter of every mode against its frequency.

        Each point is drawn separately so it can carry its own colour; the
        colour follows the position of the point in the flattened data.

        Args:
            marker: marker for the depiction
            markersize: size of the marker. When falsy, matplotlib's default is used.
            units: unit for the plots, accepted units: thz, ev, mev, ha, cm-1, cm^-1

        Returns:
            The plotting object returned by pretty_plot, with the points drawn.
        """
        u = freq_units(units)

        xs = self._gruneisen.frequencies.flatten() * u.factor
        ys = self._gruneisen.gruneisen.flatten()

        plt = pretty_plot(12, 8)

        plt.xlabel(rf"$\mathrm{{Frequency\ ({u.label})}}$")
        plt.ylabel(r"$\mathrm{Grüneisen\ parameter}$")

        # Only forward markersize when it was requested.
        size_kwargs = {"markersize": markersize} if markersize else {}
        for x, y, color in zip(xs, ys, self._gradient_colors(len(ys))):
            plt.plot(x, y, marker, color=color, **size_kwargs)

        plt.tight_layout()

        return plt

    def show(self, units="thz"):
        """
        Build the plot and display it.

        Args:
            units: units for the plot, accepted units: thz, ev, mev, ha, cm-1, cm^-1
        """
        plt = self.get_plot(units=units)
        plt.show()

    def save_plot(self, filename, img_format="pdf", units="thz"):
        """
        Build the plot, write it to a file and close it.

        Args:
            filename: name of the file to write
            img_format: format of the saved plot
            units: accepted units: thz, ev, mev, ha, cm-1, cm^-1
        """
        plt = self.get_plot(units=units)
        plt.savefig(filename, format=img_format)
        plt.close()
×
939

940

941
class GruneisenPhononBSPlotter(PhononBSPlotter):
    """
    Class to plot or get data to facilitate the plot of band structure objects.
    """

    def __init__(self, bs):
        """
        Args:
            bs: A GruneisenPhononBandStructureSymmLine object.

        Raises:
            ValueError: if bs is not a GruneisenPhononBandStructureSymmLine.
        """
        if not isinstance(bs, GruneisenPhononBandStructureSymmLine):
            raise ValueError(
                "GruneisenPhononBSPlotter only works with GruneisenPhononBandStructureSymmLine objects. "
                "A GruneisenPhononBandStructure object (on a uniform grid for instance and "
                "not along symmetry lines won't work)"
            )
        super().__init__(bs)

    def bs_plot_data(self):
        """
        Get the data nicely formatted for a plot

        Returns:
            A dict of the following format:
            ticks: The value returned by get_ticks().
            distances: A list (one element for each branch) with the
            distances of the qpoints of that branch (the x axis).
            frequency: A list (one element for each branch) of frequencies:
            [branch][mode][qpoint]. The data is stored by branch to
            facilitate the plotting
            gruneisen: Grüneisen values with the same layout as frequency:
            [branch][mode][qpoint].
            lattice: The reciprocal lattice, as a dict.
        """
        distance, frequency, gruneisen = [], [], []

        ticks = self.get_ticks()

        for branch in self._bs.branches:
            # end_index is inclusive, hence the +1.
            q_indices = range(branch["start_index"], branch["end_index"] + 1)

            distance.append([self._bs.distance[j] for j in q_indices])
            frequency.append([[self._bs.bands[i][j] for j in q_indices] for i in range(self._nb_bands)])
            gruneisen.append([[self._bs.gruneisen[i][j] for j in q_indices] for i in range(self._nb_bands)])

        return {
            "ticks": ticks,
            "distances": distance,
            "frequency": frequency,
            "gruneisen": gruneisen,
            "lattice": self._bs.lattice_rec.as_dict(),
        }

    def get_plot_gs(self, ylim=None):
        """
        Get a matplotlib object for the gruneisen bandstructure plot.

        Args:
            ylim: Specify the y-axis (gruneisen) limits; by default None let
                the code choose.

        Returns:
            The plotting object returned by pretty_plot, with one blue curve
            per mode and branch.
        """
        plt = pretty_plot(12, 8)

        data = self.bs_plot_data()
        for branch_distances, branch_gruneisen in zip(data["distances"], data["gruneisen"]):
            for mode_gruneisen in branch_gruneisen:
                plt.plot(
                    branch_distances,
                    mode_gruneisen,
                    "b-",
                    marker="o",
                    markersize=2,
                    linewidth=2,
                )

        self._maketicks(plt)

        # plot y=0 line
        plt.axhline(0, linewidth=1, color="k")

        # Main X and Y Labels
        plt.xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
        plt.ylabel(r"$\mathrm{Grüneisen\ Parameter}$", fontsize=30)

        # X range: from 0 to the last distance point
        x_max = data["distances"][-1][-1]
        plt.xlim(0, x_max)

        if ylim is not None:
            plt.ylim(ylim)

        plt.tight_layout()

        return plt

    def show_gs(self, ylim=None):
        """
        Show the plot using matplotlib.

        Args:
            ylim: Specifies the y-axis limits.
        """
        plt = self.get_plot_gs(ylim)
        plt.show()

    def save_plot_gs(self, filename, img_format="eps", ylim=None):
        """
        Save matplotlib plot to a file.

        Args:
            filename: Filename to write to.
            img_format: Image format to use. Defaults to EPS.
            ylim: Specifies the y-axis limits.
        """
        plt = self.get_plot_gs(ylim=ylim)
        plt.savefig(filename, format=img_format)
        plt.close()

    def plot_compare_gs(self, other_plotter):
        """
        plot two band structure for comparison. One is in red the other in blue.
        The two band structures need to be defined on the same symmetry lines!
        and the distance between symmetry lines is
        the one of the band structure used to build the PhononBSPlotter

        Args:
            other_plotter: another GruneisenPhononBSPlotter object defined
                along the same symmetry lines

        Returns:
            a matplotlib object with both band structures

        Raises:
            ValueError: if the two plotters do not have the same number of branches.
        """
        data_orig = self.bs_plot_data()
        data = other_plotter.bs_plot_data()

        if len(data_orig["distances"]) != len(data["distances"]):
            raise ValueError("The two objects are not compatible.")

        # The overlay below is Grüneisen data, so the base figure must be the
        # Grüneisen plot; the inherited get_plot() would put these curves on
        # an unrelated y axis.
        plt = self.get_plot_gs()
        band_linewidth = 1
        for i in range(other_plotter._nb_bands):
            for d, branch_distances in enumerate(data_orig["distances"]):
                # The x values come from this plotter, so the other plotter's
                # values are read at this plotter's qpoint count.
                plt.plot(
                    branch_distances,
                    [data["gruneisen"][d][i][j] for j in range(len(branch_distances))],
                    "r-",
                    linewidth=band_linewidth,
                )

        return plt
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc