• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

materialsproject / pymatgen / 4075885785

pending completion
4075885785

push

github

Shyue Ping Ong
Merge branch 'master' of github.com:materialsproject/pymatgen

96 of 96 new or added lines in 27 files covered. (100.0%)

81013 of 102710 relevant lines covered (78.88%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.26
/pymatgen/analysis/phase_diagram.py
1
# Copyright (c) Pymatgen Development Team.
2
# Distributed under the terms of the MIT License.
3

4
"""
1✔
5
This module defines tools to generate and analyze phase diagrams.
6
"""
7

8
from __future__ import annotations
1✔
9

10
import collections
1✔
11
import itertools
1✔
12
import json
1✔
13
import logging
1✔
14
import math
1✔
15
import os
1✔
16
import re
1✔
17
import warnings
1✔
18
from functools import lru_cache
1✔
19
from io import StringIO
1✔
20
from typing import Any, Iterator, Literal, Sequence
1✔
21

22
import numpy as np
1✔
23
import plotly.graph_objs as go
1✔
24
from monty.json import MontyDecoder, MSONable
1✔
25
from scipy.optimize import minimize
1✔
26
from scipy.spatial import ConvexHull
1✔
27
from tqdm import tqdm
1✔
28

29
from pymatgen.analysis.reaction_calculator import Reaction, ReactionError
1✔
30
from pymatgen.core.composition import Composition
1✔
31
from pymatgen.core.periodic_table import DummySpecies, Element, get_el_sp
1✔
32
from pymatgen.entries import Entry
1✔
33
from pymatgen.util.coord import Simplex, in_coord_list
1✔
34
from pymatgen.util.plotting import pretty_plot
1✔
35
from pymatgen.util.string import htmlify, latexify
1✔
36
from pymatgen.util.typing import ArrayLike
1✔
37

38
logger = logging.getLogger(__name__)
1✔
39

40
# Load the plotly layout presets from util/plotly_pd_layouts.json, resolved relative to
# this module's directory. The encoding is pinned so that parsing does not depend on
# the platform's default locale encoding.
with open(
    os.path.join(os.path.dirname(__file__), "..", "util", "plotly_pd_layouts.json"),
    encoding="utf-8",
) as f:
    plotly_layouts = json.load(f)
1✔
42

43

44
class PDEntry(Entry):
    """
    Entry carrying the data needed to place a compound in a phase diagram.

    Attributes:
        composition (Composition): The composition associated with the PDEntry.
        energy (float): The energy associated with the entry.
        name (str): Label shown in phase diagrams. Falls back to the reduced
            formula of the composition when no (or an empty) name is given.
        attribute (MSONable): Arbitrary user data, e.g. a flag marking a newly
            found compound or a custom label. Anything is accepted as long as
            it is MSONable.
    """

    def __init__(
        self,
        composition: Composition,
        energy: float,
        name: str | None = None,
        attribute: object = None,
    ):
        """
        Args:
            composition (Composition): Composition of the entry.
            energy (float): Energy for that composition.
            name (str): Optional display name. Defaults to the reduced
                chemical formula.
            attribute: Optional attribute of the entry. Must be MSONable.
        """
        super().__init__(composition, energy)
        # A falsy name (None or "") falls back to the reduced formula.
        self.name = name if name else self.composition.reduced_formula
        self.attribute = attribute

    def __repr__(self):
        # The custom name is only shown when it differs from the reduced formula.
        formula = self.composition.reduced_formula
        label = "" if self.name == formula else f" ({self.name})"
        return f"{type(self).__name__} : {self.composition}{label} with energy = {self.energy:.4f}"

    @property
    def energy(self) -> float:
        """
        Returns:
            the energy of the entry (the stored ``_energy`` value).
        """
        return self._energy

    def as_dict(self):
        """
        Returns:
            MSONable dictionary representation of PDEntry: the parent class
            dict extended with the "name" and "attribute" keys.
        """
        dct = super().as_dict()
        dct["name"] = self.name
        dct["attribute"] = self.attribute
        return dct

    @classmethod
    def from_dict(cls, d):
        """
        Args:
            d (dict): dictionary representation of PDEntry. The "name" and
                "attribute" keys are optional and default to None.

        Returns:
            PDEntry
        """
        return cls(Composition(d["composition"]), d["energy"], d.get("name"), d.get("attribute"))
116

117

118
class GrandPotPDEntry(PDEntry):
    """
    PDEntry variant for grand potential phase diagrams. It wraps an existing
    entry together with a dict of chemical potentials keyed by element; the
    elements in that dict are removed from the reported composition and their
    mu*N contribution is subtracted from the reported energy.
    """

    def __init__(self, entry, chempots, name=None):
        """
        Args:
            entry: A PDEntry-like object.
            chempots: Chemical potential specification as {Element: float}.
            name: Optional parameter to name the entry. Defaults to the name
                of the original entry.
        """
        super().__init__(
            entry.composition,
            entry.energy,
            name or entry.name,
            getattr(entry, "attribute", None),
        )
        # NOTE when built from a ComputedEntry, _energy holds the corrected
        # energy of that ComputedEntry, so the original entry is kept around
        # to avoid losing data.
        self.original_entry = entry
        self.original_comp = self._composition
        self.chempots = chempots

    @property
    def composition(self) -> Composition:
        """The composition after removing free species

        Returns:
            Composition restricted to the elements that have no entry in chempots.
        """
        full_comp = self._composition
        return Composition({el: full_comp[el] for el in full_comp.elements if el not in self.chempots})

    @property
    def chemical_energy(self):
        """The chemical energy term mu*N in the grand potential

        Returns:
            Sum over the chempots of (amount of the element in the original
            composition) * (chemical potential).
        """
        total = 0
        for el, pot in self.chempots.items():
            total += self._composition[el] * pot
        return total

    @property
    def energy(self):
        """
        Returns:
            The grand potential energy, i.e. the stored energy minus chemical_energy.
        """
        return self._energy - self.chemical_energy

    def __repr__(self):
        original = self.original_entry
        mu_terms = ", ".join(f"mu_{el} = {mu:.4f}" for el, mu in self.chempots.items())
        return (
            f"GrandPotPDEntry with original composition {original.composition}, "
            f"energy = {original.energy:.4f}, "
            f"chempots = {mu_terms}"
        )

    def as_dict(self):
        """
        Returns:
            MSONable dictionary representation of GrandPotPDEntry. Chemical
            potentials are keyed by element symbol.
        """
        chempots_by_symbol = {el.symbol: u for el, u in self.chempots.items()}
        return {
            "@module": type(self).__module__,
            "@class": type(self).__name__,
            "entry": self.original_entry.as_dict(),
            "chempots": chempots_by_symbol,
            "name": self.name,
        }

    @classmethod
    def from_dict(cls, d):
        """
        Args:
            d (dict): dictionary representation of GrandPotPDEntry

        Returns:
            GrandPotPDEntry
        """
        # Symbols are turned back into Element keys before the wrapped entry is decoded.
        chempots = {Element(symbol): u for symbol, u in d["chempots"].items()}
        original = MontyDecoder().process_decoded(d["entry"])
        return cls(original, chempots, d["name"])
1✔
205

206

207
class TransformedPDEntry(PDEntry):
    """
    This class represents a TransformedPDEntry, which allows for a PDEntry to be
    transformed to a different composition coordinate space. It is used in the
    construction of phase diagrams that do not have elements as the terminal
    compositions.
    """

    # Tolerance for determining if amount of a composition is positive.
    amount_tol = 1e-5

    def __init__(self, entry, sp_mapping, name=None):
        """
        Args:
            entry (PDEntry): Original entry to be transformed.
            sp_mapping ({Composition: DummySpecies}): dictionary
                mapping Terminal Compositions to Dummy Species
            name (str): Optional name for the entry. Defaults to the name of
                the original entry.

        Raises:
            TransformedPDEntryError: If the reaction coefficient of any terminal
                composition exceeds amount_tol.
        """
        super().__init__(
            entry.composition,
            entry.energy,
            name or entry.name,
            entry.attribute if hasattr(entry, "attribute") else None,
        )
        self.original_entry = entry
        self.sp_mapping = sp_mapping

        # Reaction with the terminal compositions on one side and this entry's
        # composition on the other, normalized to the original entry's composition.
        self.rxn = Reaction(list(self.sp_mapping), [self._composition])
        self.rxn.normalize_to(self.original_entry.composition)

        # NOTE We only allow reactions that have positive amounts of reactants.
        if not all(self.rxn.get_coeff(comp) <= TransformedPDEntry.amount_tol for comp in self.sp_mapping):
            raise TransformedPDEntryError("Only reactions with positive amounts of reactants allowed")

    @property
    def composition(self) -> Composition:
        """The composition in the dummy species space

        Returns:
            Composition
        """
        # NOTE this is not infallible as the original entry is mutable and an
        # end user could choose to normalize or change the original entry.
        # However, the risk of this seems low.
        factor = self._composition.num_atoms / self.original_entry.composition.num_atoms

        # Negated reaction coefficients become the amounts of each dummy species.
        trans_comp = {self.sp_mapping[comp]: -self.rxn.get_coeff(comp) for comp in self.sp_mapping}

        # Amounts at or below the tolerance are dropped; the rest are rescaled.
        trans_comp = {k: v * factor for k, v in trans_comp.items() if v > TransformedPDEntry.amount_tol}

        return Composition(trans_comp)

    def __repr__(self):
        output = [
            f"TransformedPDEntry {self.composition}",
            f" with original composition {self.original_entry.composition}",
            f", energy = {self.original_entry.energy:.4f}",
        ]
        return "".join(output)

    def as_dict(self):
        """
        Returns:
            MSONable dictionary representation of TransformedPDEntry. Keys from
            the original entry's as_dict() are merged in last, so they overwrite
            same-named keys set here.
        """
        d = {
            "@module": type(self).__module__,
            "@class": type(self).__name__,
            "sp_mapping": self.sp_mapping,
        }
        d.update(self.original_entry.as_dict())
        return d

    @classmethod
    def from_dict(cls, d):
        """
        Args:
            d (dict): dictionary representation of TransformedPDEntry. It is
                not modified.

        Returns:
            TransformedPDEntry

        Raises:
            KeyError: If "sp_mapping" is missing from d.
        """
        # Work on a shallow copy so the caller's dict keeps its "sp_mapping" key
        # and can be decoded again.
        entry_dict = dict(d)
        sp_mapping = entry_dict.pop("sp_mapping")
        entry = MontyDecoder().process_decoded(entry_dict)
        return cls(entry, sp_mapping)
1✔
293

294

295
class TransformedPDEntryError(Exception):
    """Error type dedicated to problems with a TransformedPDEntry."""
299

300

301
class PhaseDiagram(MSONable):
1✔
302
    """
303
    Simple phase diagram class taking in elements and entries as inputs.
304
    The algorithm is based on the work in the following papers:
305

306
    1. S. P. Ong, L. Wang, B. Kang, and G. Ceder, Li-Fe-P-O2 Phase Diagram from
307
        First Principles Calculations. Chem. Mater., 2008, 20(5), 1798-1807.
308
        doi:10.1021/cm702327g
309

310
    2. S. P. Ong, A. Jain, G. Hautier, B. Kang, G. Ceder, Thermal stabilities
311
        of delithiated olivine MPO4 (M=Fe, Mn) cathodes investigated using first
312
        principles calculations. Electrochem. Comm., 2010, 12(3), 427-430.
313
        doi:10.1016/j.elecom.2010.01.010
314

315
    Attributes:
316
        dim (int): The dimensionality of the phase diagram.
317
        elements: Elements in the phase diagram.
318
        el_refs: List of elemental references for the phase diagrams. These are
319
            entries corresponding to the lowest energy element entries for simple
320
            compositional phase diagrams.
321
        all_entries: All entries provided for Phase Diagram construction. Note that this
322
            does not mean that all these entries are actually used in the phase
323
            diagram. For example, this includes the positive formation energy
324
            entries that are filtered out before Phase Diagram construction.
325
        qhull_entries: Actual entries used in convex hull. Excludes all positive formation
326
            energy entries.
327
        qhull_data: Data used in the convex hull operation. This is essentially a matrix of
328
            composition data and energy per atom values created from qhull_entries.
329
        facets: Facets of the phase diagram in the form of  [[1,2,3],[4,5,6]...].
330
            For a ternary, it is the indices (references to qhull_entries and
331
            qhull_data) for the vertices of the phase triangles. Similarly
332
            extended to higher D simplices for higher dimensions.
333
        simplexes: The simplexes of the phase diagram as a list of Simplex objects, i.e.,
334
            the list of stable compositional coordinates in the phase diagram.
335
    """
336

337
    # Tolerance for determining if formation energy is positive.
338
    formation_energy_tol = 1e-11
1✔
339
    numerical_tol = 1e-8
1✔
340

341
    def __init__(
1✔
342
        self,
343
        entries: Sequence[PDEntry] | set[PDEntry],
344
        elements: Sequence[Element] = (),
345
        *,
346
        computed_data: dict[str, Any] | None = None,
347
    ) -> None:
348
        """
349
        Args:
350
            entries (list[PDEntry]): A list of PDEntry-like objects having an
351
                energy, energy_per_atom and composition.
352
            elements (list[Element]): Optional list of elements in the phase
353
                diagram. If set to None, the elements are determined from
354
                the entries themselves and are sorted alphabetically.
355
                If specified, element ordering (e.g. for pd coordinates)
356
                is preserved.
357
            computed_data (dict): A dict containing pre-computed data. This allows
358
                PhaseDiagram object to be reconstituted without performing the
359
                expensive convex hull computation. The dict is the output from the
360
                PhaseDiagram._compute() method and is stored in PhaseDiagram.computed_data
361
                when generated for the first time.
362
        """
363
        if not entries:
1✔
364
            raise ValueError("Unable to build phase diagram without entries.")
1✔
365

366
        self.elements = elements
1✔
367
        self.entries = entries
1✔
368
        if computed_data is None:
1✔
369
            computed_data = self._compute()
1✔
370
        else:
371
            computed_data = MontyDecoder().process_decoded(computed_data)
1✔
372
            assert isinstance(computed_data, dict)
1✔
373
            # update keys to be Element objects in case they are strings in pre-computed data
374
            computed_data["el_refs"] = [(Element(el_str), entry) for el_str, entry in computed_data["el_refs"]]
1✔
375
        self.computed_data = computed_data
1✔
376
        self.facets = computed_data["facets"]
1✔
377
        self.simplexes = computed_data["simplexes"]
1✔
378
        self.all_entries = computed_data["all_entries"]
1✔
379
        self.qhull_data = computed_data["qhull_data"]
1✔
380
        self.dim = computed_data["dim"]
1✔
381
        self.el_refs = dict(computed_data["el_refs"])
1✔
382
        self.qhull_entries = tuple(computed_data["qhull_entries"])
1✔
383
        self._qhull_spaces = tuple(frozenset(e.composition.elements) for e in self.qhull_entries)
1✔
384
        self._stable_entries = tuple({self.qhull_entries[i] for i in set(itertools.chain(*self.facets))})
1✔
385
        self._stable_spaces = tuple(frozenset(e.composition.elements) for e in self._stable_entries)
1✔
386

387
    def as_dict(self):
1✔
388
        """
389
        Returns:
390
            MSONable dictionary representation of PhaseDiagram
391
        """
392
        return {
1✔
393
            "@module": type(self).__module__,
394
            "@class": type(self).__name__,
395
            "all_entries": [e.as_dict() for e in self.all_entries],
396
            "elements": [e.as_dict() for e in self.elements],
397
            "computed_data": self.computed_data,
398
        }
399

400
    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> PhaseDiagram:
        """
        Args:
            d (dict): dictionary representation of PhaseDiagram. The
                "computed_data" key is optional; when absent, None is passed
                on to the constructor.

        Returns:
            PhaseDiagram
        """
        entries = []
        for entry_dict in d["all_entries"]:
            entries.append(MontyDecoder().process_decoded(entry_dict))
        elements = list(map(Element.from_dict, d["elements"]))
        return cls(entries, elements, computed_data=d.get("computed_data"))
1✔
413

414
    def _compute(self) -> dict[str, Any]:
        """
        Build the convex hull data of the phase diagram from self.entries.

        As a side effect, self.elements is replaced by a list of the elements
        used (inferred from the entries if none were given).

        Returns:
            dict with the keys "facets", "simplexes", "all_entries",
            "qhull_data", "dim", "el_refs" (a list of (element, entry) pairs)
            and "qhull_entries".

        Raises:
            ValueError: If an element of the diagram has no elemental entry,
                or if there is an elemental entry for an element that is not
                part of the diagram.
        """
        # Infer the elements when none were supplied. None and any empty
        # sequence are treated the same as the default empty tuple.
        if self.elements is None or len(self.elements) == 0:
            self.elements = sorted({els for e in self.entries for els in e.composition.elements})

        elements = list(self.elements)
        dim = len(elements)

        # groupby only merges adjacent items, so sort by the grouping key first.
        entries = sorted(self.entries, key=lambda e: e.composition.reduced_composition)

        el_refs: dict[Element, PDEntry] = {}
        min_entries: list[PDEntry] = []
        all_entries: list[PDEntry] = []
        for composition, group_iter in itertools.groupby(entries, key=lambda e: e.composition.reduced_composition):
            group = list(group_iter)
            # Only the lowest energy-per-atom entry of each reduced composition can be on the hull.
            min_entry = min(group, key=lambda e: e.energy_per_atom)
            if composition.is_element:
                el_refs[composition.elements[0]] = min_entry
            min_entries.append(min_entry)
            all_entries.extend(group)

        if missing := set(elements) - set(el_refs):
            raise ValueError(f"Missing terminal entries for elements {sorted(map(str, missing))}")
        if extra := set(el_refs) - set(elements):
            raise ValueError(f"There are more terminal elements than dimensions: {sorted(map(str, extra))}")

        # One row per minimum entry: atomic fraction of each element, then energy per atom.
        data = np.array(
            [[e.composition.get_atomic_fraction(el) for el in elements] + [e.energy_per_atom] for e in min_entries]
        )

        # Use only entries with negative formation energy.
        # dot(data, vec) = sum_i x_i * E_ref_i - E, so its negation is the
        # formation energy per atom relative to the elemental references.
        vec = [el_refs[el].energy_per_atom for el in elements] + [-1]
        form_e = -np.dot(data, vec)
        idx = np.where(form_e < -PhaseDiagram.formation_energy_tol)[0].tolist()

        # Add the elemental references
        idx.extend([min_entries.index(el) for el in el_refs.values()])

        qhull_entries = [min_entries[i] for i in idx]
        # The first element's fraction column is dropped.
        qhull_data = data[idx][:, 1:]

        # Add an extra point to enforce full dimensionality.
        # This point will be present in all upper hull facets.
        extra_point = np.zeros(dim) + 1 / dim
        extra_point[-1] = np.max(qhull_data) + 1
        qhull_data = np.concatenate([qhull_data, [extra_point]], axis=0)

        if dim == 1:
            facets = [qhull_data.argmin(axis=0)]
        else:
            facets = get_facets(qhull_data)
            final_facets = []
            for facet in facets:
                # Skip facets that include the extra point
                if max(facet) == len(qhull_data) - 1:
                    continue
                # Replace the energy column by ones and keep the facet only if
                # the resulting matrix is non-singular.
                m = qhull_data[facet]
                m[:, -1] = 1
                if abs(np.linalg.det(m)) > 1e-14:
                    final_facets.append(facet)
            facets = final_facets

        # Simplexes are built from the composition columns only (energy column excluded).
        simplexes = [Simplex(qhull_data[f, :-1]) for f in facets]
        self.elements = elements
        return dict(
            facets=facets,
            simplexes=simplexes,
            all_entries=all_entries,
            qhull_data=qhull_data,
            dim=dim,
            # Dictionary with Element keys is not JSON-serializable
            el_refs=list(el_refs.items()),
            qhull_entries=qhull_entries,
        )
487

488
    def pd_coords(self, comp: Composition) -> np.ndarray:
1✔
489
        """
490
        The phase diagram is generated in a reduced dimensional space
491
        (n_elements - 1). This function returns the coordinates in that space.
492
        These coordinates are compatible with the stored simplex objects.
493

494
        Args:
495
            comp (Composition): A composition
496

497
        Returns:
498
            The coordinates for a given composition in the PhaseDiagram's basis
499
        """
500
        if set(comp.elements) - set(self.elements):
1✔
501
            raise ValueError(f"{comp} has elements not in the phase diagram {self.elements}")
1✔
502
        return np.array([comp.get_atomic_fraction(el) for el in self.elements[1:]])
1✔
503

504
    @property
1✔
505
    def all_entries_hulldata(self):
1✔
506
        """
507
        Returns:
508
            The actual ndarray used to construct the convex hull.
509
        """
510
        data = [
1✔
511
            [e.composition.get_atomic_fraction(el) for el in self.elements] + [e.energy_per_atom]
512
            for e in self.all_entries
513
        ]
514
        return np.array(data)[:, 1:]
1✔
515

516
    @property
1✔
517
    def unstable_entries(self) -> set[Entry]:
1✔
518
        """
519
        Returns:
520
            set[Entry]: unstable entries in the phase diagram. Includes positive formation energy entries.
521
        """
522
        return {e for e in self.all_entries if e not in self.stable_entries}
1✔
523

524
    @property
1✔
525
    def stable_entries(self) -> set[Entry]:
1✔
526
        """
527
        Returns:
528
            set[Entry]: of stable entries in the phase diagram.
529
        """
530
        return set(self._stable_entries)
1✔
531

532
    @lru_cache(1)
1✔
533
    def _get_stable_entries_in_space(self, space) -> list[Entry]:
1✔
534
        """
535
        Args:
536
            space (set[Element]): set of Element objects
537

538
        Returns:
539
            list[Entry]: stable entries in the space.
540
        """
541
        return [e for e, s in zip(self._stable_entries, self._stable_spaces) if space.issuperset(s)]
1✔
542

543
    def get_reference_energy_per_atom(self, comp: Composition) -> float:
1✔
544
        """
545
        Args:
546
            comp (Composition): Input composition
547

548
        Returns:
549
            Reference energy of the terminal species at a given composition.
550
        """
551
        return sum(comp[el] * self.el_refs[el].energy_per_atom for el in comp.elements) / comp.num_atoms
×
552

553
    def get_form_energy(self, entry: PDEntry) -> float:
1✔
554
        """
555
        Returns the formation energy for an entry (NOT normalized) from the
556
        elemental references.
557

558
        Args:
559
            entry (PDEntry): A PDEntry-like object.
560

561
        Returns:
562
            float: Formation energy from the elemental references.
563
        """
564
        comp = entry.composition
1✔
565
        return entry.energy - sum(comp[el] * self.el_refs[el].energy_per_atom for el in comp.elements)
1✔
566

567
    def get_form_energy_per_atom(self, entry: PDEntry) -> float:
1✔
568
        """
569
        Returns the formation energy per atom for an entry from the
570
        elemental references.
571

572
        Args:
573
            entry (PDEntry): An PDEntry-like object
574

575
        Returns:
576
            Formation energy **per atom** from the elemental references.
577
        """
578
        return self.get_form_energy(entry) / entry.composition.num_atoms
1✔
579

580
    def __repr__(self) -> str:
1✔
581
        symbols = [el.symbol for el in self.elements]
1✔
582
        output = [
1✔
583
            f"{'-'.join(symbols)} phase diagram",
584
            f"{len(self.stable_entries)} stable phases: ",
585
            ", ".join(entry.name for entry in sorted(self.stable_entries, key=str)),
586
        ]
587
        return "\n".join(output)
1✔
588

589
    @lru_cache(1)
1✔
590
    def _get_facet_and_simplex(self, comp: Composition) -> tuple[Simplex, Simplex]:
1✔
591
        """
592
        Get any facet that a composition falls into. Cached so successive
593
        calls at same composition are fast.
594

595
        Args:
596
            comp (Composition): A composition
597
        """
598
        c = self.pd_coords(comp)
1✔
599
        for f, s in zip(self.facets, self.simplexes):
1✔
600
            if s.in_simplex(c, PhaseDiagram.numerical_tol / 10):
1✔
601
                return f, s
1✔
602

603
        raise RuntimeError(f"No facet found for {comp = }")
×
604

605
    def _get_all_facets_and_simplexes(self, comp):
1✔
606
        """
607
        Get all facets that a composition falls into.
608

609
        Args:
610
            comp (Composition): A composition
611
        """
612
        c = self.pd_coords(comp)
1✔
613

614
        all_facets = [
1✔
615
            f for f, s in zip(self.facets, self.simplexes) if s.in_simplex(c, PhaseDiagram.numerical_tol / 10)
616
        ]
617

618
        if not all_facets:
1✔
619
            raise RuntimeError(f"No facets found for {comp = }")
×
620

621
        return all_facets
1✔
622

623
    def _get_facet_chempots(self, facet):
1✔
624
        """
625
        Calculates the chemical potentials for each element within a facet.
626

627
        Args:
628
            facet: Facet of the phase diagram.
629

630
        Returns:
631
            {element: chempot} for all elements in the phase diagram.
632
        """
633
        comp_list = [self.qhull_entries[i].composition for i in facet]
1✔
634
        energy_list = [self.qhull_entries[i].energy_per_atom for i in facet]
1✔
635
        m = [[c.get_atomic_fraction(e) for e in self.elements] for c in comp_list]
1✔
636
        chempots = np.linalg.solve(m, energy_list)
1✔
637

638
        return dict(zip(self.elements, chempots))
1✔
639

640
    def _get_simplex_intersections(self, c1, c2):
1✔
641
        """
642
        Returns coordinates of the intersection of the tie line between two compositions
643
        and the simplexes of the PhaseDiagram.
644

645
        Args:
646
            c1: Reduced dimension coordinates of first composition
647
            c2: Reduced dimension coordinates of second composition
648

649
        Returns:
650
            Array of the intersections between the tie line and the simplexes of
651
            the PhaseDiagram
652
        """
653
        intersections = [c1, c2]
1✔
654
        for sc in self.simplexes:
1✔
655
            intersections.extend(sc.line_intersection(c1, c2))
1✔
656

657
        return np.array(intersections)
1✔
658

659
    def get_decomposition(self, comp: Composition) -> dict[PDEntry, float]:
1✔
660
        """
661
        Provides the decomposition at a particular composition.
662

663
        Args:
664
            comp (Composition): A composition
665

666
        Returns:
667
            Decomposition as a dict of {PDEntry: amount} where amount
668
            is the amount of the fractional composition.
669
        """
670
        facet, simplex = self._get_facet_and_simplex(comp)
1✔
671
        decomp_amts = simplex.bary_coords(self.pd_coords(comp))
1✔
672
        return {
1✔
673
            self.qhull_entries[f]: amt for f, amt in zip(facet, decomp_amts) if abs(amt) > PhaseDiagram.numerical_tol
674
        }
675

676
    def get_decomp_and_hull_energy_per_atom(self, comp: Composition) -> tuple[dict[PDEntry, float], float]:
1✔
677
        """
678
        Args:
679
            comp (Composition): Input composition
680

681
        Returns:
682
            Energy of lowest energy equilibrium at desired composition per atom
683
        """
684
        decomp = self.get_decomposition(comp)
1✔
685
        return decomp, sum(e.energy_per_atom * n for e, n in decomp.items())
1✔
686

687
    def get_hull_energy_per_atom(self, comp: Composition, **kwargs) -> float:
1✔
688
        """
689
        Args:
690
            comp (Composition): Input composition
691

692
        Returns:
693
            Energy of lowest energy equilibrium at desired composition.
694
        """
695
        return self.get_decomp_and_hull_energy_per_atom(comp, **kwargs)[1]
1✔
696

697
    def get_hull_energy(self, comp: Composition) -> float:
1✔
698
        """
699
        Args:
700
            comp (Composition): Input composition
701

702
        Returns:
703
            Energy of lowest energy equilibrium at desired composition. Not
704
                normalized by atoms, i.e. E(Li4O2) = 2 * E(Li2O)
705
        """
706
        return comp.num_atoms * self.get_hull_energy_per_atom(comp)
1✔
707

708
    def get_decomp_and_e_above_hull(
1✔
709
        self,
710
        entry: PDEntry,
711
        allow_negative: bool = False,
712
        check_stable: bool = True,
713
        on_error: Literal["raise", "warn", "ignore"] = "raise",
714
    ) -> tuple[dict[PDEntry, float], float] | tuple[None, None]:
715
        """
716
        Provides the decomposition and energy above convex hull for an entry.
717
        Due to caching, can be much faster if entries with the same composition
718
        are processed together.
719

720
        Args:
721
            entry (PDEntry): A PDEntry like object
722
            allow_negative (bool): Whether to allow negative e_above_hulls. Used to
723
                calculate equilibrium reaction energies. Defaults to False.
724
            check_stable (bool): Whether to first check whether an entry is stable.
725
                In normal circumstances, this is the faster option since checking for
726
                stable entries is relatively fast. However, if you have a huge proportion
727
                of unstable entries, then this check can slow things down. You should then
728
                set this to False.
729
            on_error ('raise' | 'warn' | 'ignore'): What to do if no valid decomposition was
730
                found. 'raise' will throw ValueError. 'warn' will print return (None, None).
731
                'ignore' just returns (None, None). Defaults to 'raise'.
732

733
        Raises:
734
            ValueError: If no valid decomposition exists in this phase diagram for given entry.
735

736
        Returns:
737
            (decomp, energy_above_hull). The decomposition is provided
738
                as a dict of {PDEntry: amount} where amount is the amount of the
739
                fractional composition. Stable entries should have energy above
740
                convex hull of 0. The energy is given per atom.
741
        """
742
        # Avoid computation for stable_entries.
743
        # NOTE scaled duplicates of stable_entries will not be caught.
744
        if check_stable and entry in self.stable_entries:
1✔
745
            return {entry: 1.0}, 0.0
1✔
746

747
        try:
1✔
748
            decomp, hull_energy = self.get_decomp_and_hull_energy_per_atom(entry.composition)
1✔
749
        except Exception as exc:
1✔
750
            if on_error == "raise":
1✔
751
                raise ValueError(f"Unable to get decomposition for {entry}") from exc
1✔
752
            elif on_error == "warn":
1✔
753
                warnings.warn(f"Unable to get decomposition for {entry}, encountered {exc}")
×
754
                return None, None
×
755
            else:
756
                return None, None
1✔
757
        e_above_hull = entry.energy_per_atom - hull_energy
1✔
758

759
        if allow_negative or e_above_hull >= -PhaseDiagram.numerical_tol:
1✔
760
            return decomp, e_above_hull
1✔
761

762
        msg = f"No valid decomposition found for {entry}! (e_h: {e_above_hull})"
1✔
763
        if on_error == "raise":
1✔
764
            raise ValueError(msg)
1✔
765
        elif on_error == "warn":
1✔
766
            warnings.warn(msg)
1✔
767
        return None, None  # 'ignore' and 'warn' case
1✔
768

769
    def get_e_above_hull(self, entry: PDEntry, **kwargs: Any) -> float | None:
        """
        Get the energy above the convex hull for an entry.

        Thin wrapper around `get_decomp_and_e_above_hull` that keeps only the
        second item (the energy) of the pair that method returns.

        Args:
            entry (PDEntry): A PDEntry like object
            **kwargs: Forwarded unchanged to `get_decomp_and_e_above_hull`.

        Returns:
            float | None: Energy above convex hull of entry. Stable entries should have
                energy above hull of 0. The energy is given per atom.
        """
        decomp_and_energy = self.get_decomp_and_e_above_hull(entry, **kwargs)
        return decomp_and_energy[1]
1✔
781

782
    def get_equilibrium_reaction_energy(self, entry: PDEntry) -> float | None:
        """
        Get the reaction energy of a stable entry from the neighboring
        equilibrium stable entries (also known as the inverse distance to
        hull).

        The entry is removed from the stable entries of its own element space,
        a fresh PhaseDiagram is built from the remainder, and the (possibly
        negative) energy above that modified hull is returned.

        Args:
            entry (PDEntry): A PDEntry like object

        Raises:
            ValueError: If the entry is not among the stable entries of its
                element space.

        Returns:
            float | None: Equilibrium reaction energy of entry. Stable entries should have
                equilibrium reaction energy <= 0. The energy is given per atom.
        """
        elem_space = entry.composition.elements
        stable_in_space = self._get_stable_entries_in_space(frozenset(elem_space))

        # NOTE scaled duplicates of stable_entries will not be caught.
        if entry not in stable_in_space:
            raise ValueError(
                f"{entry} is unstable, the equilibrium reaction energy is available only for stable entries."
            )

        # An elemental reference has nothing to react from.
        if entry.is_element:
            return 0

        # Hull of the same space with the entry itself taken out.
        remaining = [other for other in stable_in_space if other != entry]
        hull_without_entry = PhaseDiagram(remaining, elements=elem_space)

        decomp_and_energy = hull_without_entry.get_decomp_and_e_above_hull(entry, allow_negative=True)
        return decomp_and_energy[1]
1✔
810

811
    def get_decomp_and_phase_separation_energy(
        self,
        entry: PDEntry,
        space_limit: int = 200,
        stable_only: bool = False,
        tols: Sequence[float] = (1e-8,),
        maxiter: int = 1000,
        **kwargs: Any,
    ) -> tuple[dict[PDEntry, float], float] | tuple[None, None]:
        """
        Provides the combination of entries in the PhaseDiagram that gives the
        lowest formation enthalpy with the same composition as the given entry
        excluding entries with the same composition and the energy difference
        per atom between the given entry and the energy of the combination found.

        For unstable entries that are not polymorphs of stable entries (or completely
        novel entries) this is simply the energy above (or below) the convex hull.

        For entries with the same composition as one of the stable entries in the
        phase diagram setting `stable_only` to `False` (Default) allows for entries
        not previously on the convex hull to be considered in the combination.
        In this case the energy returned is what is referred to as the decomposition
        enthalpy in:

        1. Bartel, C., Trewartha, A., Wang, Q., Dunn, A., Jain, A., Ceder, G.,
            A critical examination of compound stability predictions from
            machine-learned formation energies, npj Computational Materials 6, 97 (2020)

        For stable entries setting `stable_only` to `True` returns the same energy
        as `get_equilibrium_reaction_energy`. This function is based on a constrained
        optimization rather than recalculation of the convex hull making it
        algorithmically cheaper. However, if `tols` is too loose there is potential
        for this algorithm to converge to a different solution.

        Args:
            entry (PDEntry): A PDEntry like object.
            space_limit (int): The maximum number of competing entries to consider
                before calculating a second convex hull to reduce the complexity
                of the optimization.
            stable_only (bool): Only use stable materials as competing entries.
            tols (Sequence[float]): Tolerances for convergence of the SLSQP optimization
                when finding the equilibrium reaction. Tighter tolerances tested first.
            maxiter (int): The maximum number of iterations of the SLSQP optimizer
                when finding the equilibrium reaction.
            **kwargs: Forwarded to `get_decomp_and_e_above_hull` on the code paths
                that fall back to it (elemental entries and entries without a
                stable entry of the same composition).

        Returns:
            (decomp, energy). The decomposition is given as a dict of {PDEntry: amount}
            for all entries in the decomp reaction where amount is the amount of the
            fractional composition. The phase separation energy is given per atom.
        """
        entry_frac = entry.composition.fractional_composition
        entry_elems = frozenset(entry_frac.elements)

        # Handle elemental materials
        if entry.is_element:
            return self.get_decomp_and_e_above_hull(entry, allow_negative=True, **kwargs)

        # Looked up once and reused below instead of re-querying for every use.
        stable_in_space = self._get_stable_entries_in_space(entry_elems)

        # Select space to compare against
        if stable_only:
            compare_entries = stable_in_space
        else:
            compare_entries = [e for e, s in zip(self.qhull_entries, self._qhull_spaces) if entry_elems.issuperset(s)]

        # get memory ids of entries with the same composition.
        same_comp_mem_ids = {
            id(c)
            for c in compare_entries
            if (  # NOTE use this construction to avoid calls to fractional_composition
                len(entry_frac) == len(c.composition)
                and all(
                    abs(v - c.composition.get_atomic_fraction(el)) <= Composition.amount_tolerance
                    for el, v in entry_frac.items()
                )
            )
        }

        # No stable entry shares this composition: the plain hull distance is the answer.
        if not any(id(e) in same_comp_mem_ids for e in stable_in_space):
            return self.get_decomp_and_e_above_hull(entry, allow_negative=True, **kwargs)

        # take entries with different compositions as competing entries
        competing_entries = {c for c in compare_entries if id(c) not in same_comp_mem_ids}

        # NOTE SLSQP optimizer doesn't scale well for > 300 competing entries.
        if len(competing_entries) > space_limit and not stable_only:
            warnings.warn(
                f"There are {len(competing_entries)} competing entries "
                f"for {entry.composition} - Calculating inner hull to discard "
                "additional unstable entries"
            )

            # Unstable competitors plus the elemental references (explicit
            # parentheses spell out the precedence of `-` over `|`).
            reduced_space = (competing_entries - {*stable_in_space}) | {*self.el_refs.values()}

            # NOTE calling PhaseDiagram is only reasonable if the composition has fewer than 5 elements
            # TODO can we call PatchedPhaseDiagram here?
            inner_hull = PhaseDiagram(reduced_space)

            # Keep only the entries on the inner hull or on the outer hull, then
            # drop those sharing the entry's composition. The filter must run over
            # this reduced set; filtering `compare_entries` again would restore
            # every entry that the inner hull was computed to discard.
            hull_candidates = inner_hull.stable_entries | {*stable_in_space}
            competing_entries = {c for c in hull_candidates if id(c) not in same_comp_mem_ids}

        if len(competing_entries) > space_limit:
            warnings.warn(
                f"There are {len(competing_entries)} competing entries "
                f"for {entry.composition} - Using SLSQP to find "
                "decomposition likely to be slow"
            )

        decomp = _get_slsqp_decomp(entry.composition, competing_entries, tols, maxiter)

        # find the minimum alternative formation energy for the decomposition
        decomp_enthalpy = np.sum([c.energy_per_atom * amt for c, amt in decomp.items()])

        decomp_enthalpy = entry.energy_per_atom - decomp_enthalpy

        return decomp, decomp_enthalpy
1✔
927

928
    def get_phase_separation_energy(self, entry, **kwargs):
        """
        Provides the energy to the convex hull for the given entry. For stable entries
        already in the phase diagram the algorithm provides the phase separation energy
        which is referred to as the decomposition enthalpy in:

        1. Bartel, C., Trewartha, A., Wang, Q., Dunn, A., Jain, A., Ceder, G.,
            A critical examination of compound stability predictions from
            machine-learned formation energies, npj Computational Materials 6, 97 (2020)

        Args:
            entry (PDEntry): A PDEntry like object
            **kwargs: Keyword args passed to `get_decomp_and_phase_separation_energy`
                space_limit (int): The maximum number of competing entries to consider.
                stable_only (bool): Only use stable materials as competing entries
                tols (Sequence[float]): The tolerances for convergence of the SLSQP
                    optimization when finding the equilibrium reaction.
                maxiter (int): The maximum number of iterations of the SLSQP optimizer
                    when finding the equilibrium reaction.

        Returns:
            phase separation energy per atom of entry. Stable entries should have
            energies <= 0, Stable elemental entries should have energies = 0 and
            unstable entries should have energies > 0. Entries that have the same
            composition as a stable energy may have positive or negative phase
            separation energies depending on their own energy.
        """
        decomp_and_energy = self.get_decomp_and_phase_separation_energy(entry, **kwargs)
        # Second item of the (decomp, energy) pair is the energy.
        return decomp_and_energy[1]
1✔
956

957
    def get_composition_chempots(self, comp):
        """
        Get the chemical potentials for all elements at a given composition.

        Looks up the facet for the composition via `_get_facet_and_simplex`
        (first item of its result) and converts that facet to chemical
        potentials with `_get_facet_chempots`.

        Args:
            comp (Composition): Composition

        Returns:
            Dictionary of chemical potentials.
        """
        facet_and_simplex = self._get_facet_and_simplex(comp)
        return self._get_facet_chempots(facet_and_simplex[0])
1✔
969

970
    def get_all_chempots(self, comp):
        """
        Get chemical potentials at a given composition.

        Every facet returned by `_get_all_facets_and_simplexes` contributes one
        item, keyed by the "-"-joined names of the qhull entries at the
        facet's indices.

        Args:
            comp (Composition): Composition

        Returns:
            Chemical potentials.
        """
        return {
            "-".join(self.qhull_entries[idx].name for idx in facet): self._get_facet_chempots(facet)
            for facet in self._get_all_facets_and_simplexes(comp)
        }
1✔
988

989
    def get_transition_chempots(self, element):
        """
        Get the critical chemical potentials for an element in the Phase
        Diagram.

        Args:
            element: An element. Has to be in the PD in the first place.

        Raises:
            ValueError: If element is not one of the phase diagram's elements.

        Returns:
            A sorted sequence of critical chemical potentials, from less
            negative to more negative.
        """
        if element not in self.elements:
            raise ValueError("get_transition_chempots can only be called with elements in the phase diagram.")

        # Chemical potential of the element on every facet, in ascending order.
        ascending_pots = sorted(self._get_facet_chempots(facet)[element] for facet in self.facets)

        # Drop values within numerical tolerance of the previously kept one.
        distinct_pots = []
        for pot in ascending_pots:
            if not distinct_pots or abs(pot - distinct_pots[-1]) > PhaseDiagram.numerical_tol:
                distinct_pots.append(pot)

        return tuple(reversed(distinct_pots))
1✔
1018

1019
    def get_critical_compositions(self, comp1, comp2):
        """
        Get the critical compositions along the tieline between two
        compositions. I.e. where the decomposition products change.
        The endpoints are also returned.

        Args:
            comp1, comp2 (Composition): compositions that define the tieline

        Returns:
            [(Composition)]: list of critical compositions. All are of
                the form x * comp1 + (1-x) * comp2
        """
        natoms1 = comp1.num_atoms
        natoms2 = comp2.num_atoms
        pd_els = self.elements
        tol = self.numerical_tol

        # NOTE the reduced dimensionality Simplexes don't use the
        # first element in the PD
        start = self.pd_coords(comp1)
        end = self.pd_coords(comp2)

        # NOTE none of the projections work if start == end, so just
        # return *copies* of the inputs
        if np.all(start == end):
            return [comp1.copy(), comp2.copy()]

        # NOTE made into method to facilitate inheritance of this method
        # in PatchedPhaseDiagram if approximate solution can be found.
        intersections = self._get_simplex_intersections(start, end)

        # unit vector along the tieline and position of each intersection along it
        direction = end - start
        direction /= np.sum(direction**2) ** 0.5
        along = np.dot(intersections - start, direction)

        # only take compositions between endpoints
        # (per the original comment, along[1] is |c2-c1|; presumably the
        # intersections helper returns the two endpoints first - confirm there)
        past_start = along > -tol
        before_end = along < along[1] + tol
        along = along[np.logical_and(past_start, before_end)]
        along.sort()

        # only unique compositions
        is_new = np.ones(len(along), dtype=bool)
        is_new[1:] = along[1:] > along[:-1] + tol
        along = along[is_new]

        reduced_coords = start + direction * along[:, None]

        # reconstruct full-dimensional composition array
        first_column = np.array([1 - np.sum(reduced_coords, axis=-1)]).T
        full_coords = np.concatenate([first_column, reduced_coords], axis=-1)

        # mixing fraction when compositions are normalized
        mix = along / np.dot(end - start, direction)

        # mixing fraction when compositions are not normalized
        mix_unnormalized = mix * natoms1 / (natoms2 + mix * (natoms1 - natoms2))
        num_atoms = natoms1 + (natoms2 - natoms1) * mix_unnormalized
        full_coords *= num_atoms[:, None]

        return [Composition((el, amt) for el, amt in zip(pd_els, row)) for row in full_coords]
1✔
1080

1081
    def get_element_profile(self, element, comp, comp_tol=1e-5):
        """
        Provides the element evolution data for a composition.
        For example, can be used to analyze Li conversion voltages by varying
        uLi and looking at the phases formed. Also can be used to analyze O2
        evolution by varying uO2.

        Args:
            element: An element. Must be in the phase diagram.
            comp: A Composition
            comp_tol: The tolerance to use when calculating decompositions.
                Phases with amounts less than this tolerance are excluded.
                Defaults to 1e-5.
                NOTE(review): this argument is not read anywhere in the body;
                it is kept only so existing callers keep working.

        Raises:
            ValueError: If element is not one of the phase diagram's elements.

        Returns:
            Evolution data as a list of dictionaries with the keys
            'chempot', 'evolution', 'element_reference', 'reaction',
            'entries' and 'critical_composition', e.g.:
            [ {'chempot': -10.487582010000001, 'evolution': -2.0,
            'reaction': Reaction Object, ...}, ...]
        """
        element = get_el_sp(element)

        if element not in self.elements:
            # The message used to name get_transition_chempots (copy-paste slip).
            raise ValueError("get_element_profile can only be called with elements in the phase diagram.")

        # Composition with the profiled element stripped out.
        gc_comp = Composition({el: amt for el, amt in comp.items() if el != element})
        el_ref = self.el_refs[element]
        el_comp = Composition(element.symbol)
        evolution = []

        # The first critical composition (the pure-element endpoint) is skipped.
        for cc in self.get_critical_compositions(el_comp, gc_comp)[1:]:
            decomp_entries = list(self.get_decomposition(cc))
            decomp = [k.composition for k in decomp_entries]
            rxn = Reaction([comp], decomp + [el_comp])
            rxn.normalize_to(comp)
            # Chemical potential evaluated at cc nudged by a small amount of the element.
            c = self.get_composition_chempots(cc + el_comp * 1e-5)[element]
            amt = -rxn.coeffs[rxn.all_comp.index(el_comp)]
            evolution.append(
                {
                    "chempot": c,
                    "evolution": amt,
                    "element_reference": el_ref,
                    "reaction": rxn,
                    "entries": decomp_entries,
                    "critical_composition": cc,
                }
            )
        return evolution
1✔
1128

1129
    def get_chempot_range_map(
        self, elements: Sequence[Element], referenced: bool = True, joggle: bool = True
    ) -> dict[Element, list[Simplex]]:
        """
        Returns a chemical potential range map for each stable entry.

        Args:
            elements: Sequence of elements to be considered as independent
                variables. E.g., if you want to show the stability ranges
                of all Li-Co-O phases wrt to uLi and uO, you will supply
                [Element("Li"), Element("O")]
            referenced: If True, gives the results with a reference being the
                energy of the elemental phase. If False, gives absolute values.
            joggle (bool): Whether to joggle the input to avoid precision
                errors.

        Returns:
            Returns a dict of the form {entry: [simplices]}. The list of
            simplices are the sides of the N-1 dim polytope bounding the
            allowable chemical potential range of each entry.
        """
        # One row of chemical potentials (ordered like self.elements) per facet.
        all_chempots = [
            [facet_pots[el] for el in self.elements] for facet_pots in map(self._get_facet_chempots, self.facets)
        ]

        inds = [self.elements.index(el) for el in elements]

        # Energies subtracted from the chemical potentials (zero when not referenced).
        el_energies = (
            {el: self.el_refs[el].energy_per_atom for el in elements} if referenced else dict.fromkeys(elements, 0.0)
        )

        chempot_ranges = collections.defaultdict(list)

        if len(all_chempots) > len(self.elements):
            vertices = get_facets(all_chempots, joggle=joggle)
        else:
            vertices = [list(range(len(self.elements)))]

        for ufacet in vertices:
            for pair in itertools.combinations(ufacet, 2):
                first_facet = self.facets[pair[0]]
                second_facet = self.facets[pair[1]]
                common_ent_ind = set(first_facet).intersection(set(second_facet))
                # Only facet pairs sharing as many entries as there are
                # independent elements contribute a simplex.
                if len(common_ent_ind) != len(elements):
                    continue
                common_entries = [self.qhull_entries[i] for i in common_ent_ind]
                data = np.array([[all_chempots[i][j] - el_energies[self.elements[j]] for j in inds] for i in pair])
                sim = Simplex(data)
                for entry in common_entries:
                    chempot_ranges[entry].append(sim)

        return chempot_ranges
1✔
1181

1182
    def getmu_vertices_stability_phase(self, target_comp, dep_elt, tol_en=1e-2):
        """
        Returns a set of chemical potentials corresponding to the vertices of
        the simplex in the chemical potential phase diagram.
        The simplex is built using all elements in the target_composition
        except dep_elt.
        The chemical potential of dep_elt is computed from the target
        composition energy.
        This method is useful to get the limiting conditions for
        defects computations for instance.

        Args:
            target_comp: A Composition object
            dep_elt: the element for which the chemical potential is computed
                from the energy of the stable phase at the target composition
            tol_en: a tolerance on the energy to set

        Returns:
             [{Element: mu}]: An array of conditions on simplex vertices for
             which each element has a chemical potential set to a given
             value. "absolute" values (i.e., not referenced to element energies).
             An empty list is returned when no entry of the chemical potential
             range map has the same reduced composition as target_comp.
        """
        # Every element except the dependent one; fixed for the whole call.
        indep_elts = [e for e in self.elements if e != dep_elt]
        mu_ref = np.array([self.el_refs[e].energy_per_atom for e in indep_elts])
        chempot_ranges = self.get_chempot_range_map(indep_elts)

        # Pad the target with zero amounts so that every PD element can be indexed.
        for e in self.elements:
            if e not in target_comp.elements:
                target_comp = target_comp + Composition({e: 0.0})

        coeff = [-target_comp[e] for e in indep_elts]

        # Bound up front so that a target without a matching stable entry
        # yields an empty result instead of an UnboundLocalError at the return.
        all_coords = []

        for entry, simplices in chempot_ranges.items():
            if entry.composition.reduced_composition == target_comp.reduced_composition:
                # Rescale the entry energy to the amount of dep_elt in target_comp.
                multiplier = entry.composition[dep_elt] / target_comp[dep_elt]
                ef = entry.energy / multiplier
                all_coords = []
                for simplex in simplices:
                    for v in simplex._coords:
                        res = {el: v[i] + mu_ref[i] for i, el in enumerate(indep_elts)}
                        res[dep_elt] = (np.dot(v + mu_ref, coeff) + ef) / target_comp[dep_elt]
                        # Skip vertices already recorded, comparing every
                        # chemical potential within tol_en.
                        already_in = any(
                            not any(abs(seen[k] - res[k]) > tol_en for k in seen) for seen in all_coords
                        )
                        if not already_in:
                            all_coords.append(res)
        return all_coords
1✔
1238

1239
    def get_chempot_range_stability_phase(self, target_comp, open_elt):
        """
        Returns a set of chemical potentials corresponding to the max and min
        chemical potential of the open element for a given composition. It is
        quite common to have for instance a ternary oxide (e.g., ABO3) for
        which you want to know what are the A and B chemical potential leading
        to the highest and lowest oxygen chemical potential (reducing and
        oxidizing conditions). This is useful for defect computations.

        Args:
            target_comp: A Composition object
            open_elt: Element that you want to constrain to be max or min

        Returns:
             {Element: (mu_min, mu_max)}: Chemical potentials are given in
             "absolute" values (i.e., not referenced to 0)

        NOTE(review): when no entry of the chemical potential range map matches
        target_comp, no vertex is ever recorded (min_mus/max_mus stay None) and
        building the result fails on indexing None - callers should pass a
        stable composition.
        """
        # Every element except the open one; fixed for the whole call.
        elts = [e for e in self.elements if e != open_elt]
        muref = np.array([self.el_refs[e].energy_per_atom for e in elts])
        chempot_ranges = self.get_chempot_range_map(elts)

        # Pad the target with zero amounts so that every PD element can be indexed.
        for e in self.elements:
            if e not in target_comp.elements:
                target_comp = target_comp + Composition({e: 0.0})

        coeff = [-target_comp[e] for e in elts]
        max_open = -float("inf")
        min_open = float("inf")
        max_mus = None
        min_mus = None

        for entry, simplices in chempot_ranges.items():
            if entry.composition.reduced_composition == target_comp.reduced_composition:
                # Rescale the entry energy to the amount of open_elt in target_comp.
                multiplicator = entry.composition[open_elt] / target_comp[open_elt]
                ef = entry.energy / multiplicator
                for simplex in simplices:
                    for v in simplex._coords:
                        # Chemical potential of the open element at this vertex.
                        test_open = (np.dot(v + muref, coeff) + ef) / target_comp[open_elt]
                        if test_open > max_open:
                            max_open = test_open
                            max_mus = v
                        if test_open < min_open:
                            min_open = test_open
                            min_mus = v

        # Vertex coordinates are shifted by the element reference energies.
        res = {el: (min_mus[i] + muref[i], max_mus[i] + muref[i]) for i, el in enumerate(elts)}
        res[open_elt] = (min_open, max_open)
        return res
1✔
1292

1293

1294
class GrandPotentialPhaseDiagram(PhaseDiagram):
    """
    A class representing a Grand potential phase diagram. Grand potential phase
    diagrams are essentially phase diagrams that are open to one or more
    components. To construct such phase diagrams, the relevant free energy is
    the grand potential, which can be written as the Legendre transform of the
    Gibbs free energy as follows

    Grand potential = G - u_X N_X

    The algorithm is based on the work in the following papers:

    1. S. P. Ong, L. Wang, B. Kang, and G. Ceder, Li-Fe-P-O2 Phase Diagram from
       First Principles Calculations. Chem. Mater., 2008, 20(5), 1798-1807.
       doi:10.1021/cm702327g

    2. S. P. Ong, A. Jain, G. Hautier, B. Kang, G. Ceder, Thermal stabilities
       of delithiated olivine MPO4 (M=Fe, Mn) cathodes investigated using first
       principles calculations. Electrochem. Comm., 2010, 12(3), 427-430.
       doi:10.1016/j.elecom.2010.01.010
    """

    def __init__(self, entries, chempots, elements=None, *, computed_data=None):
        """
        Standard constructor for grand potential phase diagram.

        Args:
            entries ([PDEntry]): A list of PDEntry-like objects having an
                energy, energy_per_atom and composition.
            chempots ({Element: float}): Specify the chemical potentials
                of the open elements.
            elements ([Element]): Optional list of elements in the phase
                diagram. If set to None, the elements are determined from
                the entries themselves.
            computed_data (dict): A dict containing pre-computed data. This allows
                PhaseDiagram object to be reconstituted without performing the
                expensive convex hull computation. The dict is the output from the
                PhaseDiagram._compute() method and is stored in PhaseDiagram.computed_data
                when generated for the first time.
        """
        if elements is None:
            elements = {els for e in entries for els in e.composition.elements}

        # Normalise the keys so that symbols and Element objects are both accepted.
        self.chempots = {get_el_sp(el): u for el, u in chempots.items()}
        # The open elements are not dimensions of the grand potential diagram.
        elements = set(elements) - set(self.chempots)

        # Entries made up only of open elements are left out.
        all_entries = [
            GrandPotPDEntry(e, self.chempots) for e in entries if len(elements.intersection(e.composition.elements)) > 0
        ]

        # Hand the caller's pre-computed data to the parent; it used to be
        # replaced by None here, which silently ignored the argument.
        super().__init__(all_entries, elements, computed_data=computed_data)

    def __repr__(self):
        chemsys = "-".join(el.symbol for el in self.elements)
        # NOTE the local name is part of the output through the `{chempots = }` specifier.
        chempots = ", ".join(f"mu_{el} = {mu:.4f}" for el, mu in self.chempots.items())

        output = [
            f"{chemsys} GrandPotentialPhaseDiagram with {chempots = }",
            f"{len(self.stable_entries)} stable phases: ",
            ", ".join(entry.name for entry in self.stable_entries),
        ]
        return "".join(output)

    def as_dict(self):
        """
        Returns:
            MSONable dictionary representation of GrandPotentialPhaseDiagram
        """
        return {
            "@module": type(self).__module__,
            "@class": type(self).__name__,
            "all_entries": [e.as_dict() for e in self.all_entries],
            "chempots": self.chempots,
            "elements": [e.as_dict() for e in self.elements],
        }

    @classmethod
    def from_dict(cls, d):
        """
        Args:
            d (dict): dictionary representation of GrandPotentialPhaseDiagram

        Returns:
            GrandPotentialPhaseDiagram
        """
        entries = MontyDecoder().process_decoded(d["all_entries"])
        elements = MontyDecoder().process_decoded(d["elements"])
        return cls(entries, d["chempots"], elements)
×
1382

1383

1384
class CompoundPhaseDiagram(PhaseDiagram):
    """
    Phase diagram whose terminal points are compounds rather than
    elements.
    """

    # Tolerance for determining if amount of a composition is positive.
    amount_tol = 1e-5

    def __init__(self, entries, terminal_compositions, normalize_terminal_compositions=True):
        """
        Initializes a CompoundPhaseDiagram.

        Args:
            entries ([PDEntry]): Sequence of input entries. For example,
               if you want a Li2O-P2O5 phase diagram, you might have all
               Li-P-O entries as an input.
            terminal_compositions ([Composition]): Terminal compositions of
                phase space. In the Li2O-P2O5 example, these will be the
                Li2O and P2O5 compositions.
            normalize_terminal_compositions (bool): Whether to normalize the
                terminal compositions to a per atom basis. If normalized,
                the energy above hulls will be consistent
                for comparison across systems. Non-normalized terminals are
                more intuitive in terms of compositional breakdowns.
        """
        self.original_entries = entries
        self.terminal_compositions = terminal_compositions
        # Must be set before transform_entries is called, which reads it.
        self.normalize_terminals = normalize_terminal_compositions

        transformed, mapping = self.transform_entries(entries, terminal_compositions)
        self.species_mapping = mapping
        super().__init__(transformed, elements=mapping.values())

    def transform_entries(self, entries, terminal_compositions):
        """
        Re-express every entry in the coordinate system spanned by the
        terminal compositions. Entries that cannot be expressed in that space
        are left out. For example, Li3PO4 is mapped into a Li2O:1.5, P2O5:0.5
        composition. Each terminal composition is represented by its own
        DummySpecies.

        Note that entries whose ``attribute`` is missing or None get it set
        to their ``entry_id`` (or None if they have none) as a side effect.

        Args:
            entries: Sequence of all input entries
            terminal_compositions: Terminal compositions of phase space.

        Returns:
            Tuple of (list of TransformedPDEntries falling within the phase
            space, dict mapping each terminal composition to its DummySpecies).
        """
        if self.normalize_terminals:
            terminal_compositions = [c.fractional_composition for c in terminal_compositions]

        # One dummy species per terminal composition, named "Xf", "Xg", ... (chr(102) == "f").
        sp_mapping = {comp: DummySpecies("X" + chr(102 + idx)) for idx, comp in enumerate(terminal_compositions)}

        new_entries = []
        for entry in entries:
            if getattr(entry, "attribute", None) is None:
                entry.attribute = getattr(entry, "entry_id", None)

            try:
                new_entries.append(TransformedPDEntry(entry, sp_mapping))
            except (ReactionError, TransformedPDEntryError):
                # ReactionError: the reaction can't be balanced, so the entry
                # does not fall into the phase space.
                # TransformedPDEntryError: the reaction has negative amounts
                # for reactants, so the entry does not fall into the phase space.
                continue

        return new_entries, sp_mapping

    def as_dict(self):
        """
        Returns:
            MSONable dictionary representation of CompoundPhaseDiagram
        """
        klass = type(self)
        dct = {"@module": klass.__module__, "@class": klass.__name__}
        dct["original_entries"] = [entry.as_dict() for entry in self.original_entries]
        dct["terminal_compositions"] = [comp.as_dict() for comp in self.terminal_compositions]
        dct["normalize_terminal_compositions"] = self.normalize_terminals
        return dct

    @classmethod
    def from_dict(cls, d):
        """
        Args:
            d (dict): dictionary representation of CompoundPhaseDiagram

        Returns:
            CompoundPhaseDiagram
        """
        decoder = MontyDecoder()
        return cls(
            decoder.process_decoded(d["original_entries"]),
            decoder.process_decoded(d["terminal_compositions"]),
            d["normalize_terminal_compositions"],
        )
×
1485

1486

1487
class PatchedPhaseDiagram(PhaseDiagram):
    """
    Computing the Convex Hull of a large set of data in multiple dimensions is
    highly expensive. This class acts to breakdown large chemical spaces into
    smaller chemical spaces which can be computed much more quickly due to having
    both reduced dimensionality and data set sizes.

    Attributes:
        spaces (list[frozenset[Element]]): Chemical spaces (sets of elements) for which
            a PhaseDiagram patch is computed, sorted by increasing number of elements.
        pds (dict[frozenset[Element], PhaseDiagram]): PhaseDiagram patch for each
            chemical space in ``spaces``.
        all_entries (list[PDEntry]): All entries provided for Phase Diagram construction.
            Note that this does not mean that all these entries are actually used in
            the phase diagram. For example, this includes the positive formation energy
            entries that are filtered out before Phase Diagram construction.
        qhull_entries (tuple[PDEntry]): The lowest energy entry of each reduced
            composition that has a negative formation energy, plus the elemental
            references. These are the entries distributed over the patches.
        el_refs (dict[Element, PDEntry]): Lowest energy entry for each element.
        elements (list[Element]): List of elements in the phase diagram.
        dim (int): Number of elements in the phase diagram.
    """

    def __init__(
        self,
        entries: Sequence[PDEntry] | set[PDEntry],
        elements: Sequence[Element] | None = None,
        keep_all_spaces: bool = False,
        verbose: bool = False,
    ) -> None:
        """
        Args:
            entries (list[PDEntry]): A list of PDEntry-like objects having an
                energy, energy_per_atom and composition.
            elements (list[Element], optional): Optional list of elements in the phase
                diagram. If set to None, the elements are determined from
                the entries themselves and are sorted alphabetically.
                If specified, element ordering (e.g. for pd coordinates)
                is preserved.
            keep_all_spaces (bool): Boolean control on whether to keep chemical spaces
                that are subspaces of other spaces.
            verbose (bool): Whether to show progress bar during convex hull construction.

        Raises:
            ValueError: If the number of elemental reference entries found in
                ``entries`` does not match the number of elements.
        """
        if elements is None:
            elements = sorted({els for e in entries for els in e.composition.elements})

        self.dim = len(elements)

        # itertools.groupby only groups consecutive items, hence sort by the same key first
        entries = sorted(entries, key=lambda e: e.composition.reduced_composition)

        el_refs: dict[Element, PDEntry] = {}
        min_entries = []
        all_entries: list[PDEntry] = []
        for composition, group_iter in itertools.groupby(entries, key=lambda e: e.composition.reduced_composition):
            group = list(group_iter)
            min_entry = min(group, key=lambda e: e.energy_per_atom)
            if composition.is_element:
                el_refs[composition.elements[0]] = min_entry
            min_entries.append(min_entry)
            all_entries.extend(group)

        if len(el_refs) < self.dim:
            missing = set(elements) - set(el_refs)
            raise ValueError(f"Missing terminal entries for elements {sorted(map(str, missing))}")
        if len(el_refs) > self.dim:
            extra = set(el_refs) - set(elements)
            raise ValueError(f"There are more terminal elements than dimensions: {extra}")

        # One row per lowest-energy entry: atomic fractions followed by energy per atom
        data = np.array(
            [[e.composition.get_atomic_fraction(el) for el in elements] + [e.energy_per_atom] for e in min_entries]
        )

        # Use only entries with negative formation energy
        vec = [el_refs[el].energy_per_atom for el in elements] + [-1]
        form_e = -np.dot(data, vec)
        inds = np.where(form_e < -PhaseDiagram.formation_energy_tol)[0].tolist()

        # Add the elemental references
        inds.extend([min_entries.index(el) for el in el_refs.values()])

        self.qhull_entries = tuple(min_entries[i] for i in inds)
        # make qhull spaces frozensets since they become keys to self.pds dict and frozensets are hashable
        # prevent repeating elements in chemical space and avoid the ordering problem (i.e. Fe-O == O-Fe automatically)
        self._qhull_spaces = tuple(frozenset(e.composition.elements) for e in self.qhull_entries)

        # Get all unique chemical spaces
        spaces = {s for s in self._qhull_spaces if len(s) > 1}

        # Remove redundant chemical spaces
        if not keep_all_spaces and len(spaces) > 1:
            max_size = max(len(s) for s in spaces)

            systems = set()
            # NOTE reduce the number of comparisons by only comparing to larger sets
            for i in range(2, max_size + 1):
                test = [s for s in spaces if len(s) == i]
                # NOTE `refer` must be a list, not a generator: it is iterated once per
                # candidate in `test`. A generator would be (partially) consumed by the
                # first candidate, so later candidates would be compared against nothing
                # and wrongly kept even when they are subspaces of a larger space.
                refer = [s for s in spaces if len(s) > i]
                systems |= {t for t in test if not any(t.issubset(r) for r in refer)}

            spaces = systems

        # TODO comprhys: refactor to have self._compute method to allow serialisation
        self.spaces = sorted(spaces, key=len, reverse=False)  # Calculate pds for smaller dimension spaces first
        self.pds = dict(self._get_pd_patch_for_space(s) for s in tqdm(self.spaces, disable=not verbose))
        self.all_entries = all_entries
        self.el_refs = el_refs
        self.elements = elements

        # Add terminal elements as we may not have PD patches including them
        # NOTE add el_refs in case no multielement entries are present for el
        _stable_entries = {se for pd in self.pds.values() for se in pd._stable_entries}
        self._stable_entries = tuple(_stable_entries | {*self.el_refs.values()})
        self._stable_spaces = tuple(frozenset(e.composition.elements) for e in self._stable_entries)

    def __repr__(self):
        return f"{type(self).__name__} covering {len(self.spaces)} sub-spaces"

    def __len__(self):
        """Number of chemical spaces covered."""
        return len(self.spaces)

    def __getitem__(self, item: frozenset[Element]) -> PhaseDiagram:
        return self.pds[item]

    def __setitem__(self, key: frozenset[Element], value: PhaseDiagram) -> None:
        self.pds[key] = value

    def __delitem__(self, key: frozenset[Element]) -> None:
        del self.pds[key]

    def __iter__(self) -> Iterator[PhaseDiagram]:
        """Iterate over the PhaseDiagram patches (not over the chemical spaces)."""
        return iter(self.pds.values())

    def __contains__(self, item: frozenset[Element]) -> bool:
        return item in self.pds

    def as_dict(self) -> dict[str, Any]:
        """
        Returns:
            dict[str, Any]: MSONable dictionary representation of PatchedPhaseDiagram.
            Only the entries and elements are stored; ``keep_all_spaces`` and
            ``verbose`` are not part of the dictionary.
        """
        return {
            "@module": type(self).__module__,
            "@class": type(self).__name__,
            "all_entries": [e.as_dict() for e in self.all_entries],
            "elements": [e.as_dict() for e in self.elements],
        }

    @classmethod
    def from_dict(cls, d):
        """
        Args:
            d (dict): dictionary representation of PatchedPhaseDiagram

        Returns:
            PatchedPhaseDiagram
        """
        entries = [MontyDecoder().process_decoded(dd) for dd in d["all_entries"]]
        elements = [Element.from_dict(dd) for dd in d["elements"]]
        return cls(entries, elements)

    # NOTE the following could be inherited unchanged from PhaseDiagram:
    #     __repr__,
    #     as_dict,
    #     all_entries_hulldata,
    #     unstable_entries,
    #     stable_entries,
    #     get_form_energy(),
    #     get_form_energy_per_atom(),
    #     get_hull_energy(),
    #     get_e_above_hull(),
    #     get_decomp_and_e_above_hull(),
    #     get_decomp_and_phase_separation_energy(),
    #     get_phase_separation_energy()

    def get_pd_for_entry(self, entry: Entry | Composition) -> PhaseDiagram:
        """
        Get the possible phase diagrams for an entry

        Args:
            entry (PDEntry | Composition): A PDEntry or Composition-like object

        Returns:
            PhaseDiagram: phase diagram that the entry is part of

        Raises:
            ValueError: If no patch covers the chemical space of the entry.
        """
        if isinstance(entry, Composition):
            entry_space = frozenset(entry.elements)
        else:
            entry_space = frozenset(entry.composition.elements)

        try:
            return self.pds[entry_space]
        except KeyError:
            # No exact match: fall back to the first patch whose space contains the entry's space
            for space, pd in self.pds.items():
                if space.issuperset(entry_space):
                    return pd

        raise ValueError(f"No suitable PhaseDiagrams found for {entry}.")

    def get_decomposition(self, comp: Composition) -> dict[PDEntry, float]:
        """
        See PhaseDiagram

        Args:
            comp (Composition): A composition

        Returns:
            Decomposition as a dict of {PDEntry: amount} where amount
            is the amount of the fractional composition.
        """
        try:
            pd = self.get_pd_for_entry(comp)
            return pd.get_decomposition(comp)
        except ValueError as e:
            # NOTE warn when stitching across pds is being used
            warnings.warn(str(e) + " Using SLSQP to find decomposition")
            competing_entries = self._get_stable_entries_in_space(frozenset(comp.elements))
            return _get_slsqp_decomp(comp, competing_entries)

    def get_equilibrium_reaction_energy(self, entry: Entry) -> float:
        """
        See PhaseDiagram

        NOTE this is only approximately the same as the what we would get
        from `PhaseDiagram` as we make use of the slsqp approach inside
        get_phase_separation_energy().

        Args:
            entry (PDEntry): A PDEntry like object

        Returns:
            Equilibrium reaction energy of entry. Stable entries should have
            equilibrium reaction energy <= 0. The energy is given per atom.
        """
        return self.get_phase_separation_energy(entry, stable_only=True)

    # NOTE the following functions are not implemented for PatchedPhaseDiagram

    def _get_facet_and_simplex(self):
        """
        Not Implemented - See PhaseDiagram
        """
        raise NotImplementedError("`_get_facet_and_simplex` not implemented for `PatchedPhaseDiagram`")

    def _get_all_facets_and_simplexes(self):
        """
        Not Implemented - See PhaseDiagram
        """
        raise NotImplementedError("`_get_all_facets_and_simplexes` not implemented for `PatchedPhaseDiagram`")

    def _get_facet_chempots(self):
        """
        Not Implemented - See PhaseDiagram
        """
        raise NotImplementedError("`_get_facet_chempots` not implemented for `PatchedPhaseDiagram`")

    def _get_simplex_intersections(self):
        """
        Not Implemented - See PhaseDiagram
        """
        raise NotImplementedError("`_get_simplex_intersections` not implemented for `PatchedPhaseDiagram`")

    def get_composition_chempots(self):
        """
        Not Implemented - See PhaseDiagram
        """
        raise NotImplementedError("`get_composition_chempots` not implemented for `PatchedPhaseDiagram`")

    def get_all_chempots(self):
        """
        Not Implemented - See PhaseDiagram
        """
        raise NotImplementedError("`get_all_chempots` not implemented for `PatchedPhaseDiagram`")

    def get_transition_chempots(self):
        """
        Not Implemented - See PhaseDiagram
        """
        raise NotImplementedError("`get_transition_chempots` not implemented for `PatchedPhaseDiagram`")

    def get_critical_compositions(self):
        """
        Not Implemented - See PhaseDiagram
        """
        raise NotImplementedError("`get_critical_compositions` not implemented for `PatchedPhaseDiagram`")

    def get_element_profile(self):
        """
        Not Implemented - See PhaseDiagram
        """
        raise NotImplementedError("`get_element_profile` not implemented for `PatchedPhaseDiagram`")

    def get_chempot_range_map(self):
        """
        Not Implemented - See PhaseDiagram
        """
        raise NotImplementedError("`get_chempot_range_map` not implemented for `PatchedPhaseDiagram`")

    def getmu_vertices_stability_phase(self):
        """
        Not Implemented - See PhaseDiagram
        """
        raise NotImplementedError("`getmu_vertices_stability_phase` not implemented for `PatchedPhaseDiagram`")

    def get_chempot_range_stability_phase(self):
        """
        Not Implemented - See PhaseDiagram
        """
        raise NotImplementedError("`get_chempot_range_stability_phase` not implemented for `PatchedPhaseDiagram`")

    def _get_pd_patch_for_space(self, space: frozenset[Element]) -> tuple[frozenset[Element], PhaseDiagram]:
        """
        Build the PhaseDiagram patch for one chemical space from the qhull
        entries whose elements all lie within that space.

        Args:
            space (frozenset[Element]): chemical space of the form A-B-X

        Returns:
            space, PhaseDiagram for the given chemical space
        """
        space_entries = [e for e, s in zip(self.qhull_entries, self._qhull_spaces) if space.issuperset(s)]

        return space, PhaseDiagram(space_entries)
1✔
1808

1809

1810
class ReactionDiagram:
    """
    Enumerates the reactions that are possible between two compounds
    (e.g., an electrolyte and an electrode) as their mixing ratio changes.
    """

    def __init__(self, entry1, entry2, all_entries, tol: float = 1e-4, float_fmt="%.4f"):
        """
        Args:
            entry1 (ComputedEntry): Entry for 1st component. Note that
                corrections, if any, must already be pre-applied. This is to
                give flexibility for different kinds of corrections, e.g.,
                if a particular entry is fitted to an experimental data (such
                as EC molecule).
            entry2 (ComputedEntry): Entry for 2nd component. Note that
                corrections must already be pre-applied. This is to
                give flexibility for different kinds of corrections, e.g.,
                if a particular entry is fitted to an experimental data (such
                as EC molecule).
            all_entries ([ComputedEntry]): All other entries to be
                considered in the analysis. Note that corrections, if any,
                must already be pre-applied.
            tol (float): Tolerance to be used to determine validity of reaction.
            float_fmt (str): Formatting string to be applied to all floats.
                Determines number of decimal places in reaction string.
        """
        symbols = set()
        for reactant in (entry1, entry2):
            symbols.update([el.symbol for el in reactant.composition.elements])

        elements = tuple(symbols)  # Freeze into a tuple so the element order stays fixed.

        def fractions(composition):
            # Atomic fraction of every element of the two reactants, in the fixed order.
            return [composition.get_atomic_fraction(el) for el in elements]

        def fmt(fl):
            return float_fmt % fl

        comp_vec1 = np.array(fractions(entry1.composition))
        comp_vec2 = np.array(fractions(entry2.composition))
        r1 = entry1.composition.reduced_composition
        r2 = entry2.composition.reduced_composition

        logger.debug(f"{len(all_entries)} total entries.")

        pd = PhaseDiagram(all_entries + [entry1, entry2])
        terminal_formulas = [entry1.composition.reduced_formula, entry2.composition.reduced_formula]

        logger.debug(f"{len(pd.stable_entries)} stable entries")
        logger.debug(f"{len(pd.facets)} facets")
        logger.debug(f"{len(pd.qhull_entries)} qhull_entries")

        rxn_entries = []
        done: list[tuple[float, float]] = []

        for facet in pd.facets:
            for face in itertools.combinations(facet, len(facet) - 1):
                face_entries = [pd.qhull_entries[idx] for idx in face]

                # Faces containing one of the two reactants themselves are skipped.
                if any(fe.composition.reduced_formula in terminal_formulas for fe in face_entries):
                    continue

                # Columns: compositions of the face entries, then the direction entry1 -> entry2.
                rows = [fractions(fe.composition) for fe in face_entries]
                rows.append(comp_vec2 - comp_vec1)
                try:
                    coeffs = np.linalg.solve(np.array(rows).T, comp_vec2)
                except np.linalg.LinAlgError:
                    logger.debug("Reactants = " + ", ".join(terminal_formulas))
                    logger.debug(f"Products = {', '.join([e.composition.reduced_formula for e in face_entries])}")
                    continue

                # Last coefficient is the fraction of entry1 in the mixture.
                x = coeffs[-1]
                non_negative = all(c >= -tol for c in coeffs)
                normalized = abs(sum(coeffs[:-1]) - 1) < tol
                if not (non_negative and normalized and tol < x < 1 - tol):
                    continue

                c1 = x / r1.num_atoms
                c2 = (1 - x) / r2.num_atoms
                factor = 1 / (c1 + c2)

                c1 *= factor
                c2 *= factor

                # Avoid duplicate reactions.
                if any(np.allclose([c1, c2], seen) for seen in done):
                    continue

                done.append((c1, c2))

                products = []
                product_entries = []

                energy = -(x * entry1.energy_per_atom + (1 - x) * entry2.energy_per_atom)

                for coeff, product in zip(coeffs[:-1], face_entries):
                    if coeff <= tol:
                        continue
                    reduced = product.composition.reduced_composition
                    products.append(f"{fmt(coeff / reduced.num_atoms * factor)} {reduced.reduced_formula}")
                    product_entries.append((coeff, product))
                    energy += coeff * product.energy_per_atom

                rxn_str = f"{fmt(c1)} {r1.reduced_formula} + {fmt(c2)} {r2.reduced_formula} -> "
                rxn_str += " + ".join(products)

                comp = x * comp_vec1 + (1 - x) * comp_vec2
                rxn_entry = PDEntry(
                    Composition(dict(zip(elements, comp))),
                    energy=energy,
                    attribute=rxn_str,
                )
                rxn_entry.decomposition = product_entries
                rxn_entries.append(rxn_entry)

        rxn_entries.sort(key=lambda e: e.name, reverse=True)

        self.entry1 = entry1
        self.entry2 = entry2
        self.rxn_entries = rxn_entries
        # Reactions are relabelled "1", "2", ...; labels maps each number to its reaction string.
        self.labels = {}
        for number, rxn_entry in enumerate(rxn_entries, start=1):
            label = str(number)
            self.labels[label] = rxn_entry.attribute
            rxn_entry.name = label
        self.all_entries = all_entries
        self.pd = pd

    def get_compound_pd(self):
        """
        Get the CompoundPhaseDiagram object, which can then be used for
        plotting.

        Returns:
            CompoundPhaseDiagram
        """
        # For this plot, since the reactions are reported in formation
        # energies, the terminal compositions need an energy of 0, so
        # zero-energy copies of the two reactants are used.
        terminals = [PDEntry(self.entry1.composition, 0), PDEntry(self.entry2.composition, 0)]
        terminal_comps = [Composition(t.composition.reduced_formula) for t in terminals]

        return CompoundPhaseDiagram(
            self.rxn_entries + terminals,
            terminal_comps,
            normalize_terminal_compositions=False,
        )
1✔
1965

1966

1967
class PhaseDiagramError(Exception):
    """
    Exception type for errors in Phase Diagram generation.
    """
1971

1972

1973
def get_facets(qhull_data: ArrayLike, joggle: bool = False) -> np.ndarray:
    """
    Get the simplex facets for the Convex hull.

    Args:
        qhull_data (np.ndarray): The data from which to construct the convex
            hull as a Nxd array (N being number of data points and d being the
            dimension)
        joggle (bool): Whether to joggle the input to avoid precision
            errors.

    Returns:
        np.ndarray: The simplices of the Convex Hull, i.e. the ``simplices``
        attribute of the ``scipy.spatial.ConvexHull`` built from the data
        (not the ConvexHull object itself).
    """
    # Only the first Qhull option differs between the two modes: "QJ" when joggling, "Qt" otherwise.
    qhull_options = "QJ i" if joggle else "Qt i"
    return ConvexHull(qhull_data, qhull_options=qhull_options).simplices
1✔
1990

1991

1992
def _get_slsqp_decomp(
    comp,
    competing_entries,
    tols=(1e-8,),
    maxiter=1000,
):
    """
    Finds the amounts of competing compositions that minimize the energy of a
    given composition

    The algorithm is based on the work in the following paper:

    1. Bartel, C., Trewartha, A., Wang, Q., Dunn, A., Jain, A., Ceder, G.,
        A critical examination of compound stability predictions from
        machine-learned formation energies, npj Computational Materials 6, 97 (2020)

    Args:
        comp (Composition): A Composition to analyze
        competing_entries ([PDEntry]): List of entries to consider for decomposition
        tols (list): tolerances to try for SLSQP convergence, tried in ascending
            order. Issues observed for tol > 1e-7 in the fractional composition
            (default 1e-8)
        maxiter (int): maximum number of SLSQP iterations

    Returns:
            decomposition as a dict of {PDEntry: amount} where amount
            is the amount of the fractional composition.

    Raises:
        ValueError: If there are no competing entries, or if SLSQP does not
            converge for any of the given tolerances.
    """
    n_entries = len(competing_entries)
    if n_entries == 0:
        # Without competing entries there is nothing to decompose into. Fail with the
        # same error as a non-converged optimization instead of a ZeroDivisionError
        # when building the initial guess below.
        raise ValueError(f"No valid decomp found for {comp}!")

    # Elemental amount present in given entry
    amts = comp.get_el_amt_dict()
    chemical_space = tuple(amts)
    b = np.array([amts[el] for el in chemical_space])

    # Elemental amounts present in competing entries (rows: elements, columns: entries)
    A_transpose = np.zeros((len(chemical_space), n_entries))
    for j, comp_entry in enumerate(competing_entries):
        amts = comp_entry.composition.get_el_amt_dict()
        for i, el in enumerate(chemical_space):
            A_transpose[i, j] = amts[el]

    # NOTE normalize arrays to avoid calls to fractional_composition
    b = b / np.sum(b)
    A_transpose = A_transpose / np.sum(A_transpose, axis=0)

    # Energies of competing entries
    Es = np.array([comp_entry.energy_per_atom for comp_entry in competing_entries])

    # The decomposition products must reproduce the (fractional) target composition
    molar_constraint = {"type": "eq", "fun": lambda x: np.dot(A_transpose, x) - b, "jac": lambda x: A_transpose}

    options = {"maxiter": maxiter, "disp": False}

    # NOTE max_bound needs to be larger than 1
    max_bound = comp.num_atoms
    bounds = [(0, max_bound)] * n_entries
    x0 = [1 / n_entries] * n_entries

    # NOTE the tolerance needs to be tight to stop the optimization
    # from exiting before convergence is reached. Issues observed for
    # tol > 1e-7 in the fractional composition (default 1e-8).
    for tol in sorted(tols):
        solution = minimize(
            fun=lambda x: np.dot(x, Es),
            x0=x0,
            method="SLSQP",
            jac=lambda x: Es,
            bounds=bounds,
            constraints=[molar_constraint],
            tol=tol,
            options=options,
        )

        if solution.success:
            decomp_amts = solution.x
            return {
                c: amt  # NOTE this is the amount of the fractional composition.
                for c, amt in zip(competing_entries, decomp_amts)
                if amt > PhaseDiagram.numerical_tol
            }

    raise ValueError(f"No valid decomp found for {comp}!")
×
2071

2072

2073
class PDPlotter:
1✔
2074
    """
2075
    A plotter class for compositional phase diagrams.
2076
    """
2077

2078
    def __init__(
        self,
        phasediagram: PhaseDiagram,
        show_unstable: float = 0.2,
        backend: Literal["plotly", "matplotlib"] = "plotly",
        **plotkwargs,
    ):
        """
        Set up a plotter for the given phase diagram.

        Args:
            phasediagram (PhaseDiagram): PhaseDiagram object to plot.
            show_unstable (float): Whether unstable (above the hull) phases will be
                plotted. If a number > 0 is entered, all phases with
                e_hull < show_unstable (eV/atom) will be shown.
            backend ("plotly" | "matplotlib"): Python package used for plotting. Defaults to "plotly".
            **plotkwargs (dict): Keyword args passed to matplotlib.pyplot.plot. Can
                be used to customize markers etc. If none are given, the default is
                a marker face color taken from the third color of palettable's
                Set1_3 palette, "markersize": 10 and "linewidth": 3.

        Raises:
            ValueError: If the phase diagram has more than 4 elements.
        """
        # note: palettable imports matplotlib
        from palettable.colorbrewer.qualitative import Set1_3

        self._pd = phasediagram
        self._dim = len(phasediagram.elements)  # type: ignore
        if self._dim > 4:
            raise ValueError("Only 1-4 components supported!")

        if self._dim > 1:
            self.lines = uniquelines(phasediagram.facets)
        else:
            # A unary diagram has no edges; use a degenerate line joining the
            # first facet vertex to itself.
            self.lines = [[phasediagram.facets[0][0], phasediagram.facets[0][0]]]

        self.show_unstable = show_unstable
        self.backend = backend
        self._min_energy = min(
            phasediagram.get_form_energy_per_atom(stable_entry) for stable_entry in phasediagram.stable_entries
        )

        if not plotkwargs:
            plotkwargs = {
                "markerfacecolor": Set1_3.mpl_colors[2],
                "markersize": 10,
                "linewidth": 3,
            }
        self.plotkwargs = plotkwargs
2117

2118
    @property  # type: ignore
1✔
2119
    @lru_cache(1)
1✔
2120
    def pd_plot_data(self):
1✔
2121
        """
2122
        Plotting data for phase diagram. Cached for repetitive calls.
2123
        2-comp - Full hull with energies
2124
        3/4-comp - Projection into 2D or 3D Gibbs triangle.
2125

2126
        Returns:
2127
            (lines, stable_entries, unstable_entries):
2128
            - lines is a list of list of coordinates for lines in the PD.
2129
            - stable_entries is a dict of {coordinates : entry} for each stable node
2130
                in the phase diagram. (Each coordinate can only have one
2131
                stable phase)
2132
            - unstable_entries is a dict of {entry: coordinates} for all unstable
2133
                nodes in the phase diagram.
2134
        """
2135
        pd = self._pd
1✔
2136
        entries = pd.qhull_entries
1✔
2137
        data = np.array(pd.qhull_data)
1✔
2138
        lines = []
1✔
2139
        stable_entries = {}
1✔
2140
        for line in self.lines:
1✔
2141
            entry1 = entries[line[0]]
1✔
2142
            entry2 = entries[line[1]]
1✔
2143
            if self._dim < 3:
1✔
2144
                x = [data[line[0]][0], data[line[1]][0]]
1✔
2145
                y = [
1✔
2146
                    pd.get_form_energy_per_atom(entry1),
2147
                    pd.get_form_energy_per_atom(entry2),
2148
                ]
2149
                coord = [x, y]
1✔
2150
            elif self._dim == 3:
1✔
2151
                coord = triangular_coord(data[line, 0:2])
1✔
2152
            else:
2153
                coord = tet_coord(data[line, 0:3])
1✔
2154
            lines.append(coord)
1✔
2155
            labelcoord = list(zip(*coord))
1✔
2156
            stable_entries[labelcoord[0]] = entry1
1✔
2157
            stable_entries[labelcoord[1]] = entry2
1✔
2158

2159
        all_entries = pd.all_entries
1✔
2160
        all_data = np.array(pd.all_entries_hulldata)
1✔
2161
        unstable_entries = {}
1✔
2162
        stable = pd.stable_entries
1✔
2163
        for i, entry in enumerate(all_entries):
1✔
2164
            if entry not in stable:
1✔
2165
                if self._dim < 3:
1✔
2166
                    x = [all_data[i][0], all_data[i][0]]
1✔
2167
                    y = [
1✔
2168
                        pd.get_form_energy_per_atom(entry),
2169
                        pd.get_form_energy_per_atom(entry),
2170
                    ]
2171
                    coord = [x, y]
1✔
2172
                elif self._dim == 3:
1✔
2173
                    coord = triangular_coord([all_data[i, 0:2], all_data[i, 0:2]])
1✔
2174
                else:
2175
                    coord = tet_coord([all_data[i, 0:3], all_data[i, 0:3], all_data[i, 0:3]])
1✔
2176
                labelcoord = list(zip(*coord))
1✔
2177
                unstable_entries[entry] = labelcoord[0]
1✔
2178

2179
        return lines, stable_entries, unstable_entries
1✔
2180

2181
    def get_plot(
1✔
2182
        self,
2183
        label_stable=True,
2184
        label_unstable=True,
2185
        ordering=None,
2186
        energy_colormap=None,
2187
        process_attributes=False,
2188
        plt=None,
2189
        label_uncertainties=False,
2190
    ):
2191
        """
2192
        Args:
2193
            label_stable: Whether to label stable compounds.
2194
            label_unstable: Whether to label unstable compounds.
2195
            ordering: Ordering of vertices (matplotlib backend only).
2196
            energy_colormap: Colormap for coloring energy (matplotlib backend only).
2197
            process_attributes: Whether to process the attributes (matplotlib
2198
                backend only).
2199
            plt: Existing plt object if plotting multiple phase diagrams (
2200
                matplotlib backend only).
2201
            label_uncertainties: Whether to add error bars to the hull (plotly
2202
                backend only). For binaries, this also shades the hull with the
2203
                uncertainty window.
2204

2205
        Returns:
2206
            go.Figure (plotly) or matplotlib.pyplot (matplotlib)
2207
        """
2208
        fig = None
1✔
2209

2210
        if self.backend == "plotly":
1✔
2211
            data = [self._create_plotly_lines()]
1✔
2212

2213
            if self._dim == 3:
1✔
2214
                data.append(self._create_plotly_ternary_support_lines())
1✔
2215
                data.append(self._create_plotly_ternary_hull())
1✔
2216

2217
            stable_labels_plot = self._create_plotly_stable_labels(label_stable)
1✔
2218
            stable_marker_plot, unstable_marker_plot = self._create_plotly_markers(label_uncertainties)
1✔
2219

2220
            if self._dim == 2 and label_uncertainties:
1✔
2221
                data.append(self._create_plotly_uncertainty_shading(stable_marker_plot))
×
2222

2223
            data.append(stable_labels_plot)
1✔
2224
            data.append(unstable_marker_plot)
1✔
2225
            data.append(stable_marker_plot)
1✔
2226

2227
            fig = go.Figure(data=data)
1✔
2228
            fig.layout = self._create_plotly_figure_layout()
1✔
2229

2230
        elif self.backend == "matplotlib":
1✔
2231
            if self._dim <= 3:
1✔
2232
                fig = self._get_2d_plot(
1✔
2233
                    label_stable,
2234
                    label_unstable,
2235
                    ordering,
2236
                    energy_colormap,
2237
                    plt=plt,
2238
                    process_attributes=process_attributes,
2239
                )
2240
            elif self._dim == 4:
1✔
2241
                fig = self._get_3d_plot(label_stable)
1✔
2242

2243
        return fig
1✔
2244

2245
    def plot_element_profile(self, element, comp, show_label_index=None, xlim=5):
        """
        Draw the element profile plot for a composition varying different
        chemical potential of an element.
        X value is the negative value of the chemical potential reference to
        elemental chemical potential. For example, if choose Element("Li"),
        X= -(µLi-µLi0), which corresponds to the voltage versus metal anode.
        Y values represent for the number of element uptake in this composition
        (unit: per atom).

        Args:
         element (Element): An element of which the chemical potential is
            considered. It also must be in the phase diagram.
         comp (Composition): A composition.
         show_label_index (list of integers): The labels for reaction products
            you want to show in the plot. Default to None (not showing any
            annotation for reaction products). For the profile steps you want
            to show the labels, just add it to the show_label_index. The
            profile step counts from zero (the index into the result of
            PhaseDiagram.get_element_profile). For example, you can set
            show_label_index=[0, 2, 5] to label profile step 0,2,5.
         xlim (float): The max x value. x value is from 0 to xlim. Default to
            5 eV. The last profile step is drawn up to this value.

        Returns:
            Plot of element profile evolution by varying the chemical potential
            of an element.
        """
        plt = pretty_plot(12, 8)
        pd = self._pd
        evolution = pd.get_element_profile(element, comp)
        num_atoms = evolution[0]["reaction"].reactants[0].num_atoms
        element_energy = evolution[0]["chempot"]
        x2, y1 = None, None
        for i, d in enumerate(evolution):
            v = -(d["chempot"] - element_energy)
            uptake = d["evolution"] / num_atoms
            if i != 0:
                # Vertical connector from the previous plateau to this one.
                plt.plot([x2, x2], [y1, uptake], "k", linewidth=2.5)
            x1 = v
            y1 = uptake

            if i != len(evolution) - 1:
                x2 = -(evolution[i + 1]["chempot"] - element_energy)
            else:
                # The final plateau extends to the right-hand axis limit. This
                # used to be a hard-coded 5.0, which cut the plateau short
                # whenever xlim > 5.
                x2 = xlim
            if show_label_index is not None and i in show_label_index:
                # Turn digits into LaTeX subscripts; the varied element itself
                # is not listed among the products.
                products = [
                    re.sub(r"(\d+)", r"$_{\1}$", p.reduced_formula)
                    for p in d["reaction"].products
                    if p.reduced_formula != element.symbol
                ]
                plt.annotate(
                    ", ".join(products),
                    xy=(v + 0.05, y1 + 0.05),
                    fontsize=24,
                    color="r",
                )
                plt.plot([x1, x2], [y1, y1], "r", linewidth=3)
            else:
                plt.plot([x1, x2], [y1, y1], "k", linewidth=2.5)

        plt.xlim((0, xlim))
        plt.xlabel("-$\\Delta{\\mu}$ (eV)")
        plt.ylabel("Uptake per atom")

        return plt
1✔
2311

2312
    def show(self, *args, **kwargs):
1✔
2313
        """
2314
        Draw the phase diagram using Plotly (or Matplotlib) and show it.
2315

2316
        Args:
2317
            *args: Passed to get_plot.
2318
            **kwargs: Passed to get_plot.
2319
        """
2320
        self.get_plot(*args, **kwargs).show()
×
2321

2322
    def _get_2d_plot(
        self,
        label_stable=True,
        label_unstable=True,
        ordering=None,
        energy_colormap=None,
        vmin_mev=-60.0,
        vmax_mev=60.0,
        show_colorbar=True,
        process_attributes=False,
        plt=None,
    ):
        """
        Draws the 2D phase diagram with matplotlib. Contains import statements
        since matplotlib is a fairly extensive library to load.

        Args:
            label_stable: Whether to annotate stable entries with their name.
            label_unstable: Whether to annotate the plotted unstable entries.
            ordering: If not None, the plot data is passed through
                order_phase_diagram with this ordering before plotting.
            energy_colormap: None for plain black markers, "default" for the
                built-in green/red colormap, or a colormap handed to
                matplotlib's ScalarMappable.
            vmin_mev: Lower limit of the color normalization in meV.
            vmax_mev: Upper limit of the color normalization in meV.
            show_colorbar: Whether to draw a colorbar (only if energy_colormap
                is set).
            process_attributes: Whether to use the entries' ``attribute`` to
                pick markers (circle for None/"existing", star otherwise) and to
                label "new" entries in green.
            plt: Existing plt object to draw into; a new one is created with
                pretty_plot(8, 6) if None.

        Returns:
            The plt object that was drawn into.
        """
        if plt is None:
            plt = pretty_plot(8, 6)
        from matplotlib.font_manager import FontProperties

        def label_placement(coords, center):
            """Offset vector (in points) and text alignment for a label at coords,
            pushing the text away from the center of the diagram."""
            vec = np.array(coords) - center
            norm = np.linalg.norm(vec)
            if norm != 0:
                vec = vec / norm * 10
            valign = "bottom" if vec[1] > 0 else "top"
            if vec[0] < -0.01:
                halign = "right"
            elif vec[0] > 0.01:
                halign = "left"
            else:
                halign = "center"
            return vec, halign, valign

        if ordering is None:
            (lines, labels, unstable) = self.pd_plot_data
        else:
            (_lines, _labels, _unstable) = self.pd_plot_data
            (lines, labels, unstable) = order_phase_diagram(_lines, _labels, _unstable, ordering)

        if energy_colormap is None:
            if process_attributes:
                for x, y in lines:
                    plt.plot(x, y, "k-", linewidth=3, markeredgecolor="k")
                # One should think about a clever way to have "complex"
                # attributes with complex processing options but with a clear
                # logic. At this moment, I just use the attributes to know
                # whether an entry is a new compound or an existing (from the
                #  ICSD or from the MP) one.
                for (x, y), entry in labels.items():
                    if entry.attribute is None or entry.attribute == "existing":
                        plt.plot(x, y, "ko", **self.plotkwargs)
                    else:
                        plt.plot(x, y, "k*", **self.plotkwargs)
            else:
                for x, y in lines:
                    plt.plot(x, y, "ko-", **self.plotkwargs)
        else:
            from matplotlib.cm import ScalarMappable
            from matplotlib.colors import LinearSegmentedColormap, Normalize

            for x, y in lines:
                plt.plot(x, y, "k-", markeredgecolor="k")
            vmin = vmin_mev / 1000.0
            vmax = vmax_mev / 1000.0
            if energy_colormap == "default":
                # Position of zero energy within [vmin, vmax], where the
                # colormap switches from green to red.
                mid = -vmin / (vmax - vmin)
                cmap = LinearSegmentedColormap.from_list(
                    "my_colormap",
                    [
                        (0.0, "#005500"),
                        (mid, "#55FF55"),
                        (mid, "#FFAAAA"),
                        (1.0, "#FF0000"),
                    ],
                )
            else:
                cmap = energy_colormap
            norm = Normalize(vmin=vmin, vmax=vmax)
            _map = ScalarMappable(norm=norm, cmap=cmap)
            _energies = [self._pd.get_equilibrium_reaction_energy(entry) for entry in labels.values()]
            # Non-negative values are nudged just below zero so that stable
            # entries always land on the negative side of the colormap.
            energies = [en if en < 0.0 else -0.00000001 for en in _energies]
            vals_stable = _map.to_rgba(energies)
            for ((x, y), entry), color in zip(labels.items(), vals_stable):
                if process_attributes:
                    if entry.attribute is None or entry.attribute == "existing":
                        plt.plot(x, y, "o", markerfacecolor=color, markersize=12)
                    else:
                        plt.plot(x, y, "*", markerfacecolor=color, markersize=18)
                else:
                    plt.plot(x, y, "o", markerfacecolor=color, markersize=15)

        font = FontProperties()
        font.set_weight("bold")
        font.set_size(24)

        # Sets a nice layout depending on the type of PD. Also defines a
        # "center" for the PD, which then allows the annotations to be spread
        # out in a nice manner.
        if len(self._pd.elements) == 3:
            plt.axis("equal")
            plt.xlim((-0.1, 1.2))
            plt.ylim((-0.1, 1.0))
            plt.axis("off")
            center = (0.5, math.sqrt(3) / 6)
        else:
            miny = min(c[1] for c in labels)
            ybuffer = max(abs(miny) * 0.1, 0.1)
            plt.xlim((-0.1, 1.1))
            plt.ylim((miny - ybuffer, ybuffer))
            center = (0.5, miny / 2)
            plt.xlabel("Fraction", fontsize=28, fontweight="bold")
            plt.ylabel("Formation energy (eV/atom)", fontsize=28, fontweight="bold")

        for coords in sorted(labels, key=lambda x: -x[1]):
            entry = labels[coords]
            label = entry.name

            # The following defines an offset for the annotation text emanating
            # from the center of the PD. Results in fairly nice layouts for the
            # most part.
            vec, halign, valign = label_placement(coords, center)
            if label_stable:
                if process_attributes and entry.attribute == "new":
                    plt.annotate(
                        latexify(label),
                        coords,
                        xytext=vec,
                        textcoords="offset points",
                        horizontalalignment=halign,
                        verticalalignment=valign,
                        fontproperties=font,
                        color="g",
                    )
                else:
                    plt.annotate(
                        latexify(label),
                        coords,
                        xytext=vec,
                        textcoords="offset points",
                        horizontalalignment=halign,
                        verticalalignment=valign,
                        fontproperties=font,
                    )

        if self.show_unstable:
            font = FontProperties()
            font.set_size(16)
            energies_unstable = [self._pd.get_e_above_hull(entry) for entry in unstable]
            if energy_colormap is not None:
                energies.extend(energies_unstable)
                vals_unstable = _map.to_rgba(energies_unstable)
            # idx runs over ALL unstable entries so that it matches the layout
            # of vals_unstable, even when some entries are filtered out below.
            for idx, (entry, coords) in enumerate(unstable.items()):
                ehull = energies_unstable[idx]
                if ehull < self.show_unstable:
                    # Alignment is computed for this entry's own position
                    # rather than reusing that of the last stable label.
                    vec, halign, valign = label_placement(coords, center)
                    label = entry.name
                    if energy_colormap is None:
                        plt.plot(
                            coords[0],
                            coords[1],
                            "ks",
                            linewidth=3,
                            markeredgecolor="k",
                            markerfacecolor="r",
                            markersize=8,
                        )
                    else:
                        plt.plot(
                            coords[0],
                            coords[1],
                            "s",
                            linewidth=3,
                            markeredgecolor="k",
                            markerfacecolor=vals_unstable[idx],
                            markersize=8,
                        )
                    if label_unstable:
                        plt.annotate(
                            latexify(label),
                            coords,
                            xytext=vec,
                            textcoords="offset points",
                            horizontalalignment=halign,
                            color="b",
                            verticalalignment=valign,
                            fontproperties=font,
                        )
        if energy_colormap is not None and show_colorbar:
            _map.set_array(energies)
            cbar = plt.colorbar(_map)
            cbar.set_label(
                "Energy [meV/at] above hull (positive values)\nInverse energy [meV/at] above hull (negative values)",
                rotation=-90,
                ha="center",
                va="bottom",
            )
        f = plt.gcf()
        f.set_size_inches((8, 6))
        plt.subplots_adjust(left=0.09, right=0.98, top=0.98, bottom=0.07)
        return plt
1✔
2523

2524
    def _get_3d_plot(self, label_stable=True):
        """
        Draws the phase diagram on a matplotlib 3D axis. Imports are done
        inside the method since plotting is a fairly expensive library to load
        and not all machines have matplotlib installed.

        Args:
            label_stable: Whether to label stable entries. Single-element
                entries are labelled with their name; all others get a running
                number that is resolved in a legend text on the figure.

        Returns:
            The matplotlib.pyplot module holding the figure.
        """
        import matplotlib.pyplot as plt
        from matplotlib.font_manager import FontProperties

        figure = plt.figure()
        axes = figure.add_subplot(111, projection="3d")
        bold_font = FontProperties(weight="bold", size=13)
        segments, stable_by_coord, _unstable = self.pd_plot_data

        line_style = {
            "linewidth": 3,
            "markeredgecolor": "b",
            "markerfacecolor": "r",
            "markersize": 10,
        }
        for segment in segments:
            xs, ys, zs = segment
            axes.plot(xs, ys, zs, "bo-", **line_style)

        legend_rows = []
        for position in sorted(stable_by_coord):
            stable_entry = stable_by_coord[position]
            name = stable_entry.name
            if not label_stable:
                continue
            px, py, pz = position[0], position[1], position[2]
            if len(stable_entry.composition.elements) == 1:
                axes.text(px, py, pz, name, fontproperties=bold_font)
                continue
            # Compounds are numbered from 1 in sorted-coordinate order.
            number = len(legend_rows) + 1
            axes.text(px, py, pz, str(number), fontsize=12)
            legend_rows.append(f"{number} : {latexify(name)}")

        plt.figtext(0.01, 0.01, "\n".join(legend_rows), fontproperties=bold_font)
        axes.axis("off")
        axes.set_xlim(-0.1, 0.72)
        axes.set_ylim(0, 0.66)
        axes.set_zlim(0, 0.56)  # pylint: disable=E1101
        return plt
1✔
2566

2567
    def write_image(self, stream: str | StringIO, image_format: str = "svg", **kwargs) -> None:
1✔
2568
        """
2569
        Writes the phase diagram to an image in a stream.
2570

2571
        Args:
2572
            stream (str | StringIO): stream to write to. Can be a file stream or a StringIO stream.
2573
            image_format (str): format for image. Can be any of matplotlib supported formats.
2574
                Defaults to 'svg' for best results for vector graphics.
2575
            **kwargs: Pass through to get_plot function.
2576
        """
2577
        plt = self.get_plot(**kwargs)
×
2578

2579
        f = plt.gcf()
×
2580
        f.set_size_inches((12, 10))
×
2581

2582
        plt.savefig(stream, format=image_format)
×
2583

2584
    def plot_chempot_range_map(self, elements, referenced=True):
1✔
2585
        """
2586
        Plot the chemical potential range _map. Currently works only for
2587
        3-component PDs.
2588

2589
        Args:
2590
            elements: Sequence of elements to be considered as independent
2591
                variables. E.g., if you want to show the stability ranges of
2592
                all Li-Co-O phases wrt to uLi and uO, you will supply
2593
                [Element("Li"), Element("O")]
2594
            referenced: if True, gives the results with a reference being the
2595
                        energy of the elemental phase. If False, gives absolute values.
2596
        """
2597
        self.get_chempot_range_map_plot(elements, referenced=referenced).show()
×
2598

2599
    def get_chempot_range_map_plot(self, elements, referenced=True):
        """
        Returns a plot of the chemical potential range map. Currently works
        only for 3-component PDs.

        Each entry returned by PhaseDiagram.get_chempot_range_map is outlined
        with its boundary lines and annotated with its name. Entries that
        contain none of one of the chosen elements are completed afterwards
        with lines running to the lower axis limits.

        Args:
            elements: Sequence of elements to be considered as independent
                variables. E.g., if you want to show the stability ranges of
                all Li-Co-O phases wrt to uLi and uO, you will supply
                [Element("Li"), Element("O")]. The first element is used for
                the x axis and the second for the y axis.
            referenced: if True, gives the results with a reference being the
                        energy of the elemental phase. If False, gives absolute values.

        Returns:
            A matplotlib plot object.
        """
        plt = pretty_plot(12, 8)
        chempot_ranges = self._pd.get_chempot_range_map(elements, referenced=referenced)
        # Entries whose region has to be completed once the axis limits are known.
        missing_lines = {}
        # Points outlining the region that gets shaded at the end.
        excluded_region = []
        for entry, lines in chempot_ranges.items():
            comp = entry.composition
            center_x = 0
            center_y = 0
            coords = []
            # True if the entry has exactly zero fraction of any chosen element.
            contain_zero = any(comp.get_atomic_fraction(el) == 0 for el in elements)
            # True if the entry contains every chosen element and their
            # fractions sum to exactly 1.
            is_boundary = (not contain_zero) and sum(comp.get_atomic_fraction(el) for el in elements) == 1
            for line in lines:
                (x, y) = line.coords.transpose()
                plt.plot(x, y, "k-")

                # Collect each distinct vertex once and accumulate it for the
                # label position (mean of the vertices).
                for coord in line.coords:
                    if not in_coord_list(coords, coord):
                        coords.append(coord.tolist())
                        center_x += coord[0]
                        center_y += coord[1]
                if is_boundary:
                    excluded_region.extend(line.coords)

            if coords and contain_zero:
                # Labelled later, after the missing lines have been added.
                missing_lines[entry] = coords
            else:
                xy = (center_x / len(coords), center_y / len(coords))
                plt.annotate(latexify(entry.name), xy, fontsize=22)

        # The limits are read only now, after all boundary lines were drawn.
        ax = plt.gca()
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()

        # Shade the forbidden chemical potential regions.
        excluded_region.append([xlim[1], ylim[1]])
        excluded_region = sorted(excluded_region, key=lambda c: c[0])
        (x, y) = np.transpose(excluded_region)
        plt.fill(x, y, "0.80")

        # The hull does not generate the missing horizontal and vertical lines.
        # The following code fixes this.
        el0 = elements[0]
        el1 = elements[1]
        for entry, coords in missing_lines.items():
            center_x = sum(c[0] for c in coords)
            center_y = sum(c[1] for c in coords)
            comp = entry.composition
            # is_x / is_y: the entry has (almost) none of the x / y element.
            is_x = comp.get_atomic_fraction(el0) < 0.01
            is_y = comp.get_atomic_fraction(el1) < 0.01
            n = len(coords)
            if not (is_x and is_y):
                if is_x:
                    # Horizontal lines from the lowest and highest vertex to
                    # the lower x limit.
                    coords = sorted(coords, key=lambda c: c[1])
                    for i in [0, -1]:
                        x = [min(xlim), coords[i][0]]
                        y = [coords[i][1], coords[i][1]]
                        plt.plot(x, y, "k")
                        center_x += min(xlim)
                        center_y += coords[i][1]
                elif is_y:
                    # Vertical lines from the leftmost and rightmost vertex to
                    # the lower y limit.
                    coords = sorted(coords, key=lambda c: c[0])
                    for i in [0, -1]:
                        x = [coords[i][0], coords[i][0]]
                        y = [coords[i][1], min(ylim)]
                        plt.plot(x, y, "k")
                        center_x += coords[i][0]
                        center_y += min(ylim)
                # Mean over the n vertices plus the two added end points.
                xy = (center_x / (n + 2), center_y / (n + 2))
            else:
                # Neither element present: average the vertices together with
                # the lower-left corner of the axes.
                center_x = sum(coord[0] for coord in coords) + xlim[0]
                center_y = sum(coord[1] for coord in coords) + ylim[0]
                xy = (center_x / (n + 1), center_y / (n + 1))

            plt.annotate(
                latexify(entry.name),
                xy,
                horizontalalignment="center",
                verticalalignment="center",
                fontsize=22,
            )

        plt.xlabel(f"$\\mu_{{{el0.symbol}}} - \\mu_{{{el0.symbol}}}^0$ (eV)")
        plt.ylabel(f"$\\mu_{{{el1.symbol}}} - \\mu_{{{el1.symbol}}}^0$ (eV)")
        plt.tight_layout()
        return plt
1✔
2700

2701
    def get_contour_pd_plot(self):
1✔
2702
        """
2703
        Plot a contour phase diagram plot, where phase triangles are colored
2704
        according to degree of instability by interpolation. Currently only
2705
        works for 3-component phase diagrams.
2706

2707
        Returns:
2708
            A matplotlib plot object.
2709
        """
2710
        from matplotlib import cm
1✔
2711
        from scipy import interpolate
1✔
2712

2713
        pd = self._pd
1✔
2714
        entries = pd.qhull_entries
1✔
2715
        data = np.array(pd.qhull_data)
1✔
2716

2717
        plt = self._get_2d_plot()
1✔
2718
        data[:, 0:2] = triangular_coord(data[:, 0:2]).transpose()
1✔
2719
        for i, e in enumerate(entries):
1✔
2720
            data[i, 2] = self._pd.get_e_above_hull(e)
1✔
2721

2722
        gridsize = 0.005
1✔
2723
        xnew = np.arange(0, 1.0, gridsize)
1✔
2724
        ynew = np.arange(0, 1, gridsize)
1✔
2725

2726
        f = interpolate.LinearNDInterpolator(data[:, 0:2], data[:, 2])
1✔
2727
        znew = np.zeros((len(ynew), len(xnew)))
1✔
2728
        for i, xval in enumerate(xnew):
1✔
2729
            for j, yval in enumerate(ynew):
1✔
2730
                znew[j, i] = f(xval, yval)
1✔
2731

2732
        # pylint: disable=E1101
2733
        plt.contourf(xnew, ynew, znew, 1000, cmap=cm.autumn_r)
1✔
2734

2735
        plt.colorbar()
1✔
2736
        return plt
1✔
2737

2738
    def _create_plotly_lines(self):
        """
        Creates Plotly scatter (line) plots for all phase diagram facets.

        Every line in ``self.pd_plot_data[0]`` is appended to shared coordinate
        lists followed by a ``None`` separator, so that all facet edges are
        carried by a single trace.

        Returns:
            go.Scatter (or go.Scatter3d) plot, or None if ``self._dim`` is not
            2, 3 or 4.
        """
        line_plot = None
        x, y, z = [], [], []
        stable_entries = self.pd_plot_data[1]  # {coordinate: entry}

        for line in self.pd_plot_data[0]:
            x.extend(list(line[0]) + [None])
            y.extend(list(line[1]) + [None])

            if self._dim == 3:
                # ternary: the third axis is the formation energy of the stable
                # entry sitting at each end point of the line
                z.extend(
                    [self._pd.get_form_energy_per_atom(stable_entries[coord]) for coord in zip(line[0], line[1])]
                    + [None]
                )
            elif self._dim == 4:
                # quaternary: the third axis is a plain spatial coordinate
                z.extend(list(line[2]) + [None])

        plot_args = dict(
            mode="lines",
            hoverinfo="none",
            line={"color": "rgba(0,0,0,1.0)", "width": 7.0},
            showlegend=False,
        )

        if self._dim == 2:
            line_plot = go.Scatter(x=x, y=y, **plot_args)
        elif self._dim == 3:
            # x and y are deliberately swapped for the ternary 3D scene
            line_plot = go.Scatter3d(x=y, y=x, z=z, **plot_args)
        elif self._dim == 4:
            line_plot = go.Scatter3d(x=x, y=y, z=z, **plot_args)

        return line_plot
1✔
2783

2784
    def _create_plotly_stable_labels(self, label_stable=True):
        """
        Builds a (hidable) text-only scatter trace labelling the stable phases.
        Label positions are nudged away from their markers for legibility.

        Args:
            label_stable (bool): If False (or if the diagram is quaternary) the
                trace is created with visible="legendonly".

        Returns:
            go.Scatter (or go.Scatter3d) plot, or None if ``self._dim`` is not
            2, 3 or 4.
        """
        xs, ys, zs, labels, positions = [], [], [], [], []
        shift_2d = 0.005  # extra distance to offset label position for clarity
        shift_3d = 0.01
        z_shift = -0.1 * self._min_energy

        dim = self._dim

        lowest_x = None
        if dim == 2:
            # x coordinate of the stable point with the smallest y value
            lowest_x = min(self.pd_plot_data[1], key=lambda pt: pt[1])[0]

        for coords, entry in self.pd_plot_data[1].items():
            if entry.composition.is_element:  # taken care of by other function
                continue

            label_x, label_y = coords[0], coords[1]
            position = None

            if dim == 2:
                if label_x >= lowest_x:
                    position = "bottom right"
                    label_x = label_x + shift_2d
                else:
                    position = "bottom left"
                    label_x = label_x - shift_2d
                label_y = label_y - shift_2d
            elif dim == 3:
                position = "middle center"
                label_x = label_x + shift_3d if coords[0] > 0.5 else label_x - shift_3d
                label_y = label_y - shift_3d if coords[1] > 0.866 / 2 else label_y + shift_3d
                zs.append(self._pd.get_form_energy_per_atom(entry) + z_shift)
            elif dim == 4:
                position = "bottom right"
                label_x = label_x - shift_3d
                label_y = label_y - shift_3d
                zs.append(coords[2])

            xs.append(label_x)
            ys.append(label_y)
            positions.append(position)

            # entries wrapping an original entry are labelled by the original composition
            if hasattr(entry, "original_entry"):
                composition = entry.original_entry.composition
            else:
                composition = entry.composition
            labels.append(htmlify(composition.reduced_formula))

        hidden = (not label_stable) or dim == 4
        trace_kwargs = dict(
            text=labels,
            textposition=positions,
            mode="text",
            name="Labels (stable)",
            hoverinfo="skip",
            opacity=1.0,
            visible="legendonly" if hidden else True,
            showlegend=True,
        )

        if dim == 2:
            return go.Scatter(x=xs, y=ys, **trace_kwargs)
        if dim == 3:
            # x and y are swapped for the ternary 3D scene
            return go.Scatter3d(x=ys, y=xs, z=zs, **trace_kwargs)
        if dim == 4:
            return go.Scatter3d(x=xs, y=ys, z=zs, **trace_kwargs)
        return None
1✔
2871

2872
    def _create_plotly_element_annotations(self):
        """
        Creates terminal element annotations for Plotly phase diagrams.

        One annotation is produced per stable entry whose composition is a
        single element, starting from a copy of
        ``plotly_layouts["default_annotation_layout"]``.

        Returns:
            list of annotation dicts. For ternary diagrams an extra invisible
            annotation is appended at the end.
        """
        annotations_list = []

        for coords, entry in self.pd_plot_data[1].items():
            if not entry.composition.is_element:
                continue

            x, y = coords[0], coords[1]

            clean_formula = str(entry.composition.elements[0])
            if hasattr(entry, "original_entry"):
                # label by the composition of the wrapped original entry
                orig_comp = entry.original_entry.composition
                clean_formula = htmlify(orig_comp.reduced_formula)

            annotation = plotly_layouts["default_annotation_layout"].copy()
            annotation.update(
                {
                    "x": x,
                    "y": y,
                    "font": {"color": "#000000", "size": 24.0},
                    "text": clean_formula,
                    "opacity": 1.0,
                }
            )

            if self._dim in (3, 4):
                for d in ["xref", "yref"]:
                    annotation.pop(d)  # Scatter3d cannot contain xref, yref

                if self._dim == 3:
                    # x and y are swapped for the ternary 3D scene
                    annotation.update({"x": y, "y": x})
                    z = 0.9 * self._min_energy  # place label 10% above base
                else:
                    z = coords[2]

                annotation.update({"z": z})

            annotations_list.append(annotation)

        # extra point ensures equilateral triangular scaling is displayed
        if self._dim == 3:
            annotations_list.append(dict(x=1, y=1, z=0, opacity=0, text=""))

        return annotations_list
1✔
2930

2931
    def _create_plotly_figure_layout(self, label_stable=True):
        """
        Creates layout for plotly phase diagram figure and updates with
        figure annotations.

        Args:
            label_stable (bool): Whether to label stable compounds. When False,
                the "annotations" entry of the layout is set to None.

        Returns:
            Dictionary with Plotly figure layout settings (empty if
            ``self._dim`` is not 2, 3 or 4).
        """
        annotations_list = None
        layout = {}

        if label_stable:
            annotations_list = self._create_plotly_element_annotations()

        if self._dim == 2:
            layout = plotly_layouts["default_binary_layout"].copy()
            layout["annotations"] = annotations_list
        elif self._dim in (3, 4):
            template = "default_ternary_layout" if self._dim == 3 else "default_quaternary_layout"
            layout = plotly_layouts[template].copy()
            # dict.copy() is shallow: build a fresh "scene" dict instead of
            # updating the nested one in place, which would write the
            # annotations into the shared module-level plotly_layouts.
            layout["scene"] = {**layout["scene"], "annotations": annotations_list}

        return layout
1✔
2959

2960
    def _create_plotly_markers(self, label_uncertainties=False):
        """
        Creates stable and unstable marker plots for overlaying on the phase diagram.

        Args:
            label_uncertainties (bool): Whether to add the
                ``correction_uncertainty_per_atom`` of stable entries (when the
                attribute exists) to their hover text and error bars.

        Returns:
            Tuple of Plotly go.Scatter (or go.Scatter3d) objects in order:
            (stable markers, unstable markers)
        """

        def get_marker_props(coords, entries, stable=True):
            """Method for getting marker locations, hovertext, and error bars
            from pd_plot_data
            """
            x, y, z, texts, energies, uncertainties = [], [], [], [], [], []

            for coord, entry in zip(coords, entries):
                energy = round(self._pd.get_form_energy_per_atom(entry), 3)

                entry_id = getattr(entry, "entry_id", "no ID")
                comp = entry.composition

                if hasattr(entry, "original_entry"):
                    comp = entry.original_entry.composition
                    entry_id = getattr(entry, "attribute", "no ID")

                formula = comp.reduced_formula
                clean_formula = htmlify(formula)
                label = f"{clean_formula} ({entry_id}) <br> {energy} eV/atom"

                if not stable:
                    e_above_hull = round(self._pd.get_e_above_hull(entry), 3)
                    if e_above_hull > self.show_unstable:
                        continue  # too far above the hull to be displayed
                    label += f" (+{e_above_hull} eV/atom)"
                    energies.append(e_above_hull)
                else:
                    uncertainty = 0
                    if hasattr(entry, "correction_uncertainty_per_atom") and label_uncertainties:
                        uncertainty = round(entry.correction_uncertainty_per_atom, 4)
                        label += f"<br> (Error: +/- {uncertainty} eV/atom)"

                    uncertainties.append(uncertainty)
                    energies.append(energy)

                texts.append(label)

                x.append(coord[0])
                y.append(coord[1])

                if self._dim == 3:
                    z.append(energy)
                elif self._dim == 4:
                    z.append(coord[2])

            return {
                "x": x,
                "y": y,
                "z": z,
                "texts": texts,
                "energies": energies,
                "uncertainties": uncertainties,
            }

        # pd_plot_data[1] maps coordinate -> entry, pd_plot_data[2] maps
        # entry -> coordinate. Keys and values are collected separately because
        # unpacking zip(*d.items()) into two names raises ValueError when the
        # dict is empty (e.g. a diagram without unstable entries).
        stable_coords = list(self.pd_plot_data[1].keys())
        stable_entries = list(self.pd_plot_data[1].values())
        unstable_entries = list(self.pd_plot_data[2].keys())
        unstable_coords = list(self.pd_plot_data[2].values())

        stable_props = get_marker_props(stable_coords, stable_entries)

        unstable_props = get_marker_props(unstable_coords, unstable_entries, stable=False)

        stable_markers, unstable_markers = {}, {}

        if self._dim == 2:
            stable_markers = plotly_layouts["default_binary_marker_settings"].copy()
            stable_markers.update(
                dict(
                    x=list(stable_props["x"]),
                    y=list(stable_props["y"]),
                    name="Stable",
                    marker=dict(color="darkgreen", size=11, line=dict(color="black", width=2)),
                    opacity=0.9,
                    hovertext=stable_props["texts"],
                    error_y=dict(
                        array=list(stable_props["uncertainties"]),
                        type="data",
                        color="gray",
                        thickness=2.5,
                        width=5,
                    ),
                )
            )

            unstable_markers = plotly_layouts["default_binary_marker_settings"].copy()
            unstable_markers.update(
                dict(
                    x=list(unstable_props["x"]),
                    y=list(unstable_props["y"]),
                    name="Above Hull",
                    marker=dict(
                        color=unstable_props["energies"],
                        colorscale=plotly_layouts["unstable_colorscale"],
                        size=6,
                        symbol="diamond",
                    ),
                    hovertext=unstable_props["texts"],
                )
            )

        elif self._dim == 3:
            # x and y are swapped for the ternary 3D scene
            stable_markers = plotly_layouts["default_ternary_marker_settings"].copy()
            stable_markers.update(
                dict(
                    x=list(stable_props["y"]),
                    y=list(stable_props["x"]),
                    z=list(stable_props["z"]),
                    name="Stable",
                    marker=dict(
                        color="black",
                        size=12,
                        opacity=0.8,
                        line=dict(color="black", width=3),
                    ),
                    hovertext=stable_props["texts"],
                    error_z=dict(
                        array=list(stable_props["uncertainties"]),
                        type="data",
                        color="darkgray",
                        width=10,
                        thickness=5,
                    ),
                )
            )

            unstable_markers = plotly_layouts["default_ternary_marker_settings"].copy()
            unstable_markers.update(
                dict(
                    x=unstable_props["y"],
                    y=unstable_props["x"],
                    z=unstable_props["z"],
                    name="Above Hull",
                    marker=dict(
                        color=unstable_props["energies"],
                        colorscale=plotly_layouts["unstable_colorscale"],
                        size=6,
                        symbol="diamond",
                        colorbar=dict(title="Energy Above Hull<br>(eV/atom)", x=0.05, len=0.75),
                    ),
                    hovertext=unstable_props["texts"],
                )
            )

        elif self._dim == 4:
            stable_markers = plotly_layouts["default_quaternary_marker_settings"].copy()
            stable_markers.update(
                dict(
                    x=stable_props["x"],
                    y=stable_props["y"],
                    z=stable_props["z"],
                    name="Stable",
                    marker=dict(
                        color=stable_props["energies"],
                        colorscale=plotly_layouts["stable_markers_colorscale"],
                        size=8,
                        opacity=0.9,
                    ),
                    hovertext=stable_props["texts"],
                )
            )

            unstable_markers = plotly_layouts["default_quaternary_marker_settings"].copy()
            unstable_markers.update(
                dict(
                    x=unstable_props["x"],
                    y=unstable_props["y"],
                    z=unstable_props["z"],
                    name="Above Hull",
                    marker=dict(
                        color=unstable_props["energies"],
                        colorscale=plotly_layouts["unstable_colorscale"],
                        size=5,
                        symbol="diamond",
                        colorbar=dict(title="Energy Above Hull<br>(eV/atom)", x=0.05, len=0.75),
                    ),
                    hovertext=unstable_props["texts"],
                    visible="legendonly",
                )
            )

        stable_marker_plot = go.Scatter(**stable_markers) if self._dim == 2 else go.Scatter3d(**stable_markers)
        unstable_marker_plot = go.Scatter(**unstable_markers) if self._dim == 2 else go.Scatter3d(**unstable_markers)

        return stable_marker_plot, unstable_marker_plot
1✔
3152

3153
    def _create_plotly_uncertainty_shading(self, stable_marker_plot):
        """
        Builds the shaded uncertainty band around the stable entries. Only
        binary (dim=2) phase diagrams are supported; other dimensions yield None.

        Args:
            stable_marker_plot: go.Scatter object with stable markers and their
            error bars.

        Returns:
            Plotly go.Scatter object with uncertainty window shading, or None
            when ``self._dim`` is not 2.
        """
        xs = stable_marker_plot.x
        ys = stable_marker_plot.y

        is_transformed = hasattr(self._pd, "original_entries") or hasattr(self._pd, "chempots")

        if self._dim != 2:
            return None

        errors = stable_marker_plot.error_y["array"]

        # rows of (x, y, error), ordered by composition (x)
        table = np.append(xs, [ys, errors]).reshape(3, -1).T
        table = table[table[:, 0].argsort()]  # pylint: disable=E1136

        # upper edge of the window: y + error, walking left to right
        upper = table[:, :2].copy()
        upper[:, 1] += table[:, 2]

        # lower edge: y - error, walking back right to left; the right-most
        # point is left out unless the diagram is a transformed one, which
        # allows for uncertainty in terminal compounds
        stop = None if is_transformed else -1
        lower = np.flip(table[:stop, :].copy(), axis=0)
        lower[:, 1] -= lower[:, 2]

        boundary = np.vstack((upper, lower[:, :2]))

        return go.Scatter(
            x=boundary[:, 0],
            y=boundary[:, 1],
            name="Uncertainty (window)",
            fill="toself",
            mode="lines",
            line=dict(width=0),
            fillcolor="lightblue",
            hoverinfo="skip",
            opacity=0.4,
        )
×
3205

3206
    def _create_plotly_ternary_support_lines(self):
        """
        Builds the guide lines that make the ternary hull easier to read in
        three dimensions: a triangle through the elemental corners at z=0, the
        same triangle at z=``self._min_energy``, and a vertical line at each
        corner joining the two.

        Returns:
            go.Scatter3d plot of support lines for ternary phase diagram.
        """
        # invert {coordinate: entry} so corners can be looked up by entry
        coords_by_entry = {entry: coord for coord, entry in self.pd_plot_data[1].items()}
        corners = [coords_by_entry[ref] for ref in self._pd.el_refs.values()]

        xs, ys, zs = [], [], []

        # top and bottom triangle guidelines (None separates the segments)
        for start, end in itertools.combinations(corners, 2):
            xs += [start[0], end[0], None, start[0], end[0], None]
            ys += [start[1], end[1], None, start[1], end[1], None]
            zs += [0, 0, None, self._min_energy, self._min_energy, None]

        # vertical guidelines
        for corner in corners:
            xs += [corner[0], corner[0], None]
            ys += [corner[1], corner[1], None]
            zs += [0, self._min_energy, None]

        # x and y are swapped for the ternary 3D scene
        return go.Scatter3d(
            x=list(ys),
            y=list(xs),
            z=list(zs),
            mode="lines",
            hoverinfo="none",
            line=dict(color="rgba (0, 0, 0, 0.4)", dash="solid", width=1.0),
            showlegend=False,
        )
3240

3241
    def _create_plotly_ternary_hull(self):
        """
        Builds the shaded mesh that colors the ternary hull by formation energy.

        Returns:
            go.Mesh3d plot
        """
        facet_indices = np.array(self._pd.facets)

        # all rows of qhull_data except the last one are used
        # NOTE(review): presumably the last row is an auxiliary hull point - confirm
        hull_points = self._pd.qhull_data[:-1]
        tri_coords = np.array([triangular_coord(pair) for pair in zip(hull_points[:, 0], hull_points[:, 1])])

        form_energies = np.array([self._pd.get_form_energy_per_atom(entry) for entry in self._pd.qhull_entries])

        # x and y are swapped for the ternary 3D scene
        mesh_kwargs = dict(
            x=list(tri_coords[:, 1]),
            y=list(tri_coords[:, 0]),
            z=list(form_energies),
            i=list(facet_indices[:, 1]),
            j=list(facet_indices[:, 0]),
            k=list(facet_indices[:, 2]),
            opacity=0.8,
            intensity=list(form_energies),
            colorscale=plotly_layouts["stable_colorscale"],
            colorbar=dict(title="Formation energy<br>(eV/atom)", x=0.9, len=0.75),
            hoverinfo="none",
            lighting=dict(diffuse=0.0, ambient=1.0),
            name="Convex Hull (shading)",
            flatshading=True,
            showlegend=True,
        )
        return go.Mesh3d(**mesh_kwargs)
3269

3270

3271
def uniquelines(q):
    """
    Collapse a collection of facets into the set of distinct edges they
    contain. Used to turn convex hull facets into pairs of vertex indices
    that can be drawn as lines.

    Args:
        q: A 2-dim sequence, where each row represents a facet. E.g.,
            [[1,2,3],[3,6,7],...]

    Returns:
        setoflines:
            A set of tuple of lines. E.g., ((1,2), (1,3), (2,3), ....)
            Each tuple is sorted, so (2, 1) and (1, 2) are the same line.
    """
    return {tuple(sorted(pair)) for facet in q for pair in itertools.combinations(facet, 2)}
1✔
3289

3290

3291
def triangular_coord(coord):
    """
    Map a 2D coordinate onto a triangle-based coordinate system, which gives
    a prettier phase diagram.

    Args:
        coord: coordinate used in the convex hull computation.

    Returns:
        coordinates in a triangular-based coordinate system. The product is
        transposed before being returned, so an (n, 2) input gives a (2, n)
        result.
    """
    # rows are the images of the two input axes: (1, 0) and (1/2, sqrt(3)/2)
    basis = np.array([[1, 0], [0.5, math.sqrt(3) / 2]])
    return np.array(coord).dot(basis).T
1✔
3306

3307

3308
def tet_coord(coord):
    """
    Map a 3D coordinate onto a tetrahedron-based coordinate system, which
    gives a prettier phase diagram.

    Args:
        coord: coordinate used in the convex hull computation.

    Returns:
        coordinates in a tetrahedron-based coordinate system. The product is
        transposed before being returned, so an (n, 3) input gives a (3, n)
        result.
    """
    # rows are the images of the three input axes
    basis = np.array(
        [
            [1, 0, 0],
            [0.5, math.sqrt(3) / 2, 0],
            [0.5, 1.0 / 3.0 * math.sqrt(3) / 2, math.sqrt(6) / 3],
        ]
    )
    return np.array(coord).dot(basis).T
1✔
3328

3329

3330
def order_phase_diagram(lines, stable_entries, unstable_entries, ordering):
    """
    Orders the entries (their coordinates) in a phase diagram plot according
    to the user specified ordering.
    Ordering should be given as ['Up', 'Left', 'Right'], where Up,
    Left and Right are the names of the entries in the upper, left and right
    corners of the triangle respectively.

    The current corners are taken to be the stable entries with the largest
    y (up), smallest x (left) and largest x (right). Depending on where these
    sit in ``ordering``, the coordinates are returned untouched, mirrored in
    x, rotated by 120 or 240 degrees about the point (0.5, sqrt(3)/6), or
    passed through one of two other fixed linear maps.

    Args:
        lines: list of list of coordinates for lines in the PD.
        stable_entries: {coordinate : entry} for each stable node in the
            phase diagram. (Each coordinate can only have one stable phase)
        unstable_entries: {entry: coordinates} for all unstable nodes in the
            phase diagram.
        ordering: Ordering of the phase diagram, given as a list ['Up',
            'Left','Right']

    Returns:
        (newlines, newstable_entries, newunstable_entries):
        - newlines is a list of list of coordinates for lines in the PD.
        - newstable_entries is a {coordinate : entry} for each stable node
        in the phase diagram. (Each coordinate can only have one
        stable phase)
        - newunstable_entries is a {entry: coordinates} for all unstable
        nodes in the phase diagram.
        When the coordinates already follow ``ordering`` the three input
        objects are returned as they are.

    Raises:
        ValueError: if the name of a current corner entry is not in
            ``ordering`` (this includes an empty ``stable_entries``), or if
            the up corner matches none of the first three items of
            ``ordering``.
    """
    yup = -1000.0
    xleft = 1000.0
    xright = -1000.0
    # initialised so that an empty stable_entries ends in the ValueError below
    # instead of an UnboundLocalError
    nameup = nameleft = nameright = None

    for coord, entry in stable_entries.items():
        if coord[0] > xright:
            xright = coord[0]
            nameright = entry.name
        if coord[0] < xleft:
            xleft = coord[0]
            nameleft = entry.name
        if coord[1] > yup:
            yup = coord[1]
            nameup = entry.name

    if (nameup not in ordering) or (nameright not in ordering) or (nameleft not in ordering):
        raise ValueError(
            "Error in ordering_phase_diagram :\n"
            f"{nameup!r}, {nameleft!r} and {nameright!r} should be in ordering : {ordering}"
        )

    # centre of rotation; dtype=float replaces np.float_, which NumPy 2.0 removed
    cc = np.array([0.5, np.sqrt(3.0) / 6.0], dtype=float)

    def rotation_about_cc(cos_a, sin_a):
        """Return the map rotating (x, y) about cc by the angle with the given cosine and sine."""

        def rotate(x, y):
            return (
                cos_a * (x - cc[0]) - sin_a * (y - cc[1]) + cc[0],
                sin_a * (x - cc[0]) + cos_a * (y - cc[1]) + cc[1],
            )

        return rotate

    if nameup == ordering[0]:
        if nameleft == ordering[1]:
            # The coordinates were already in the user ordering
            return lines, stable_entries, unstable_entries

        def transform(x, y):
            return 1.0 - x, y  # mirror left <-> right

    elif nameup == ordering[1]:
        c120 = np.cos(2.0 * np.pi / 3.0)
        s120 = np.sin(2.0 * np.pi / 3.0)
        if nameleft == ordering[2]:
            transform = rotation_about_cc(c120, s120)
        else:

            def transform(x, y):
                return (
                    -c120 * (x - 1.0) - s120 * y + 1.0,
                    -s120 * (x - 1.0) + c120 * y,
                )

    elif nameup == ordering[2]:
        c240 = np.cos(4.0 * np.pi / 3.0)
        s240 = np.sin(4.0 * np.pi / 3.0)
        if nameleft == ordering[0]:
            transform = rotation_about_cc(c240, s240)
        else:

            def transform(x, y):
                return (
                    -c240 * x - s240 * y,
                    -s240 * x + c240 * y,
                )

    else:
        raise ValueError("Invalid ordering.")

    return _apply_pd_transform(lines, stable_entries, unstable_entries, transform)


def _apply_pd_transform(lines, stable_entries, unstable_entries, transform):
    """Apply transform(x, y) -> (new_x, new_y) to all lines and entry coordinates of a phase diagram plot."""
    newlines = []
    for x, y in lines:
        # float arrays: integer-typed input must not truncate the new coordinates
        newx, newy = transform(np.asarray(x, dtype=float), np.asarray(y, dtype=float))
        newlines.append([newx, newy])
    newstable_entries = {transform(c[0], c[1]): entry for c, entry in stable_entries.items()}
    newunstable_entries = {entry: transform(c[0], c[1]) for entry, c in unstable_entries.items()}
    return newlines, newstable_entries, newunstable_entries
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc