• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

materialsproject / pymatgen / 4075885785

pending completion
4075885785

push

github

Shyue Ping Ong
Merge branch 'master' of github.com:materialsproject/pymatgen

96 of 96 new or added lines in 27 files covered. (100.0%)

81013 of 102710 relevant lines covered (78.88%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

51.14
/pymatgen/electronic_structure/plotter.py
1
# Copyright (c) Pymatgen Development Team.
2
# Distributed under the terms of the MIT License.
3
"""
1✔
4
This module implements plotter for DOS and band structure.
5
"""
6

7
from __future__ import annotations
1✔
8

9
import copy
1✔
10
import itertools
1✔
11
import logging
1✔
12
import math
1✔
13
import warnings
1✔
14
from collections import Counter
1✔
15
from typing import List, Literal, cast
1✔
16

17
import matplotlib.lines as mlines
1✔
18
import numpy as np
1✔
19
import scipy.interpolate as scint
1✔
20
from monty.dev import requires
1✔
21
from monty.json import jsanitize
1✔
22

23
from pymatgen.core.periodic_table import Element
1✔
24
from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine
1✔
25
from pymatgen.electronic_structure.boltztrap import BoltztrapError
1✔
26
from pymatgen.electronic_structure.core import OrbitalType, Spin
1✔
27
from pymatgen.electronic_structure.dos import CompleteDos, Dos
1✔
28
from pymatgen.util.plotting import add_fig_kwargs, get_ax3d_fig_plt, pretty_plot
1✔
29
from pymatgen.util.typing import ArrayLike
1✔
30

31
try:
1✔
32
    from mayavi import mlab
1✔
33
except ImportError:
1✔
34
    mlab = None
1✔
35

36
# Authorship / versioning metadata for this module.
__author__ = "Shyue Ping Ong, Geoffroy Hautier, Anubhav Jain"
__copyright__ = "Copyright 2012, The Materials Project"
__version__ = "0.1"
__maintainer__ = "Shyue Ping Ong"
__email__ = "shyuep@gmail.com"
__date__ = "May 1, 2012"

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
44

45

46
class DosPlotter:
    """
    Class for plotting DOSs. Note that the interface is extremely flexible
    given that there are many different ways in which people want to view
    DOS. The typical usage is::

        # Initializes plotter with some optional args. Defaults are usually
        # fine,
        plotter = DosPlotter()

        # Adds a DOS with a label.
        plotter.add_dos("Total DOS", dos)

        # Alternatively, you can add a dict of DOSs. This is the typical
        # form returned by CompleteDos.get_spd/element/others_dos().
        plotter.add_dos_dict({"dos1": dos1, "dos2": dos2})
        plotter.add_dos_dict(complete_dos.get_spd_dos())
    """

    def __init__(self, zero_at_efermi: bool = True, stack: bool = False, sigma: float | None = None) -> None:
        """
        Args:
            zero_at_efermi (bool): Whether to shift all Dos to have zero energy at the
                fermi energy. Defaults to True.
            stack (bool): Whether to plot the DOS as a stacked area graph
            sigma (float): Specify a standard deviation for Gaussian smearing
                the DOS for nicer looking plots. Defaults to None for no
                smearing.
        """
        self.zero_at_efermi = zero_at_efermi
        self.stack = stack
        self.sigma = sigma
        # Cleared by add_dos as soon as one added Dos has no norm_vol; selects
        # the units shown in the y-axis label of get_plot.
        self._norm_val = True
        self._doses: dict[
            str, dict[Literal["energies", "densities", "efermi"], float | ArrayLike | dict[Spin, ArrayLike]]
        ] = {}

    def add_dos(self, label: str, dos: Dos) -> None:
        """
        Adds a dos for plotting.

        Args:
            label: label for the DOS. Must be unique (adding a DOS under an
                existing label replaces the previous entry).
            dos: Dos object
        """
        if dos.norm_vol is None:
            self._norm_val = False
        energies = dos.energies - dos.efermi if self.zero_at_efermi else dos.energies
        densities = dos.get_smeared_densities(self.sigma) if self.sigma else dos.densities
        efermi = dos.efermi
        self._doses[label] = {
            "energies": energies,
            "densities": densities,
            "efermi": efermi,
        }

    def add_dos_dict(self, dos_dict, key_sort_func=None):
        """
        Add a dictionary of doses, with an optional sorting function for the
        keys.

        Args:
            dos_dict: dict of {label: Dos}
            key_sort_func: function used to sort the dos_dict keys. When None,
                the dict's own iteration order is used.
        """
        keys = sorted(dos_dict, key=key_sort_func) if key_sort_func else list(dos_dict)
        for label in keys:
            self.add_dos(label, dos_dict[label])

    def get_dos_dict(self):
        """
        Returns the added doses as a json-serializable dict. Note that if you
        have specified smearing for the DOS plot, the densities returned will
        be the smeared densities, not the original densities.

        Returns:
            dict: Dict of dos data. Generally of the form
            {label: {'energies':..., 'densities': {'up':...}, 'efermi':efermi}}
        """
        return jsanitize(self._doses)

    def get_plot(self, xlim=None, ylim=None):
        """
        Get a matplotlib plot showing the DOS.

        Args:
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits. Set to None to scale the y-axis
                to the data lying inside the visible x range.

        Returns:
            The plt object returned by pretty_plot, with the DOS drawn on it.
        """
        ncolors = min(9, max(3, len(self._doses)))

        import palettable

        # pylint: disable=E1101
        colors = palettable.colorbrewer.qualitative.Set1_9.mpl_colors

        stacked = None
        alldensities = []
        allenergies = []
        plt = pretty_plot(12, 8)

        # Note that this complicated processing of energies is to allow for
        # stacked plots in matplotlib.
        for dos in self._doses.values():
            energies = dos["energies"]
            densities = dos["densities"]
            if stacked is None:
                # Running totals per spin; only used when self.stack is True.
                stacked = {
                    Spin.up: np.zeros(energies.shape),
                    Spin.down: np.zeros(energies.shape),
                }
            newdens = {}
            for spin in [Spin.up, Spin.down]:
                if spin in densities:
                    if self.stack:
                        stacked[spin] += densities[spin]
                        newdens[spin] = stacked[spin].copy()
                    else:
                        newdens[spin] = densities[spin]
            allenergies.append(energies)
            alldensities.append(newdens)

        # Draw in reverse insertion order so that, for stacked plots, the
        # largest cumulative area is filled first and smaller ones on top.
        keys = list(self._doses)
        keys.reverse()
        alldensities.reverse()
        allenergies.reverse()
        allpts = []
        for idx, key in enumerate(keys):
            xs = []
            ys = []
            for spin in [Spin.up, Spin.down]:
                if spin in alldensities[idx]:
                    # int(spin) is used as a sign so spin-down is drawn negative.
                    densities = list(int(spin) * alldensities[idx][spin])
                    energies = list(allenergies[idx])
                    if spin == Spin.down:
                        energies.reverse()
                        densities.reverse()
                    xs.extend(energies)
                    ys.extend(densities)
            allpts.extend(list(zip(xs, ys)))
            if self.stack:
                plt.fill(xs, ys, color=colors[idx % ncolors], label=str(key))
            else:
                plt.plot(xs, ys, color=colors[idx % ncolors], label=str(key), linewidth=3)
            if not self.zero_at_efermi:
                # Use a separate name: reusing ``ylim`` here used to clobber the
                # caller's argument and disable the autoscaling below.
                cur_ylim = plt.ylim()
                plt.plot(
                    [self._doses[key]["efermi"], self._doses[key]["efermi"]],
                    cur_ylim,
                    color=colors[idx % ncolors],
                    linestyle="--",
                    linewidth=2,
                )

        if xlim:
            plt.xlim(xlim)
        if ylim:
            plt.ylim(ylim)
        else:
            xlim = plt.xlim()
            relevanty = [p[1] for p in allpts if xlim[0] < p[0] < xlim[1]]
            # Nothing to scale to (no DOS added, or no point inside the x
            # window): keep matplotlib's limits instead of crashing on min([]).
            if relevanty:
                plt.ylim((min(relevanty), max(relevanty)))

        if self.zero_at_efermi:
            ylim = plt.ylim()
            plt.plot([0, 0], ylim, "k--", linewidth=2)

        plt.xlabel("Energies (eV)")

        if self._norm_val:
            plt.ylabel("Density of states (states/eV/Å³)")
        else:
            plt.ylabel("Density of states (states/eV)")

        plt.axhline(y=0, color="k", linestyle="--", linewidth=2)
        plt.legend()
        leg = plt.gca().get_legend()
        ltext = leg.get_texts()  # all the text.Text instance in the legend
        plt.setp(ltext, fontsize=30)
        plt.tight_layout()
        return plt

    def save_plot(self, filename, img_format="eps", xlim=None, ylim=None):
        """
        Save matplotlib plot to a file.

        Args:
            filename: Filename to write to.
            img_format: Image format to use. Defaults to EPS.
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits.
        """
        plt = self.get_plot(xlim, ylim)
        plt.savefig(filename, format=img_format)
        # Release the figure (as BSPlotter.save_plot does) so repeated saves
        # do not accumulate open figures.
        plt.close()

    def show(self, xlim=None, ylim=None):
        """
        Show the plot using matplotlib.

        Args:
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits.
        """
        plt = self.get_plot(xlim, ylim)
        plt.show()

259

260
class BSPlotter:
1✔
261
    """
262
    Class to plot or get data to facilitate the plot of band structure objects.
263
    """
264

265
    def __init__(self, bs: BandStructureSymmLine) -> None:
1✔
266
        """
267
        Args:
268
            bs: A BandStructureSymmLine object.
269
        """
270
        self._bs: list[BandStructureSymmLine] = []
1✔
271
        self._nb_bands: list[int] = []
1✔
272

273
        self.add_bs(bs)
1✔
274

275
    def _check_bs_kpath(self, bs_list: list[BandStructureSymmLine]) -> Literal[True]:
1✔
276
        """
277
        Helper method that check all the band objs in bs_list are
278
        BandStructureSymmLine objs and they all have the same kpath.
279
        """
280
        # check obj type
281
        for bs in bs_list:
1✔
282
            if not isinstance(bs, BandStructureSymmLine):
1✔
283
                raise ValueError(
×
284
                    "BSPlotter only works with BandStructureSymmLine objects. "
285
                    "A BandStructure object (on a uniform grid for instance and "
286
                    "not along symmetry lines won't work)"
287
                )
288

289
        # check the kpath
290
        if len(bs_list) == 1 and not self._bs:
1✔
291
            return True
1✔
292

293
        if not self._bs:
1✔
294
            kpath_ref = [br["name"] for br in bs_list[0].branches]
1✔
295
        else:
296
            kpath_ref = [br["name"] for br in self._bs[0].branches]
1✔
297

298
        for bs in bs_list:
1✔
299
            if kpath_ref != [br["name"] for br in bs.branches]:
1✔
300
                msg = (
×
301
                    f"BSPlotter only works with BandStructureSymmLine "
302
                    f"which have the same kpath. \n{bs} has a different kpath!"
303
                )
304
                raise ValueError(msg)
×
305

306
        return True
1✔
307

308
    def add_bs(self, bs: BandStructureSymmLine | list[BandStructureSymmLine]) -> None:
1✔
309
        """
310
        Method to add bands objects to the BSPlotter
311
        """
312
        if not isinstance(bs, list):
1✔
313
            bs = [bs]
1✔
314

315
        if self._check_bs_kpath(bs):
1✔
316
            self._bs.extend(bs)
1✔
317
            # TODO: come with an intelligent way to cut the highest unconverged
318
            # bands
319
            self._nb_bands.extend([b.nb_bands for b in bs])
1✔
320

321
    def _maketicks(self, plt):
        """
        Utility private method to add ticks to a band structure plot.

        Consecutive duplicate labels (as returned by get_ticks) are collapsed
        into a single x tick, and a vertical black line is drawn at each
        labelled position that differs from the previous label.

        Args:
            plt: the matplotlib pyplot-like object to draw on.

        Returns:
            The same plt object.
        """
        ticks = self.get_ticks()
        distances, labels = ticks["distance"], ticks["label"]

        # Sanitize: only plot the unique (non consecutive-duplicate) values
        uniq_d, uniq_l = [], []
        for i, (dist, label) in enumerate(zip(distances, labels)):
            if i > 0 and label == labels[i - 1]:
                logger.debug("Skipping label %s", label)
                continue
            # Message previously had distance and label swapped.
            logger.debug("Adding label %s at %s", label, dist)
            uniq_d.append(dist)
            uniq_l.append(label)

        logger.debug("Unique labels are %s", list(zip(uniq_d, uniq_l)))
        plt.gca().set_xticks(uniq_d)
        plt.gca().set_xticklabels(uniq_l)

        for i, label in enumerate(labels):
            if label is None:
                continue
            # don't print the same label twice
            if i != 0 and label == labels[i - 1]:
                logger.debug("already print label... skipping label %s", label)
                continue
            logger.debug("Adding a line at %s for label %s", distances[i], label)
            plt.axvline(distances[i], color="k")
        return plt
360

361
    @staticmethod
1✔
362
    def _get_branch_steps(branches):
1✔
363
        """
364
        Method to find discontinuous branches
365
        """
366
        steps = [0]
1✔
367
        for b1, b2 in zip(branches[:-1], branches[1:]):
1✔
368
            if b2["name"].split("-")[0] != b1["name"].split("-")[-1]:
1✔
369
                steps.append(b2["start_index"])
1✔
370
        steps.append(branches[-1]["end_index"] + 1)
1✔
371
        return steps
1✔
372

373
    @staticmethod
1✔
374
    def _rescale_distances(bs_ref, bs):
1✔
375
        """
376
        Method to rescale distances of bs to distances in bs_ref.
377
        This is used for plotting two bandstructures (same k-path)
378
        of different materials.
379
        """
380
        scaled_distances = []
1✔
381

382
        for br, br2 in zip(bs_ref.branches, bs.branches):
1✔
383
            s = br["start_index"]
1✔
384
            e = br["end_index"]
1✔
385
            max_d = bs_ref.distance[e]
1✔
386
            min_d = bs_ref.distance[s]
1✔
387
            s2 = br2["start_index"]
1✔
388
            e2 = br2["end_index"]
1✔
389
            np = e2 - s2
1✔
390
            if np == 0:
1✔
391
                # it deals with single point branches
392
                scaled_distances.extend([min_d])
×
393
            else:
394
                scaled_distances.extend([(max_d - min_d) / np * i + min_d for i in range(np + 1)])
1✔
395

396
        return scaled_distances
1✔
397

398
    def bs_plot_data(self, zero_to_efermi=True, bs=None, bs_ref=None, split_branches=True):
1✔
399
        """
400
        Get the data nicely formatted for a plot
401

402
        Args:
403
            zero_to_efermi: Automatically subtract off the Fermi energy from the
404
                eigenvalues and plot.
405
            bs: the bandstructure to get the data from. If not provided, the first
406
                one in the self._bs list will be used.
407
            bs_ref: is the bandstructure of reference when a rescale of the distances
408
                is need to plot multiple bands
409
            split_branches: if True distances and energies are split according to the
410
                branches. If False distances and energies are split only where branches
411
                are discontinuous (reducing the number of lines to plot).
412

413
        Returns:
414
            dict: A dictionary of the following format:
415
            ticks: A dict with the 'distances' at which there is a kpoint (the
416
            x axis) and the labels (None if no label).
417
            energy: A dict storing bands for spin up and spin down data
418
            {Spin:[np.array(nb_bands,kpoints),...]} as a list of discontinuous kpath
419
            of energies. The energy of multiple continuous branches are stored together.
420
            vbm: A list of tuples (distance,energy) marking the vbms. The
421
            energies are shifted with respect to the fermi level is the
422
            option has been selected.
423
            cbm: A list of tuples (distance,energy) marking the cbms. The
424
            energies are shifted with respect to the fermi level is the
425
            option has been selected.
426
            lattice: The reciprocal lattice.
427
            zero_energy: This is the energy used as zero for the plot.
428
            band_gap:A string indicating the band gap and its nature (empty if
429
            it's a metal).
430
            is_metal: True if the band structure is metallic (i.e., there is at
431
            least one band crossing the fermi level).
432
        """
433
        if bs is None:
1✔
434
            if isinstance(self._bs, list):
1✔
435
                # if BSPlotter
436
                bs = self._bs[0]
1✔
437
            else:
438
                # if BSPlotterProjected
439
                bs = self._bs
1✔
440

441
        energies = {str(sp): [] for sp in bs.bands}
1✔
442

443
        bs_is_metal = bs.is_metal()
1✔
444

445
        if not bs_is_metal:
1✔
446
            vbm = bs.get_vbm()
1✔
447
            cbm = bs.get_cbm()
1✔
448

449
        zero_energy = 0.0
1✔
450
        if zero_to_efermi:
1✔
451
            if bs_is_metal:
1✔
452
                zero_energy = bs.efermi
1✔
453
            else:
454
                zero_energy = vbm["energy"]
1✔
455

456
        # rescale distances when a bs_ref is given as reference,
457
        # and when bs and bs_ref have different points in branches.
458
        # Usually bs_ref is the first one in self._bs list is bs_ref
459
        distances = bs.distance
1✔
460
        if bs_ref is not None:
1✔
461
            if bs_ref.branches != bs.branches:
1✔
462
                distances = self._rescale_distances(bs_ref, bs)
1✔
463

464
        if split_branches:
1✔
465
            steps = [br["end_index"] + 1 for br in bs.branches][:-1]
1✔
466
        else:
467
            # join all the continuous branches
468
            # to reduce the total number of branches to plot
469
            steps = self._get_branch_steps(bs.branches)[1:-1]
1✔
470

471
        distances = np.split(distances, steps)
1✔
472
        for sp in bs.bands:
1✔
473
            energies[str(sp)] = np.hsplit(bs.bands[sp] - zero_energy, steps)
1✔
474

475
        ticks = self.get_ticks()
1✔
476

477
        vbm_plot = []
1✔
478
        cbm_plot = []
1✔
479
        bg_str = ""
1✔
480

481
        if not bs_is_metal:
1✔
482
            for index in cbm["kpoint_index"]:
1✔
483
                cbm_plot.append(
1✔
484
                    (
485
                        bs.distance[index],
486
                        cbm["energy"] - zero_energy if zero_to_efermi else cbm["energy"],
487
                    )
488
                )
489

490
            for index in vbm["kpoint_index"]:
1✔
491
                vbm_plot.append(
1✔
492
                    (
493
                        bs.distance[index],
494
                        vbm["energy"] - zero_energy if zero_to_efermi else vbm["energy"],
495
                    )
496
                )
497

498
            bg = bs.get_band_gap()
1✔
499
            direct = "Indirect"
1✔
500
            if bg["direct"]:
1✔
501
                direct = "Direct"
1✔
502

503
            bg_str = f"{direct} {bg['transition']} bandgap = {bg['energy']}"
1✔
504

505
        return {
1✔
506
            "ticks": ticks,
507
            "distances": distances,
508
            "energy": energies,
509
            "vbm": vbm_plot,
510
            "cbm": cbm_plot,
511
            "lattice": bs.lattice_rec.as_dict(),
512
            "zero_energy": zero_energy,
513
            "is_metal": bs_is_metal,
514
            "band_gap": bg_str,
515
        }
516

517
    @staticmethod
1✔
518
    def _interpolate_bands(distances, energies, smooth_tol=0, smooth_k=3, smooth_np=100):
1✔
519
        """
520
        Method that interpolates the provided energies using B-splines as
521
        implemented in scipy.interpolate. Distances and energies has to provided
522
        already split into pieces (branches work good, for longer segments
523
        the interpolation may fail).
524

525
        Interpolation failure can be caused by trying to fit an entire
526
        band with one spline rather than fitting with piecewise splines
527
        (splines are ill-suited to fit discontinuities).
528

529
        The number of splines used to fit a band is determined by the
530
        number of branches (high symmetry lines) defined in the
531
        BandStructureSymmLine object (see BandStructureSymmLine._branches).
532
        """
533
        int_energies, int_distances = [], []
1✔
534
        smooth_k_orig = smooth_k
1✔
535

536
        for dist, ene in zip(distances, energies):
1✔
537
            br_en = []
1✔
538
            warning_nan = (
1✔
539
                f"WARNING! Distance / branch, band cannot be "
540
                f"interpolated. See full warning in source. "
541
                f"If this is not a mistake, try increasing "
542
                f"smooth_tol. Current smooth_tol is {smooth_tol}."
543
            )
544

545
            warning_m_fewer_k = (
1✔
546
                f"The number of points (m) has to be higher then "
547
                f"the order (k) of the splines. In this branch {len(dist)} "
548
                f"points are found, while k is set to {smooth_k}. "
549
                f"Smooth_k will be reduced to {smooth_k - 1} for this branch."
550
            )
551

552
            # skip single point branches
553
            if len(dist) in (2, 3):
1✔
554
                # reducing smooth_k when the number
555
                # of points are fewer then k
556
                smooth_k = len(dist) - 1
×
557
                warnings.warn(warning_m_fewer_k)
×
558
            elif len(dist) == 1:
1✔
559
                warnings.warn("Skipping single point branch")
×
560
                continue
×
561

562
            int_distances.append(np.linspace(dist[0], dist[-1], smooth_np))
1✔
563

564
            for ien in ene:
1✔
565
                tck = scint.splrep(dist, ien, s=smooth_tol, k=smooth_k)
1✔
566

567
                br_en.append(scint.splev(int_distances[-1], tck))
1✔
568

569
            smooth_k = smooth_k_orig
1✔
570

571
            int_energies.append(np.vstack(br_en))
1✔
572

573
            if np.any(np.isnan(int_energies[-1])):
1✔
574
                warnings.warn(warning_nan)
×
575

576
        return int_distances, int_energies
1✔
577

578
    def get_plot(
        self,
        zero_to_efermi=True,
        ylim=None,
        smooth=False,
        vbm_cbm_marker=False,
        smooth_tol=0,
        smooth_k=3,
        smooth_np=100,
        bs_labels=None,
    ):
        """
        Get a matplotlib object for the bandstructures plot.
        Multiple bandstructure objs are plotted together if they have the
        same high symm path.

        Args:
            zero_to_efermi: Automatically subtract off the Fermi energy from
                the eigenvalues and plot (E-Ef).
            ylim: Specify the y-axis (energy) limits; by default None let
                the code choose. It is vbm-4 and cbm+4 if insulator
                efermi-10 and efermi+10 if metal
            smooth (bool or list(bools)): interpolates the bands by a spline cubic.
                A single bool values means to interpolate all the bandstructure objs.
                A list of bools allows to select the bandstructure obs to interpolate.
            vbm_cbm_marker (bool): if True, a marker is added to the vbm and cbm.
            smooth_tol (float) : tolerance for fitting spline to band data.
                Default is 0, i.e. the spline passes through every point.
            smooth_k (int): degree of splines 1<k<5
            smooth_np (int): number of interpolated points per each branch.
            bs_labels: labels for each band for the plot legend.

        Returns:
            The plt object returned by pretty_plot, with the bands drawn on it.
        """
        plt = pretty_plot(12, 8)

        if isinstance(smooth, bool):
            smooth = [smooth] * len(self._bs)

        handles = []
        vbm_min, cbm_max = [], []
        # Whether ANY band structure is a metal, used for the default ylim.
        # Must be initialised once, before the loop, or only the last counts.
        one_is_metal = False

        colors = list(plt.rcParams["axes.prop_cycle"].by_key().values())[0]
        for ibs, bs in enumerate(self._bs):
            # Wrap around the colour cycle when there are more bs than colours.
            color = colors[ibs % len(colors)]
            # set first bs in the list as ref for rescaling the distances of the other bands
            bs_ref = self._bs[0] if len(self._bs) > 1 and ibs > 0 else None

            # interpolation works well on short segments, so split at every
            # branch when smoothing and only at discontinuities otherwise
            data = self.bs_plot_data(zero_to_efermi, bs, bs_ref, split_branches=bool(smooth[ibs]))

            if data["is_metal"]:
                one_is_metal = True

            # remember all the cbm and vbm for setting the ylim later
            if not data["is_metal"]:
                cbm_max.append(data["cbm"][0][1])
                vbm_min.append(data["vbm"][0][1])
            else:
                cbm_max.append(bs.efermi)
                vbm_min.append(bs.efermi)

            for sp in bs.bands:
                ls = "-" if str(sp) == "1" else "--"

                if bs_labels is None:
                    bs_label = f"Band {ibs} {sp.name}"
                else:
                    # assume bs_labels is Sequence[str]
                    bs_label = f"{bs_labels[ibs]} {sp.name}"

                handles.append(mlines.Line2D([], [], lw=2, ls=ls, color=color, label=bs_label))

                distances, energies = data["distances"], data["energy"][str(sp)]

                if smooth[ibs]:
                    # Each interpolated piece is one branch; plot them as-is.
                    # (Re-joining and splitting them with k-point indices from
                    # bs.branches used wrong positions, because each piece has
                    # smooth_np points, and joined discontinuous segments.)
                    distances, energies = self._interpolate_bands(
                        distances,
                        energies,
                        smooth_tol=smooth_tol,
                        smooth_k=smooth_k,
                        smooth_np=smooth_np,
                    )

                for dist, ene in zip(distances, energies):
                    plt.plot(dist, ene.T, c=color, ls=ls)

            # plot markers for vbm and cbm
            if vbm_cbm_marker:
                for cbm in data["cbm"]:
                    plt.scatter(cbm[0], cbm[1], color="r", marker="o", s=100)
                for vbm in data["vbm"]:
                    plt.scatter(vbm[0], vbm[1], color="g", marker="o", s=100)

            # Draw Fermi energy, only if not the zero
            if not zero_to_efermi:
                plt.axhline(bs.efermi, lw=2, ls="-.", color=color)

        # defaults for ylim
        e_min, e_max = (-10, 10) if one_is_metal else (-4, 4)

        if ylim is not None:
            plt.ylim(ylim)
        elif zero_to_efermi:
            if one_is_metal:
                # Plot A Metal
                plt.ylim(e_min, e_max)
            else:
                plt.ylim(e_min, max(cbm_max) + e_max)
        else:
            all_efermi = [b.efermi for b in self._bs]
            ll = min(min(vbm_min), min(all_efermi))
            hh = max(max(cbm_max), max(all_efermi))
            plt.ylim(ll + e_min, hh + e_max)

        self._maketicks(plt)

        # Main X and Y Labels
        plt.xlabel(r"$\mathrm{Wave\ Vector}$", fontsize=30)
        ylabel = r"$\mathrm{E\ -\ E_f\ (eV)}$" if zero_to_efermi else r"$\mathrm{Energy\ (eV)}$"
        plt.ylabel(ylabel, fontsize=30)

        # X range (K)
        # last distance point
        x_max = data["distances"][-1][-1]
        plt.xlim(0, x_max)

        plt.legend(handles=handles)

        plt.tight_layout()

        # auto tight_layout when resizing or pressing t
        def fix_layout(event):
            if (event.name == "key_press_event" and event.key == "t") or event.name == "resize_event":
                plt.gcf().tight_layout()
                plt.gcf().canvas.draw()

        plt.gcf().canvas.mpl_connect("key_press_event", fix_layout)
        plt.gcf().canvas.mpl_connect("resize_event", fix_layout)

        return plt
734

735
    def show(self, zero_to_efermi=True, ylim=None, smooth=False, smooth_tol=None):
        """
        Show the plot using matplotlib.

        Args:
            zero_to_efermi: Automatically subtract off the Fermi energy from
                the eigenvalues and plot (E-Ef).
            ylim: Specify the y-axis (energy) limits; by default None let
                the code choose. It is vbm-4 and cbm+4 if insulator
                efermi-10 and efermi+10 if metal
            smooth: interpolates the bands by a spline cubic
            smooth_tol (float) : tolerance for fitting spline to band data.
                Default is None, in which case get_plot's own default
                tolerance is used.
        """
        # smooth_tol used to be accepted and documented but silently dropped.
        # Forward it only when the caller set it, so get_plot keeps its own
        # default otherwise.
        extra = {} if smooth_tol is None else {"smooth_tol": smooth_tol}
        plt = self.get_plot(zero_to_efermi=zero_to_efermi, ylim=ylim, smooth=smooth, **extra)
        plt.show()
751

752
    def save_plot(self, filename, img_format="eps", ylim=None, zero_to_efermi=True, smooth=False):
        """
        Save matplotlib plot to a file.

        Args:
            filename: Filename to write to.
            img_format: Image format to use. Defaults to EPS.
            ylim: Specifies the y-axis limits.
            zero_to_efermi: Automatically the Fermi level as the origin.
            smooth: Cubic spline interpolation of the bands.

        Raises:
            Whatever ``savefig`` raises (e.g. for an unsupported format); the
            figure is closed in every case.
        """
        plt = self.get_plot(ylim=ylim, zero_to_efermi=zero_to_efermi, smooth=smooth)
        try:
            plt.savefig(filename, format=img_format)
        finally:
            # Close even when savefig fails so the figure does not stay
            # registered with pyplot (memory leak in long sessions).
            plt.close()
766

767
    def get_ticks(self):
        """
        Collect tick positions and labels for a band structure plot.

        The branches of the (first) band structure are walked in order.
        Single-point branches (both end labels equal) are ignored. Labels
        that start with a backslash or contain an underscore are wrapped in
        ``$...$`` so matplotlib renders them as math. When a branch does not
        start where the previous one ended, the two labels are merged into a
        single tick separated by a vertical bar.

        Returns:
            dict: ``{"distance": [...], "label": [...]}``, the distances at
            which ticks should be placed and the label of each tick.
        """
        band_struct = self._bs[0] if isinstance(self._bs, list) else self._bs

        def _texify(label):
            # Math-mode wrapping for LaTeX-looking labels.
            if label.startswith("\\") or "_" in label:
                return "$" + label + "$"
            return label

        tick_labels = []
        tick_positions = []
        for branch in band_struct.branches:
            start, end = branch["start_index"], branch["end_index"]
            raw_labels = branch["name"].split("-")
            if raw_labels[0] == raw_labels[1]:
                continue
            labels = [_texify(lbl) for lbl in raw_labels]

            is_continuous = not tick_labels or labels[0] == tick_labels[-1]
            if is_continuous:
                tick_labels.extend(labels)
                tick_positions.extend([band_struct.distance[start], band_struct.distance[end]])
            else:
                # Jump in the k-path: fuse the end of the previous branch
                # with the start of this one.
                tick_labels[-1] = tick_labels[-1] + "$\\mid$" + labels[0]
                tick_labels.append(labels[1])
                tick_positions.append(band_struct.distance[end])

        return {"distance": tick_positions, "label": tick_labels}
806

807
    def get_ticks_old(self):
        """
        Get all ticks and labels for a band structure plot.

        Legacy algorithm that walks the labelled k-points (rather than the
        branches, as get_ticks does). When a new label starts a different
        branch than the previous label, the previous tick is replaced by a
        combined "previous|current" label at the previous position.

        Returns:
            dict: A dictionary with 'distance': a list of distance at which
            ticks should be set and 'label': a list of label for each of those
            ticks.
        """
        # BSPlotter stores a list of band structures, BSPlotterProjected a
        # single one; accept both like get_ticks (a bare self._bs[0] crashed
        # for the latter).
        bs = self._bs[0] if isinstance(self._bs, list) else self._bs

        def _texify(label):
            # Wrap LaTeX-looking labels in math mode.
            if label.startswith("\\") or label.find("_") != -1:
                return "$" + label + "$"
            return label

        tick_distance = []
        tick_labels = []
        previous_label = bs.kpoints[0].label
        previous_branch = bs.branches[0]["name"]
        for i, kpt in enumerate(bs.kpoints):
            if kpt.label is None:
                continue
            tick_distance.append(bs.distance[i])
            # Name of the first branch containing this k-point index.
            this_branch = next(
                (b["name"] for b in bs.branches if b["start_index"] <= i <= b["end_index"]),
                None,
            )
            if kpt.label != previous_label and previous_branch != this_branch:
                # Discontinuity: drop the tick just added and merge its label
                # into the previous tick.
                tick_labels.pop()
                tick_distance.pop()
                tick_labels.append(_texify(previous_label) + "$\\mid$" + _texify(kpt.label))
            else:
                tick_labels.append(_texify(kpt.label))
            previous_label = kpt.label
            previous_branch = this_branch
        return {"distance": tick_distance, "label": tick_labels}
847

848
    def plot_compare(self, other_plotter, legend=True):
        """
        Plot two band structures for comparison.

        The band structure of this plotter is drawn by get_plot; the one of
        ``other_plotter`` is overlaid in cyan (spin up, solid) and magenta
        (spin down, dashed). The two band structures need to be defined on
        the same symmetry lines, and the distances along the symmetry lines
        are taken from the band structure of this plotter.

        Deprecated: use ``BSPlotter([sbs1, sbs2, ...]).get_plot()`` instead.

        Args:
            other_plotter: Another plotter (BSPlotter or BSPlotterProjected)
                whose band structure is defined along the same symmetry lines.
            legend: True to add a legend to the plot

        Returns:
            a matplotlib object with both band structures
        """
        warnings.warn("Deprecated method. Use BSPlotter([sbs1,sbs2,...]).get_plot() instead.", stacklevel=2)

        # TODO: add exception if the band structures are not compatible
        plt = self.get_plot()
        data_orig = self.bs_plot_data()
        data = other_plotter.bs_plot_data()
        # BSPlotter keeps a list of band structures, BSPlotterProjected a
        # single one; accept both (as get_ticks does).
        other_bs = other_plotter._bs[0] if isinstance(other_plotter._bs, list) else other_plotter._bs
        # bs_plot_data()["energy"] is indexed [str(spin)][branch][band]; the
        # previous code iterated it as a list and could not work.
        energies = data["energy"]
        band_linewidth = 1
        for i in range(other_bs.nb_bands):
            for d, distances in enumerate(data_orig["distances"]):
                plt.plot(distances, energies[str(Spin.up)][d][i], "c-", linewidth=band_linewidth)
                if other_bs.is_spin_polarized:
                    plt.plot(distances, energies[str(Spin.down)][d][i], "m--", linewidth=band_linewidth)
        if legend:
            handles = [
                mlines.Line2D([], [], linewidth=2, color="b", label="bs 1 up"),
                mlines.Line2D([], [], linewidth=2, color="r", label="bs 1 down", linestyle="--"),
                mlines.Line2D([], [], linewidth=2, color="c", label="bs 2 up"),
                mlines.Line2D([], [], linewidth=2, color="m", linestyle="--", label="bs 2 down"),
            ]

            plt.legend(handles=handles)
        return plt
896

897
    def plot_brillouin(self):
        """
        Plot the Brillouin zone together with the k-path of the band structure.

        Every labelled k-point becomes a zone label (fractional coordinates)
        and every branch becomes a segment between the fractional
        coordinates of its first and last k-point; both are handed to
        plot_brillouin_zone with the reciprocal lattice.
        """
        # BSPlotter stores a list of band structures, BSPlotterProjected a
        # single one; a bare self._bs[0] crashed for the latter.
        bs = self._bs[0] if isinstance(self._bs, list) else self._bs

        # Later k-points with the same label overwrite earlier ones.
        labels = {k.label: k.frac_coords for k in bs.kpoints if k.label}
        lines = [
            [bs.kpoints[b["start_index"]].frac_coords, bs.kpoints[b["end_index"]].frac_coords]
            for b in bs.branches
        ]

        plot_brillouin_zone(bs.lattice_rec, lines=lines, labels=labels)
915

916

917
class BSPlotterProjected(BSPlotter):
1✔
918
    """
919
    Class to plot or get data to facilitate the plot of band structure objects
920
    projected along orbitals, elements or sites.
921
    """
922

923
    def __init__(self, bs):
        """
        Args:
            bs: A BandStructureSymmLine object with projections. When a list
                is given, a warning is emitted and only its first element is
                used.

        Raises:
            ValueError: If the band structure has no projections.
        """
        if not isinstance(bs, list):
            band_structure = bs
        else:
            warnings.warn(
                "Multiple bands are not handled by BSPlotterProjected. The first band in the list will be considered"
            )
            band_structure = bs[0]

        if not len(band_structure.projections):
            raise ValueError("try to plot projections on a band structure without any")

        self._bs = band_structure
        self._nb_bands = band_structure.nb_bands
939

940
    def _get_projections_by_branches(self, dictio):
        """
        Regroup element/orbital projections branch by branch.

        Args:
            dictio: {element: [orbitals]} selection forwarded to
                get_projections_on_elements_and_orbitals.

        Returns:
            list: one dict per branch of the band structure, keyed by
            str(Spin.up) (plus str(Spin.down) when spin polarized). Each value
            is a list over bands of lists over the branch's k-points of
            {element: {orbital: projection}} dicts.
        """
        proj = self._bs.get_projections_on_elements_and_orbitals(dictio)
        spins = [Spin.up, Spin.down] if self._bs.is_spin_polarized else [Spin.up]

        def _point(spin, band, kpt):
            # Copy of the {element: {orbital: value}} mapping of one point.
            values = proj[spin][band][kpt]
            return {e: {o: values[e][o] for o in values[e]} for e in values}

        proj_br = []
        for b in self._bs.branches:
            kpt_range = range(b["start_index"], b["end_index"] + 1)
            # Both spin channels use only the k-points of THIS branch. The
            # spin-down part previously re-looped over every branch and
            # appended all of them to the current one, misaligning indices.
            proj_br.append(
                {
                    str(spin): [[_point(spin, i, j) for j in kpt_range] for i in range(self._nb_bands)]
                    for spin in spins
                }
            )
        return proj_br
970

971
    def get_projected_plots_dots(self, dictio, zero_to_efermi=True, ylim=None, vbm_cbm_marker=False):
        """
        Method returning a plot composed of subplots along different elements
        and orbitals.

        Args:
            dictio: The element and orbitals you want a projection on. The
                format is {Element:[Orbitals]} for instance
                {'Cu':['d','s'],'O':['p']} will give projections for Cu on
                d and s orbitals and on oxygen p.
                If you use this class to plot LobsterBandStructureSymmLine,
                the orbitals are named as in the FATBAND filename, e.g.
                "2p" or "2p_x"
            zero_to_efermi: Forwarded to bs_plot_data to plot E - E_f. For
                metals it also decides whether the default energy window is
                centered on 0 or on the Fermi energy.
            ylim: Explicit y-axis (energy) limits. If None, the window is
                efermi-10/efermi+10 for metals and vbm-4/cbm+4 otherwise.
            vbm_cbm_marker: If True, mark the CBM (red) and VBM (green)
                points. Only used for non-metals when ylim is None.

        Returns:
            a pylab object with different subfigures for each projection
            The blue and red colors are for spin up and spin down.
            The bigger the red or blue dot in the band structure the higher
            character for the corresponding element and orbital.
        """
        band_linewidth = 1.0
        # Grid: one row per element, one column per orbital of the longest
        # orbital list. The (nrows, ncols, index) form of subplot replaces the
        # packed three-digit integer, which breaks beyond 9 panels.
        n_rows = len(dictio)
        n_cols = max(len(v) for v in dictio.values())
        proj = self._get_projections_by_branches(dictio)
        data = self.bs_plot_data(zero_to_efermi)
        plt = pretty_plot(12, 8)
        e_min = -4
        e_max = 4
        if self._bs.is_metal():
            e_min = -10
            e_max = 10
        up, down = str(Spin.up), str(Spin.down)
        count = 1

        for el in dictio:
            for o in dictio[el]:
                plt.subplot(n_rows, n_cols, count)
                self._maketicks(plt)
                for b, distances in enumerate(data["distances"]):
                    for i in range(self._nb_bands):
                        energies_up = data["energy"][up][b][i]
                        plt.plot(distances, energies_up, "b-", linewidth=band_linewidth)
                        if self._bs.is_spin_polarized:
                            energies_down = data["energy"][down][b][i]
                            plt.plot(distances, energies_down, "r--", linewidth=band_linewidth)
                            for j in range(len(energies_up)):
                                plt.plot(
                                    distances[j],
                                    energies_down[j],
                                    "ro",
                                    markersize=proj[b][down][i][j][str(el)][o] * 15.0,
                                )
                        # Dot size is proportional to the orbital character.
                        for j in range(len(energies_up)):
                            plt.plot(
                                distances[j],
                                energies_up[j],
                                "bo",
                                markersize=proj[b][up][i][j][str(el)][o] * 15.0,
                            )
                if ylim is None:
                    if self._bs.is_metal():
                        if zero_to_efermi:
                            plt.ylim(e_min, e_max)
                        else:
                            plt.ylim(self._bs.efermi + e_min, self._bs.efermi + e_max)
                    else:
                        if vbm_cbm_marker:
                            for cbm in data["cbm"]:
                                plt.scatter(cbm[0], cbm[1], color="r", marker="o", s=100)

                            for vbm in data["vbm"]:
                                plt.scatter(vbm[0], vbm[1], color="g", marker="o", s=100)

                        plt.ylim(data["vbm"][0][1] + e_min, data["cbm"][0][1] + e_max)
                else:
                    plt.ylim(ylim)
                plt.title(str(el) + " " + str(o))
                count += 1
        return plt
1057

1058
    def get_elt_projected_plots(self, zero_to_efermi=True, ylim=None, vbm_cbm_marker=False):
        """
        Method returning a plot composed of subplots along different elements

        Each element of the structure gets its own panel. All bands are drawn
        as lines (spin up solid light grey, spin down dashed darker grey) and
        overlaid with dots whose size and color intensity grow with the summed
        s, p and d character on that element.

        Args:
            zero_to_efermi: Forwarded to bs_plot_data to plot E - E_f. For
                metals it also decides whether the default energy window is
                centered on 0 or on the Fermi energy.
            ylim: Explicit y-axis (energy) limits. If None, the window is
                efermi-10/efermi+10 for metals and vbm-4/cbm+4 otherwise.
            vbm_cbm_marker: If True, mark the CBM (red) and VBM (green)
                points. Only used for non-metals when ylim is None.

        Returns:
            a pylab object with one subfigure per element. The bigger the
            dot in the band structure the higher the character for the
            corresponding element.
        """
        band_linewidth = 1.0
        elements = self._bs.structure.composition.elements
        proj = self._get_projections_by_branches({e.symbol: ["s", "p", "d"] for e in elements})
        data = self.bs_plot_data(zero_to_efermi)
        plt = pretty_plot(12, 8)
        e_min = -4
        e_max = 4
        if self._bs.is_metal():
            e_min = -10
            e_max = 10
        # Two columns of panels. Up to four elements this is the same 2x2
        # grid as the former subplot(220 + count); more rows are added for
        # larger compounds, where the packed integer was out of range.
        n_cols = 2
        n_rows = max(2, math.ceil(len(elements) / n_cols))
        up, down = str(Spin.up), str(Spin.down)
        count = 1
        for el in elements:
            plt.subplot(n_rows, n_cols, count)
            self._maketicks(plt)
            for b, distances in enumerate(data["distances"]):
                for i in range(self._nb_bands):
                    energies_up = data["energy"][up][b][i]
                    plt.plot(
                        distances,
                        energies_up,
                        "-",
                        color=[192 / 255, 192 / 255, 192 / 255],
                        linewidth=band_linewidth,
                    )
                    if self._bs.is_spin_polarized:
                        energies_down = data["energy"][down][b][i]
                        plt.plot(
                            distances,
                            energies_down,
                            "--",
                            color=[128 / 255, 128 / 255, 128 / 255],
                            linewidth=band_linewidth,
                        )
                        for j in range(len(energies_up)):
                            orbitals = proj[b][down][i][j][str(el)]
                            markerscale = sum(orbitals[o] for o in orbitals)
                            # Plain "o": the color keyword decides the color; a
                            # color letter in the fmt string only triggers
                            # matplotlib's "redundantly defined" warning.
                            plt.plot(
                                distances[j],
                                energies_down[j],
                                "o",
                                markersize=markerscale * 15.0,
                                color=[markerscale, 0.3 * markerscale, 0.4 * markerscale],
                            )
                    for j in range(len(energies_up)):
                        orbitals = proj[b][up][i][j][str(el)]
                        markerscale = sum(orbitals[o] for o in orbitals)
                        plt.plot(
                            distances[j],
                            energies_up[j],
                            "o",
                            markersize=markerscale * 15.0,
                            color=[markerscale, 0.3 * markerscale, 0.4 * markerscale],
                        )
            if ylim is None:
                if self._bs.is_metal():
                    if zero_to_efermi:
                        plt.ylim(e_min, e_max)
                    else:
                        plt.ylim(self._bs.efermi + e_min, self._bs.efermi + e_max)
                else:
                    if vbm_cbm_marker:
                        for cbm in data["cbm"]:
                            plt.scatter(cbm[0], cbm[1], color="r", marker="o", s=100)

                        for vbm in data["vbm"]:
                            plt.scatter(vbm[0], vbm[1], color="g", marker="o", s=100)

                    plt.ylim(data["vbm"][0][1] + e_min, data["cbm"][0][1] + e_max)
            else:
                plt.ylim(ylim)
            plt.title(str(el))
            count += 1

        return plt
1148

1149
    def get_elt_projected_plots_color(self, zero_to_efermi=True, elt_ordered=None):
        """
        Returns a pylab plot object with one plot where the band structure
        line color depends on the character of the band (along different
        elements). Each element is associated with red, green or blue
        and the corresponding rgb color depending on the character of the band
        is used. The method can only deal with unary, binary and ternary
        compounds (a unary compound is drawn in shades of red, a binary one in
        red/blue).

        spin up and spin down are differentiated by a '-' and a '--' line

        Args:
            zero_to_efermi: Forwarded to bs_plot_data to plot E - E_f. For
                metals it also selects whether the -10/+10 eV window is
                centered on 0 or on the Fermi energy.
            elt_ordered: A list of Element ordered. The first one is red,
                second green, last blue

        Returns:
            a pylab object

        Raises:
            ValueError: If the structure contains more than three elements.
        """
        band_linewidth = 3.0
        elements = self._bs.structure.composition.elements
        if len(elements) > 3:
            raise ValueError("get_elt_projected_plots_color can only deal with compounds of at most three elements")
        if elt_ordered is None:
            elt_ordered = elements
        proj = self._get_projections_by_branches({e.symbol: ["s", "p", "d"] for e in elements})
        data = self.bs_plot_data(zero_to_efermi)
        plt = pretty_plot(12, 8)

        spins = [Spin.up]
        if self._bs.is_spin_polarized:
            spins = [Spin.up, Spin.down]
        self._maketicks(plt)
        for s in spins:
            key = str(s)
            sign = "--" if s == Spin.down else "-"
            for b, distances in enumerate(data["distances"]):
                for i in range(self._nb_bands):
                    energies = data["energy"][key][b][i]
                    # One colored segment per pair of consecutive k-points,
                    # colored by the character of the first point.
                    for j in range(len(energies) - 1):
                        point = proj[b][key][i][j]
                        weights = [sum(point[str(el)][o] for o in point[str(el)]) for el in elt_ordered]
                        sum_e = sum(weights)
                        if sum_e == 0.0:
                            color = [0.0] * len(weights)
                        else:
                            color = [w / sum_e for w in weights]
                        # Pad to a valid RGB triple: 1 element -> red,
                        # 2 elements -> red/blue.
                        if len(color) == 1:
                            color = [color[0], 0.0, 0.0]
                        elif len(color) == 2:
                            color = [color[0], 0.0, color[1]]
                        plt.plot(
                            [distances[j], distances[j + 1]],
                            [energies[j], energies[j + 1]],
                            sign,
                            color=color,
                            linewidth=band_linewidth,
                        )

        if self._bs.is_metal():
            e_min = -10
            e_max = 10
            # The shifted window belongs to the non-zeroed case; previously it
            # always overwrote the zero-centered one and nothing was set when
            # zero_to_efermi was False.
            if zero_to_efermi:
                plt.ylim(e_min, e_max)
            else:
                plt.ylim(self._bs.efermi + e_min, self._bs.efermi + e_max)
        else:
            plt.ylim(data["vbm"][0][1] - 4.0, data["cbm"][0][1] + 2.0)
        return plt
1224

1225
    def _get_projections_by_branches_patom_pmorb(self, dictio, dictpa, sum_atoms, sum_morbs, selected_branches):
1✔
1226
        import copy
1✔
1227

1228
        setos = {
1✔
1229
            "s": 0,
1230
            "py": 1,
1231
            "pz": 2,
1232
            "px": 3,
1233
            "dxy": 4,
1234
            "dyz": 5,
1235
            "dz2": 6,
1236
            "dxz": 7,
1237
            "dx2": 8,
1238
            "f_3": 9,
1239
            "f_2": 10,
1240
            "f_1": 11,
1241
            "f0": 12,
1242
            "f1": 13,
1243
            "f2": 14,
1244
            "f3": 15,
1245
        }
1246

1247
        num_branches = len(self._bs.branches)
1✔
1248
        if selected_branches is not None:
1✔
1249
            indices = []
×
1250
            if not isinstance(selected_branches, list):
×
1251
                raise TypeError("You do not give a correct type of 'selected_branches'. It should be 'list' type.")
×
1252
            if len(selected_branches) == 0:
×
1253
                raise ValueError("The 'selected_branches' is empty. We cannot do anything.")
×
1254
            for index in selected_branches:
×
1255
                if not isinstance(index, int):
×
1256
                    raise ValueError(
×
1257
                        "You do not give a correct type of index of symmetry lines. It should be 'int' type"
1258
                    )
1259
                if index > num_branches or index < 1:
×
1260
                    raise ValueError(
×
1261
                        f"You give a incorrect index of symmetry lines: {index}. The index should be in range of "
1262
                        f"[1, {num_branches}]."
1263
                    )
1264
                indices.append(index - 1)
×
1265
        else:
1266
            indices = range(0, num_branches)
1✔
1267

1268
        proj = self._bs.projections
1✔
1269
        proj_br = []
1✔
1270
        for index in indices:
1✔
1271
            b = self._bs.branches[index]
1✔
1272
            print(b)
1✔
1273
            if self._bs.is_spin_polarized:
1✔
1274
                proj_br.append(
×
1275
                    {
1276
                        str(Spin.up): [[] for l in range(self._nb_bands)],
1277
                        str(Spin.down): [[] for l in range(self._nb_bands)],
1278
                    }
1279
                )
1280
            else:
1281
                proj_br.append({str(Spin.up): [[] for l in range(self._nb_bands)]})
1✔
1282

1283
            for i in range(self._nb_bands):
1✔
1284
                for j in range(b["start_index"], b["end_index"] + 1):
1✔
1285
                    edict = {}
1✔
1286
                    for elt in dictpa:
1✔
1287
                        for anum in dictpa[elt]:
1✔
1288
                            edict[elt + str(anum)] = {}
1✔
1289
                            for morb in dictio[elt]:
1✔
1290
                                edict[elt + str(anum)][morb] = proj[Spin.up][i][j][setos[morb]][anum - 1]
1✔
1291
                    proj_br[-1][str(Spin.up)][i].append(edict)
1✔
1292

1293
            if self._bs.is_spin_polarized:
1✔
1294
                for i in range(self._nb_bands):
×
1295
                    for j in range(b["start_index"], b["end_index"] + 1):
×
1296
                        edict = {}
×
1297
                        for elt in dictpa:
×
1298
                            for anum in dictpa[elt]:
×
1299
                                edict[elt + str(anum)] = {}
×
1300
                                for morb in dictio[elt]:
×
1301
                                    edict[elt + str(anum)][morb] = proj[Spin.up][i][j][setos[morb]][anum - 1]
×
1302
                        proj_br[-1][str(Spin.down)][i].append(edict)
×
1303

1304
        # Adjusting  projections for plot
1305
        dictio_d, dictpa_d = self._summarize_keys_for_plot(dictio, dictpa, sum_atoms, sum_morbs)
1✔
1306
        print(f"dictio_d: {str(dictio_d)}")
1✔
1307
        print(f"dictpa_d: {str(dictpa_d)}")
1✔
1308

1309
        if (sum_atoms is None) and (sum_morbs is None):
1✔
1310
            proj_br_d = copy.deepcopy(proj_br)
1✔
1311
        else:
1312
            proj_br_d = []
×
1313
            branch = -1
×
1314
            for index in indices:
×
1315
                branch += 1
×
1316
                br = self._bs.branches[index]
×
1317
                if self._bs.is_spin_polarized:
×
1318
                    proj_br_d.append(
×
1319
                        {
1320
                            str(Spin.up): [[] for l in range(self._nb_bands)],
1321
                            str(Spin.down): [[] for l in range(self._nb_bands)],
1322
                        }
1323
                    )
1324
                else:
1325
                    proj_br_d.append({str(Spin.up): [[] for l in range(self._nb_bands)]})
×
1326

1327
                if (sum_atoms is not None) and (sum_morbs is None):
×
1328
                    for i in range(self._nb_bands):
×
1329
                        for j in range(br["end_index"] - br["start_index"] + 1):
×
1330
                            atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.up)][i][j])
×
1331
                            edict = {}
×
1332
                            for elt in dictpa:
×
1333
                                if elt in sum_atoms:
×
1334
                                    for anum in dictpa_d[elt][:-1]:
×
1335
                                        edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum])
×
1336
                                    edict[elt + dictpa_d[elt][-1]] = {}
×
1337
                                    for morb in dictio[elt]:
×
1338
                                        sprojection = 0.0
×
1339
                                        for anum in sum_atoms[elt]:
×
1340
                                            sprojection += atoms_morbs[elt + str(anum)][morb]
×
1341
                                        edict[elt + dictpa_d[elt][-1]][morb] = sprojection
×
1342
                                else:
1343
                                    for anum in dictpa_d[elt]:
×
1344
                                        edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum])
×
1345
                            proj_br_d[-1][str(Spin.up)][i].append(edict)
×
1346
                    if self._bs.is_spin_polarized:
×
1347
                        for i in range(self._nb_bands):
×
1348
                            for j in range(br["end_index"] - br["start_index"] + 1):
×
1349
                                atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.down)][i][j])
×
1350
                                edict = {}
×
1351
                                for elt in dictpa:
×
1352
                                    if elt in sum_atoms:
×
1353
                                        for anum in dictpa_d[elt][:-1]:
×
1354
                                            edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum])
×
1355
                                        edict[elt + dictpa_d[elt][-1]] = {}
×
1356
                                        for morb in dictio[elt]:
×
1357
                                            sprojection = 0.0
×
1358
                                            for anum in sum_atoms[elt]:
×
1359
                                                sprojection += atoms_morbs[elt + str(anum)][morb]
×
1360
                                            edict[elt + dictpa_d[elt][-1]][morb] = sprojection
×
1361
                                    else:
1362
                                        for anum in dictpa_d[elt]:
×
1363
                                            edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum])
×
1364
                                proj_br_d[-1][str(Spin.down)][i].append(edict)
×
1365

1366
                elif (sum_atoms is None) and (sum_morbs is not None):
×
1367
                    for i in range(self._nb_bands):
×
1368
                        for j in range(br["end_index"] - br["start_index"] + 1):
×
1369
                            atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.up)][i][j])
×
1370
                            edict = {}
×
1371
                            for elt in dictpa:
×
1372
                                if elt in sum_morbs:
×
1373
                                    for anum in dictpa_d[elt]:
×
1374
                                        edict[elt + anum] = {}
×
1375
                                        for morb in dictio_d[elt][:-1]:
×
1376
                                            edict[elt + anum][morb] = atoms_morbs[elt + anum][morb]
×
1377
                                        sprojection = 0.0
×
1378
                                        for morb in sum_morbs[elt]:
×
1379
                                            sprojection += atoms_morbs[elt + anum][morb]
×
1380
                                        edict[elt + anum][dictio_d[elt][-1]] = sprojection
×
1381
                                else:
1382
                                    for anum in dictpa_d[elt]:
×
1383
                                        edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum])
×
1384
                            proj_br_d[-1][str(Spin.up)][i].append(edict)
×
1385
                    if self._bs.is_spin_polarized:
×
1386
                        for i in range(self._nb_bands):
×
1387
                            for j in range(br["end_index"] - br["start_index"] + 1):
×
1388
                                atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.down)][i][j])
×
1389
                                edict = {}
×
1390
                                for elt in dictpa:
×
1391
                                    if elt in sum_morbs:
×
1392
                                        for anum in dictpa_d[elt]:
×
1393
                                            edict[elt + anum] = {}
×
1394
                                            for morb in dictio_d[elt][:-1]:
×
1395
                                                edict[elt + anum][morb] = atoms_morbs[elt + anum][morb]
×
1396
                                            sprojection = 0.0
×
1397
                                            for morb in sum_morbs[elt]:
×
1398
                                                sprojection += atoms_morbs[elt + anum][morb]
×
1399
                                            edict[elt + anum][dictio_d[elt][-1]] = sprojection
×
1400
                                    else:
1401
                                        for anum in dictpa_d[elt]:
×
1402
                                            edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum])
×
1403
                                proj_br_d[-1][str(Spin.down)][i].append(edict)
×
1404

1405
                else:
1406
                    for i in range(self._nb_bands):
×
1407
                        for j in range(br["end_index"] - br["start_index"] + 1):
×
1408
                            atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.up)][i][j])
×
1409
                            edict = {}
×
1410
                            for elt in dictpa:
×
1411
                                if (elt in sum_atoms) and (elt in sum_morbs):
×
1412
                                    for anum in dictpa_d[elt][:-1]:
×
1413
                                        edict[elt + anum] = {}
×
1414
                                        for morb in dictio_d[elt][:-1]:
×
1415
                                            edict[elt + anum][morb] = atoms_morbs[elt + anum][morb]
×
1416
                                        sprojection = 0.0
×
1417
                                        for morb in sum_morbs[elt]:
×
1418
                                            sprojection += atoms_morbs[elt + anum][morb]
×
1419
                                        edict[elt + anum][dictio_d[elt][-1]] = sprojection
×
1420

1421
                                    edict[elt + dictpa_d[elt][-1]] = {}
×
1422
                                    for morb in dictio_d[elt][:-1]:
×
1423
                                        sprojection = 0.0
×
1424
                                        for anum in sum_atoms[elt]:
×
1425
                                            sprojection += atoms_morbs[elt + str(anum)][morb]
×
1426
                                        edict[elt + dictpa_d[elt][-1]][morb] = sprojection
×
1427

1428
                                    sprojection = 0.0
×
1429
                                    for anum in sum_atoms[elt]:
×
1430
                                        for morb in sum_morbs[elt]:
×
1431
                                            sprojection += atoms_morbs[elt + str(anum)][morb]
×
1432
                                    edict[elt + dictpa_d[elt][-1]][dictio_d[elt][-1]] = sprojection
×
1433

1434
                                elif (elt in sum_atoms) and (elt not in sum_morbs):
×
1435
                                    for anum in dictpa_d[elt][:-1]:
×
1436
                                        edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum])
×
1437
                                    edict[elt + dictpa_d[elt][-1]] = {}
×
1438
                                    for morb in dictio[elt]:
×
1439
                                        sprojection = 0.0
×
1440
                                        for anum in sum_atoms[elt]:
×
1441
                                            sprojection += atoms_morbs[elt + str(anum)][morb]
×
1442
                                        edict[elt + dictpa_d[elt][-1]][morb] = sprojection
×
1443

1444
                                elif (elt not in sum_atoms) and (elt in sum_morbs):
×
1445
                                    for anum in dictpa_d[elt]:
×
1446
                                        edict[elt + anum] = {}
×
1447
                                        for morb in dictio_d[elt][:-1]:
×
1448
                                            edict[elt + anum][morb] = atoms_morbs[elt + anum][morb]
×
1449
                                        sprojection = 0.0
×
1450
                                        for morb in sum_morbs[elt]:
×
1451
                                            sprojection += atoms_morbs[elt + anum][morb]
×
1452
                                        edict[elt + anum][dictio_d[elt][-1]] = sprojection
×
1453

1454
                                else:
1455
                                    for anum in dictpa_d[elt]:
×
1456
                                        edict[elt + anum] = {}
×
1457
                                        for morb in dictio_d[elt]:
×
1458
                                            edict[elt + anum][morb] = atoms_morbs[elt + anum][morb]
×
1459
                            proj_br_d[-1][str(Spin.up)][i].append(edict)
×
1460

1461
                    if self._bs.is_spin_polarized:
×
1462
                        for i in range(self._nb_bands):
×
1463
                            for j in range(br["end_index"] - br["start_index"] + 1):
×
1464
                                atoms_morbs = copy.deepcopy(proj_br[branch][str(Spin.down)][i][j])
×
1465
                                edict = {}
×
1466
                                for elt in dictpa:
×
1467
                                    if (elt in sum_atoms) and (elt in sum_morbs):
×
1468
                                        for anum in dictpa_d[elt][:-1]:
×
1469
                                            edict[elt + anum] = {}
×
1470
                                            for morb in dictio_d[elt][:-1]:
×
1471
                                                edict[elt + anum][morb] = atoms_morbs[elt + anum][morb]
×
1472
                                            sprojection = 0.0
×
1473
                                            for morb in sum_morbs[elt]:
×
1474
                                                sprojection += atoms_morbs[elt + anum][morb]
×
1475
                                            edict[elt + anum][dictio_d[elt][-1]] = sprojection
×
1476

1477
                                        edict[elt + dictpa_d[elt][-1]] = {}
×
1478
                                        for morb in dictio_d[elt][:-1]:
×
1479
                                            sprojection = 0.0
×
1480
                                            for anum in sum_atoms[elt]:
×
1481
                                                sprojection += atoms_morbs[elt + str(anum)][morb]
×
1482
                                            edict[elt + dictpa_d[elt][-1]][morb] = sprojection
×
1483

1484
                                        sprojection = 0.0
×
1485
                                        for anum in sum_atoms[elt]:
×
1486
                                            for morb in sum_morbs[elt]:
×
1487
                                                sprojection += atoms_morbs[elt + str(anum)][morb]
×
1488
                                        edict[elt + dictpa_d[elt][-1]][dictio_d[elt][-1]] = sprojection
×
1489

1490
                                    elif (elt in sum_atoms) and (elt not in sum_morbs):
×
1491
                                        for anum in dictpa_d[elt][:-1]:
×
1492
                                            edict[elt + anum] = copy.deepcopy(atoms_morbs[elt + anum])
×
1493
                                        edict[elt + dictpa_d[elt][-1]] = {}
×
1494
                                        for morb in dictio[elt]:
×
1495
                                            sprojection = 0.0
×
1496
                                            for anum in sum_atoms[elt]:
×
1497
                                                sprojection += atoms_morbs[elt + str(anum)][morb]
×
1498
                                            edict[elt + dictpa_d[elt][-1]][morb] = sprojection
×
1499

1500
                                    elif (elt not in sum_atoms) and (elt in sum_morbs):
×
1501
                                        for anum in dictpa_d[elt]:
×
1502
                                            edict[elt + anum] = {}
×
1503
                                            for morb in dictio_d[elt][:-1]:
×
1504
                                                edict[elt + anum][morb] = atoms_morbs[elt + anum][morb]
×
1505
                                            sprojection = 0.0
×
1506
                                            for morb in sum_morbs[elt]:
×
1507
                                                sprojection += atoms_morbs[elt + anum][morb]
×
1508
                                            edict[elt + anum][dictio_d[elt][-1]] = sprojection
×
1509

1510
                                    else:
1511
                                        for anum in dictpa_d[elt]:
×
1512
                                            edict[elt + anum] = {}
×
1513
                                            for morb in dictio_d[elt]:
×
1514
                                                edict[elt + anum][morb] = atoms_morbs[elt + anum][morb]
×
1515
                                proj_br_d[-1][str(Spin.down)][i].append(edict)
×
1516

1517
        return proj_br_d, dictio_d, dictpa_d, indices
1✔
1518

1519
    def get_projected_plots_dots_patom_pmorb(
        self,
        dictio,
        dictpa,
        sum_atoms=None,
        sum_morbs=None,
        zero_to_efermi=True,
        ylim=None,
        vbm_cbm_marker=False,
        selected_branches=None,
        w_h_size=(12, 8),
        num_column=None,
    ):
        """
        Method returns a plot composed of subplots for different atoms and
        orbitals (subshell orbitals such as 's', 'p', 'd' and 'f' defined by
        azimuthal quantum numbers l = 0, 1, 2 and 3, respectively or
        individual orbitals like 'px', 'py' and 'pz' defined by magnetic
        quantum numbers m = -1, 1 and 0, respectively).
        This is an extension of "get_projected_plots_dots" method.

        Args:
            dictio: The elements and the orbitals you need to project on. The
                format is {Element:[Orbitals]}, for instance:
                {'Cu':['dxy','s','px'],'O':['px','py','pz']} will give
                projections for Cu on orbitals dxy, s, px and
                for O on orbitals px, py, pz. If you want to sum over all
                individual orbitals of subshell orbitals,
                for example, 'px', 'py' and 'pz' of O, just simply set
                {'Cu':['dxy','s','px'],'O':['p']} and set sum_morbs (see
                explanations below) as {'O': ['p'], ...}.
                Otherwise, you will get an error.
            dictpa: The elements and their sites (defined by site numbers) you
                need to project on. The format is
                {Element: [Site numbers]}, for instance: {'Cu':[1,5],'O':[3,4]}
                will give projections for Cu on site-1
                and on site-5, O on site-3 and on site-4 in the cell.
                Attention:
                The correct site numbers of atoms are consistent with
                themselves in the structure computed. Normally,
                the structure should be totally similar with POSCAR file,
                however, sometimes VASP can rotate or
                translate the cell. Thus, it would be safe if using Vasprun
                class to get the final_structure and as a
                result, correct index numbers of atoms.
            sum_atoms: Sum projection of the similar atoms together (e.g.: Cu
                on site-1 and Cu on site-5). The format is
                {Element: [Site numbers]}, for instance:
                {'Cu': [1,5], 'O': [3,4]} means summing projections over Cu on
                site-1 and Cu on site-5 and O on site-3
                and on site-4. If you do not want to use this option, just
                turn it off by setting sum_atoms = None.
            sum_morbs: Sum projections of individual orbitals of similar atoms
                together (e.g.: 'dxy' and 'dxz'). The
                format is {Element: [individual orbitals]}, for instance:
                {'Cu': ['dxy', 'dxz'], 'O': ['px', 'py']} means summing
                projections over 'dxy' and 'dxz' of Cu and 'px'
                and 'py' of O. If you do not want to use this option, just
                turn it off by setting sum_morbs = None.
            zero_to_efermi: Passed to bs_plot_data. For metals with ylim=None,
                False places the default energy window around self._bs.efermi
                instead of around 0.
            ylim: Explicit y-limits applied to every subplot via plt.ylim. If
                None, the limits are chosen automatically (wider for metals).
            vbm_cbm_marker: If True (and ylim is None for a non-metal), draw the
                VBM (green) and CBM (red) points from bs_plot_data.
            selected_branches: The index of symmetry lines you chose for
                plotting. This can be useful when the number of
                symmetry lines (in KPOINTS file) is large while you only want
                to show certain ones. The format is
                [index of line], for instance:
                [1, 3, 4] means you just need to do projection along lines
                number 1, 3 and 4 while neglecting lines
                number 2 and so on. By default, this is None type and all
                symmetry lines will be plotted.
            w_h_size: This variable helps you to control the width and height
                of figure. By default, width = 12 and
                height = 8 (inches). The width/height ratio is kept the same
                for subfigures and the size of each depends
                on how many subfigures are plotted.
            num_column: This variable helps you to manage how the subfigures are
                arranged in the figure by setting
                up the number of columns of subfigures. The value should be a
                positive int. For example, num_column = 3
                means you want to plot subfigures in 3 columns. By default,
                num_column = None and subfigures are
                aligned in 2 columns (1 column if there is a single subfigure).
                The number of rows is always rounded up so that every
                subfigure gets a slot.

        Returns:
            A pylab object with different subfigures for different projections.
            The blue and red colors lines are bands
            for spin up and spin down. The green and cyan dots are projections
            for spin up and spin down. The bigger
            the green or cyan dots in the projected band structures, the higher
            character for the corresponding elements
            and orbitals. List of individual orbitals and their numbers (set up
            by VASP and no special meaning):
            s = 0; py = 1 pz = 2 px = 3; dxy = 4 dyz = 5 dz2 = 6 dxz = 7 dx2 = 8;
            f_3 = 9 f_2 = 10 f_1 = 11 f0 = 12 f1 = 13 f2 = 14 f3 = 15

        Raises:
            ValueError: If num_column is neither None nor a positive integer
                (in addition to the input-validation errors raised by the
                helper methods).
        """
        dictio, sum_morbs = self._Orbitals_SumOrbitals(dictio, sum_morbs)
        dictpa, sum_atoms, number_figs = self._number_of_subfigures(dictio, dictpa, sum_atoms, sum_morbs)
        print(f"Number of subfigures: {number_figs}")
        if number_figs > 9:
            print(
                f"The number of sub-figures {number_figs} might be too manny and the implementation might take a long "
                f"time.\n A smaller number or a plot with selected symmetry lines (selected_branches) might be better."
            )

        band_linewidth = 0.5
        plt = pretty_plot(w_h_size[0], w_h_size[1])
        proj_br_d, dictio_d, dictpa_d, branches = self._get_projections_by_branches_patom_pmorb(
            dictio, dictpa, sum_atoms, sum_morbs, selected_branches
        )
        data = self.bs_plot_data(zero_to_efermi)
        is_metal = self._bs.is_metal()
        # Default energy window around the reference energy; metals get a wider one.
        e_min, e_max = (-10, 10) if is_metal else (-4, 4)

        # Subplot grid. matplotlib requires integer row/column counts, so the
        # number of rows is the ceiling of number_figs / n_cols.
        if num_column is None:
            n_cols = 1 if number_figs == 1 else 2
        elif isinstance(num_column, int):
            if num_column < 1:
                raise ValueError("The invalid 'num_column' is assigned. It should be a positive integer.")
            n_cols = num_column
        else:
            raise ValueError("The invalid 'num_column' is assigned. It should be an integer.")
        n_rows = math.ceil(number_figs / n_cols)

        up_key, down_key = str(Spin.up), str(Spin.down)
        spin_polarized = self._bs.is_spin_polarized

        count = 0
        for elt, site_labels in dictpa_d.items():
            for numa in site_labels:
                for o in dictio_d[elt]:
                    count += 1
                    plt.subplot(n_rows, n_cols, count)

                    plt, shift = self._maketicks_selected(plt, branches)
                    for br, b in enumerate(branches):
                        # x positions of this branch, offset by the per-branch
                        # shift returned by _maketicks_selected.
                        xs = [x - shift[br] for x in data["distances"][b]]
                        for i in range(self._nb_bands):
                            e_up = data["energy"][up_key][b][i]
                            plt.plot(xs, [e_up[j] for j in range(len(xs))], "b-", linewidth=band_linewidth)

                            if spin_polarized:
                                e_down = data["energy"][down_key][b][i]
                                plt.plot(xs, [e_down[j] for j in range(len(xs))], "r--", linewidth=band_linewidth)
                                for j in range(len(e_up)):
                                    plt.plot(
                                        xs[j],
                                        e_down[j],
                                        "co",
                                        markersize=proj_br_d[br][down_key][i][j][elt + numa][o] * 15.0,
                                    )

                            # Spin-up projection dots are drawn last, on top of the lines.
                            for j in range(len(e_up)):
                                plt.plot(
                                    xs[j],
                                    e_up[j],
                                    "go",
                                    markersize=proj_br_d[br][up_key][i][j][elt + numa][o] * 15.0,
                                )

                    if ylim is not None:
                        plt.ylim(ylim)
                    elif is_metal:
                        if zero_to_efermi:
                            plt.ylim(e_min, e_max)
                        else:
                            plt.ylim(self._bs.efermi + e_min, self._bs.efermi + e_max)
                    else:
                        if vbm_cbm_marker:
                            # NOTE(review): VBM/CBM coordinates are used as returned by
                            # bs_plot_data, without the _maketicks_selected shift —
                            # confirm alignment when selected_branches is used.
                            for cbm in data["cbm"]:
                                plt.scatter(cbm[0], cbm[1], color="r", marker="o", s=100)

                            for vbm in data["vbm"]:
                                plt.scatter(vbm[0], vbm[1], color="g", marker="o", s=100)

                        plt.ylim(data["vbm"][0][1] + e_min, data["cbm"][0][1] + e_max)
                    plt.title(elt + " " + numa + " " + str(o))

        return plt
1✔
1720

1721
    @classmethod
    def _Orbitals_SumOrbitals(cls, dictio, sum_morbs):
        """
        Validate the orbital selections and expand subshell labels.

        Args:
            dictio: {element symbol: [orbital names]} giving the orbitals to
                project on. Names must be one of "s", "p", "d", "f" or an
                individual orbital ("px", "py", "pz", "dxy", "dyz", "dxz",
                "dx2", "dz2", "f_3", "f_2", "f_1", "f0", "f1", "f2", "f3").
            sum_morbs: None, or {element symbol: [orbital names]} giving the
                orbitals whose projections are summed together. Every key must
                also be a key of dictio.

        Returns:
            tuple: (dictio, sum_morbs) after normalization. When an element has
            several orbitals, subshell labels are replaced by their individual
            orbitals (appended after the remaining orbitals, in label order).
            When an element has a single subshell in dictio and sums over
            exactly that subshell, both entries are expanded to the individual
            orbitals. The caller's dictionaries and lists are not modified;
            sum_morbs is returned as None if it was None.

        Raises:
            TypeError: If dictio/sum_morbs or one of their values has the wrong
                container type, or an entry of sum_morbs is not a string.
            KeyError: If a dictionary is empty or a key is not a valid element
                symbol.
            ValueError: If an orbital name is invalid, not a string (dictio),
                repeated, or inconsistent between dictio and sum_morbs.
        """
        all_orbitals = [
            "s",
            "p",
            "d",
            "f",
            "px",
            "py",
            "pz",
            "dxy",
            "dyz",
            "dxz",
            "dx2",
            "dz2",
            "f_3",
            "f_2",
            "f_1",
            "f0",
            "f1",
            "f2",
            "f3",
        ]
        individual_orbs = {
            "p": ["px", "py", "pz"],
            "d": ["dxy", "dyz", "dxz", "dx2", "dz2"],
            "f": ["f_3", "f_2", "f_1", "f0", "f1", "f2", "f3"],
        }

        # ---- validate dictio -------------------------------------------------
        if not isinstance(dictio, dict):
            raise TypeError("The invalid type of 'dictio' was bound. It should be dict type.")
        if len(dictio) == 0:
            raise KeyError("The 'dictio' is empty. We cannot do anything.")

        for elt in dictio:
            if not Element.is_valid_symbol(elt):
                raise KeyError(f"The invalid element was put into 'dictio' as a key: {elt}")
            if not isinstance(dictio[elt], list):
                raise TypeError(f"The invalid type of value was put into 'dictio[{elt}]'. It should be list type.")
            if len(dictio[elt]) == 0:
                raise ValueError(f"The dictio[{elt}] is empty. We cannot do anything")
            for orb in dictio[elt]:
                if not isinstance(orb, str):
                    raise ValueError(
                        f"The invalid format of orbitals is in 'dictio[{elt}]': {orb}. They should be string."
                    )
                if orb not in all_orbitals:
                    raise ValueError(f"The invalid name of orbital is given in 'dictio[{elt}]'.")
                # A subshell may not be listed together with one of its own individual orbitals.
                if orb in individual_orbs and set(dictio[elt]).intersection(individual_orbs[orb]):
                    raise ValueError(f"The 'dictio[{elt}]' contains orbitals repeated.")
            nelems = Counter(dictio[elt]).values()
            if sum(nelems) > len(nelems):
                raise ValueError(f"You put in at least two similar orbitals in dictio[{elt}].")

        # ---- validate sum_morbs ----------------------------------------------
        if sum_morbs is None:
            print("You do not want to sum projection over orbitals.")
        elif not isinstance(sum_morbs, dict):
            raise TypeError("The invalid type of 'sum_orbs' was bound. It should be dict or 'None' type.")
        elif len(sum_morbs) == 0:
            raise KeyError("The 'sum_morbs' is empty. We cannot do anything")
        else:
            for elt in sum_morbs:
                if not Element.is_valid_symbol(elt):
                    raise KeyError(f"The invalid element was put into 'sum_morbs' as a key: {elt}")
                if not isinstance(sum_morbs[elt], list):
                    raise TypeError(
                        f"The invalid type of value was put into 'sum_morbs[{elt}]'. It should be list type."
                    )
                for orb in sum_morbs[elt]:
                    if not isinstance(orb, str):
                        raise TypeError(
                            f"The invalid format of orbitals is in 'sum_morbs[{elt}]': {orb}. "
                            "They should be string."
                        )
                    if orb not in all_orbitals:
                        raise ValueError(f"The invalid name of orbital in 'sum_morbs[{elt}]' is given.")
                    if orb in individual_orbs and set(sum_morbs[elt]).intersection(individual_orbs[orb]):
                        raise ValueError(f"The 'sum_morbs[{elt}]' contains orbitals repeated.")
                nelems = Counter(sum_morbs[elt]).values()
                if sum(nelems) > len(nelems):
                    raise ValueError(f"You put in at least two similar orbitals in sum_morbs[{elt}].")
                if elt not in dictio:
                    raise ValueError(
                        f"You cannot sum projection over orbitals of atoms {elt!r} because they are not "
                        "mentioned in 'dictio'."
                    )

        # ---- normalize -------------------------------------------------------
        def expand_subshells(orbs):
            # Keep non-subshell orbitals in their original order, then append the
            # individual orbitals of every subshell label in the order they appear.
            kept = [orb for orb in orbs if orb not in individual_orbs]
            return kept + [sub for orb in orbs if orb in individual_orbs for sub in individual_orbs[orb]]

        # Work on copies so the caller's dicts/lists are never mutated and so the
        # expanded entries of dictio and sum_morbs never share a list object.
        dictio = copy.deepcopy(dictio)
        sum_morbs = copy.deepcopy(sum_morbs)

        for elt in dictio:
            if len(dictio[elt]) == 1:
                single = dictio[elt][0]
                if len(single) > 1:
                    # A lone individual orbital (e.g. "dxy") leaves nothing to sum over.
                    if sum_morbs is not None and elt in sum_morbs:
                        raise ValueError(
                            f"You cannot sum projection over one individual orbital {single!r} of {elt!r}."
                        )
                elif sum_morbs is not None:
                    if elt not in sum_morbs:
                        print(f"You do not want to sum projection over orbitals of element: {elt}")
                    elif len(sum_morbs[elt]) == 0:
                        raise ValueError(f"The empty list is an invalid value for sum_morbs[{elt}].")
                    elif len(sum_morbs[elt]) > 1:
                        for orb in sum_morbs[elt]:
                            if single not in orb:
                                raise ValueError(f"The invalid orbital {orb!r} was put into 'sum_morbs[{elt}]'.")
                    else:
                        orb = sum_morbs[elt][0]
                        # Only summing over the very subshell selected in dictio
                        # (and never over "s") can be expanded meaningfully.
                        if orb == "s" or len(orb) > 1 or orb != single:
                            raise ValueError(f"The invalid orbital {orb!r} was put into sum_orbs[{elt!r}].")
                        sum_morbs[elt] = list(individual_orbs[single])
                        dictio[elt] = list(individual_orbs[single])
            else:
                dictio[elt] = expand_subshells(dictio[elt])

                if sum_morbs is None:
                    continue
                if elt not in sum_morbs:
                    print(f"You do not want to sum projection over orbitals of element: {elt}")
                    continue

                if len(sum_morbs[elt]) == 0:
                    raise ValueError(f"The empty list is an invalid value for sum_morbs[{elt}].")
                if len(sum_morbs[elt]) == 1:
                    orb = sum_morbs[elt][0]
                    if orb == "s":
                        raise ValueError(
                            "We do not sum projection over only 's' orbital of the same type of element."
                        )
                    if orb not in individual_orbs:
                        raise ValueError(f"You never sum projection over one orbital in sum_morbs[{elt}]")
                    sum_morbs[elt] = list(individual_orbs[orb])
                else:
                    sum_morbs[elt] = expand_subshells(sum_morbs[elt])

                for orb in sum_morbs[elt]:
                    if orb not in dictio[elt]:
                        raise ValueError(f"The orbitals of sum_morbs[{elt}] conflict with those of dictio[{elt}].")

        return dictio, sum_morbs
1✔
1880

1881
    def _number_of_subfigures(self, dictio, dictpa, sum_atoms, sum_morbs):
        """
        Validate the atom selections and count how many sub-figures are needed.

        ``dictpa`` (and ``sum_atoms``, when given) must map valid element symbols
        to non-empty lists of distinct 1-based site numbers of that element. The
        string "all" in such a list is replaced in place by every site number of
        that element. ``dictio`` and ``dictpa`` must have the same keys, and every
        atom listed (or expanded) in ``sum_atoms`` must also be selected in
        ``dictpa``.

        Args:
            dictio (dict): element symbol -> orbitals to project onto.
            dictpa (dict): element symbol -> site numbers (or ["all"]) to project onto.
            sum_atoms (dict | None): element symbol -> site numbers (or ["all"]) whose
                projections are summed together.
            sum_morbs (dict | None): element symbol -> orbitals whose projections are
                summed together.

        Returns:
            tuple: (dictpa, sum_atoms, number_figs), where "all" has been expanded in
            the first two and number_figs is the number of sub-figures to draw.

        Raises:
            TypeError: if dictpa/sum_atoms or one of their values has the wrong type.
            KeyError: if dictpa/sum_atoms is empty, a key is not a valid element
                symbol, or the keys of dictio and dictpa differ.
            ValueError: if a site-number list is empty, contains invalid or repeated
                numbers, sums over a single atom, or is inconsistent with dictpa.
        """

        def site_numbers_of(elt):
            # 1-based numbers of the sites whose (first) species is element ``elt``.
            return [
                idx + 1 for idx, site in enumerate(self._bs.structure.sites) if list(site._species)[0] == Element(elt)
            ]

        # --- validate dictpa -------------------------------------------------
        if not isinstance(dictpa, dict):
            raise TypeError("The invalid type of 'dictpa' was bound. It should be dict type.")
        if len(dictpa) == 0:
            raise KeyError("The 'dictpa' is empty. We cannot do anything.")
        for elt in dictpa:
            if not Element.is_valid_symbol(elt):
                raise KeyError(f"The invalid element was put into 'dictpa' as a key: {elt}")
            if not isinstance(dictpa[elt], list):
                raise TypeError(f"The invalid type of value was put into 'dictpa[{elt}]'. It should be list type.")
            if len(dictpa[elt]) == 0:
                raise ValueError(f"The dictpa[{elt}] is empty. We cannot do anything")
            indices = site_numbers_of(elt)
            for number in dictpa[elt]:
                if isinstance(number, str) and number.lower() == "all":
                    dictpa[elt] = indices
                    print(f"You want to consider all {elt!r} atoms.")
                    break
                if not isinstance(number, int) or number not in indices:
                    raise ValueError(f"You put wrong site numbers in 'dictpa[{elt}]': {number}.")
            if len(set(dictpa[elt])) != len(dictpa[elt]):
                raise ValueError(f"You put at least two similar site numbers into 'dictpa[{elt}]'.")

        # --- dictio and dictpa must describe the same elements ----------------
        if len(list(dictio)) != len(list(dictpa)):
            raise KeyError("The number of keys in 'dictio' and 'dictpa' are not the same.")
        for elt in dictio:
            if elt not in dictpa:
                raise KeyError(f"The element {elt!r} is not in both dictpa and dictio.")
        for elt in dictpa:
            if elt not in dictio:
                raise KeyError(f"The element {elt!r} is not in both dictpa and dictio.")

        # --- validate sum_atoms ----------------------------------------------
        if sum_atoms is None:
            print("You do not want to sum projection over atoms.")
        elif not isinstance(sum_atoms, dict):
            raise TypeError("The invalid type of 'sum_atoms' was bound. It should be dict type.")
        elif len(sum_atoms) == 0:
            raise KeyError("The 'sum_atoms' is empty. We cannot do anything.")
        else:
            for elt in sum_atoms:
                if not Element.is_valid_symbol(elt):
                    raise KeyError(f"The invalid element was put into 'sum_atoms' as a key: {elt}")
                if not isinstance(sum_atoms[elt], list):
                    raise TypeError(
                        f"The invalid type of value was put into 'sum_atoms[{elt}]'. It should be list type."
                    )
                if len(sum_atoms[elt]) == 0:
                    raise ValueError(f"The sum_atoms[{elt}] is empty. We cannot do anything")
                # Must run before anything below indexes dictpa[elt]; otherwise a
                # missing element surfaces as a bare KeyError instead of this message.
                if elt not in dictpa:
                    raise ValueError(
                        f"You cannot sum projection over atoms {elt!r} because it is not mentioned in 'dictio'."
                    )
                indices = site_numbers_of(elt)
                for number in sum_atoms[elt]:
                    if isinstance(number, str) and number.lower() == "all":
                        sum_atoms[elt] = indices
                        print(f"You want to sum projection over all {elt!r} atoms.")
                        break
                    if not isinstance(number, int) or number not in indices:
                        raise ValueError(f"You put wrong site numbers in 'sum_atoms[{elt}]'.")
                # Every summed atom, including those produced by expanding "all", must
                # also be plotted; the figure count below relies on sum_atoms <= dictpa.
                for number in sum_atoms[elt]:
                    if number not in dictpa[elt]:
                        raise ValueError(
                            f"You cannot sum projection with atom number {number!r} because it is not "
                            f"mentioned in dictpa[{elt}]"
                        )
                if len(set(sum_atoms[elt])) != len(sum_atoms[elt]):
                    raise ValueError(f"You put at least two similar site numbers into 'sum_atoms[{elt}]'.")
                if len(sum_atoms[elt]) == 1:
                    raise ValueError(f"We do not sum projection over only one atom: {elt}")

        # --- count sub-figures -----------------------------------------------
        # One figure per (orbital, atom) pair, minus the figures merged by summing.
        max_number_figs = sum(len(dictio[elt]) * len(dictpa[elt]) for elt in dictio)
        decrease = 0
        if sum_atoms is not None:
            for elt in sum_atoms:
                # n summed atoms collapse into one column for each orbital.
                decrease += (len(sum_atoms[elt]) - 1) * len(dictio[elt])
        if sum_morbs is not None:
            for elt in sum_morbs:
                if sum_atoms is not None and elt in sum_atoms:
                    # Atom columns left after summing: unsummed atoms plus one summed column.
                    n_atom_columns = len(dictpa[elt]) - len(sum_atoms[elt]) + 1
                else:
                    n_atom_columns = len(dictpa[elt])
                decrease += (len(sum_morbs[elt]) - 1) * n_atom_columns
        number_figs = max_number_figs - decrease

        return dictpa, sum_atoms, number_figs
2008

2009
    def _summarize_keys_for_plot(self, dictio, dictpa, sum_atoms, sum_morbs):
1✔
2010
        from pymatgen.core.periodic_table import Element
1✔
2011

2012
        individual_orbs = {
1✔
2013
            "p": ["px", "py", "pz"],
2014
            "d": ["dxy", "dyz", "dxz", "dx2", "dz2"],
2015
            "f": ["f_3", "f_2", "f_1", "f0", "f1", "f2", "f3"],
2016
        }
2017

2018
        def number_label(list_numbers):
1✔
2019
            list_numbers = sorted(list_numbers)
×
2020
            divide = [[]]
×
2021
            divide[0].append(list_numbers[0])
×
2022
            group = 0
×
2023
            for i in range(1, len(list_numbers)):
×
2024
                if list_numbers[i] == list_numbers[i - 1] + 1:
×
2025
                    divide[group].append(list_numbers[i])
×
2026
                else:
2027
                    group += 1
×
2028
                    divide.append([list_numbers[i]])
×
2029
            label = ""
×
2030
            for elem in divide:
×
2031
                if len(elem) > 1:
×
2032
                    label += str(elem[0]) + "-" + str(elem[-1]) + ","
×
2033
                else:
2034
                    label += str(elem[0]) + ","
×
2035
            return label[:-1]
×
2036

2037
        def orbital_label(list_orbitals):
1✔
2038
            divide = {}
×
2039
            for orb in list_orbitals:
×
2040
                if orb[0] in divide:
×
2041
                    divide[orb[0]].append(orb)
×
2042
                else:
2043
                    divide[orb[0]] = []
×
2044
                    divide[orb[0]].append(orb)
×
2045
            label = ""
×
2046
            for elem, v in divide.items():
×
2047
                if elem == "s":
×
2048
                    label += "s,"
×
2049
                else:
2050
                    if len(v) == len(individual_orbs[elem]):
×
2051
                        label += elem + ","
×
2052
                    else:
2053
                        l = [o[1:] for o in v]
×
2054
                        label += elem + str(l).replace("['", "").replace("']", "").replace("', '", "-") + ","
×
2055
            return label[:-1]
×
2056

2057
        if (sum_atoms is None) and (sum_morbs is None):
1✔
2058
            dictio_d = dictio
1✔
2059
            dictpa_d = {elt: [str(anum) for anum in dictpa[elt]] for elt in dictpa}
1✔
2060

2061
        elif (sum_atoms is not None) and (sum_morbs is None):
×
2062
            dictio_d = dictio
×
2063
            dictpa_d = {}
×
2064
            for elt in dictpa:
×
2065
                dictpa_d[elt] = []
×
2066
                if elt in sum_atoms:
×
2067
                    _sites = self._bs.structure.sites
×
2068
                    indices = []
×
2069
                    for i in range(0, len(_sites)):  # pylint: disable=C0200
×
2070
                        if list(_sites[i]._species)[0] == Element(elt):
×
2071
                            indices.append(i + 1)
×
2072
                    flag_1 = len(set(dictpa[elt]).intersection(indices))
×
2073
                    flag_2 = len(set(sum_atoms[elt]).intersection(indices))
×
2074
                    if flag_1 == len(indices) and flag_2 == len(indices):
×
2075
                        dictpa_d[elt].append("all")
×
2076
                    else:
2077
                        for anum in dictpa[elt]:
×
2078
                            if anum not in sum_atoms[elt]:
×
2079
                                dictpa_d[elt].append(str(anum))
×
2080
                        label = number_label(sum_atoms[elt])
×
2081
                        dictpa_d[elt].append(label)
×
2082
                else:
2083
                    for anum in dictpa[elt]:
×
2084
                        dictpa_d[elt].append(str(anum))
×
2085

2086
        elif (sum_atoms is None) and (sum_morbs is not None):
×
2087
            dictio_d = {}
×
2088
            for elt in dictio:
×
2089
                dictio_d[elt] = []
×
2090
                if elt in sum_morbs:
×
2091
                    for morb in dictio[elt]:
×
2092
                        if morb not in sum_morbs[elt]:
×
2093
                            dictio_d[elt].append(morb)
×
2094
                    label = orbital_label(sum_morbs[elt])
×
2095
                    dictio_d[elt].append(label)
×
2096
                else:
2097
                    dictio_d[elt] = dictio[elt]
×
2098
            dictpa_d = {elt: [str(anum) for anum in dictpa[elt]] for elt in dictpa}
×
2099

2100
        else:
2101
            dictio_d = {}
×
2102
            for elt in dictio:
×
2103
                dictio_d[elt] = []
×
2104
                if elt in sum_morbs:
×
2105
                    for morb in dictio[elt]:
×
2106
                        if morb not in sum_morbs[elt]:
×
2107
                            dictio_d[elt].append(morb)
×
2108
                    label = orbital_label(sum_morbs[elt])
×
2109
                    dictio_d[elt].append(label)
×
2110
                else:
2111
                    dictio_d[elt] = dictio[elt]
×
2112
            dictpa_d = {}
×
2113
            for elt in dictpa:
×
2114
                dictpa_d[elt] = []
×
2115
                if elt in sum_atoms:
×
2116
                    _sites = self._bs.structure.sites
×
2117
                    indices = []
×
2118
                    for i in range(0, len(_sites)):  # pylint: disable=C0200
×
2119
                        if list(_sites[i]._species)[0] == Element(elt):
×
2120
                            indices.append(i + 1)
×
2121
                    flag_1 = len(set(dictpa[elt]).intersection(indices))
×
2122
                    flag_2 = len(set(sum_atoms[elt]).intersection(indices))
×
2123
                    if flag_1 == len(indices) and flag_2 == len(indices):
×
2124
                        dictpa_d[elt].append("all")
×
2125
                    else:
2126
                        for anum in dictpa[elt]:
×
2127
                            if anum not in sum_atoms[elt]:
×
2128
                                dictpa_d[elt].append(str(anum))
×
2129
                        label = number_label(sum_atoms[elt])
×
2130
                        dictpa_d[elt].append(label)
×
2131
                else:
2132
                    for anum in dictpa[elt]:
×
2133
                        dictpa_d[elt].append(str(anum))
×
2134

2135
        return dictio_d, dictpa_d
1✔
2136

2137
    def _maketicks_selected(self, plt, branches):
        """
        Utility private method to add ticks to a band structure with selected branches.

        The selected branches are laid end to end along the x-axis. When two
        consecutive selected branches do not share an endpoint label, their
        junction tick gets a combined "A$\\mid$B" label. A vertical line is drawn
        at every labelled tick whose label differs from the previous one.

        Args:
            plt: matplotlib pyplot module (or a compatible object); its current axes
                receive the ticks and vertical lines.
            branches (list[int]): indices of the branches to show, in plotting order.

        Returns:
            tuple: (plt, shift), where shift[i] is the start of branches[i] in the
            full band structure minus its start in the compacted plot.
        """
        ticks = self.get_ticks()
        mid = "$\\mid$"

        # Drop ticks whose label repeats the previous one so that each branch is
        # bounded by exactly two ticks.
        distance = []
        label = []
        for idx, (dist, lab) in enumerate(zip(ticks["distance"], ticks["label"])):
            if idx > 0 and lab == ticks["label"][idx - 1]:
                continue
            distance.append(dist)
            label.append(lab)
        l_branches = [end - start for start, end in zip(distance, distance[1:])]

        # Length and (start, end) labels of each selected branch. A "$\mid$" label
        # marks a discontinuity; keep only the side belonging to this branch.
        n_distance = []
        n_label = []
        for branch in branches:
            n_distance.append(l_branches[branch])
            start, end = label[branch], label[branch + 1]
            if mid in start:
                start = start.split("$")[-1]
            if mid in end:
                end = end.split("$")[0]
            n_label.append([start, end])

        # Lay the selected branches end to end. f_* hold tick positions/labels;
        # rf_distance holds the start of every selected branch (plus the end).
        f_distance = [0.0, n_distance[0]]
        rf_distance = [0.0, n_distance[0]]
        f_label = [n_label[0][0], n_label[0][1]]
        length = n_distance[0]
        for i in range(1, len(n_distance)):
            if n_label[i][0] == n_label[i - 1][1]:
                f_distance.append(length)
                f_distance.append(length + n_distance[i])
                f_label.append(n_label[i][0])
                f_label.append(n_label[i][1])
            else:
                f_distance.append(length + n_distance[i])
                f_label[-1] = n_label[i - 1][1] + mid + n_label[i][0]
                f_label.append(n_label[i][1])
            rf_distance.append(length + n_distance[i])
            length += n_distance[i]

        # Set tick positions/labels, skipping consecutive duplicate labels.
        uniq_d = []
        uniq_l = []
        for idx, (dist, lab) in enumerate(zip(f_distance, f_label)):
            if idx > 0 and lab == f_label[idx - 1]:
                logger.debug("Skipping label %s", lab)
                continue
            logger.debug("Adding label %s at %s", lab, dist)
            uniq_d.append(dist)
            uniq_l.append(lab)

        logger.debug("Unique labels are %s", list(zip(uniq_d, uniq_l)))
        plt.gca().set_xticks(uniq_d)
        plt.gca().set_xticklabels(uniq_l)

        # Vertical separators at labelled ticks, again once per run of equal labels.
        for idx, (dist, lab) in enumerate(zip(f_distance, f_label)):
            if lab is None:
                continue
            if idx != 0 and lab == f_label[idx - 1]:
                logger.debug("already print label... skipping label %s", lab)
            else:
                logger.debug("Adding a line at %s for label %s", dist, lab)
                plt.axvline(dist, color="k")

        shift = [distance[branch] - rf_distance[pos] for pos, branch in enumerate(branches)]

        return plt, shift
2230

2231

2232
class BSDOSPlotter:
1✔
2233
    """
2234
    A joint, aligned band structure and density of states plot. Contributions
2235
    from Jan Pohls as well as the online example from Germain Salvato-Vallverdu:
2236
    http://gvallver.perso.univ-pau.fr/?p=587
2237
    """
2238

2239
    def __init__(
1✔
2240
        self,
2241
        bs_projection: Literal["elements"] | None = "elements",
2242
        dos_projection: str = "elements",
2243
        vb_energy_range: float = 4,
2244
        cb_energy_range: float = 4,
2245
        fixed_cb_energy: bool = False,
2246
        egrid_interval: float = 1,
2247
        font: str = "Times New Roman",
2248
        axis_fontsize: float = 20,
2249
        tick_fontsize: float = 15,
2250
        legend_fontsize: float = 14,
2251
        bs_legend: str = "best",
2252
        dos_legend: str = "best",
2253
        rgb_legend: bool = True,
2254
        fig_size: tuple[float, float] = (11, 8.5),
2255
    ) -> None:
2256
        """
2257
        Instantiate plotter settings.
2258

2259
        Args:
2260
            bs_projection ('elements' | None): Whether to project the bands onto elements.
2261
            dos_projection (str): "elements", "orbitals", or None
2262
            vb_energy_range (float): energy in eV to show of valence bands
2263
            cb_energy_range (float): energy in eV to show of conduction bands
2264
            fixed_cb_energy (bool): If true, the cb_energy_range will be interpreted
2265
                as constant (i.e., no gap correction for cb energy)
2266
            egrid_interval (float): interval for grid marks
2267
            font (str): font family
2268
            axis_fontsize (float): font size for axis
2269
            tick_fontsize (float): font size for axis tick labels
2270
            legend_fontsize (float): font size for legends
2271
            bs_legend (str): matplotlib string location for legend or None
2272
            dos_legend (str): matplotlib string location for legend or None
2273
            rgb_legend (bool): (T/F) whether to draw RGB triangle/bar for element proj.
2274
            fig_size(tuple): dimensions of figure size (width, height)
2275
        """
2276
        self.bs_projection = bs_projection
1✔
2277
        self.dos_projection = dos_projection
1✔
2278
        self.vb_energy_range = vb_energy_range
1✔
2279
        self.cb_energy_range = cb_energy_range
1✔
2280
        self.fixed_cb_energy = fixed_cb_energy
1✔
2281
        self.egrid_interval = egrid_interval
1✔
2282
        self.font = font
1✔
2283
        self.axis_fontsize = axis_fontsize
1✔
2284
        self.tick_fontsize = tick_fontsize
1✔
2285
        self.legend_fontsize = legend_fontsize
1✔
2286
        self.bs_legend = bs_legend
1✔
2287
        self.dos_legend = dos_legend
1✔
2288
        self.rgb_legend = rgb_legend
1✔
2289
        self.fig_size = fig_size
1✔
2290

2291
    def get_plot(self, bs: BandStructureSymmLine, dos: Dos | CompleteDos | None = None):
1✔
2292
        """
2293
        Get a matplotlib plot object.
2294
        Args:
2295
            bs (BandStructureSymmLine): the bandstructure to plot. Projection
2296
                data must exist for projected plots.
2297
            dos (Dos): the Dos to plot. Projection data must exist (i.e.,
2298
                CompleteDos) for projected plots.
2299

2300
        Returns:
2301
            matplotlib.pyplot object on which you can call commands like show()
2302
            and savefig()
2303
        """
2304
        import matplotlib.lines as mlines
1✔
2305
        import matplotlib.pyplot as mplt
1✔
2306
        from matplotlib.gridspec import GridSpec
1✔
2307

2308
        # make sure the user-specified band structure projection is valid
2309
        bs_projection = self.bs_projection
1✔
2310
        if dos:
1✔
2311
            elements = [e.symbol for e in dos.structure.composition.elements]
1✔
2312
        elif bs_projection and bs.structure:
1✔
2313
            elements = [e.symbol for e in bs.structure.composition.elements]
1✔
2314
        else:
2315
            elements = []
×
2316

2317
        rgb_legend = (
1✔
2318
            self.rgb_legend and bs_projection and bs_projection.lower() == "elements" and len(elements) in [2, 3, 4]
2319
        )
2320

2321
        if (
1✔
2322
            bs_projection
2323
            and bs_projection.lower() == "elements"
2324
            and (len(elements) not in [2, 3, 4] or not bs.get_projection_on_elements())
2325
        ):
2326
            warnings.warn(
1✔
2327
                "Cannot get element projected data; either the projection data "
2328
                "doesn't exist, or you don't have a compound with exactly 2 "
2329
                "or 3 or 4 unique elements."
2330
            )
2331
            bs_projection = None
1✔
2332

2333
        # specify energy range of plot
2334
        emin = -self.vb_energy_range
1✔
2335
        emax = self.cb_energy_range if self.fixed_cb_energy else self.cb_energy_range + bs.get_band_gap()["energy"]
1✔
2336

2337
        # initialize all the k-point labels and k-point x-distances for bs plot
2338
        xlabels = []  # all symmetry point labels on x-axis
1✔
2339
        xlabel_distances = []  # positions of symmetry point x-labels
1✔
2340

2341
        x_distances_list = []
1✔
2342
        prev_right_klabel = None  # used to determine which branches require a midline separator
1✔
2343

2344
        for branch in bs.branches:
1✔
2345
            x_distances = []
1✔
2346

2347
            # get left and right kpoint labels of this branch
2348
            left_k, right_k = branch["name"].split("-")
1✔
2349

2350
            # add $ notation for LaTeX kpoint labels
2351
            if left_k[0] == "\\" or "_" in left_k:
1✔
2352
                left_k = "$" + left_k + "$"
1✔
2353
            if right_k[0] == "\\" or "_" in right_k:
1✔
2354
                right_k = "$" + right_k + "$"
1✔
2355

2356
            # add left k label to list of labels
2357
            if prev_right_klabel is None:
1✔
2358
                xlabels.append(left_k)
1✔
2359
                xlabel_distances.append(0)
1✔
2360
            elif prev_right_klabel != left_k:  # used for pipe separator
1✔
2361
                xlabels[-1] = xlabels[-1] + "$\\mid$ " + left_k
1✔
2362

2363
            # add right k label to list of labels
2364
            xlabels.append(right_k)
1✔
2365
            prev_right_klabel = right_k
1✔
2366

2367
            # add x-coordinates for labels
2368
            left_kpoint = bs.kpoints[branch["start_index"]].cart_coords
1✔
2369
            right_kpoint = bs.kpoints[branch["end_index"]].cart_coords
1✔
2370
            distance = np.linalg.norm(right_kpoint - left_kpoint)
1✔
2371
            xlabel_distances.append(xlabel_distances[-1] + distance)  # type: ignore
1✔
2372

2373
            # add x-coordinates for kpoint data
2374
            npts = branch["end_index"] - branch["start_index"]
1✔
2375
            distance_interval = distance / npts
1✔
2376
            x_distances.append(xlabel_distances[-2])
1✔
2377
            for _ in range(npts):
1✔
2378
                x_distances.append(x_distances[-1] + distance_interval)
1✔
2379
            x_distances_list.append(x_distances)
1✔
2380

2381
        # set up bs and dos plot
2382
        gs = GridSpec(1, 2, width_ratios=[2, 1]) if dos else GridSpec(1, 1)
1✔
2383

2384
        fig = mplt.figure(figsize=self.fig_size)
1✔
2385
        fig.patch.set_facecolor("white")
1✔
2386
        bs_ax = mplt.subplot(gs[0])
1✔
2387
        if dos:
1✔
2388
            dos_ax = mplt.subplot(gs[1])
1✔
2389

2390
        # set basic axes limits for the plot
2391
        bs_ax.set_xlim(0, x_distances_list[-1][-1])
1✔
2392
        bs_ax.set_ylim(emin, emax)
1✔
2393
        if dos:
1✔
2394
            dos_ax.set_ylim(emin, emax)
1✔
2395

2396
        # add BS xticks, labels, etc.
2397
        bs_ax.set_xticks(xlabel_distances)
1✔
2398
        bs_ax.set_xticklabels(xlabels, size=self.tick_fontsize)
1✔
2399
        bs_ax.set_xlabel("Wavevector $k$", fontsize=self.axis_fontsize, family=self.font)
1✔
2400
        bs_ax.set_ylabel("$E-E_F$ / eV", fontsize=self.axis_fontsize, family=self.font)
1✔
2401

2402
        # add BS fermi level line at E=0 and gridlines
2403
        bs_ax.hlines(y=0, xmin=0, xmax=x_distances_list[-1][-1], color="k", lw=2)
1✔
2404
        bs_ax.set_yticks(np.arange(emin, emax + 1e-5, self.egrid_interval))
1✔
2405
        bs_ax.set_yticklabels(np.arange(emin, emax + 1e-5, self.egrid_interval), size=self.tick_fontsize)
1✔
2406
        bs_ax.set_axisbelow(True)
1✔
2407
        bs_ax.grid(color=[0.5, 0.5, 0.5], linestyle="dotted", linewidth=1)
1✔
2408
        if dos:
1✔
2409
            dos_ax.set_yticks(np.arange(emin, emax + 1e-5, self.egrid_interval))
1✔
2410
            dos_ax.set_yticklabels([])
1✔
2411
            dos_ax.grid(color=[0.5, 0.5, 0.5], linestyle="dotted", linewidth=1)
1✔
2412

2413
        # renormalize the band energy to the Fermi level
2414
        band_energies: dict[Spin, list[float]] = {}
1✔
2415
        for spin in (Spin.up, Spin.down):
1✔
2416
            if spin in bs.bands:
1✔
2417
                band_energies[spin] = []
1✔
2418
                for band in bs.bands[spin]:
1✔
2419
                    band = cast(List[float], band)
1✔
2420
                    band_energies[spin].append([e - bs.efermi for e in band])  # type: ignore
1✔
2421

2422
        # renormalize the DOS energies to Fermi level
2423
        if dos:
1✔
2424
            dos_energies = [e - dos.efermi for e in dos.energies]
1✔
2425

2426
        # get the projection data to set colors for the band structure
2427
        colordata = self._get_colordata(bs, elements, bs_projection)
1✔
2428

2429
        # plot the colored band structure lines
2430
        for spin in (Spin.up, Spin.down):
1✔
2431
            if spin in band_energies:
1✔
2432
                linestyles = "solid" if spin == Spin.up else "dotted"
1✔
2433
                for band_idx, band in enumerate(band_energies[spin]):
1✔
2434
                    current_pos = 0
1✔
2435
                    for x_distances in x_distances_list:
1✔
2436
                        sub_band = band[current_pos : current_pos + len(x_distances)]
1✔
2437

2438
                        self._rgbline(
1✔
2439
                            bs_ax,
2440
                            x_distances,
2441
                            sub_band,
2442
                            colordata[spin][band_idx, :, 0][current_pos : current_pos + len(x_distances)],
2443
                            colordata[spin][band_idx, :, 1][current_pos : current_pos + len(x_distances)],
2444
                            colordata[spin][band_idx, :, 2][current_pos : current_pos + len(x_distances)],
2445
                            linestyles=linestyles,
2446
                        )
2447

2448
                        current_pos += len(x_distances)
1✔
2449

2450
        if dos:
1✔
2451
            # Plot the DOS and projected DOS
2452
            for spin in (Spin.up, Spin.down):
1✔
2453
                if spin in dos.densities:
1✔
2454
                    # plot the total DOS
2455
                    dos_densities = dos.densities[spin] * int(spin)
1✔
2456
                    label = "total" if spin == Spin.up else None
1✔
2457
                    dos_ax.plot(dos_densities, dos_energies, color=(0.6, 0.6, 0.6), label=label)
1✔
2458
                    dos_ax.fill_betweenx(
1✔
2459
                        dos_energies,
2460
                        0,
2461
                        dos_densities,
2462
                        color=(0.7, 0.7, 0.7),
2463
                        facecolor=(0.7, 0.7, 0.7),
2464
                    )
2465

2466
                    if self.dos_projection is None:
1✔
2467
                        pass
×
2468

2469
                    elif self.dos_projection.lower() == "elements":
1✔
2470
                        # plot the atom-projected DOS
2471
                        colors = ["b", "r", "g", "m", "y", "c", "k", "w"]
1✔
2472
                        el_dos = dos.get_element_dos()
1✔
2473
                        for idx, el in enumerate(elements):
1✔
2474
                            dos_densities = el_dos[Element(el)].densities[spin] * int(spin)
1✔
2475
                            label = el if spin == Spin.up else None
1✔
2476
                            dos_ax.plot(
1✔
2477
                                dos_densities,
2478
                                dos_energies,
2479
                                color=colors[idx],
2480
                                label=label,
2481
                            )
2482

2483
                    elif self.dos_projection.lower() == "orbitals":
×
2484
                        # plot each of the atomic projected DOS
2485
                        colors = ["b", "r", "g", "m"]
×
2486
                        spd_dos = dos.get_spd_dos()
×
2487
                        for idx, orb in enumerate([OrbitalType.s, OrbitalType.p, OrbitalType.d, OrbitalType.f]):
×
2488
                            if orb in spd_dos:
×
2489
                                dos_densities = spd_dos[orb].densities[spin] * int(spin)
×
2490
                                label = orb if spin == Spin.up else None  # type: ignore
×
2491
                                dos_ax.plot(
×
2492
                                    dos_densities,
2493
                                    dos_energies,
2494
                                    color=colors[idx],
2495
                                    label=label,
2496
                                )
2497

2498
            # get index of lowest and highest energy being plotted, used to help auto-scale DOS x-axis
2499
            emin_idx = next(x[0] for x in enumerate(dos_energies) if x[1] >= emin)
1✔
2500
            emax_idx = len(dos_energies) - next(x[0] for x in enumerate(reversed(dos_energies)) if x[1] <= emax)
1✔
2501

2502
            # determine DOS x-axis range
2503
            dos_xmin = (
1✔
2504
                0 if Spin.down not in dos.densities else -max(dos.densities[Spin.down][emin_idx : emax_idx + 1] * 1.05)
2505
            )
2506
            dos_xmax = max([max(dos.densities[Spin.up][emin_idx:emax_idx]) * 1.05, abs(dos_xmin)])
1✔
2507

2508
            # set up the DOS x-axis and add Fermi level line
2509
            dos_ax.set_xlim(dos_xmin, dos_xmax)
1✔
2510
            dos_ax.set_xticklabels([])
1✔
2511
            dos_ax.hlines(y=0, xmin=dos_xmin, xmax=dos_xmax, color="k", lw=2)
1✔
2512
            dos_ax.set_xlabel("DOS", fontsize=self.axis_fontsize, family=self.font)
1✔
2513

2514
        # add legend for band structure
2515
        if self.bs_legend and not rgb_legend:
1✔
2516
            handles = []
1✔
2517

2518
            if bs_projection is None:
1✔
2519
                handles = [
1✔
2520
                    mlines.Line2D([], [], linewidth=2, color="k", label="spin up"),
2521
                    mlines.Line2D(
2522
                        [],
2523
                        [],
2524
                        linewidth=2,
2525
                        color="b",
2526
                        linestyle="dotted",
2527
                        label="spin down",
2528
                    ),
2529
                ]
2530

2531
            elif bs_projection.lower() == "elements":
×
2532
                colors = ["b", "r", "g"]
×
2533
                for idx, el in enumerate(elements):
×
2534
                    handles.append(mlines.Line2D([], [], linewidth=2, color=colors[idx], label=el))
×
2535

2536
            bs_ax.legend(
1✔
2537
                handles=handles,
2538
                fancybox=True,
2539
                prop={"size": self.legend_fontsize, "family": self.font},
2540
                loc=self.bs_legend,
2541
            )
2542

2543
        elif self.bs_legend and rgb_legend:
1✔
2544
            if len(elements) == 2:
1✔
2545
                self._rb_line(bs_ax, elements[1], elements[0], loc=self.bs_legend)
×
2546
            elif len(elements) == 3:
1✔
2547
                self._rgb_triangle(bs_ax, elements[1], elements[2], elements[0], loc=self.bs_legend)
×
2548
            elif len(elements) == 4:
1✔
2549
                self._cmyk_triangle(bs_ax, elements[1], elements[2], elements[0], elements[3], loc=self.bs_legend)
1✔
2550
        # add legend for DOS
2551
        if dos and self.dos_legend:
1✔
2552
            dos_ax.legend(
1✔
2553
                fancybox=True,
2554
                prop={"size": self.legend_fontsize, "family": self.font},
2555
                loc=self.dos_legend,
2556
            )
2557

2558
        mplt.subplots_adjust(wspace=0.1)
1✔
2559
        return mplt
1✔
2560

2561
    @staticmethod
    def _rgbline(ax, k, e, red, green, blue, alpha=1, linestyles="solid"):
        """
        Add an RGB-colored polyline to an axis.

        The polyline through the points (k[i], e[i]) is split into len(k) - 1
        segments; each segment is colored with the mean of the colors of its
        two end points. Segment construction based on:
        http://nbviewer.ipython.org/urls/raw.github.com/dpsanders/matplotlib-examples/master/colorline.ipynb

        Args:
            ax: matplotlib axis
            k: x-axis data (k-points)
            e: y-axis data (energies)
            red: red channel value at each point
            green: green channel value at each point
            blue: blue channel value at each point
            alpha: alpha (opacity) of the segments; a scalar applies to all of them
            linestyles: linestyle for plot (e.g., "solid" or "dotted")
        """
        from matplotlib.collections import LineCollection

        # (n_points, 1, 2) points -> consecutive pairs give (n_segments, 2, 2)
        pts = np.array([k, e]).T.reshape(-1, 1, 2)  # pylint: disable=E1121
        seg = np.concatenate([pts[:-1], pts[1:]], axis=1)

        nseg = len(k) - 1
        r = [0.5 * (red[i] + red[i + 1]) for i in range(nseg)]
        g = [0.5 * (green[i] + green[i + 1]) for i in range(nseg)]
        b = [0.5 * (blue[i] + blue[i + 1]) for i in range(nseg)]
        # np.float_ was removed in NumPy 2.0; np.float64 is the same dtype.
        a = np.ones(nseg, np.float64) * alpha
        lc = LineCollection(seg, colors=list(zip(r, g, b, a)), linewidth=2, linestyles=linestyles)
        ax.add_collection(lc)
2589

2590
    @staticmethod
    def _get_colordata(bs, elements, bs_projection):
        """
        Get color data, including projected band structures.

        Args:
            bs: Bandstructure object
            elements: elements (in desired order) used for the color channels.
                With element projection, elements[0], elements[1], elements[2]
                drive the blue, red and green channels (RGB, up to 3 elements);
                with exactly 4 elements they drive yellow, cyan, magenta and
                black (CMYK), converted to RGB.
            bs_projection: None for no projection, "elements" for element projection

        Returns:
            dict mapping each spin present in ``bs.bands`` to a numpy array of
            RGB triplets indexed as [band][k-point]. Without projection, spin up
            is black and spin down is blue.
        """
        # These checks do not depend on spin, band or k-point: evaluate them once.
        project_on_elements = bool(bs_projection) and bs_projection.lower() == "elements"
        projections = bs.get_projection_on_elements() if project_on_elements else None
        use_cmyk = project_on_elements and len(elements) == 4

        contribs = {}
        for spin in (Spin.up, Spin.down):
            if spin not in bs.bands:
                continue
            spin_colors = []
            for band_idx in range(bs.nb_bands):
                colors = []
                for k_idx in range(len(bs.kpoints)):
                    if project_on_elements:
                        c = [0, 0, 0, 0]
                        # note: squared color interpolations are smoother
                        # see: https://youtu.be/LKnqECcg6Gw
                        projs = {key: v**2 for key, v in projections[spin][band_idx][k_idx].items()}
                        total = sum(projs.values())
                        if total > 0:
                            for idx, e in enumerate(elements):
                                c[idx] = math.sqrt(projs[e] / total)

                        # channel order: elements[1], elements[2], elements[0], elements[3]
                        c = [c[1], c[2], c[0], c[3]]
                        if use_cmyk:
                            # convert cmyk to rgb
                            c = [(1 - c[0]) * (1 - c[3]), (1 - c[1]) * (1 - c[3]), (1 - c[2]) * (1 - c[3])]
                        else:
                            c = c[:3]
                    else:
                        c = [0, 0, 0] if spin == Spin.up else [0, 0, 1]  # black for spin up, blue for spin down

                    colors.append(c)

                spin_colors.append(colors)
            contribs[spin] = np.array(spin_colors)

        return contribs
2645

2646
    @staticmethod
    def _cmyk_triangle(ax, c_label, m_label, y_label, k_label, loc):
        """
        Draw a CMYK triangle legend as an inset on the desired axis.

        The triangle has cyan at the top, magenta at the bottom right and yellow
        at the bottom left; the black (K) label is written in white near the
        center.

        Args:
            ax: matplotlib axis that receives the inset.
            c_label: label for the cyan (top) corner.
            m_label: label for the magenta (bottom-right) corner.
            y_label: label for the yellow (bottom-left) corner.
            k_label: label for the black component (center).
            loc: inset location code; values outside 1-10 fall back to 2.
        """
        if loc not in range(1, 11):
            loc = 2

        from mpl_toolkits.axes_grid1.inset_locator import inset_axes

        inset_ax = inset_axes(ax, width=1.5, height=1.5, loc=loc)
        mesh = 35
        top = mesh - 1
        x = []
        y = []
        color = []
        for c, ye, m in itertools.product(range(mesh), repeat=3):
            # skip the all-zero point (would divide by zero) and the all-maximum point
            if (c, ye, m) in ((0, 0, 0), (top, top, top)):
                continue
            total = c + ye + m
            c1, ye1, m1 = c / total, ye / total, m / total
            # Barycentric placement. Magenta drives the x offset so that the
            # magenta corner is bottom right and the yellow corner bottom left,
            # matching where m_label and y_label are written below.
            x.append(0.33 * (2.0 * m1 + c1) / (c1 + ye1 + m1))
            y.append(0.33 * np.sqrt(3) * c1 / (c1 + ye1 + m1))
            # subtractive mixing: cyan removes red, magenta green, yellow blue
            color.append([1 - c / top, 1 - m / top, 1 - ye / top])

        # plot the triangle
        inset_ax.scatter(x, y, s=7, marker=".", edgecolor=color)
        inset_ax.set_xlim([-0.35, 1.00])
        inset_ax.set_ylim([-0.35, 1.00])

        # add the labels
        for xpos, ypos, text, rgb, halign in (
            (0.70, -0.2, m_label, (0, 0, 0), "left"),
            (0.325, 0.70, c_label, (0, 0, 0), "center"),
            (-0.05, -0.2, y_label, (0, 0, 0), "right"),
            (0.325, 0.22, k_label, (1, 1, 1), "center"),
        ):
            inset_ax.text(
                xpos, ypos, text, fontsize=13, family="Times New Roman", color=rgb, horizontalalignment=halign
            )

        inset_ax.get_xaxis().set_visible(False)
        inset_ax.get_yaxis().set_visible(False)
2698

2699
    @staticmethod
    def _rgb_triangle(ax, r_label, g_label, b_label, loc):
        """
        Draw an RGB triangle legend as an inset on the desired axis.

        Red sits at the top corner, green at the bottom right and blue at the
        bottom left; each corner is labelled with the matching argument.

        Args:
            ax: matplotlib axis that receives the inset.
            r_label: label for the red (top) corner.
            g_label: label for the green (bottom-right) corner.
            b_label: label for the blue (bottom-left) corner.
            loc: inset location code; values outside 1-10 fall back to 2.
        """
        if loc not in range(1, 11):
            loc = 2

        from mpl_toolkits.axes_grid1.inset_locator import inset_axes

        inset_ax = inset_axes(ax, width=1, height=1, loc=loc)
        mesh = 35
        xs, ys, colors = [], [], []
        for r, g, b in itertools.product(range(mesh), repeat=3):
            if r == 0 and b == 0 and g == 0:
                continue  # no weight at all: position undefined
            weight = r + g + b
            r1, g1, b1 = r / weight, g / weight, b / weight
            norm = r1 + b1 + g1
            xs.append(0.33 * (2.0 * g1 + r1) / norm)
            ys.append(0.33 * np.sqrt(3) * r1 / norm)
            # normalize so that red**2 + green**2 + blue**2 == 1
            sq_sum = r**2 + g**2 + b**2
            colors.append([math.sqrt(r**2 / sq_sum), math.sqrt(g**2 / sq_sum), math.sqrt(b**2 / sq_sum)])

        # draw the triangle of colored dots
        inset_ax.scatter(xs, ys, s=7, marker=".", edgecolor=colors)  # pylint: disable=E1101
        inset_ax.set_xlim([-0.35, 1.00])  # pylint: disable=E1101
        inset_ax.set_ylim([-0.35, 1.00])  # pylint: disable=E1101

        # corner labels: green (bottom right), red (top), blue (bottom left)
        for xpos, ypos, text, halign in (
            (0.70, -0.2, g_label, "left"),
            (0.325, 0.70, r_label, "center"),
            (-0.05, -0.2, b_label, "right"),
        ):
            inset_ax.text(  # pylint: disable=E1101
                xpos,
                ypos,
                text,
                fontsize=13,
                family="Times New Roman",
                color=(0, 0, 0),
                horizontalalignment=halign,
            )

        inset_ax.get_xaxis().set_visible(False)  # pylint: disable=E1101
        inset_ax.get_yaxis().set_visible(False)  # pylint: disable=E1101
2766

2767
    @staticmethod
    def _rb_line(ax, r_label, b_label, loc):
        """
        Draw a red-to-blue bar legend as an inset on the desired axis.

        Args:
            ax: matplotlib axis that receives the inset.
            r_label: label placed at the red (left) end of the bar.
            b_label: label placed at the blue (right) end of the bar.
            loc: inset location code; values outside 1-10 fall back to 2.
        """
        if loc not in range(1, 11):
            loc = 2
        from mpl_toolkits.axes_grid1.inset_locator import inset_axes

        inset_ax = inset_axes(ax, width=1.2, height=0.4, loc=loc)

        n_points = 1000
        xs = [step / 1800.0 + 0.55 for step in range(n_points)]
        ys = [0] * n_points
        # blue weight grows quadratically along the bar; the sqrt keeps red**2 + blue**2 == 1
        colors = [
            [math.sqrt(w) for w in [1 - (step / 1000) ** 2, 0, (step / 1000) ** 2]] for step in range(n_points)
        ]

        # draw the bar, then the end labels (blue on the right, red on the left)
        inset_ax.scatter(xs, ys, s=250.0, marker="s", c=colors)
        inset_ax.set_xlim([-0.1, 1.7])
        for xpos, text, halign in ((1.35, b_label, "left"), (0.30, r_label, "right")):
            inset_ax.text(
                xpos,
                0,
                text,
                fontsize=13,
                family="Times New Roman",
                color=(0, 0, 0),
                horizontalalignment=halign,
                verticalalignment="center",
            )

        inset_ax.get_xaxis().set_visible(False)
        inset_ax.get_yaxis().set_visible(False)
2812

2813

2814
class BoltztrapPlotter:
1✔
2815
    # TODO: We need a unittest for this. Come on folks.
2816
    """
2817
    class containing methods to plot the data from Boltztrap.
2818
    """
2819

2820
    def __init__(self, bz):
1✔
2821
        """
2822
        Args:
2823
            bz: a BoltztrapAnalyzer object
2824
        """
2825
        self._bz = bz
×
2826

2827
    def _plot_doping(self, plt, temp):
1✔
2828
        if len(self._bz.doping) != 0:
×
2829
            limit = 2.21e15
×
2830
            plt.axvline(self._bz.mu_doping["n"][temp][0], linewidth=3.0, linestyle="--")
×
2831
            plt.text(
×
2832
                self._bz.mu_doping["n"][temp][0] + 0.01,
2833
                limit,
2834
                "$n$=10$^{" + str(math.log10(self._bz.doping["n"][0])) + "}$",
2835
                color="b",
2836
            )
2837
            plt.axvline(self._bz.mu_doping["n"][temp][-1], linewidth=3.0, linestyle="--")
×
2838
            plt.text(
×
2839
                self._bz.mu_doping["n"][temp][-1] + 0.01,
2840
                limit,
2841
                "$n$=10$^{" + str(math.log10(self._bz.doping["n"][-1])) + "}$",
2842
                color="b",
2843
            )
2844
            plt.axvline(self._bz.mu_doping["p"][temp][0], linewidth=3.0, linestyle="--")
×
2845
            plt.text(
×
2846
                self._bz.mu_doping["p"][temp][0] + 0.01,
2847
                limit,
2848
                "$p$=10$^{" + str(math.log10(self._bz.doping["p"][0])) + "}$",
2849
                color="b",
2850
            )
2851
            plt.axvline(self._bz.mu_doping["p"][temp][-1], linewidth=3.0, linestyle="--")
×
2852
            plt.text(
×
2853
                self._bz.mu_doping["p"][temp][-1] + 0.01,
2854
                limit,
2855
                "$p$=10$^{" + str(math.log10(self._bz.doping["p"][-1])) + "}$",
2856
                color="b",
2857
            )
2858

2859
    def _plot_bg_limits(self, plt):
1✔
2860
        plt.axvline(0.0, color="k", linewidth=3.0)
×
2861
        plt.axvline(self._bz.gap, color="k", linewidth=3.0)
×
2862

2863
    def plot_seebeck_eff_mass_mu(self, temps=(300,), output="average", Lambda=0.5):
        """
        Plot the Seebeck effective mass as a function of the chemical potential,
        calculated as explained in Ref.
        Gibbs, Z. M. et al., Effective mass and fermi surface complexity factor
        from ab initio band structure calculations.
        npj Computational Materials 3, 8 (2017).

        Values at chemical potentials mu with start < mu < stop, where start/stop
        are the first p-type/n-type doping chemical potentials, are dropped to
        remove noise inside the gap. Each curve is therefore drawn as two
        segments that share one color.

        Args:
            temps: list of temperatures of calculated seebeck.
            output: 'average' returns the seebeck effective mass calculated
                using the average of the three diagonal components of the
                seebeck tensor. 'tensor' returns the seebeck effective mass
                with respect to the three diagonal components of the seebeck tensor.
            Lambda: fitting parameter used to model the scattering (0.5 means
                constant relaxation time); forwarded to ``get_seebeck_eff_mass``.

        Returns:
            a matplotlib object
        """
        plt = pretty_plot(9, 7)
        for T in temps:
            # forward the caller's Lambda rather than a fixed 0.5
            sbk_mass = self._bz.get_seebeck_eff_mass(output=output, temp=T, Lambda=Lambda)
            # remove noise inside the gap
            start = self._bz.mu_doping["p"][T][0]
            stop = self._bz.mu_doping["n"][T][0]
            mu_steps_1, sbk_mass_1 = [], []
            mu_steps_2, sbk_mass_2 = [], []
            for i, mu in enumerate(self._bz.mu_steps):
                if mu <= start:
                    mu_steps_1.append(mu)
                    sbk_mass_1.append(sbk_mass[i])
                elif mu >= stop:
                    mu_steps_2.append(mu)
                    sbk_mass_2.append(sbk_mass[i])

            lines_1 = plt.plot(mu_steps_1, sbk_mass_1, label=str(T) + "K", linewidth=3.0)
            lines_2 = plt.plot(mu_steps_2, sbk_mass_2, linewidth=3.0)
            # Color each second segment like its first segment. Use the lines just
            # returned, not absolute indices into the axes, so that every
            # temperature (not only the first) is handled.
            for line_1, line_2 in zip(lines_1, lines_2):
                line_2.set_color(line_1.get_color())
            if output == "tensor":
                # one labelled line per diagonal component and temperature
                for component, line in zip(("x", "y", "z"), lines_1):
                    line.set_label(component + "_" + str(T) + "K")

        plt.xlabel("E-E$_f$ (eV)", fontsize=30)
        plt.ylabel("Seebeck effective mass", fontsize=30)
        plt.xticks(fontsize=25)
        plt.yticks(fontsize=25)
        if output in ("tensor", "average"):
            plt.legend(fontsize=20)
        plt.tight_layout()
        return plt
2923

2924
    def plot_complexity_factor_mu(self, temps=(300,), output="average", Lambda=0.5):
        """
        Plot the Fermi surface complexity factor as a function of the chemical
        potential, calculated as explained in Ref.
        Gibbs, Z. M. et al., Effective mass and fermi surface complexity factor
        from ab initio band structure calculations.
        npj Computational Materials 3, 8 (2017).

        Values at chemical potentials mu with start < mu < stop, where start/stop
        are the first p-type/n-type doping chemical potentials, are dropped.
        Each curve is therefore drawn as two segments that share one color.

        Args:
            output: 'average' returns the complexity factor calculated using the average
                    of the three diagonal components of the seebeck and conductivity tensors.
                    'tensor' returns the complexity factor with respect to the three
                    diagonal components of seebeck and conductivity tensors.
            temps:  list of temperatures of calculated seebeck and conductivity.
            Lambda: fitting parameter used to model the scattering (0.5 means constant
                    relaxation time); forwarded to ``get_complexity_factor``.

        Returns:
            a matplotlib object
        """
        plt = pretty_plot(9, 7)
        for T in temps:
            cmplx_fact = self._bz.get_complexity_factor(output=output, temp=T, Lambda=Lambda)
            start = self._bz.mu_doping["p"][T][0]
            stop = self._bz.mu_doping["n"][T][0]
            mu_steps_1, cmplx_fact_1 = [], []
            mu_steps_2, cmplx_fact_2 = [], []
            for i, mu in enumerate(self._bz.mu_steps):
                if mu <= start:
                    mu_steps_1.append(mu)
                    cmplx_fact_1.append(cmplx_fact[i])
                elif mu >= stop:
                    mu_steps_2.append(mu)
                    cmplx_fact_2.append(cmplx_fact[i])

            lines_1 = plt.plot(mu_steps_1, cmplx_fact_1, label=str(T) + "K", linewidth=3.0)
            lines_2 = plt.plot(mu_steps_2, cmplx_fact_2, linewidth=3.0)
            # Color each second segment like its first segment. Use the lines just
            # returned, not absolute indices into the axes, so that every
            # temperature (not only the first) is handled.
            for line_1, line_2 in zip(lines_1, lines_2):
                line_2.set_color(line_1.get_color())
            if output == "tensor":
                # one labelled line per diagonal component and temperature
                for component, line in zip(("x", "y", "z"), lines_1):
                    line.set_label(component + "_" + str(T) + "K")

        plt.xlabel("E-E$_f$ (eV)", fontsize=30)
        plt.ylabel("Complexity Factor", fontsize=30)
        plt.xticks(fontsize=25)
        plt.yticks(fontsize=25)
        if output in ("tensor", "average"):
            plt.legend(fontsize=20)
        plt.tight_layout()
        return plt
2983

2984
    def plot_seebeck_mu(self, temp=600, output="eig", xlim=None):
        """
        Plot the Seebeck coefficient as a function of the Fermi level.

        Args:
            temp: the temperature (key into the analyzer's results).
            output: output format forwarded to ``get_seebeck``; for 'eig' a
                legend S_1, S_2, S_3 is added.
            xlim: a list of min and max fermi energy; by default
                (-0.5, band gap + 0.5).

        Returns:
            a matplotlib object
        """
        plotter = pretty_plot(9, 7)
        seebeck_by_temp = self._bz.get_seebeck(output=output, doping_levels=False)
        plotter.plot(self._bz.mu_steps, seebeck_by_temp[temp], linewidth=3.0)

        # band-gap limits and doping markers
        self._plot_bg_limits(plotter)
        self._plot_doping(plotter, temp)

        if output == "eig":
            plotter.legend([f"S$_{n}$" for n in (1, 2, 3)])

        left, right = (-0.5, self._bz.gap + 0.5) if xlim is None else (xlim[0], xlim[1])
        plotter.xlim(left, right)

        plotter.ylabel("Seebeck \n coefficient  ($\\mu$V/K)", fontsize=30.0)
        plotter.xlabel("E-E$_f$ (eV)", fontsize=30)
        for set_ticks in (plotter.xticks, plotter.yticks):
            set_ticks(fontsize=25)
        plotter.tight_layout()
        return plotter
3015

3016
    def plot_conductivity_mu(self, temp=600, output="eig", relaxation_time=1e-14, xlim=None):
        """
        Semi-log plot of the conductivity as a function of the Fermi level.

        Args:
            temp: the temperature (key into the analyzer's results).
            output: output format forwarded to ``get_conductivity``; for 'eig' a
                legend Sigma_1, Sigma_2, Sigma_3 is added.
            relaxation_time: constant relaxation time in s, forwarded to
                ``get_conductivity``; the y-range is set to
                [1e13, 1e20] * relaxation_time.
            xlim: a list of min and max fermi energy; by default
                (-0.5, band gap + 0.5).

        Returns:
            a matplotlib object
        """
        conductivity = self._bz.get_conductivity(
            relaxation_time=relaxation_time, output=output, doping_levels=False
        )[temp]
        plotter = pretty_plot(9, 7)
        plotter.semilogy(self._bz.mu_steps, conductivity, linewidth=3.0)

        # band-gap limits and doping markers
        self._plot_bg_limits(plotter)
        self._plot_doping(plotter, temp)

        if output == "eig":
            plotter.legend([f"$\\Sigma_{n}$" for n in (1, 2, 3)])

        if xlim is not None:
            plotter.xlim(xlim)
        else:
            plotter.xlim(-0.5, self._bz.gap + 0.5)
        plotter.ylim([factor * relaxation_time for factor in (1e13, 1e20)])

        plotter.ylabel("conductivity,\n $\\Sigma$ (1/($\\Omega$ m))", fontsize=30.0)
        plotter.xlabel("E-E$_f$ (eV)", fontsize=30.0)
        for set_ticks in (plotter.xticks, plotter.yticks):
            set_ticks(fontsize=25)
        plotter.tight_layout()
        return plotter
3048

3049
    def plot_power_factor_mu(self, temp=600, output="eig", relaxation_time=1e-14, xlim=None):
        """
        Semi-log plot of the power factor as a function of the Fermi level.

        Args:
            temp: the temperature (key into the analyzer's results).
            output: output format forwarded to ``get_power_factor``; for 'eig' a
                legend PF_1, PF_2, PF_3 is added.
            relaxation_time: constant relaxation time in s, forwarded to
                ``get_power_factor``.
            xlim: a list of min and max fermi energy; by default
                (-0.5, band gap + 0.5).

        Returns:
            a matplotlib object
        """
        plotter = pretty_plot(9, 7)
        power_factor = self._bz.get_power_factor(
            relaxation_time=relaxation_time, output=output, doping_levels=False
        )[temp]
        plotter.semilogy(self._bz.mu_steps, power_factor, linewidth=3.0)

        # band-gap limits and doping markers
        self._plot_bg_limits(plotter)
        self._plot_doping(plotter, temp)

        if output == "eig":
            plotter.legend([f"PF$_{n}$" for n in (1, 2, 3)])

        if xlim is not None:
            plotter.xlim(xlim)
        else:
            plotter.xlim(-0.5, self._bz.gap + 0.5)

        plotter.ylabel("Power factor, ($\\mu$W/(mK$^2$))", fontsize=30.0)
        plotter.xlabel("E-E$_f$ (eV)", fontsize=30.0)
        for set_ticks in (plotter.xticks, plotter.yticks):
            set_ticks(fontsize=25)
        plotter.tight_layout()
        return plotter
3080

3081
    def plot_zt_mu(self, temp=600, output="eig", relaxation_time=1e-14, xlim=None):
        """
        Plot the figure of merit ZT as a function of the Fermi level.

        Args:
            temp: the temperature (key into the analyzer's results).
            output: output format forwarded to ``get_zt``; for 'eig' a legend
                ZT_1, ZT_2, ZT_3 is added.
            relaxation_time: constant relaxation time in s, forwarded to
                ``get_zt``.
            xlim: a list of min and max fermi energy; by default
                (-0.5, band gap + 0.5).

        Returns:
            a matplotlib object
        """
        plotter = pretty_plot(9, 7)
        zt_values = self._bz.get_zt(relaxation_time=relaxation_time, output=output, doping_levels=False)[temp]
        plotter.plot(self._bz.mu_steps, zt_values, linewidth=3.0)

        # band-gap limits and doping markers
        self._plot_bg_limits(plotter)
        self._plot_doping(plotter, temp)

        if output == "eig":
            plotter.legend([f"ZT$_{n}$" for n in (1, 2, 3)])

        if xlim is not None:
            plotter.xlim(xlim)
        else:
            plotter.xlim(-0.5, self._bz.gap + 0.5)

        plotter.ylabel("ZT", fontsize=30.0)
        plotter.xlabel("E-E$_f$ (eV)", fontsize=30.0)
        for set_ticks in (plotter.xticks, plotter.yticks):
            set_ticks(fontsize=25)
        plotter.tight_layout()
        return plotter
3112

3113
    def plot_seebeck_temp(self, doping="all", output="average"):
1✔
3114
        """
3115
        Plot the Seebeck coefficient in function of temperature for different
3116
        doping levels.
3117

3118
        Args:
3119
            dopings: the default 'all' plots all the doping levels in the analyzer.
3120
                     Specify a list of doping levels if you want to plot only some.
3121
            output: with 'average' you get an average of the three directions
3122
                    with 'eigs' you get all the three directions.
3123

3124
        Returns:
3125
            a matplotlib object
3126
        """
3127
        if output == "average":
×
3128
            sbk = self._bz.get_seebeck(output="average")
×
3129
        elif output == "eigs":
×
3130
            sbk = self._bz.get_seebeck(output="eigs")
×
3131

3132
        plt = pretty_plot(22, 14)
×
3133
        tlist = sorted(sbk["n"])
×
3134
        doping = self._bz.doping["n"] if doping == "all" else doping
×
3135
        for i, dt in enumerate(["n", "p"]):
×
3136
            plt.subplot(121 + i)
×
3137
            for dop in doping:
×
3138
                d = self._bz.doping[dt].index(dop)
×
3139
                sbk_temp = []
×
3140
                for temp in tlist:
×
3141
                    sbk_temp.append(sbk[dt][temp][d])
×
3142
                if output == "average":
×
3143
                    plt.plot(tlist, sbk_temp, marker="s", label=str(dop) + " $cm^{-3}$")
×
3144
                elif output == "eigs":
×
3145
                    for xyz in range(3):
×
3146
                        plt.plot(
×
3147
                            tlist,
3148
                            list(zip(*sbk_temp))[xyz],
3149
                            marker="s",
3150
                            label=str(xyz) + " " + str(dop) + " $cm^{-3}$",
3151
                        )
3152
            plt.title(dt + "-type", fontsize=20)
×
3153
            if i == 0:
×
3154
                plt.ylabel("Seebeck \n coefficient  ($\\mu$V/K)", fontsize=30.0)
×
3155
            plt.xlabel("Temperature (K)", fontsize=30.0)
×
3156

3157
            p = "lower right" if i == 0 else "best"
×
3158
            plt.legend(loc=p, fontsize=15)
×
3159
            plt.grid()
×
3160
            plt.xticks(fontsize=25)
×
3161
            plt.yticks(fontsize=25)
×
3162

3163
        plt.tight_layout()
×
3164

3165
        return plt
×
3166

3167
    def plot_conductivity_temp(self, doping="all", output="average", relaxation_time=1e-14):
1✔
3168
        """
3169
        Plot the conductivity in function of temperature for different doping levels.
3170

3171
        Args:
3172
            dopings: the default 'all' plots all the doping levels in the analyzer.
3173
                     Specify a list of doping levels if you want to plot only some.
3174
            output: with 'average' you get an average of the three directions
3175
                    with 'eigs' you get all the three directions.
3176
            relaxation_time: specify a constant relaxation time value
3177

3178
        Returns:
3179
            a matplotlib object
3180
        """
3181
        if output == "average":
×
3182
            cond = self._bz.get_conductivity(relaxation_time=relaxation_time, output="average")
×
3183
        elif output == "eigs":
×
3184
            cond = self._bz.get_conductivity(relaxation_time=relaxation_time, output="eigs")
×
3185

3186
        plt = pretty_plot(22, 14)
×
3187
        tlist = sorted(cond["n"])
×
3188
        doping = self._bz.doping["n"] if doping == "all" else doping
×
3189
        for i, dt in enumerate(["n", "p"]):
×
3190
            plt.subplot(121 + i)
×
3191
            for dop in doping:
×
3192
                d = self._bz.doping[dt].index(dop)
×
3193
                cond_temp = []
×
3194
                for temp in tlist:
×
3195
                    cond_temp.append(cond[dt][temp][d])
×
3196
                if output == "average":
×
3197
                    plt.plot(tlist, cond_temp, marker="s", label=str(dop) + " $cm^{-3}$")
×
3198
                elif output == "eigs":
×
3199
                    for xyz in range(3):
×
3200
                        plt.plot(
×
3201
                            tlist,
3202
                            list(zip(*cond_temp))[xyz],
3203
                            marker="s",
3204
                            label=str(xyz) + " " + str(dop) + " $cm^{-3}$",
3205
                        )
3206
            plt.title(dt + "-type", fontsize=20)
×
3207
            if i == 0:
×
3208
                plt.ylabel("conductivity $\\sigma$ (1/($\\Omega$ m))", fontsize=30.0)
×
3209
            plt.xlabel("Temperature (K)", fontsize=30.0)
×
3210

3211
            p = "best"  # 'lower right' if i == 0 else ''
×
3212
            plt.legend(loc=p, fontsize=15)
×
3213
            plt.grid()
×
3214
            plt.xticks(fontsize=25)
×
3215
            plt.yticks(fontsize=25)
×
3216
            plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
×
3217

3218
        plt.tight_layout()
×
3219

3220
        return plt
×
3221

3222
    def plot_power_factor_temp(self, doping="all", output="average", relaxation_time=1e-14):
        """
        Plot the Power Factor in function of temperature for different doping levels.

        One subplot is drawn per carrier type (n-type left, p-type right).

        Args:
            doping: the default 'all' plots all the doping levels the analyzer
                holds for each carrier type. Specify a list of doping levels if
                you want to plot only some.
            output: with 'average' you get an average of the three directions
                with 'eigs' you get all the three directions.
            relaxation_time: specify a constant relaxation time value

        Returns:
            a matplotlib object

        Raises:
            ValueError: if output is neither 'average' nor 'eigs'.
        """
        if output not in ("average", "eigs"):
            # Fail fast: an unsupported value would otherwise leave the data
            # unbound and crash later with an unhelpful error.
            raise ValueError(f"output must be 'average' or 'eigs', got {output!r}")
        pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output=output)

        plt = pretty_plot(22, 14)
        tlist = sorted(pf["n"])
        for i, dt in enumerate(["n", "p"]):
            plt.subplot(121 + i)
            # With 'all', use this carrier type's own doping list so the
            # .index() lookup below always finds the level.
            dop_list = self._bz.doping[dt] if doping == "all" else doping
            for dop in dop_list:
                d = self._bz.doping[dt].index(dop)
                pf_temp = [pf[dt][temp][d] for temp in tlist]
                if output == "average":
                    plt.plot(tlist, pf_temp, marker="s", label=str(dop) + " $cm^{-3}$")
                else:
                    # Transpose per-temperature entries into one series per direction.
                    components = list(zip(*pf_temp))
                    for xyz in range(3):
                        plt.plot(
                            tlist,
                            components[xyz],
                            marker="s",
                            label=str(xyz) + " " + str(dop) + " $cm^{-3}$",
                        )
            plt.title(dt + "-type", fontsize=20)
            if i == 0:
                plt.ylabel("Power Factor ($\\mu$W/(mK$^2$))", fontsize=30.0)
            plt.xlabel("Temperature (K)", fontsize=30.0)

            plt.legend(loc="best", fontsize=15)
            plt.grid()
            plt.xticks(fontsize=25)
            plt.yticks(fontsize=25)
            plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))

        plt.tight_layout()
        return plt
3275

3276
    def plot_zt_temp(self, doping="all", output="average", relaxation_time=1e-14):
        """
        Plot the figure of merit zT in function of temperature for different doping levels.

        One subplot is drawn per carrier type (n-type left, p-type right).

        Args:
            doping: the default 'all' plots all the doping levels the analyzer
                holds for each carrier type. Specify a list of doping levels if
                you want to plot only some.
            output: with 'average' you get an average of the three directions
                with 'eigs' you get all the three directions.
            relaxation_time: specify a constant relaxation time value

        Returns:
            a matplotlib object

        Raises:
            ValueError: if output is neither 'average' nor 'eigs'.
        """
        if output not in ("average", "eigs"):
            # Fail fast: an unsupported value would otherwise leave the data
            # unbound and crash later with an unhelpful error.
            raise ValueError(f"output must be 'average' or 'eigs', got {output!r}")
        zt = self._bz.get_zt(relaxation_time=relaxation_time, output=output)

        plt = pretty_plot(22, 14)
        tlist = sorted(zt["n"])
        for i, dt in enumerate(["n", "p"]):
            plt.subplot(121 + i)
            # With 'all', use this carrier type's own doping list so the
            # .index() lookup below always finds the level.
            dop_list = self._bz.doping[dt] if doping == "all" else doping
            for dop in dop_list:
                d = self._bz.doping[dt].index(dop)
                zt_temp = [zt[dt][temp][d] for temp in tlist]
                if output == "average":
                    plt.plot(tlist, zt_temp, marker="s", label=str(dop) + " $cm^{-3}$")
                else:
                    # Transpose per-temperature entries into one series per direction.
                    components = list(zip(*zt_temp))
                    for xyz in range(3):
                        plt.plot(
                            tlist,
                            components[xyz],
                            marker="s",
                            label=str(xyz) + " " + str(dop) + " $cm^{-3}$",
                        )
            plt.title(dt + "-type", fontsize=20)
            if i == 0:
                plt.ylabel("zT", fontsize=30.0)
            plt.xlabel("Temperature (K)", fontsize=30.0)

            plt.legend(loc="best", fontsize=15)
            plt.grid()
            plt.xticks(fontsize=25)
            plt.yticks(fontsize=25)

        plt.tight_layout()
        return plt
3328

3329
    def plot_eff_mass_temp(self, doping="all", output="average"):
        """
        Plot the average effective mass in function of temperature
        for different doping levels.

        One subplot is drawn per carrier type (n-type left, p-type right).

        Args:
            doping: the default 'all' plots all the doping levels the analyzer
                holds for each carrier type. Specify a list of doping levels if
                you want to plot only some.
            output: with 'average' you get an average of the three directions
                with 'eigs' you get all the three directions.

        Returns:
            a matplotlib object

        Raises:
            ValueError: if output is neither 'average' nor 'eigs'.
        """
        if output not in ("average", "eigs"):
            # Fail fast: an unsupported value would otherwise leave the data
            # unbound and crash later with an unhelpful error.
            raise ValueError(f"output must be 'average' or 'eigs', got {output!r}")
        em = self._bz.get_average_eff_mass(output=output)

        plt = pretty_plot(22, 14)
        tlist = sorted(em["n"])
        for i, dt in enumerate(["n", "p"]):
            plt.subplot(121 + i)
            # With 'all', use this carrier type's own doping list so the
            # .index() lookup below always finds the level.
            dop_list = self._bz.doping[dt] if doping == "all" else doping
            for dop in dop_list:
                d = self._bz.doping[dt].index(dop)
                em_temp = [em[dt][temp][d] for temp in tlist]
                if output == "average":
                    plt.plot(tlist, em_temp, marker="s", label=str(dop) + " $cm^{-3}$")
                else:
                    # Transpose per-temperature entries into one series per direction.
                    components = list(zip(*em_temp))
                    for xyz in range(3):
                        plt.plot(
                            tlist,
                            components[xyz],
                            marker="s",
                            label=str(xyz) + " " + str(dop) + " $cm^{-3}$",
                        )
            plt.title(dt + "-type", fontsize=20)
            if i == 0:
                plt.ylabel("Effective mass (m$_e$)", fontsize=30.0)
            plt.xlabel("Temperature (K)", fontsize=30.0)

            plt.legend(loc="best", fontsize=15)
            plt.grid()
            plt.xticks(fontsize=25)
            plt.yticks(fontsize=25)

        plt.tight_layout()
        return plt
3381

3382
    def plot_seebeck_dop(self, temps="all", output="average"):
        """
        Plot the Seebeck coefficient in function of doping levels for different temperatures.

        One subplot is drawn per carrier type (n-type left, p-type right); the
        doping axis is logarithmic.

        Args:
            temps: the default 'all' plots all the temperatures in the analyzer.
                Specify a list of temperatures if you want to plot only some.
            output: with 'average' you get an average of the three directions
                with 'eigs' you get all the three directions.

        Returns:
            a matplotlib object

        Raises:
            ValueError: if output is neither 'average' nor 'eigs'.
        """
        if output not in ("average", "eigs"):
            # Fail fast: an unsupported value would otherwise leave the data
            # unbound and crash later with an unhelpful error.
            raise ValueError(f"output must be 'average' or 'eigs', got {output!r}")
        sbk = self._bz.get_seebeck(output=output)

        tlist = sorted(sbk["n"]) if temps == "all" else temps
        plt = pretty_plot(22, 14)
        for i, dt in enumerate(["n", "p"]):
            plt.subplot(121 + i)
            dop_levels = self._bz.doping[dt]
            for temp in tlist:
                if output == "eigs":
                    # Transpose per-doping entries into one series per direction.
                    components = list(zip(*sbk[dt][temp]))
                    for xyz in range(3):
                        plt.semilogx(
                            dop_levels,
                            components[xyz],
                            marker="s",
                            label=str(xyz) + " " + str(temp) + " K",
                        )
                else:
                    plt.semilogx(
                        dop_levels,
                        sbk[dt][temp],
                        marker="s",
                        label=str(temp) + " K",
                    )
            plt.title(dt + "-type", fontsize=20)
            if i == 0:
                plt.ylabel("Seebeck coefficient ($\\mu$V/K)", fontsize=30.0)
            plt.xlabel("Doping concentration (cm$^{-3}$)", fontsize=30.0)

            plt.legend(loc="lower right" if i == 0 else "best", fontsize=15)
            plt.grid()
            plt.xticks(fontsize=25)
            plt.yticks(fontsize=25)

        plt.tight_layout()

        return plt
3434

3435
    def plot_conductivity_dop(self, temps="all", output="average", relaxation_time=1e-14):
        """
        Plot the conductivity in function of doping levels for different
        temperatures.

        One subplot is drawn per carrier type (n-type left, p-type right); the
        doping axis is logarithmic.

        Args:
            temps: the default 'all' plots all the temperatures in the analyzer.
                Specify a list of temperatures if you want to plot only some.
            output: with 'average' you get an average of the three directions
                with 'eigs' you get all the three directions.
            relaxation_time: specify a constant relaxation time value

        Returns:
            a matplotlib object

        Raises:
            ValueError: if output is neither 'average' nor 'eigs'.
        """
        if output not in ("average", "eigs"):
            # Fail fast: an unsupported value would otherwise leave the data
            # unbound and crash later with an unhelpful error.
            raise ValueError(f"output must be 'average' or 'eigs', got {output!r}")
        cond = self._bz.get_conductivity(relaxation_time=relaxation_time, output=output)

        tlist = sorted(cond["n"]) if temps == "all" else temps
        plt = pretty_plot(22, 14)
        for i, dt in enumerate(["n", "p"]):
            plt.subplot(121 + i)
            dop_levels = self._bz.doping[dt]
            for temp in tlist:
                if output == "eigs":
                    # Transpose per-doping entries into one series per direction.
                    components = list(zip(*cond[dt][temp]))
                    for xyz in range(3):
                        plt.semilogx(
                            dop_levels,
                            components[xyz],
                            marker="s",
                            label=str(xyz) + " " + str(temp) + " K",
                        )
                else:
                    plt.semilogx(
                        dop_levels,
                        cond[dt][temp],
                        marker="s",
                        label=str(temp) + " K",
                    )
            plt.title(dt + "-type", fontsize=20)
            if i == 0:
                plt.ylabel("conductivity $\\sigma$ (1/($\\Omega$ m))", fontsize=30.0)
            plt.xlabel("Doping concentration ($cm^{-3}$)", fontsize=30.0)
            plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
            plt.legend(fontsize=15)
            plt.grid()
            plt.xticks(fontsize=25)
            plt.yticks(fontsize=25)

        plt.tight_layout()

        return plt
3488

3489
    def plot_power_factor_dop(self, temps="all", output="average", relaxation_time=1e-14):
        """
        Plot the Power Factor in function of doping levels for different temperatures.

        One subplot is drawn per carrier type (n-type left, p-type right); the
        doping axis is logarithmic.

        Args:
            temps: the default 'all' plots all the temperatures in the analyzer.
                Specify a list of temperatures if you want to plot only some.
            output: with 'average' you get an average of the three directions
                with 'eigs' you get all the three directions.
            relaxation_time: specify a constant relaxation time value

        Returns:
            a matplotlib object

        Raises:
            ValueError: if output is neither 'average' nor 'eigs'.
        """
        if output not in ("average", "eigs"):
            # Fail fast: an unsupported value would otherwise leave the data
            # unbound and crash later with an unhelpful error.
            raise ValueError(f"output must be 'average' or 'eigs', got {output!r}")
        pf = self._bz.get_power_factor(relaxation_time=relaxation_time, output=output)

        tlist = sorted(pf["n"]) if temps == "all" else temps
        plt = pretty_plot(22, 14)
        for i, dt in enumerate(["n", "p"]):
            plt.subplot(121 + i)
            dop_levels = self._bz.doping[dt]
            for temp in tlist:
                if output == "eigs":
                    # Transpose per-doping entries into one series per direction.
                    components = list(zip(*pf[dt][temp]))
                    for xyz in range(3):
                        plt.semilogx(
                            dop_levels,
                            components[xyz],
                            marker="s",
                            label=str(xyz) + " " + str(temp) + " K",
                        )
                else:
                    plt.semilogx(
                        dop_levels,
                        pf[dt][temp],
                        marker="s",
                        label=str(temp) + " K",
                    )
            plt.title(dt + "-type", fontsize=20)
            if i == 0:
                plt.ylabel("Power Factor  ($\\mu$W/(mK$^2$))", fontsize=30.0)
            plt.xlabel("Doping concentration ($cm^{-3}$)", fontsize=30.0)
            plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0))
            plt.legend(loc="best", fontsize=15)
            plt.grid()
            plt.xticks(fontsize=25)
            plt.yticks(fontsize=25)

        plt.tight_layout()

        return plt
3542

3543
    def plot_zt_dop(self, temps="all", output="average", relaxation_time=1e-14):
        """
        Plot the figure of merit zT in function of doping levels for different
        temperatures.

        One subplot is drawn per carrier type (n-type left, p-type right); the
        doping axis is logarithmic.

        Args:
            temps: the default 'all' plots all the temperatures in the analyzer.
                Specify a list of temperatures if you want to plot only some.
            output: with 'average' you get an average of the three directions
                with 'eigs' you get all the three directions.
            relaxation_time: specify a constant relaxation time value

        Returns:
            a matplotlib object

        Raises:
            ValueError: if output is neither 'average' nor 'eigs'.
        """
        if output not in ("average", "eigs"):
            # Fail fast: an unsupported value would otherwise leave the data
            # unbound and crash later with an unhelpful error.
            raise ValueError(f"output must be 'average' or 'eigs', got {output!r}")
        zt = self._bz.get_zt(relaxation_time=relaxation_time, output=output)

        tlist = sorted(zt["n"]) if temps == "all" else temps
        plt = pretty_plot(22, 14)
        for i, dt in enumerate(["n", "p"]):
            plt.subplot(121 + i)
            dop_levels = self._bz.doping[dt]
            for temp in tlist:
                if output == "eigs":
                    # Transpose per-doping entries into one series per direction.
                    components = list(zip(*zt[dt][temp]))
                    for xyz in range(3):
                        plt.semilogx(
                            dop_levels,
                            components[xyz],
                            marker="s",
                            label=str(xyz) + " " + str(temp) + " K",
                        )
                else:
                    plt.semilogx(
                        dop_levels,
                        zt[dt][temp],
                        marker="s",
                        label=str(temp) + " K",
                    )
            plt.title(dt + "-type", fontsize=20)
            if i == 0:
                plt.ylabel("zT", fontsize=30.0)
            plt.xlabel("Doping concentration ($cm^{-3}$)", fontsize=30.0)

            plt.legend(loc="lower right" if i == 0 else "best", fontsize=15)
            plt.grid()
            plt.xticks(fontsize=25)
            plt.yticks(fontsize=25)

        plt.tight_layout()

        return plt
3597

3598
    def plot_eff_mass_dop(self, temps="all", output="average"):
        """
        Plot the average effective mass in function of doping levels
        for different temperatures.

        One subplot is drawn per carrier type (n-type left, p-type right); the
        doping axis is logarithmic.

        Args:
            temps: the default 'all' plots all the temperatures in the analyzer.
                Specify a list of temperatures if you want to plot only some.
            output: with 'average' you get an average of the three directions
                with 'eigs' you get all the three directions.

        Returns:
            a matplotlib object

        Raises:
            ValueError: if output is neither 'average' nor 'eigs'.
        """
        if output not in ("average", "eigs"):
            # Fail fast: an unsupported value would otherwise leave the data
            # unbound and crash later with an unhelpful error.
            raise ValueError(f"output must be 'average' or 'eigs', got {output!r}")
        em = self._bz.get_average_eff_mass(output=output)

        tlist = sorted(em["n"]) if temps == "all" else temps
        plt = pretty_plot(22, 14)
        for i, dt in enumerate(["n", "p"]):
            plt.subplot(121 + i)
            dop_levels = self._bz.doping[dt]
            for temp in tlist:
                if output == "eigs":
                    # Transpose per-doping entries into one series per direction.
                    components = list(zip(*em[dt][temp]))
                    for xyz in range(3):
                        plt.semilogx(
                            dop_levels,
                            components[xyz],
                            marker="s",
                            label=str(xyz) + " " + str(temp) + " K",
                        )
                else:
                    plt.semilogx(
                        dop_levels,
                        em[dt][temp],
                        marker="s",
                        label=str(temp) + " K",
                    )
            plt.title(dt + "-type", fontsize=20)
            if i == 0:
                plt.ylabel("Effective mass (m$_e$)", fontsize=30.0)
            plt.xlabel("Doping concentration ($cm^{-3}$)", fontsize=30.0)

            plt.legend(loc="lower right" if i == 0 else "best", fontsize=15)
            plt.grid()
            plt.xticks(fontsize=25)
            plt.yticks(fontsize=25)

        plt.tight_layout()

        return plt
3652

3653
    def plot_dos(self, sigma=0.05):
        """
        Plot the density of states held by the wrapped analyzer (``self._bz.dos``).

        Args:
            sigma: smearing value forwarded to DosPlotter.

        Returns:
            a matplotlib object
        """
        dos_plotter = DosPlotter(sigma=sigma)
        # The single DOS is registered under the legend label "t".
        dos_plotter.add_dos("t", self._bz.dos)
        result = dos_plotter.get_plot()
        return result
3665

3666
    def plot_carriers(self, temp=300):
        """
        Plot the carrier concentration as a function of the Fermi level.

        Concentrations are plotted as absolute values on a logarithmic y-axis.

        Args:
            temp: temperature whose carrier concentrations are plotted.

        Returns:
            a matplotlib object
        """
        plt = pretty_plot(9, 7)
        # Scale the cell volume by 1e-24 (presumably Angstrom^3 -> cm^3) so the
        # concentrations come out per cm^3, matching the axis label.
        volume_cm3 = self._bz.vol * 1e-24
        carriers = [abs(conc / volume_cm3) for conc in self._bz._carrier_conc[temp]]
        plt.semilogy(self._bz.mu_steps, carriers, linewidth=3.0, color="r")
        # Overlays drawn by sibling helpers (presumably gap edges and doping levels).
        self._plot_bg_limits(plt)
        self._plot_doping(plt, temp)
        plt.xlim(-0.5, self._bz.gap + 0.5)
        plt.ylim(1e14, 1e22)
        plt.ylabel("carrier concentration (cm-3)", fontsize=30.0)
        plt.xlabel("E-E$_f$ (eV)", fontsize=30)
        for set_ticks in (plt.xticks, plt.yticks):
            set_ticks(fontsize=25)
        plt.tight_layout()
        return plt
3689

3690
    def plot_hall_carriers(self, temp=300):
        """
        Plot the Hall carrier concentration as a function of the Fermi level.

        Concentrations are plotted as absolute values on a logarithmic y-axis.

        Args:
            temp: temperature whose Hall carrier concentrations are plotted.

        Returns:
            a matplotlib object
        """
        plt = pretty_plot(9, 7)
        hall_by_temp = self._bz.get_hall_carrier_concentration()
        hall_carriers = list(map(abs, hall_by_temp[temp]))
        plt.semilogy(self._bz.mu_steps, hall_carriers, linewidth=3.0, color="r")
        # Overlays drawn by sibling helpers (presumably gap edges and doping levels).
        self._plot_bg_limits(plt)
        self._plot_doping(plt, temp)
        plt.xlim(-0.5, self._bz.gap + 0.5)
        plt.ylim(1e14, 1e22)
        plt.ylabel("Hall carrier concentration (cm-3)", fontsize=30.0)
        plt.xlabel("E-E$_f$ (eV)", fontsize=30)
        for set_ticks in (plt.xticks, plt.yticks):
            set_ticks(fontsize=25)
        plt.tight_layout()
        return plt
3713

3714

3715
class CohpPlotter:
    """
    Class for plotting crystal orbital Hamilton populations (COHPs) or
    crystal orbital overlap populations (COOPs). It is modeled after the
    DosPlotter object.
    """

    def __init__(self, zero_at_efermi=True, are_coops=False, are_cobis=False):
        """
        Args:
            zero_at_efermi: Whether to shift all populations to have zero
                energy at the Fermi level. Defaults to True.
            are_coops: Switch to indicate that these are COOPs, not COHPs.
                Defaults to False for COHPs.
            are_cobis: Switch to indicate that these are COBIs, not COHPs/COOPs.
                Defaults to False for COHPs
        """
        self.zero_at_efermi = zero_at_efermi
        self.are_coops = are_coops
        self.are_cobis = are_cobis
        # label -> {"energies", "COHP", "ICOHP", "efermi"}, in insertion order.
        self._cohps = {}

    def add_cohp(self, label, cohp):
        """
        Adds a COHP for plotting.

        Args:
            label: Label for the COHP. Must be unique.

            cohp: COHP object.
        """
        energies = cohp.energies - cohp.efermi if self.zero_at_efermi else cohp.energies
        populations = cohp.get_cohp()
        int_populations = cohp.get_icohp()
        self._cohps[label] = {
            "energies": energies,
            "COHP": populations,
            "ICOHP": int_populations,
            "efermi": cohp.efermi,
        }

    def add_cohp_dict(self, cohp_dict, key_sort_func=None):
        """
        Adds a dictionary of COHPs with an optional sorting function
        for the keys.

        Args:
            cohp_dict: dict of the form {label: Cohp}

            key_sort_func: function used to sort the cohp_dict keys.
        """
        keys = sorted(cohp_dict, key=key_sort_func) if key_sort_func else list(cohp_dict)
        for label in keys:
            self.add_cohp(label, cohp_dict[label])

    def get_cohp_dict(self):
        """
        Returns the added COHPs as a json-serializable dict. Note that if you
        have specified smearing for the COHP plot, the populations returned
        will be the smeared and not the original populations.

        Returns:
            dict: Dict of COHP data of the form {label: {"efermi": efermi,
            "energies": ..., "COHP": {Spin.up: ...}, "ICOHP": ...}}.
        """
        return jsanitize(self._cohps)

    def get_plot(
        self,
        xlim=None,
        ylim=None,
        plot_negative=None,
        integrated=False,
        invert_axes=True,
    ):
        """
        Get a matplotlib plot showing the COHP.

        Args:
            xlim: Specifies the x-axis limits. Defaults to None for
                automatic determination.

            ylim: Specifies the y-axis limits. Defaults to None for
                automatic determination.

            plot_negative: It is common to plot -COHP(E) so that the
                sign means the same for COOPs and COHPs. Defaults to None
                for automatic determination: If are_coops is True, this
                will be set to False, else it will be set to True.

            integrated: Switch to plot ICOHPs. Defaults to False.

            invert_axes: Put the energies onto the y-axis, which is
                common in chemistry.

        Returns:
            A matplotlib object.
        """
        if self.are_coops:
            cohp_label = "COOP"
        elif self.are_cobis:
            cohp_label = "COBI"
        else:
            cohp_label = "COHP"

        if plot_negative is None:
            plot_negative = (not self.are_coops) and (not self.are_cobis)

        if integrated:
            cohp_label = "I" + cohp_label + " (eV)"

        if plot_negative:
            cohp_label = "-" + cohp_label

        energy_label = "$E - E_f$ (eV)" if self.zero_at_efermi else "$E$ (eV)"

        # Cycle through between 3 and 9 colours of the Set1_9 palette.
        ncolors = min(9, max(3, len(self._cohps)))

        import palettable

        # pylint: disable=E1101
        colors = palettable.colorbrewer.qualitative.Set1_9.mpl_colors

        plt = pretty_plot(12, 8)

        allpts = []
        for i, (key, cohp_data) in enumerate(self._cohps.items()):
            energies = cohp_data["energies"]
            populations = cohp_data["ICOHP" if integrated else "COHP"]
            for spin in [Spin.up, Spin.down]:
                if spin not in populations:
                    continue
                pops = -populations[spin] if plot_negative else populations[spin]
                x, y = (pops, energies) if invert_axes else (energies, pops)
                allpts.extend(zip(x, y))
                if spin == Spin.up:
                    plt.plot(
                        x,
                        y,
                        color=colors[i % ncolors],
                        linestyle="-",
                        label=str(key),
                        linewidth=3,
                    )
                else:
                    plt.plot(x, y, color=colors[i % ncolors], linestyle="--", linewidth=3)

        if xlim:
            plt.xlim(xlim)
        if ylim:
            plt.ylim(ylim)
        else:
            xlim = plt.xlim()
            relevanty = [p[1] for p in allpts if xlim[0] < p[0] < xlim[1]]
            # min()/max() of an empty list would raise, e.g. when no COHPs
            # were added or no points lie inside the x-range.
            if relevanty:
                plt.ylim((min(relevanty), max(relevanty)))

        xlim = plt.xlim()
        ylim = plt.ylim()
        # Solid black line marking zero population.
        if invert_axes:
            plt.plot([0, 0], ylim, "k-", linewidth=2)
        else:
            plt.plot(xlim, [0, 0], "k-", linewidth=2)

        if self.zero_at_efermi:
            # All energies are shifted, so the Fermi level sits at zero.
            if invert_axes:
                plt.plot(xlim, [0, 0], "k--", linewidth=2)
            else:
                plt.plot([0, 0], ylim, "k--", linewidth=2)
        else:
            # Unshifted energies: each COHP may carry its own Fermi level, so
            # draw one dashed line per COHP in that COHP's colour.
            for i, cohp_data in enumerate(self._cohps.values()):
                efermi = cohp_data["efermi"]
                if invert_axes:
                    line_x, line_y = xlim, [efermi, efermi]
                else:
                    line_x, line_y = [efermi, efermi], ylim
                plt.plot(
                    line_x,
                    line_y,
                    color=colors[i % ncolors],
                    linestyle="--",
                    linewidth=2,
                )

        if invert_axes:
            plt.xlabel(cohp_label)
            plt.ylabel(energy_label)
        else:
            plt.xlabel(energy_label)
            plt.ylabel(cohp_label)

        plt.legend()
        leg = plt.gca().get_legend()
        ltext = leg.get_texts()
        plt.setp(ltext, fontsize=30)
        plt.tight_layout()
        return plt

    def save_plot(self, filename, img_format="eps", xlim=None, ylim=None):
        """
        Save matplotlib plot to a file.

        Args:
            filename: File name to write to.
            img_format: Image format to use. Defaults to EPS.
            xlim: Specifies the x-axis limits. Defaults to None for
                automatic determination.
            ylim: Specifies the y-axis limits. Defaults to None for
                automatic determination.
        """
        plt = self.get_plot(xlim, ylim)
        plt.savefig(filename, format=img_format)

    def show(self, xlim=None, ylim=None):
        """
        Show the plot using matplotlib.

        Args:
            xlim: Specifies the x-axis limits. Defaults to None for
                automatic determination.
            ylim: Specifies the y-axis limits. Defaults to None for
                automatic determination.
        """
        plt = self.get_plot(xlim, ylim)
        plt.show()
3953

3954

3955
def _draw_bz_and_kpoints_mayavi(bz, structure, kpoints_dict, fig, points_scale_factor, labels_scale_factor):
    """
    Draw the Brillouin-zone skeleton and the labelled k-points on a mayavi figure.

    Private helper for plot_fermi_surface, which needs this drawing both for a
    single shared figure and for every per-energy-level figure.

    Args:
        bz: Wigner-Seitz cell of the reciprocal lattice as returned by
            ``get_wigner_seitz_cell()``; iterated as faces, each face being a
            sequence of vertex coordinates.
        structure: Structure whose reciprocal lattice converts the fractional
            coordinates in ``kpoints_dict`` to Cartesian ones.
        kpoints_dict (dict): Mapping of label -> fractional k-point coordinates.
        fig: mayavi figure to draw on.
        points_scale_factor (float): Size of the k-point markers.
        labels_scale_factor (float): Size of the k-point labels.
    """
    # A pair of vertices of one face that both also belong to a later face is
    # taken to be an edge of the cell and drawn as a black line.
    for iface, face in enumerate(bz):
        for start, end in itertools.combinations(face, 2):
            for other in bz[iface + 1 :]:
                if any(np.all(start == x) for x in other) and any(np.all(end == x) for x in other):
                    mlab.plot3d(
                        *zip(start, end),
                        color=(0, 0, 0),
                        tube_radius=None,
                        figure=fig,
                    )

    reciprocal_lattice = structure.lattice.reciprocal_lattice
    for label, coords in kpoints_dict.items():
        label_coords = reciprocal_lattice.get_cartesian_coords(coords)
        mlab.points3d(
            *label_coords,
            scale_factor=points_scale_factor,
            color=(0, 0, 0),
            figure=fig,
        )
        mlab.text3d(
            *label_coords,
            text=label,
            scale=labels_scale_factor,
            color=(0, 0, 0),
            figure=fig,
        )


@requires(mlab is not None, "MayAvi mlab not imported! Please install mayavi.")
def plot_fermi_surface(
    data,
    structure,
    cbm,
    energy_levels=None,
    multiple_figure=True,
    mlab_figure=None,
    kpoints_dict=None,
    colors=None,
    transparency_factor=None,
    labels_scale_factor=0.05,
    points_scale_factor=0.02,
    interactive=True,
):
    """
    Plot the Fermi surface at specific energy value using Boltztrap 1 FERMI
    mode.

    The easiest way to use this plotter is:

        1. Run boltztrap in 'FERMI' mode using BoltztrapRunner,
        2. Load BoltztrapAnalyzer using your method of choice (e.g., from_files)
        3. Pass in your BoltztrapAnalyzer's fermi_surface_data as this
            function's data argument.

    Args:
        data: energy values in a 3D grid from a CUBE file via read_cube_file
            function, or from a BoltztrapAnalyzer.fermi_surface_data
        structure: structure object of the material
        cbm (bool): Boolean value to specify if the considered band is a
            conduction band or not. If True, the data are sign-flipped
            (multiplied by -1) before the energy range check and before
            contouring, so ``energy_levels`` are compared against that
            sign-flipped range.
        energy_levels ([float]): Energy values for plotting the fermi surface(s)
            By default 0 eV correspond to the VBM, as in the plot of band
            structure along symmetry line.
            Default: a single surface 0.01 away from the extreme of the
            (sign-flipped, if cbm) data: max - 0.01 if not cbm, min + 0.01 if cbm.
        multiple_figure (bool): If True a figure for each energy level will be
            shown. If False all the surfaces will be shown in the same figure.
            In this last case, tune the transparency factor.
        mlab_figure (mayavi.mlab.figure): A previous figure to plot a new
            surface on. Note: when multiple_figure is True a new figure is
            still created for every energy level, so mlab_figure is only used
            as the drawing target when multiple_figure is False.
        kpoints_dict (dict): dictionary of kpoints to label in the plot.
            Example: {"K":[0.5,0.0,0.5]}, coords are fractional
        colors ([tuple]): Iterable of (r, g, b) 3-tuples defining the color of
            each surface (one per energy level). Should be the same length as
            the number of surfaces being plotted. Defaults to blue (0, 0, 1)
            for every surface.
            Example (3 surfaces): colors=[(1,0,0), (0,1,0), (0,0,1)]
            Example (2 surfaces): colors=[(0, 0.5, 0.5), (1, 0, 0)]
        transparency_factor [float]: Values in the range [0,1] to tune the
            opacity of each surface. Should be one transparency_factor per
            surface. Defaults to 1 for every surface.
        labels_scale_factor (float): factor to tune size of the kpoint labels
        points_scale_factor (float): factor to tune size of the kpoint points
        interactive (bool): if True an interactive figure will be shown.
            If False a non interactive figure will be shown, but it is possible
            to plot other surfaces on the same figure. To make it interactive,
            run mlab.show().

    Returns:
        ((mayavi.mlab.figure, mayavi.mlab)): The figure the (last) surface was
            drawn on and the mlab module, which can be used to control the plot.

    Raises:
        BoltztrapError: if a requested energy level lies outside the range of
            the (sign-flipped, if cbm) data.

    Note: Experimental.
          Please, double check the surface shown by using some
          other software and report issues.
    """
    bz = structure.lattice.reciprocal_lattice.get_wigner_seitz_cell()
    cell = structure.lattice.reciprocal_lattice.matrix

    # Conduction bands are handled by flipping the sign of the data; the range
    # check and the contouring below both operate on fact * data.
    fact = 1 if not cbm else -1
    data_1d = data.ravel()
    en_min = np.min(fact * data_1d)
    en_max = np.max(fact * data_1d)

    if energy_levels is None:
        energy_levels = [en_min + 0.01] if cbm else [en_max - 0.01]
        print("Energy level set to: " + str(energy_levels[0]) + " eV")
    else:
        for e in energy_levels:
            if e > en_max or e < en_min:
                raise BoltztrapError(
                    "energy level "
                    + str(e)
                    + " not in the range of possible energies: ["
                    + str(en_min)
                    + ", "
                    + str(en_max)
                    + "]"
                )

    n_surfaces = len(energy_levels)
    if colors is None:
        colors = [(0, 0, 1)] * n_surfaces
    if transparency_factor is None:
        transparency_factor = [1] * n_surfaces
    if kpoints_dict is None:
        kpoints_dict = {}

    if mlab_figure:
        fig = mlab_figure

    if mlab_figure is None and not multiple_figure:
        # One shared figure: draw the BZ and labels once, before the surfaces.
        fig = mlab.figure(size=(1024, 768), bgcolor=(1, 1, 1))
        _draw_bz_and_kpoints_mayavi(bz, structure, kpoints_dict, fig, points_scale_factor, labels_scale_factor)

    for i, isolevel in enumerate(energy_levels):
        alpha = transparency_factor[i]
        color = colors[i]
        if multiple_figure:
            # One figure per energy level, each with its own BZ and labels.
            fig = mlab.figure(size=(1024, 768), bgcolor=(1, 1, 1))
            _draw_bz_and_kpoints_mayavi(bz, structure, kpoints_dict, fig, points_scale_factor, labels_scale_factor)

        cp = mlab.contour3d(
            fact * data,
            contours=[isolevel],
            transparent=True,
            colormap="hot",
            color=color,
            opacity=alpha,
            figure=fig,
        )

        polydata = cp.actor.actors[0].mapper.input
        pts = np.array(polydata.points)
        # NOTE(review): presumably contour3d returns points in grid-index units;
        # dividing by data.shape gives fractional coordinates, which the
        # reciprocal-lattice matrix then maps to Cartesian coordinates.
        polydata.points = np.dot(pts, cell / np.array(data.shape)[:, np.newaxis])

        # Shift the surface so its centroid sits at the origin, then scale it by 2.
        # Points are re-read from polydata so the centroid uses the stored values.
        center = np.mean(np.array(polydata.points), axis=0)
        polydata.points = (np.array(polydata.points) - center) * 2

        fig.scene.isometric_view()

    if interactive:
        mlab.show()

    return fig, mlab
4153

4154

4155
def plot_wigner_seitz(lattice, ax=None, **kwargs):
    """
    Adds the skeleton of the Wigner-Seitz cell of the lattice to a matplotlib Axes.

    A pair of vertices of one face that both also belong to a later face is
    taken to be an edge of the cell and is drawn with ``ax.plot``.

    Args:
        lattice: Lattice object
        ax: matplotlib :class:`Axes` or None if a new figure should be created.
        kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to black
            and linewidth to 1.

    Returns:
        matplotlib figure and matplotlib ax
    """
    # Resolve the axes once; the previous implementation called this twice.
    ax, fig, _ = get_ax3d_fig_plt(ax)

    kwargs.setdefault("color", "k")
    kwargs.setdefault("linewidth", 1)

    bz = lattice.get_wigner_seitz_cell()
    for iface, face in enumerate(bz):
        for start, end in itertools.combinations(face, 2):
            # Only faces after iface are compared, so each face pair is checked once.
            for other in bz[iface + 1 :]:
                if any(np.all(start == x) for x in other) and any(np.all(end == x) for x in other):
                    ax.plot(*zip(start, end), **kwargs)

    return fig, ax
4188

4189

4190
def plot_lattice_vectors(lattice, ax=None, **kwargs):
    """
    Draw the three basis vectors of a lattice, starting at the origin, on a matplotlib Axes.

    Args:
        lattice: Lattice object
        ax: matplotlib :class:`Axes` or None if a new figure should be created.
        kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to green
            and linewidth to 3.

    Returns:
        matplotlib figure and matplotlib ax
    """
    ax, fig, _ = get_ax3d_fig_plt(ax)

    kwargs.setdefault("color", "g")
    kwargs.setdefault("linewidth", 3)

    origin = lattice.get_cartesian_coords([0.0, 0.0, 0.0])
    for frac_tip in ([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]):
        tip = lattice.get_cartesian_coords(frac_tip)
        ax.plot(*zip(origin, tip), **kwargs)

    return fig, ax
4219

4220

4221
def plot_path(line, lattice=None, coords_are_cartesian=False, ax=None, **kwargs):
    """
    Draw a polyline through consecutive coordinates of 'line' on a matplotlib Axes.

    Args:
        line: list of coordinates.
        lattice: Lattice object used to convert from reciprocal to Cartesian coordinates
        coords_are_cartesian: Set to True if you are providing
            coordinates in Cartesian coordinates. Defaults to False.
            Requires lattice if False.
        ax: matplotlib :class:`Axes` or None if a new figure should be created.
        kwargs: kwargs passed to the matplotlib function 'plot'. Color defaults to red
            and linewidth to 3.

    Returns:
        matplotlib figure and matplotlib ax

    Raises:
        ValueError: if a segment must be converted (coords_are_cartesian False)
            and no lattice was given.
    """
    ax, fig, _ = get_ax3d_fig_plt(ax)

    kwargs.setdefault("color", "r")
    kwargs.setdefault("linewidth", 3)

    for idx in range(1, len(line)):
        start, end = line[idx - 1], line[idx]
        if not coords_are_cartesian:
            # The lattice is only needed (and checked) once there is a segment to convert.
            if lattice is None:
                raise ValueError("coords_are_cartesian False requires the lattice")
            start = lattice.get_cartesian_coords(start)
            end = lattice.get_cartesian_coords(end)
        ax.plot(*zip(start, end), **kwargs)

    return fig, ax
4256

4257

4258
def plot_labels(labels, lattice=None, coords_are_cartesian=False, ax=None, **kwargs):
    """
    Write text labels at given coordinates on a matplotlib Axes.

    Args:
        labels: dict containing the label as a key and the coordinates as value.
        lattice: Lattice object used to convert from reciprocal to Cartesian coordinates
        coords_are_cartesian: Set to True if you are providing
            coordinates in Cartesian coordinates. Defaults to False.
            Requires lattice if False.
        ax: matplotlib :class:`Axes` or None if a new figure should be created.
        kwargs: kwargs passed to the matplotlib function 'text'. Color defaults to blue
            and size to 25.

    Returns:
        matplotlib figure and matplotlib ax

    Raises:
        ValueError: if coords_are_cartesian is False and no lattice was given
            (raised when the first label is processed).
    """
    ax, fig, _ = get_ax3d_fig_plt(ax)

    kwargs.setdefault("color", "b")
    kwargs.setdefault("size", 25)

    offset = 0.01  # small shift so the text does not sit exactly on the point
    for name, coords in labels.items():
        # Names starting with a backslash or containing "_" are wrapped in $...$ (math mode).
        text = f"${name}$" if name.startswith("\\") or "_" in name else name
        if coords_are_cartesian:
            position = np.array(coords)
        elif lattice is None:
            raise ValueError("coords_are_cartesian False requires the lattice")
        else:
            position = lattice.get_cartesian_coords(coords)
        ax.text(*(position + offset), s=text, **kwargs)

    return fig, ax
4296

4297

4298
def fold_point(p, lattice, coords_are_cartesian=False):
    """
    Fold a point into the first Brillouin zone of the lattice.

    The fractional coordinates are first wrapped into the unit cell centred
    on the origin. The point is then shifted by whichever of the 27 nearest
    lattice points (offsets -1, 0, 1 along each axis) is closest to it.

    Args:
        p: coordinates of one point
        lattice: Lattice object used to convert from reciprocal to Cartesian coordinates
        coords_are_cartesian: Set to True if you are providing
            coordinates in Cartesian coordinates. Defaults to False.

    Returns:
        The Cartesian coordinates folded inside the first Brillouin zone
    """
    frac = lattice.get_fractional_coords(p) if coords_are_cartesian else np.array(p)

    # Wrap each fractional coordinate into roughly [-0.5, 0.5]; the 1e-10 offsets
    # decide on which side a value lying exactly on the +/-0.5 boundary ends up.
    frac = np.mod(frac + 0.5 - 1e-10, 1) - 0.5 + 1e-10
    cart = lattice.get_cartesian_coords(frac)

    nearest = None
    best = 10000
    for shift in itertools.product((-1, 0, 1), repeat=3):
        candidate = np.dot(shift, lattice.matrix)
        distance = np.linalg.norm(cart - candidate)
        # Strict "<" keeps the first candidate in iteration order on ties.
        if nearest is None or distance < best:
            nearest, best = candidate, distance

    if not np.allclose(nearest, (0, 0, 0)):
        cart = cart - nearest

    return cart
4334

4335

4336
def plot_points(points, lattice=None, coords_are_cartesian=False, fold=False, ax=None, **kwargs):
    """
    Scatter a collection of points on a matplotlib Axes.

    Args:
        points: list of coordinates
        lattice: Lattice object used to convert from reciprocal to Cartesian coordinates
        coords_are_cartesian: Set to True if you are providing
            coordinates in Cartesian coordinates. Defaults to False.
            Requires lattice if False.
        fold: whether the points should be folded inside the first Brillouin Zone.
            Defaults to False. Requires lattice if True.
        ax: matplotlib :class:`Axes` or None if a new figure should be created.
        kwargs: kwargs passed to the matplotlib function 'scatter'. Color defaults to blue

    Returns:
        matplotlib figure and matplotlib ax

    Raises:
        ValueError: if a lattice is required (fractional input or folding) but missing.
    """
    ax, fig, _ = get_ax3d_fig_plt(ax)

    kwargs.setdefault("color", "b")

    needs_lattice = fold or not coords_are_cartesian
    if needs_lattice and lattice is None:
        raise ValueError("coords_are_cartesian False or fold True require the lattice")

    for point in points:
        if fold:
            # fold_point already returns Cartesian coordinates.
            xyz = fold_point(point, lattice, coords_are_cartesian=coords_are_cartesian)
        elif not coords_are_cartesian:
            xyz = lattice.get_cartesian_coords(point)
        else:
            xyz = point
        ax.scatter(*xyz, **kwargs)

    return fig, ax
4372

4373

4374
@add_fig_kwargs
def plot_brillouin_zone_from_kpath(kpath, ax=None, **kwargs):
    """
    Plot the Brillouin zone together with the high-symmetry path of a k-path object.

    Every segment of ``kpath.kpath["path"]`` (a sequence of k-point labels) is
    turned into a list of coordinates looked up in ``kpath.kpath["kpoints"]``.
    Those coordinates are drawn on the Brillouin zone of ``kpath.prim_rec``,
    and every labelled k-point is marked.

    Args:
        kpath (HighSymmKpath): a HighSymmKPath object
        ax: matplotlib :class:`Axes` or None if a new figure should be created.
        **kwargs: keyword arguments reaching this function are forwarded to
            plot_brillouin_zone; see the add_fig_kwargs decorator for the
            figure options it handles.

    Returns:
        matplotlib figure
    """
    labelled_kpoints = kpath.kpath["kpoints"]
    segments = [[labelled_kpoints[name] for name in segment] for segment in kpath.kpath["path"]]
    return plot_brillouin_zone(
        bz_lattice=kpath.prim_rec,
        lines=segments,
        ax=ax,
        labels=labelled_kpoints,
        **kwargs,
    )
4397

4398

4399
@add_fig_kwargs
def plot_brillouin_zone(
    bz_lattice,
    lines=None,
    labels=None,
    kpoints=None,
    fold=False,
    coords_are_cartesian=False,
    ax=None,
    **kwargs,
):
    """
    Draw a 3D view of a Brillouin zone with optional paths, labels and k-points.

    The lattice vectors and the Wigner-Seitz cell of ``bz_lattice`` are always
    drawn. The axes are limited to [-1, 1] in x, y and z and hidden.

    Args:
        bz_lattice: Lattice object of the Brillouin zone
        lines: list of lists of coordinates. Each list represent a different path
        labels: dict containing the label as a key and the coordinates as value.
            Each labelled point is also marked (without folding).
        kpoints: list of coordinates
        fold: whether the kpoints should be folded inside the first Brillouin Zone.
            Defaults to False. Requires lattice if True.
        coords_are_cartesian: Set to True if you are providing
            coordinates in Cartesian coordinates. Defaults to False.
        ax: matplotlib :class:`Axes` or None if a new figure should be created.
        kwargs: provided by add_fig_kwargs decorator

    Returns:
        matplotlib figure
    """
    fig, ax = plot_lattice_vectors(bz_lattice, ax=ax)
    plot_wigner_seitz(bz_lattice, ax=ax)

    shared = {"coords_are_cartesian": coords_are_cartesian, "ax": ax}

    if lines is not None:
        for path in lines:
            plot_path(path, bz_lattice, **shared)

    if labels is not None:
        plot_labels(labels, bz_lattice, **shared)
        plot_points(labels.values(), bz_lattice, fold=False, **shared)

    if kpoints is not None:
        plot_points(kpoints, bz_lattice, fold=fold, **shared)

    for set_limits in (ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d):
        set_limits(-1, 1)

    ax.axis("off")

    return fig
4462

4463

4464
def plot_ellipsoid(
    hessian,
    center,
    lattice=None,
    rescale=1.0,
    ax=None,
    coords_are_cartesian=False,
    arrows=False,
    **kwargs,
):
    """
    Plots a 3D ellipsoid representing the Hessian matrix in input.
    Useful to get a graphical visualization of the effective mass
    of a band in a single k-point.

    The semi-axes of the ellipsoid are 1/sqrt of the singular values of the
    Hessian, oriented along the corresponding right singular vectors.

    Args:
        hessian: the Hessian matrix
        center: the center of the ellipsoid in reciprocal coords (Default)
        lattice: Lattice object of the Brillouin zone
        rescale: factor for size scaling of the ellipsoid
        ax: matplotlib :class:`Axes` or None if a new figure should be created.
        coords_are_cartesian: Set to True if you are providing a center in
                              Cartesian coordinates. Defaults to False.
        arrows: if True, also draw one arrow per principal axis (colored blue,
                green, red) from the center, with length equal to that
                semi-axis times rescale.
        kwargs: kwargs passed to the matplotlib function 'plot_wireframe'.
                Color defaults to blue, rstride and cstride
                default to 4, alpha defaults to 0.2.

    Returns:
        matplotlib figure and matplotlib ax

    Raises:
        ValueError: if coords_are_cartesian is False and no lattice is given.

    Example of use:
        fig,ax=plot_wigner_seitz(struct.lattice.reciprocal_lattice)
        plot_ellipsoid(hessian,[0.0,0.0,0.0], struct.lattice.reciprocal_lattice,ax=ax)
    """
    if not coords_are_cartesian:
        if lattice is None:
            raise ValueError("coords_are_cartesian False requires the lattice")
        center = lattice.get_cartesian_coords(center)

    kwargs.setdefault("color", "b")
    kwargs.setdefault("rstride", 4)
    kwargs.setdefault("cstride", 4)
    kwargs.setdefault("alpha", 0.2)

    # Singular values s give the semi-axes (radii = 1/sqrt(s)); the rows of
    # `rotation` (V^T from the SVD) are the corresponding principal directions.
    _, s, rotation = np.linalg.svd(hessian)
    radii = 1.0 / np.sqrt(s)

    # Axis-aligned ellipsoid surface from a spherical parametrisation (100 x 100 mesh).
    u = np.linspace(0.0, 2.0 * np.pi, 100)
    v = np.linspace(0.0, np.pi, 100)
    x = radii[0] * np.outer(np.cos(u), np.sin(v))
    y = radii[1] * np.outer(np.sin(u), np.sin(v))
    z = radii[2] * np.outer(np.ones_like(u), np.cos(v))

    # Rotate every mesh point into the principal frame, rescale and translate to
    # the center in one vectorised step (same formula as a per-point np.dot).
    mesh = np.stack([x, y, z], axis=-1)
    mesh = np.dot(mesh, rotation) * rescale + center
    x, y, z = np.moveaxis(mesh, -1, 0)

    # add the ellipsoid to the current axes
    ax, fig, _ = get_ax3d_fig_plt(ax)
    ax.plot_wireframe(x, y, z, **kwargs)

    if arrows:
        arrow_colors = ("b", "g", "r")
        # Normalise each principal direction to unit length.
        directions = rotation / np.linalg.norm(rotation, axis=1)[:, np.newaxis]
        for i in range(3):
            ax.quiver3D(
                center[0],
                center[1],
                center[2],
                directions[i, 0],
                directions[i, 1],
                directions[i, 2],
                pivot="tail",
                arrow_length_ratio=0.2,
                length=radii[i] * rescale,
                color=arrow_colors[i],
            )

    return fig, ax
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc