• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

materialsproject / pymatgen / 4075885785

pending completion
4075885785

push

github

Shyue Ping Ong
Merge branch 'master' of github.com:materialsproject/pymatgen

96 of 96 new or added lines in 27 files covered. (100.0%)

81013 of 102710 relevant lines covered (78.88%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

45.87
/pymatgen/transformations/advanced_transformations.py
1
# Copyright (c) Pymatgen Development Team.
2
# Distributed under the terms of the MIT License.
3

4
"""
1✔
5
This module implements more advanced transformations.
6
"""
7

8
from __future__ import annotations
1✔
9

10
import logging
1✔
11
import math
1✔
12
import warnings
1✔
13
from fractions import Fraction
1✔
14
from itertools import groupby, product
1✔
15
from math import gcd
1✔
16
from string import ascii_lowercase
1✔
17
from typing import Callable, Iterable
1✔
18

19
import numpy as np
1✔
20
import tqdm
1✔
21
from monty.dev import requires
1✔
22
from monty.fractions import lcm
1✔
23
from monty.json import MSONable
1✔
24

25
from pymatgen.analysis.adsorption import AdsorbateSiteFinder
1✔
26
from pymatgen.analysis.bond_valence import BVAnalyzer
1✔
27
from pymatgen.analysis.energy_models import SymmetryModel
1✔
28
from pymatgen.analysis.ewald import EwaldSummation
1✔
29
from pymatgen.analysis.gb.grain import GrainBoundaryGenerator
1✔
30
from pymatgen.analysis.local_env import MinimumDistanceNN
1✔
31
from pymatgen.analysis.structure_matcher import SpinComparator, StructureMatcher
1✔
32
from pymatgen.analysis.structure_prediction.substitution_probability import (
1✔
33
    SubstitutionPredictor,
34
)
35
from pymatgen.command_line.enumlib_caller import EnumError, EnumlibAdaptor
1✔
36
from pymatgen.command_line.mcsqs_caller import run_mcsqs
1✔
37
from pymatgen.core.periodic_table import DummySpecies, Element, Species, get_el_sp
1✔
38
from pymatgen.core.structure import Structure
1✔
39
from pymatgen.core.surface import SlabGenerator
1✔
40
from pymatgen.electronic_structure.core import Spin
1✔
41
from pymatgen.io.ase import AseAtomsAdaptor
1✔
42
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
1✔
43
from pymatgen.transformations.standard_transformations import (
1✔
44
    OrderDisorderedStructureTransformation,
45
    SubstitutionTransformation,
46
    SupercellTransformation,
47
)
48
from pymatgen.transformations.transformation_abc import AbstractTransformation
1✔
49

50
try:
1✔
51
    import hiphive
1✔
52
except ImportError:
1✔
53
    hiphive = None
1✔
54

55
__author__ = "Shyue Ping Ong, Stephen Dacek, Anubhav Jain, Matthew Horton, Alex Ganose"
1✔
56

57
logger = logging.getLogger(__name__)
1✔
58

59

60
class ChargeBalanceTransformation(AbstractTransformation):
    """
    Transformation that charge-balances an oxidation state-decorated
    structure by partially removing (disordering) one chosen species.
    """

    def __init__(self, charge_balance_sp):
        """
        Args:
            charge_balance_sp: Species whose occupancy is adjusted to balance
                the charge. Only removal of the species is supported; it is
                stored as its string representation.
        """
        self.charge_balance_sp = str(charge_balance_sp)

    def apply_transformation(self, structure):
        """
        Reduce the occupancy of the balancing species so that the structure
        charge is compensated.

        Args:
            structure: Input Structure

        Returns:
            The structure returned by the underlying SubstitutionTransformation.

        Raises:
            ValueError: If balancing would require adding the species rather
                than removing it.
        """
        total_charge = structure.charge
        balancing_sp = get_el_sp(self.charge_balance_sp)
        # Number of balancing ions whose removal cancels the net charge.
        excess_count = total_charge / balancing_sp.oxi_state
        fraction_removed = excess_count / structure.composition[balancing_sp]
        if fraction_removed < 0:
            raise ValueError("addition of specie not yet supported by ChargeBalanceTransformation")

        label = self.charge_balance_sp
        # Map the species onto itself with the reduced occupancy.
        partial_occupancy = {label: 1 - fraction_removed}
        return SubstitutionTransformation({label: partial_occupancy}).apply_transformation(structure)

    def __str__(self):
        return f"Charge Balance Transformation : Species to remove = {self.charge_balance_sp}"

    def __repr__(self):
        return str(self)

    @property
    def inverse(self):
        """No inverse transformation is defined; always None."""
        return None

    @property
    def is_one_to_many(self) -> bool:
        """False: a single structure is produced per input."""
        return False
109

110

111
class SuperTransformation(AbstractTransformation):
    """
    An inherently one-to-many transformation built from a list of
    transformations. Each member transformation contributes its own output
    structure(s). The primary use for this class is extending a transmuter
    object.
    """

    def __init__(self, transformations, nstructures_per_trans=1):
        """
        Args:
            transformations ([transformations]): Transformations to apply to
                the input structure; each one yields its own output entry.
            nstructures_per_trans (int): For member transformations that are
                themselves one-to-many, the number of structures requested
                from each. Defaults to 1, i.e., only the best structure.
        """
        self._transformations = transformations
        self.nstructures_per_trans = nstructures_per_trans

    def apply_transformation(self, structure: Structure, return_ranked_list=False):
        """
        Apply every member transformation to the structure.

        Args:
            structure: Input Structure
            return_ranked_list: Must be truthy; this transformation has no
                single best output.

        Returns:
            List of dicts, each holding a "structure" and the "transformation"
            that produced it.

        Raises:
            ValueError: If return_ranked_list is falsy.
        """
        if not return_ranked_list:
            raise ValueError("SuperTransformation has no single best structure output. Must use return_ranked_list")

        results = []
        for trans in self._transformations:
            if not trans.is_one_to_many:
                # One-to-one members contribute exactly one entry.
                results.append({"transformation": trans, "structure": trans.apply_transformation(structure)})
                continue
            ranked = trans.apply_transformation(structure, return_ranked_list=self.nstructures_per_trans)
            for entry in ranked:
                # Tag each ranked entry with the transformation it came from.
                entry["transformation"] = trans
                results.append(entry)
        return results

    def __str__(self):
        joined = " ".join(str(trans) for trans in self._transformations)
        return f"Super Transformation : Transformations = {joined}"

    def __repr__(self):
        return str(self)

    @property
    def inverse(self):
        """No inverse transformation is defined; always None."""
        return None

    @property
    def is_one_to_many(self) -> bool:
        """True: one output entry per member transformation (or more)."""
        return True
170

171

172
class MultipleSubstitutionTransformation:
    """
    Performs multiple substitutions on a structure. For example, can do a
    fractional replacement of Ge in LiGePS with a list of species, creating one
    structure for each substitution. Ordering is done using a dummy element so
    only one ordering must be done per substitution oxidation state. Charge
    balancing of the structure is optionally performed.

    .. note::
        There are no checks to make sure that removal fractions are possible
        and rounding may occur. Currently charge balancing only works for
        removal of species.
    """

    def __init__(
        self,
        sp_to_replace,
        r_fraction,
        substitution_dict,
        charge_balance_species=None,
        order=True,
    ):
        """
        Performs multiple fractional substitutions on a transmuter.

        Args:
            sp_to_replace: species to be replaced
            r_fraction: fraction of that specie to replace
            substitution_dict: dictionary of the format
                {2: ["Mg", "Ti", "V", "As", "Cr", "Ta", "N", "Nb"],
                3: ["Ru", "Fe", "Co", "Ce", "As", "Cr", "Ta", "N", "Nb"],
                4: ["Ru", "V", "Cr", "Ta", "N", "Nb"],
                5: ["Ru", "W", "Mn"]
                }
                The number is the charge used for each of the list of elements
                (an element can be present in multiple lists)
            charge_balance_species: If specified, will balance the charge on
                the structure using that specie.
            order: Whether to order the structures.
        """
        self.sp_to_replace = sp_to_replace
        self.r_fraction = r_fraction
        self.substitution_dict = substitution_dict
        self.charge_balance_species = charge_balance_species
        self.order = order

    @staticmethod
    def _oxi_suffix(charge):
        """Return the oxidation-state suffix for a charge, e.g. 2 -> "2+", -2 -> "2-"."""
        sign = "+" if charge > 0 else "-"
        # Use the magnitude so that negative charges give "2-" rather than "-2-".
        return f"{abs(charge)}{sign}"

    def apply_transformation(self, structure: Structure, return_ranked_list=False):
        """
        Applies the transformation.

        For every charge in substitution_dict, a fraction of sp_to_replace is
        first swapped for a dummy species of that charge (optionally charge
        balanced and ordered once), and the dummy is then replaced by each
        listed element in turn.

        Args:
            structure: Input Structure
            return_ranked_list: Must be truthy; this transformation has no
                single best output.

        Returns:
            List of {"structure": Structure} dicts, one per (charge, element)
            pair in substitution_dict.

        Raises:
            ValueError: If return_ranked_list is falsy.
        """
        if not return_ranked_list:
            raise ValueError(
                "MultipleSubstitutionTransformation has no single"
                " best structure output. Must use"
                " return_ranked_list."
            )
        outputs = []
        for charge, el_list in self.substitution_dict.items():
            # The same suffix labels both the dummy species and its
            # replacements, so the second substitution always finds the dummy.
            oxi = self._oxi_suffix(charge)
            dummy_sp = f"X{oxi}"
            mapping = {
                self.sp_to_replace: {
                    self.sp_to_replace: 1 - self.r_fraction,
                    dummy_sp: self.r_fraction,
                }
            }
            trans = SubstitutionTransformation(mapping)  # type: ignore
            dummy_structure = trans.apply_transformation(structure)
            if self.charge_balance_species is not None:
                cbt = ChargeBalanceTransformation(self.charge_balance_species)
                dummy_structure = cbt.apply_transformation(dummy_structure)
            if self.order:
                # Order once per charge; every element of this charge reuses it.
                trans = OrderDisorderedStructureTransformation()
                dummy_structure = trans.apply_transformation(dummy_structure)

            for el in el_list:
                st = SubstitutionTransformation({dummy_sp: f"{el}{oxi}"})
                new_structure = st.apply_transformation(dummy_structure)
                outputs.append({"structure": new_structure})
        return outputs

    def __str__(self):
        return f"Multiple Substitution Transformation : Substitution on {self.sp_to_replace}"

    def __repr__(self):
        return str(self)

    @property
    def inverse(self):
        """Returns: None"""
        return None

    @property
    def is_one_to_many(self) -> bool:
        """Returns: True"""
        return True
282

283

284
class EnumerateStructureTransformation(AbstractTransformation):
    """
    Order a disordered structure using enumlib. For complete orderings, this
    generally produces fewer structures than the OrderDisorderedStructure
    transformation, and at a much faster speed.
    """

    def __init__(
        self,
        min_cell_size: int = 1,
        max_cell_size: int = 1,
        symm_prec: float = 0.1,
        refine_structure: bool = False,
        enum_precision_parameter: float = 0.001,
        check_ordered_symmetry: bool = True,
        max_disordered_sites=None,
        sort_criteria: str | Callable = "ewald",
        timeout=None,
    ):
        """
        Args:
            min_cell_size:
                The minimum cell size wanted. Must be an int. Defaults to 1.
            max_cell_size:
                The maximum cell size wanted. Must be an int. Defaults to 1.
            symm_prec:
                Tolerance to use for symmetry.
            refine_structure:
                This parameter has the same meaning as in enumlib_caller.
                If you are starting from a structure that has been relaxed via
                some electronic structure code, it is usually much better to
                start with symmetry determination and then obtain a refined
                structure. The refined structure has cell parameters and
                atomic positions shifted to the expected symmetry positions,
                which makes it much less sensitive to precision issues in
                enumlib. If you are already starting from an experimental cif,
                refinement should have already been done and it is not
                necessary. Defaults to False.
            enum_precision_parameter (float): Finite precision parameter for
                enumlib. Default of 0.001 is usually ok, but you might need to
                tweak it for certain cells.
            check_ordered_symmetry (bool): Whether to check the symmetry of
                the ordered sites. If the symmetry of the ordered sites is
                lower, the lowest symmetry ordered sites is included in the
                enumeration. This is important if the ordered sites break
                symmetry in a way that is important for getting possible
                structures. But sometimes including ordered sites
                slows down enumeration to the point that it cannot be
                completed. Switch to False in those cases. Defaults to True.
            max_disordered_sites (int):
                An alternate parameter to max_cell size. Will sequentially try
                larger and larger cell sizes until (i) getting a result or (ii)
                the number of disordered sites in the cell exceeds
                max_disordered_sites. Must set max_cell_size to None when using
                this parameter.
            sort_criteria (str or callable): Sort by Ewald energy ("ewald", the default; needs oxidation states
                and is slow, otherwise the ranking falls back to number of sites) or M3GNet relaxed energy
                ("m3gnet_relax", which is the most accurate but most expensive and provides pre-relaxed
                structures - needs m3gnet package installed) or by M3GNet static energy ("m3gnet_static") or by
                number of sites ("nsites", the fastest). The expense of m3gnet_relax or m3gnet_static can be
                worth it if it significantly reduces the number of structures to be considered. m3gnet_relax
                speeds up the subsequent DFT calculations. Alternatively, a callable can be supplied that takes
                a structure and returns a (Structure, energy) tuple.
            timeout (float): timeout in minutes to pass to EnumlibAdaptor

        Raises:
            ValueError: If both max_cell_size and max_disordered_sites are set.
        """
        self.symm_prec = symm_prec
        self.min_cell_size = min_cell_size
        self.max_cell_size = max_cell_size
        self.refine_structure = refine_structure
        self.enum_precision_parameter = enum_precision_parameter
        self.check_ordered_symmetry = check_ordered_symmetry
        self.max_disordered_sites = max_disordered_sites
        self.sort_criteria = sort_criteria
        self.timeout = timeout

        if max_cell_size and max_disordered_sites:
            raise ValueError("Cannot set both max_cell_size and max_disordered_sites!")

    def apply_transformation(self, structure: Structure, return_ranked_list=False):
        """
        Returns either a single ordered structure or a sequence of all ordered
        structures.

        Args:
            structure: Structure to order.
            return_ranked_list (bool): Whether or not multiple structures are
                returned. If return_ranked_list is a number, that number of
                structures is returned.

        Returns:
            Depending on returned_ranked list, either a transformed structure
            or a list of dictionaries, where each dictionary is of the form
            {"structure" = .... , "other_arguments"}

            The list is ranked by energy per site whenever an energy was
            computed for the structures (Ewald summation on an oxidation
            state decorated structure, an M3GNet criterion, or a callable
            sort_criteria). Otherwise, it is ranked by number of sites, with
            smallest number of sites first.

        Raises:
            ValueError: If there are more disordered sites than
                max_disordered_sites, or if no structures could be enumerated.
        """
        try:
            num_to_return = int(return_ranked_list)
        except ValueError:
            num_to_return = 1

        if self.refine_structure:
            finder = SpacegroupAnalyzer(structure, self.symm_prec)
            structure = finder.get_refined_structure()

        contains_oxidation_state = all(
            hasattr(sp, "oxi_state") and sp.oxi_state != 0 for sp in structure.composition.elements
        )

        structures = None

        if structure.is_ordered:
            warnings.warn(
                f"Enumeration skipped for structure with composition {structure.composition} because it is ordered"
            )
            # Fallback result, used if the enumeration attempts below fail.
            structures = [structure.copy()]

        if self.max_disordered_sites:
            ndisordered = sum(1 for site in structure if not site.is_ordered)
            if ndisordered > self.max_disordered_sites:
                raise ValueError(f"Too many disordered sites! ({ndisordered} > {self.max_disordered_sites})")
            # NOTE(review): for a fully ordered structure ndisordered is 0 and
            # the division below raises ZeroDivisionError - confirm the
            # intended behavior before guarding it.
            max_cell_sizes: Iterable[int] = range(
                self.min_cell_size,
                int(math.floor(self.max_disordered_sites / ndisordered)) + 1,
            )
        else:
            max_cell_sizes = [self.max_cell_size]

        # Try each candidate cell size until one yields structures.
        for max_cell_size in max_cell_sizes:
            adaptor = EnumlibAdaptor(
                structure,
                min_cell_size=self.min_cell_size,
                max_cell_size=max_cell_size,
                symm_prec=self.symm_prec,
                refine_structure=False,
                enum_precision_parameter=self.enum_precision_parameter,
                check_ordered_symmetry=self.check_ordered_symmetry,
                timeout=self.timeout,
            )
            try:
                adaptor.run()
                structures = adaptor.structures
                if structures:
                    break
            except EnumError:
                warnings.warn(f"Unable to enumerate for {max_cell_size = }")

        if structures is None:
            raise ValueError("Unable to enumerate")

        original_latt = structure.lattice
        inv_latt = np.linalg.inv(original_latt.matrix)
        ewald_matrices = {}  # supercell matrix -> EwaldSummation, reused across structures
        all_structures = []
        m3gnet_model = None  # created lazily on first use
        for s in tqdm.tqdm(structures):
            has_energy = True
            if callable(self.sort_criteria):
                s, energy = self.sort_criteria(s)
            elif contains_oxidation_state and self.sort_criteria == "ewald":
                # Integer supercell matrix relating this structure to the input cell.
                transformation = np.dot(s.lattice.matrix, inv_latt)
                transformation = tuple(tuple(int(round(cell)) for cell in row) for row in transformation)
                if transformation not in ewald_matrices:
                    s_supercell = structure * transformation
                    ewald = EwaldSummation(s_supercell)
                    ewald_matrices[transformation] = ewald
                else:
                    ewald = ewald_matrices[transformation]
                energy = ewald.compute_sub_structure(s)
            elif self.sort_criteria.startswith("m3gnet"):
                if self.sort_criteria == "m3gnet_relax":
                    if m3gnet_model is None:
                        from m3gnet.models import Relaxer

                        m3gnet_model = Relaxer()
                    relax_results = m3gnet_model.relax(s)
                    energy = float(relax_results["trajectory"].energies[-1])
                    # Keep the relaxed structure rather than the enumerated one.
                    s = relax_results["final_structure"]
                else:
                    if m3gnet_model is None:
                        from m3gnet.models import M3GNet, M3GNetCalculator, Potential

                        potential = Potential(M3GNet.load())
                        m3gnet_model = M3GNetCalculator(potential=potential, stress_weight=0.01)

                    atoms = AseAtomsAdaptor().get_atoms(s)
                    m3gnet_model.calculate(atoms)
                    energy = float(m3gnet_model.results["energy"])
            else:
                has_energy = False

            if has_energy:
                all_structures.append({"num_sites": len(s), "energy": energy, "structure": s})
            else:
                all_structures.append({"num_sites": len(s), "structure": s})

        def sort_func(entry):
            # Rank by energy per site whenever an energy was computed. The
            # scoring branch is the same for every structure in a run, so the
            # keys are always comparable.
            if "energy" in entry:
                return entry["energy"] / entry["num_sites"]
            return entry["num_sites"]

        self._all_structures = sorted(all_structures, key=sort_func)

        if return_ranked_list:
            return self._all_structures[0:num_to_return]
        return self._all_structures[0]["structure"]

    def __str__(self):
        return "EnumerateStructureTransformation"

    def __repr__(self):
        return str(self)

    @property
    def inverse(self):
        """Returns: None"""
        return None

    @property
    def is_one_to_many(self) -> bool:
        """Returns: True"""
        return True
522

523

524
class SubstitutionPredictorTransformation(AbstractTransformation):
    """
    Uses the structure prediction module to propose likely species
    substitutions for a structure and applies each of them.
    """

    def __init__(self, threshold=1e-2, scale_volumes=True, **kwargs):
        """
        Args:
            threshold: Threshold for substitution, forwarded to the
                SubstitutionPredictor.
            scale_volumes: Whether to scale volumes after substitution.
            **kwargs: Extra arguments forwarded to SubstitutionPredictor
                (e.g. lambda_table, alpha).
        """
        self.kwargs = kwargs
        self.threshold = threshold
        self.scale_volumes = scale_volumes
        self._substitutor = SubstitutionPredictor(threshold=threshold, **kwargs)

    def apply_transformation(self, structure: Structure, return_ranked_list=False):
        """
        Predict substitutions for the structure's composition and apply them.

        Args:
            structure: Input Structure
            return_ranked_list: Must be truthy; a single-structure output is
                not supported.

        Returns:
            List of dicts ordered by decreasing probability, each with the
            keys "structure", "probability", "threshold" and "substitutions".

        Raises:
            ValueError: If return_ranked_list is falsy.
        """
        if not return_ranked_list:
            raise ValueError("SubstitutionPredictorTransformation doesn't support returning 1 structure")

        def by_probability(prediction):
            return prediction["probability"]

        predictions = self._substitutor.composition_prediction(structure.composition, to_this_composition=False)
        # Most probable substitutions first.
        predictions.sort(key=by_probability, reverse=True)

        results = []
        for prediction in predictions:
            substitution = SubstitutionTransformation(prediction["substitutions"])
            results.append(
                {
                    "structure": substitution.apply_transformation(structure),
                    "probability": prediction["probability"],
                    "threshold": self.threshold,
                    # keys and values are stringified so the result is JSON friendly
                    "substitutions": {
                        str(original): str(replacement)
                        for original, replacement in prediction["substitutions"].items()
                    },
                }
            )
        return results

    def __str__(self):
        return "SubstitutionPredictorTransformation"

    def __repr__(self):
        return str(self)

    @property
    def inverse(self):
        """No inverse transformation is defined; always None."""
        return None

    @property
    def is_one_to_many(self) -> bool:
        """True: one output entry per predicted substitution."""
        return True
590

591

592
class MagOrderParameterConstraint(MSONable):
    """
    Restricts a MagOrderingTransformation order parameter to the subset
    of species or sites that satisfy the given constraints. This is useful
    for, e.g., ferrimagnetic structures that order on certain motifs, where
    the global order parameter depends on how many sites match the motif.
    """

    def __init__(
        self,
        order_parameter,
        species_constraints=None,
        site_constraint_name=None,
        site_constraints=None,
    ):
        """
        :param order_parameter (float): any number from 0.0 to 1.0,
            typically 0.5 (antiferromagnetic) or 1.0 (ferromagnetic)
        :param species_constraints (list): str or list of strings
            of Species symbols that the constraint should apply to
        :param site_constraint_name (str): name of the site property
            that the constraint should apply to, e.g. "coordination_no"
        :param site_constraints (list): list of values of the site
            property that the constraints should apply to
        """
        # The property name and its allowed values must be given together.
        if not site_constraint_name:
            if site_constraints and site_constraints != [None]:
                raise ValueError("Specify the name of the site constraint.")
        elif not site_constraints:
            raise ValueError("Please specify some site constraints.")

        # Scalars (including None) are wrapped so both attributes are lists.
        species_constraints = species_constraints if isinstance(species_constraints, list) else [species_constraints]
        site_constraints = site_constraints if isinstance(site_constraints, list) else [site_constraints]

        if order_parameter > 1 or order_parameter < 0:
            raise ValueError("Order parameter must lie between 0 and 1")
        if order_parameter != 0.5:
            warnings.warn(
                "Use care when using a non-standard order parameter, "
                "though it can be useful in some cases it can also "
                "lead to unintended behavior. Consult documentation."
            )

        self.order_parameter = order_parameter
        self.species_constraints = species_constraints
        self.site_constraint_name = site_constraint_name
        self.site_constraints = site_constraints

    def satisfies_constraint(self, site):
        """
        Checks if a periodic site satisfies the constraint.

        Disordered sites never match. When the site carries the named site
        property, only that property's value is tested; otherwise the site's
        species is tested against the species constraints.
        """
        if not site.is_ordered:
            return False

        species_match = self.species_constraints and str(site.specie) in self.species_constraints

        property_name = self.site_constraint_name
        if property_name and property_name in site.properties:
            # The site-property check takes the place of the species check.
            return site.properties[property_name] in self.site_constraints

        return species_match
657

658

659
class MagOrderingTransformation(AbstractTransformation):
    """
    This transformation takes a structure and returns a list of collinear
    magnetic orderings. For disordered structures, make an ordered
    approximation first.
    """

    def __init__(self, mag_species_spin, order_parameter=0.5, energy_model=None, **kwargs):
        """
        :param mag_species_spin: A mapping of elements/species to their
            spin magnitudes, e.g. {"Fe3+": 5, "Mn3+": 4}
        :param order_parameter (float or list): if float, specifies a
            global order parameter and can take values from 0.0 to 1.0
            (e.g. 0.5 for antiferromagnetic or 1.0 for ferromagnetic), if
            list has to be a non-empty list of
            :class:`pymatgen.transformations.advanced_transformations.MagOrderParameterConstraint`
            to specify more complicated orderings, see documentation for
            MagOrderParameterConstraint more details on usage. Any other
            type (including a plain int) raises a ValueError.
        :param energy_model: Energy model to rank the returned structures,
            see :mod: `pymatgen.analysis.energy_models` for more information (note
            that this is not necessarily a physical energy). By default, returned
            structures use SymmetryModel() which ranks structures from most
            symmetric to least.
        :param kwargs: Additional kwargs that are passed to
        :class:`EnumerateStructureTransformation` such as min_cell_size etc.
        """
        # checking for sensible order_parameter values
        if isinstance(order_parameter, float):
            # convert to constraint format
            order_parameter = [
                MagOrderParameterConstraint(
                    order_parameter=order_parameter,
                    species_constraints=list(mag_species_spin),
                )
            ]
        elif isinstance(order_parameter, list):
            # every entry must be a constraint: a mixed list would otherwise
            # only fail later when as_dict() is called on the stray entry.
            # An empty list is rejected as well.
            if not order_parameter or not all(
                isinstance(item, MagOrderParameterConstraint) for item in order_parameter
            ):
                raise ValueError("Order parameter not correctly defined.")
        else:
            raise ValueError("Order parameter not correctly defined.")

        self.mag_species_spin = mag_species_spin
        # store order parameter constraints as dicts to save implementing
        # to/from dict methods for MSONable compatibility
        self.order_parameter = [op.as_dict() for op in order_parameter]
        self.energy_model = energy_model or SymmetryModel()
        self.enum_kwargs = kwargs

    @staticmethod
    def determine_min_cell(disordered_structure):
        """
        Determine the smallest supercell that is able to enumerate
        the provided structure with the given order parameter

        :param disordered_structure: structure containing at least one
            disordered site (as produced by _add_dummy_species)
        :return: (float) the largest per-species minimum cell size
        """
        # assumes all order parameters for a given species are the same
        mag_species_order_parameter = {}
        mag_species_occurrences = {}
        for site in disordered_structure:
            if not site.is_ordered:
                # this very hacky bit of code only works because we know
                # that on disordered sites in this class, all species are the same
                # but have different spins, and this is comma-delimited
                sp = str(list(site.species)[0]).split(",", maxsplit=1)[0]
                if sp in mag_species_order_parameter:
                    mag_species_occurrences[sp] += 1
                else:
                    op = max(site.species.values())
                    mag_species_order_parameter[sp] = op
                    mag_species_occurrences[sp] = 1

        smallest_n = []

        for sp, order_parameter in mag_species_order_parameter.items():
            denom = Fraction(order_parameter).limit_denominator(100).denominator
            num_atom_per_specie = mag_species_occurrences[sp]
            n_gcd = gcd(denom, num_atom_per_specie)
            # n_gcd divides denom, so lcm(n_gcd, denom) is just denom; the
            # true division keeps the float result callers have always received
            smallest_n.append(denom / n_gcd)

        return max(smallest_n)

    @staticmethod
    def _add_dummy_species(structure, order_parameters):
        """
        :param structure: ordered Structure
        :param order_parameters: list of MagOrderParameterConstraints
        :return: A structure decorated with disordered
            DummySpecies on which to perform the enumeration.
            Note that the DummySpecies are super-imposed on
            to the original sites, to make it easier to
            retrieve the original site after enumeration is
            performed (this approach is preferred over a simple
            mapping since multiple species may have the same
            DummySpecies, depending on the constraints specified).
            This approach can also preserve site properties even after
            enumeration.
        :raises ValueError: if a site satisfies more than one constraint
        """
        dummy_struct = structure.copy()

        def generate_dummy_specie():
            """
            Generator which returns DummySpecies symbols Mma, Mmb, etc.
            """
            subscript_length = 1
            while True:
                for subscript in product(ascii_lowercase, repeat=subscript_length):
                    yield "Mm" + "".join(subscript)
                subscript_length += 1

        dummy_species_gen = generate_dummy_specie()

        # one dummy species for each order parameter constraint
        dummy_species_symbols = [next(dummy_species_gen) for _ in order_parameters]
        dummy_species = [
            {
                DummySpecies(symbol, properties={"spin": Spin.up}): constraint.order_parameter,
                DummySpecies(symbol, properties={"spin": Spin.down}): 1 - constraint.order_parameter,
            }
            for symbol, constraint in zip(dummy_species_symbols, order_parameters)
        ]

        # iterate over the untouched input rather than the copy that is being
        # appended to, so freshly added dummy sites are never re-examined
        for site in structure:
            satisfies_constraints = [c.satisfies_constraint(site) for c in order_parameters]
            if satisfies_constraints.count(True) > 1:
                # site should either not satisfy any constraints, or satisfy
                # one constraint
                raise ValueError(f"Order parameter constraints conflict for site: {site.specie}, {site.properties}")
            if any(satisfies_constraints):
                dummy_specie_idx = satisfies_constraints.index(True)
                # NOTE(review): the third positional argument is site.lattice;
                # confirm against Structure.append which parameter it binds to
                dummy_struct.append(dummy_species[dummy_specie_idx], site.coords, site.lattice)

        return dummy_struct

    @staticmethod
    def _remove_dummy_species(structure):
        """
        :param structure: ordered structure containing dummy species; it is
            modified in place
        :return: Structure with dummy species removed, but
        their corresponding spin properties merged with the
        original sites. Used after performing enumeration.
        :raises Exception: if the structure is not ordered, or a dummy site
            does not have exactly one neighbor within the search radius
        """
        if not structure.is_ordered:
            raise Exception("Something went wrong with enumeration.")

        sites_to_remove = []
        logger.debug(f"Dummy species structure:\n{structure}")
        for idx, site in enumerate(structure):
            if isinstance(site.specie, DummySpecies):
                sites_to_remove.append(idx)
                spin = site.specie._properties.get("spin", None)
                neighbors = structure.get_neighbors(
                    site,
                    0.05,  # arbitrary threshold, needs to be << any bond length
                    # but >> floating point precision issues
                    include_index=True,
                )
                if len(neighbors) != 1:
                    raise Exception(f"This shouldn't happen, found neighbors: {neighbors}")
                orig_site_idx = neighbors[0][2]
                orig_specie = structure[orig_site_idx].specie
                new_specie = Species(
                    orig_specie.symbol,
                    getattr(orig_specie, "oxi_state", None),
                    properties={"spin": spin},
                )
                structure.replace(
                    orig_site_idx,
                    new_specie,
                    properties=structure[orig_site_idx].properties,
                )
        # dummy sites are only dropped once every spin has been transferred
        structure.remove_sites(sites_to_remove)
        logger.debug(f"Structure with dummy species removed:\n{structure}")
        return structure

    def _add_spin_magnitudes(self, structure):
        """
        Replaces Spin.up/Spin.down with spin magnitudes specified
        by mag_species_spin.

        :param structure: structure whose species carry a "spin" property;
            it is modified in place
        :return: the same structure, with each truthy spin replaced by its
            sign times the magnitude looked up in mag_species_spin (0 if the
            species is not listed)
        """
        for idx, site in enumerate(structure):
            if getattr(site.specie, "_properties", None):
                spin = site.specie._properties.get("spin", None)
                sign = int(spin) if spin else 0
                if spin:
                    new_properties = site.specie._properties.copy()
                    # this very hacky bit of code only works because we know
                    # that on disordered sites in this class, all species are the same
                    # but have different spins, and this is comma-delimited
                    sp = str(site.specie).split(",", maxsplit=1)[0]
                    new_properties.update({"spin": sign * self.mag_species_spin.get(sp, 0)})
                    new_specie = Species(
                        site.specie.symbol,
                        getattr(site.specie, "oxi_state", None),
                        new_properties,
                    )
                    structure.replace(idx, new_specie, properties=site.properties)
        logger.debug(f"Structure with spin magnitudes:\n{structure}")
        return structure

    def apply_transformation(
        self, structure: Structure, return_ranked_list: bool | int = False
    ) -> Structure | list[Structure]:
        """Apply MagOrderTransformation to an input structure.

        Args:
            structure (Structure): Any ordered structure.
            return_ranked_list (bool, optional): As in other Transformations. Defaults to False.

        Raises:
            ValueError: On disordered structures.

        Returns:
            Structure | list[Structure]: Structure(s) after MagOrderTransformation.
        """
        if not structure.is_ordered:
            raise ValueError("Create an ordered approximation of your  input structure first.")

        # retrieve order parameters
        order_parameters = [MagOrderParameterConstraint.from_dict(op_dict) for op_dict in self.order_parameter]
        # add dummy species on which to perform enumeration
        structure = self._add_dummy_species(structure, order_parameters)

        # trivial case
        if structure.is_ordered:
            structure = self._remove_dummy_species(structure)
            return [structure] if return_ranked_list > 1 else structure

        enum_kwargs = self.enum_kwargs.copy()

        # never enumerate below the smallest cell able to host the ordering
        enum_kwargs["min_cell_size"] = max(int(self.determine_min_cell(structure)), enum_kwargs.get("min_cell_size", 1))

        if enum_kwargs.get("max_cell_size", None):
            if enum_kwargs["min_cell_size"] > enum_kwargs["max_cell_size"]:
                warnings.warn(
                    f"Specified max cell size ({enum_kwargs['max_cell_size']}) is "
                    "smaller than the minimum enumerable cell size "
                    f"({enum_kwargs['min_cell_size']}), changing max cell size to "
                    f"{enum_kwargs['min_cell_size']}"
                )
                enum_kwargs["max_cell_size"] = enum_kwargs["min_cell_size"]
        else:
            enum_kwargs["max_cell_size"] = enum_kwargs["min_cell_size"]

        t = EnumerateStructureTransformation(**enum_kwargs)

        alls = t.apply_transformation(structure, return_ranked_list=return_ranked_list)

        # handle the fact that EnumerateStructureTransformation can either
        # return a single Structure or a list
        if isinstance(alls, Structure):
            # remove dummy species and replace Spin.up or Spin.down
            # with spin magnitudes given in mag_species_spin arg
            alls = self._remove_dummy_species(alls)
            alls = self._add_spin_magnitudes(alls)
        else:
            for entry in alls:
                entry["structure"] = self._remove_dummy_species(entry["structure"])
                entry["structure"] = self._add_spin_magnitudes(entry["structure"])

        try:
            num_to_return = int(return_ranked_list)
        except ValueError:
            num_to_return = 1

        if num_to_return == 1 or not return_ranked_list:
            return alls[0]["structure"] if num_to_return else alls

        # remove duplicate structures and group according to energy model
        m = StructureMatcher(comparator=SpinComparator())

        def key(x):
            return SpacegroupAnalyzer(x, 0.1).get_space_group_number()

        out = []
        # groupby needs its input sorted by the same key
        for _, same_sg in groupby(sorted((d["structure"] for d in alls), key=key), key):
            candidates = list(same_sg)
            grouped = m.group_structures(candidates)
            # keep one representative per group of matching structures
            out.extend(
                [{"structure": group[0], "energy": self.energy_model.get_energy(group[0])} for group in grouped]
            )

        self._all_structures = sorted(out, key=lambda d: d["energy"])

        return self._all_structures[0:num_to_return]  # type: ignore

    def __str__(self):
        return "MagOrderingTransformation"

    def __repr__(self):
        return str(self)

    @property
    def inverse(self):
        """Returns: None"""
        return None

    @property
    def is_one_to_many(self) -> bool:
        """Returns: True"""
        return True
×
965

966

967
def _find_codopant(target, oxidation_state, allowed_elements=None):
1✔
968
    """
969
    Finds the element from "allowed elements" that (i) possesses the desired
970
    "oxidation state" and (ii) is closest in ionic radius to the target specie
971

972
    Args:
973
        target: (Species) provides target ionic radius.
974
        oxidation_state: (float) codopant oxidation state.
975
        allowed_elements: ([str]) List of allowed elements. If None,
976
            all elements are tried.
977

978
    Returns:
979
        (Species) with oxidation_state that has ionic radius closest to
980
        target.
981
    """
982
    ref_radius = target.ionic_radius
×
983
    candidates = []
×
984
    symbols = allowed_elements or [el.symbol for el in Element]
×
985
    for sym in symbols:
×
986
        try:
×
987
            with warnings.catch_warnings():
×
988
                warnings.simplefilter("ignore")
×
989
                sp = Species(sym, oxidation_state)
×
990
                r = sp.ionic_radius
×
991
                if r is not None:
×
992
                    candidates.append((r, sp))
×
993
        except Exception:
×
994
            pass
×
995
    return min(candidates, key=lambda l: abs(l[0] / ref_radius - 1))[1]
×
996

997

998
class DopingTransformation(AbstractTransformation):
    """
    A transformation that performs doping of a structure.
    """

    def __init__(
        self,
        dopant,
        ionic_radius_tol=float("inf"),
        min_length=10,
        alio_tol=0,
        codopant=False,
        max_structures_per_enum=100,
        allowed_doping_species=None,
        **kwargs,
    ):
        """
        Args:
            dopant (Species-like): E.g., Al3+. Must have oxidation state.
            ionic_radius_tol (float): E.g., Fractional allowable ionic radii
                mismatch for dopant to fit into a site. Default of inf means
                that any dopant with the right oxidation state is allowed.
            min_length (float): Min. lattice parameter between periodic
                images of dopant. Defaults to 10A for now.
            alio_tol (int): If this is not 0, attempt will be made to dope
                sites with oxidation_states +- alio_tol of the dopant. E.g.,
                1 means that the ions like Ca2+ and Ti4+ are considered as
                potential doping sites for Al3+.
            codopant (bool): If True, doping will be carried out with a
                codopant to maintain charge neutrality. Otherwise, vacancies
                will be used.
            max_structures_per_enum (float): Maximum number of structures to
                return per enumeration. Note that there can be more than one
                candidate doping site, and each site enumeration will return at
                max max_structures_per_enum structures. Defaults to 100.
            allowed_doping_species (list): Species that are allowed to be
                doping sites. This is an inclusionary list. If specified,
                compatible species that are not in this list are not doped.
            **kwargs:
                Same keyword args as :class:`EnumerateStructureTransformation`,
                i.e., min_cell_size, etc.
        """
        self.dopant = get_el_sp(dopant)
        self.ionic_radius_tol = ionic_radius_tol
        self.min_length = min_length
        self.alio_tol = alio_tol
        self.codopant = codopant
        self.max_structures_per_enum = max_structures_per_enum
        self.allowed_doping_species = allowed_doping_species
        self.kwargs = kwargs

    def apply_transformation(self, structure: Structure, return_ranked_list=False):
        """
        Args:
            structure (Structure): Input structure to dope
            return_ranked_list: if truthy, used as the upper bound of the
                slice of doped structures that is returned

        Returns:
            [{"structure": Structure, "energy": float}] when
            return_ranked_list is truthy, otherwise the first doped structure.

        Raises:
            IndexError: if return_ranked_list is falsy and no doped structure
                was generated.
        """
        comp = structure.composition
        logger.info(f"Composition: {comp}")

        # decorate with oxidation states if any species lacks one
        for sp in comp:
            try:
                sp.oxi_state
            except AttributeError:
                analyzer = BVAnalyzer()
                structure = analyzer.get_oxi_state_decorated_structure(structure)
                comp = structure.composition
                break

        ox = self.dopant.oxi_state
        radius = self.dopant.ionic_radius

        compatible_species = [
            sp for sp in comp if sp.oxi_state == ox and abs(sp.ionic_radius / radius - 1) < self.ionic_radius_tol
        ]

        if (not compatible_species) and self.alio_tol:
            # We only consider aliovalent doping if there are no compatible
            # isovalent species.
            compatible_species = [
                sp
                for sp in comp
                if abs(sp.oxi_state - ox) <= self.alio_tol
                and abs(sp.ionic_radius / radius - 1) < self.ionic_radius_tol
                and sp.oxi_state * ox >= 0
            ]

        if self.allowed_doping_species is not None:
            # Only keep allowed doping species. The allowed list is parsed
            # once instead of once per candidate.
            allowed_species = [get_el_sp(s) for s in self.allowed_doping_species]
            compatible_species = [sp for sp in compatible_species if sp in allowed_species]

        logger.info(f"Compatible species: {compatible_species}")

        lengths = structure.lattice.abc
        # smallest integer repeat along each axis that reaches min_length
        scaling = [max(1, math.ceil(self.min_length / x)) for x in lengths]
        logger.info(f"Lengths are {str(lengths)}")
        logger.info(f"Scaling = {str(scaling)}")

        all_structures = []
        t = EnumerateStructureTransformation(**self.kwargs)

        for sp in compatible_species:
            supercell = structure * scaling
            nsp = supercell.composition[sp]
            if sp.oxi_state == ox:
                supercell.replace_species({sp: {sp: (nsp - 1) / nsp, self.dopant: 1 / nsp}})  # type: ignore
                logger.info(f"Doping {sp} for {self.dopant} at level {1 / nsp:.3f}")
            elif self.codopant:
                codopant = _find_codopant(sp, 2 * sp.oxi_state - ox)  # type: ignore
                supercell.replace_species(
                    {sp: {sp: (nsp - 2) / nsp, self.dopant: 1 / nsp, codopant: 1 / nsp}}
                )  # type: ignore
                logger.info(f"Doping {sp} for {self.dopant} + {codopant} at level {1 / nsp:.3f}")
            elif abs(sp.oxi_state) < abs(ox):  # type: ignore
                # Strategy: replace the target species with a
                # combination of dopant and vacancy.
                # We will choose the lowest oxidation state species as a
                # vacancy compensation species as it is likely to be lower in
                # energy
                sp_to_remove = min(
                    (s for s in comp if s.oxi_state * ox > 0),
                    key=lambda ss: abs(ss.oxi_state),  # type: ignore
                )

                if sp_to_remove == sp:
                    common_charge = lcm(int(abs(sp.oxi_state)), int(abs(ox)))  # type: ignore
                    ndopant = common_charge / abs(ox)
                    nsp_to_remove = common_charge / abs(sp.oxi_state)  # type: ignore
                    logger.info(f"Doping {nsp_to_remove} {sp} with {ndopant} {self.dopant}.")
                    supercell.replace_species(
                        {sp: {sp: (nsp - nsp_to_remove) / nsp, self.dopant: ndopant / nsp}}  # type: ignore
                    )
                else:
                    ox_diff = int(abs(round(sp.oxi_state - ox)))
                    vac_ox = int(abs(sp_to_remove.oxi_state)) * ox_diff  # type: ignore
                    common_charge = lcm(vac_ox, ox_diff)
                    ndopant = common_charge / ox_diff
                    nx_to_remove = common_charge / vac_ox
                    nx = supercell.composition[sp_to_remove]
                    logger.info(f"Doping {ndopant} {sp} with {self.dopant} and removing {nx_to_remove} {sp_to_remove}.")
                    supercell.replace_species(
                        {
                            sp: {sp: (nsp - ndopant) / nsp, self.dopant: ndopant / nsp},  # type: ignore
                            sp_to_remove: {sp_to_remove: (nx - nx_to_remove) / nx},  # type: ignore
                        }
                    )
            elif abs(sp.oxi_state) > abs(ox):  # type: ignore
                # Strategy: replace the target species with dopant and also
                # remove some opposite charged species for charge neutrality
                if ox > 0:
                    sp_to_remove = max(supercell.composition, key=lambda el: el.X)
                else:
                    sp_to_remove = min(supercell.composition, key=lambda el: el.X)
                # Confirm species are of opposite oxidation states.
                assert sp_to_remove.oxi_state * sp.oxi_state < 0  # type: ignore

                ox_diff = int(abs(round(sp.oxi_state - ox)))
                anion_ox = int(abs(sp_to_remove.oxi_state))  # type: ignore
                nx = supercell.composition[sp_to_remove]
                common_charge = lcm(anion_ox, ox_diff)
                ndopant = common_charge / ox_diff
                nx_to_remove = common_charge / anion_ox
                logger.info(f"Doping {ndopant} {sp} with {self.dopant} and removing {nx_to_remove} {sp_to_remove}.")
                supercell.replace_species(
                    {
                        sp: {sp: (nsp - ndopant) / nsp, self.dopant: ndopant / nsp},  # type: ignore
                        sp_to_remove: {sp_to_remove: (nx - nx_to_remove) / nx},  # type: ignore
                    }
                )

            ss = t.apply_transformation(supercell, return_ranked_list=self.max_structures_per_enum)
            logger.info(f"{len(ss)} distinct structures")
            all_structures.extend(ss)

        logger.info(f"Total {len(all_structures)} doped structures")
        if return_ranked_list:
            return all_structures[:return_ranked_list]

        return all_structures[0]["structure"]

    @property
    def inverse(self):
        """Returns: None"""
        return None

    @property
    def is_one_to_many(self) -> bool:
        """Returns: True"""
        return True
×
1191

1192

1193
class SlabTransformation(AbstractTransformation):
    """
    Transformation that cuts a slab out of a structure using SlabGenerator.
    """

    def __init__(
        self,
        miller_index,
        min_slab_size,
        min_vacuum_size,
        lll_reduce=False,
        center_slab=False,
        in_unit_planes=False,
        primitive=True,
        max_normal_search=None,
        shift=0,
        tol=0.1,
    ):
        """
        Args:
            miller_index (3-tuple or list): miller index of slab
            min_slab_size (float): minimum slab size in angstroms
            min_vacuum_size (float): minimum size of vacuum
            lll_reduce (bool): whether to apply LLL reduction
            center_slab (bool): whether to center the slab
            in_unit_planes (bool): forwarded unchanged to SlabGenerator
            primitive (bool): whether to reduce slabs to most primitive cell
            max_normal_search (int): maximum index to include in linear
                combinations of indices to find c lattice vector orthogonal
                to slab surface
            shift (float): shift to get termination
            tol (float): tolerance for primitive cell finding
        """
        # settings handed to SlabGenerator, in the order of its arguments
        self.miller_index = miller_index
        self.min_slab_size = min_slab_size
        self.min_vacuum_size = min_vacuum_size
        self.lll_reduce = lll_reduce
        self.center_slab = center_slab
        self.in_unit_planes = in_unit_planes
        self.primitive = primitive
        self.max_normal_search = max_normal_search
        # settings handed to SlabGenerator.get_slab
        self.shift = shift
        self.tol = tol

    def apply_transformation(self, structure):
        """
        Build the slab for the stored settings.

        Args:
            structure: Input Structure

        Returns:
            The slab returned by SlabGenerator.get_slab for the stored
            shift and tol.
        """
        generator_args = (
            structure,
            self.miller_index,
            self.min_slab_size,
            self.min_vacuum_size,
            self.lll_reduce,
            self.center_slab,
            self.in_unit_planes,
            self.primitive,
            self.max_normal_search,
        )
        generator = SlabGenerator(*generator_args)
        return generator.get_slab(self.shift, self.tol)

    @property
    def inverse(self):
        """Returns: None"""
        return None

    @property
    def is_one_to_many(self) -> bool:
        """Returns: False"""
        return False
×
1269

1270

1271
class DisorderOrderedTransformation(AbstractTransformation):
    """
    Not to be confused with OrderDisorderedTransformation,
    this transformation attempts to obtain a
    *disordered* structure from an input ordered structure.
    This may or may not be physically plausible, further
    inspection of the returned structures is advised.
    The main purpose for this transformation is for structure
    matching to crystal prototypes for structures that have
    been derived from a parent prototype structure by
    substitutions or alloying additions.
    """

    def __init__(self, max_sites_to_merge=2):
        """
        Args:
            max_sites_to_merge: only merge this number of sites together
        """
        self.max_sites_to_merge = max_sites_to_merge

    def apply_transformation(self, structure: Structure, return_ranked_list=False):
        """
        Args:
            structure: ordered structure
            return_ranked_list: as in other pymatgen Transformations

        Returns:
            Transformed disordered structure(s): None if no disordered
            structure could be built, the first structure if
            return_ranked_list is falsy, otherwise a list of at most
            return_ranked_list dicts with "structure" and "mapping" keys.

        Raises:
            ValueError: if the input structure is not ordered.
        """
        if not structure.is_ordered:
            # the input must be ordered; the previous message said the opposite
            raise ValueError("This transformation is for ordered structures only.")

        partitions = self._partition_species(structure.composition, max_components=self.max_sites_to_merge)
        disorder_mappings = self._get_disorder_mappings(structure.composition, partitions)

        disordered_structures = []
        for mapping in disorder_mappings:
            disordered_structure = structure.copy()
            disordered_structure.replace_species(mapping)
            disordered_structures.append({"structure": disordered_structure, "mapping": mapping})

        if len(disordered_structures) == 0:
            return None
        if not return_ranked_list:
            return disordered_structures[0]["structure"]
        if len(disordered_structures) > return_ranked_list:
            disordered_structures = disordered_structures[0:return_ranked_list]
        return disordered_structures

    @property
    def inverse(self):
        """Returns: None"""
        return None

    @property
    def is_one_to_many(self) -> bool:
        """Returns: True"""
        return True

    @staticmethod
    def _partition_species(composition, max_components=2):
        """
        Private method to split a list of species into
        various partitions.

        Args:
            composition: iterable of species
            max_components: largest number of species allowed in one subset

        Returns:
            List of partitions (each a list of species lists), ordered by
            the partitions to try first and excluding the partition that
            leaves every species on its own. Empty if the composition is
            empty or no partition passes the max_components filter.
        """

        def _partition(collection):
            # thanks https://stackoverflow.com/a/30134039

            if len(collection) == 1:
                yield [collection]
                return

            first = collection[0]
            for smaller in _partition(collection[1:]):
                # insert `first` in each of the subpartition's subsets
                for n, subset in enumerate(smaller):
                    yield smaller[:n] + [[first] + subset] + smaller[n + 1 :]
                # put `first` in its own subset
                yield [[first]] + smaller

        def _sort_partitions(partitions_to_sort):
            """
            Sort partitions by those we want to check first
            (typically, merging two sites into one is the
            one to try first).
            """
            partition_indices = [(idx, [len(p) for p in partition]) for idx, partition in enumerate(partitions_to_sort)]

            # sort by maximum length of partition first (try smallest maximums first)
            # and secondarily by number of partitions (most partitions first, i.e.
            # create the 'least disordered' structures first)
            partition_indices = sorted(partition_indices, key=lambda x: (max(x[1]), -len(x[1])))

            # merge at most max_component sites,
            # e.g. merge at most 2 species into 1 disordered site
            partition_indices = [x for x in partition_indices if max(x[1]) <= max_components]

            # the first entry is just the input structure; slicing (rather
            # than pop) tolerates the filter having removed everything
            partition_indices = partition_indices[1:]

            return [partitions_to_sort[x[0]] for x in partition_indices]

        collection = list(composition)
        if not collection:
            # nothing to partition
            return []

        return _sort_partitions(list(_partition(collection)))

    @staticmethod
    def _get_disorder_mappings(composition, partitions):
        """
        Private method to obtain the mapping to create
        a disordered structure from a given partition.

        Args:
            composition: mapping from species to amount
            partitions: partitions as returned by _partition_species

        Returns:
            One replacement dict per partition, suitable for
            Structure.replace_species().
        """

        def _get_replacement_dict_from_partition(partition):
            d = {}  # to be passed to Structure.replace_species()
            for sp_list in partition:
                if len(sp_list) > 1:
                    # merged species share the site in proportion to their amounts
                    total_occ = sum(composition[sp] for sp in sp_list)
                    merged_comp = {sp: composition[sp] / total_occ for sp in sp_list}
                    for sp in sp_list:
                        d[sp] = merged_comp
            return d

        return [_get_replacement_dict_from_partition(p) for p in partitions]
1✔
1401

1402

1403
class GrainBoundaryTransformation(AbstractTransformation):
    """
    Transformation that builds a grain boundary (GB) structure from a bulk
    structure by delegating to GrainBoundaryGenerator.
    """

    def __init__(
        self,
        rotation_axis,
        rotation_angle,
        expand_times=4,
        vacuum_thickness=0.0,
        ab_shift=None,
        normal=False,
        ratio=True,
        plane=None,
        max_search=20,
        tol_coi=1.0e-8,
        rm_ratio=0.7,
        quick_gen=False,
    ):
        """
        Store the grain boundary parameters; nothing is computed until
        apply_transformation is called.

        Args:
            rotation_axis (list): Rotation axis of the GB as a list of
                integers, e.g. [1, 1, 0].
            rotation_angle (float): Rotation angle in degrees used to generate
                the GB. The angle has to be accurate, e.g. the sigma 3 twist
                GB with rotation axis [1, 1, 1] and GB plane (1, 1, 1) can use
                60.000000000 degrees. When only the sigma value is known,
                get_rotation_angle_from_sigma returns all matching angles.
            expand_times (int): Multiple used to expand one unit grain into a
                larger grain, so that the two GBs in one cell do not interact
                with each other. Defaults to 4.
            vacuum_thickness (float): Thickness of the vacuum inserted between
                the two grains. Defaults to 0.
            ab_shift (list of float): In-plane shift of the two grains, in
                units of the a and b vectors of the GB. A falsy value is
                replaced by [0, 0].
            normal (bool): Whether the c axis of the top grain (first
                transformation matrix) is required to be perpendicular to the
                surface. Defaults to False.
            ratio (list of int | True | None): Lattice axial ratio. True (the
                default) determines it from the structure via
                GrainBoundaryGenerator.get_ratio().
                Cubic: not needed, can be None.
                Tetragonal: [mu, mv] with mu/mv = c2/a2; None if irrational.
                Orthorhombic: [mu, lam, mv] with mu:lam:mv = c2:b2:a2; use
                None for an irrational axis, e.g. c2,None,a2 means b2 is
                irrational.
                Rhombohedral: [mu, mv] with mu/mv equal to
                (1+2*cos(alpha))/cos(alpha); None if irrational.
                Hexagonal: [mu, mv] with mu/mv = c2/a2; None if irrational.
            plane (list): Grain boundary plane as a list of integers, e.g.
                [1, 2, 3]. If None, a twist GB is built and the plane is
                perpendicular to the rotation axis.
            max_search (int): Maximum search range for the GB lattice vectors
                giving the smallest GB lattice. If normal is True, it also
                bounds the search for the GB c vector perpendicular to the
                plane. Smaller values are faster for complex GBs, but values
                that are too small may lead to errors.
            tol_coi (float): Tolerance used to find coincidence sites. When
                the ratio is approximated, this probably has to be increased
                to get the correct number of coincidence sites; compare the
                sigma of the generated GB with the expected sigma to check.
            rm_ratio (float): Atoms closer than rm_ratio * bond_length of the
                bulk system are removed. Defaults to 0.7.
            quick_gen (bool): If True, generate a supercell quickly without
                looking for the smallest cell.
        """
        self.rotation_axis = rotation_axis
        self.rotation_angle = rotation_angle
        self.expand_times = expand_times
        self.vacuum_thickness = vacuum_thickness
        # no (or a falsy) shift means "do not shift the grains in plane"
        self.ab_shift = ab_shift or [0, 0]
        self.normal = normal
        self.ratio = ratio
        self.plane = plane
        self.max_search = max_search
        self.tol_coi = tol_coi
        self.rm_ratio = rm_ratio
        self.quick_gen = quick_gen

    def apply_transformation(self, structure):
        """
        Generate the grain boundary for the given bulk structure.

        Args:
            structure: Input Structure.

        Returns:
            The structure returned by
            GrainBoundaryGenerator.gb_from_parameters.
        """
        generator = GrainBoundaryGenerator(structure)
        # ratio=True asks for the axial ratio to be derived from the structure
        axial_ratio = generator.get_ratio() if self.ratio is True else self.ratio
        return generator.gb_from_parameters(
            self.rotation_axis,
            self.rotation_angle,
            expand_times=self.expand_times,
            vacuum_thickness=self.vacuum_thickness,
            ab_shift=self.ab_shift,
            normal=self.normal,
            ratio=axial_ratio,
            plane=self.plane,
            max_search=self.max_search,
            tol_coi=self.tol_coi,
            rm_ratio=self.rm_ratio,
            quick_gen=self.quick_gen,
        )

    @property
    def inverse(self):
        """Returns: None, no inverse transformation is provided."""
        return None

    @property
    def is_one_to_many(self) -> bool:
        """Returns: False, exactly one structure is produced."""
        return False
×
1529

1530

1531
class CubicSupercellTransformation(AbstractTransformation):
    """
    A transformation that aims to generate a nearly cubic supercell structure
    from a structure.

    The algorithm solves for a transformation matrix that makes the supercell
    cubic. The matrix must have integer entries, so entries are rounded (in such
    a way that forces the matrix to be nonsingular). From the supercell
    resulting from this transformation matrix, vector projections are used to
    determine the side length of the largest cube that can fit inside the
    supercell. The algorithm will iteratively increase the size of the supercell
    until the largest inscribed cube's side length is at least 'min_length'
    and the number of atoms in the supercell falls in the range
    ``min_atoms <= n <= max_atoms``.
    """

    def __init__(
        self,
        min_atoms: int | None = None,
        max_atoms: int | None = None,
        min_length: float = 15.0,
        force_diagonal: bool = False,
        force_90_degrees: bool = False,
        angle_tolerance: float = 1e-3,
    ):
        """
        Args:
            min_atoms: Minimum number of atoms allowed in the supercell.
                A falsy value (None or 0) means no lower bound.
            max_atoms: Maximum number of atoms allowed in the supercell.
                A falsy value (None or 0) means no upper bound.
            min_length: Minimum length of the smallest supercell lattice vector.
            force_diagonal: If True, return a transformation with a diagonal
                transformation matrix.
            force_90_degrees: If True, return a transformation for a supercell
                with 90 degree angles (if possible). To avoid long run times,
                please use max_atoms
            angle_tolerance: tolerance to determine the 90 degree angles
        """
        # np.inf rather than the np.Inf alias, which NumPy 2.0 removed
        self.min_atoms = min_atoms or -np.inf
        self.max_atoms = max_atoms or np.inf
        self.min_length = min_length
        self.force_diagonal = force_diagonal
        self.force_90_degrees = force_90_degrees
        self.angle_tolerance = angle_tolerance
        # set by apply_transformation to the matrix that was last tried/used
        self.transformation_matrix = None

    def apply_transformation(self, structure: Structure) -> Structure:
        """
        The algorithm solves for a transformation matrix that makes the
        supercell cubic. The matrix must have integer entries, so entries are
        rounded (in such a way that forces the matrix to be nonsingular). From
        the supercell resulting from this transformation matrix, vector
        projections are used to determine the side length of the largest cube
        that can fit inside the supercell. The algorithm will iteratively
        increase the size of the supercell until the largest inscribed cube's
        side length is at least 'min_length' and the number of atoms in the
        supercell falls in the range defined by min_atoms and max_atoms. With
        force_90_degrees, the supercell angles must additionally be within
        angle_tolerance of 90 degrees.

        If force_diagonal is set, a diagonal matrix built from
        ceil(min_length / lattice length) is applied and returned right away,
        without the atom count or angle checks.

        Args:
            structure: Structure to build the supercell from.

        Returns:
            supercell: Transformed supercell.

        Raises:
            AttributeError: If a proposed supercell that does not satisfy all
                constraints has more than max_atoms atoms.
        """
        lat_vecs = structure.lattice.matrix

        if self.force_diagonal:
            scale = self.min_length / np.array(structure.lattice.abc)
            self.transformation_matrix = np.diag(np.ceil(scale).astype(int))  # type: ignore
            st = SupercellTransformation(self.transformation_matrix)
            return st.apply_transformation(structure)

        # the inverse of the input lattice does not change between iterations
        inv_lat_vecs = np.linalg.inv(lat_vecs)

        # target_sc_size is used as the desired cubic side length
        target_sc_size = self.min_length
        # the loop is only left by returning a supercell or by raising
        while True:
            target_sc_lat_vecs = np.eye(3, 3) * target_sc_size

            # round the entries of T and force T to be nonsingular
            self.transformation_matrix = _round_and_make_arr_singular(  # type: ignore
                target_sc_lat_vecs @ inv_lat_vecs
            )

            proposed_sc_lat_vecs = self.transformation_matrix @ lat_vecs

            # For every pair of lattice vectors, take the component of one
            # that is perpendicular to the other; the shortest of these is
            # compared against min_length.
            a, b, c = proposed_sc_lat_vecs
            length_vecs = np.array(
                [
                    c - _proj(c, a),  # a-c plane
                    a - _proj(a, c),
                    b - _proj(b, a),  # b-a plane
                    a - _proj(a, b),
                    b - _proj(b, c),  # b-c plane
                    c - _proj(c, b),
                ]
            )

            # Get number of atoms
            st = SupercellTransformation(self.transformation_matrix)
            superstructure = st.apply_transformation(structure)
            num_at = superstructure.num_sites

            # Check if constraints are satisfied
            if (
                np.min(np.linalg.norm(length_vecs, axis=1)) >= self.min_length
                and self.min_atoms <= num_at <= self.max_atoms
            ):
                if not self.force_90_degrees:
                    return superstructure
                angle_deviation = np.absolute(np.array(superstructure.lattice.angles) - 90.0)
                if np.all(angle_deviation < self.angle_tolerance):
                    return superstructure

            # Increase threshold until proposed supercell meets requirements
            target_sc_size += 0.1
            if num_at > self.max_atoms:
                raise AttributeError(
                    "While trying to solve for the supercell, the max "
                    "number of atoms was exceeded. Try lowering the number "
                    "of nearest neighbor distances."
                )

    @property
    def inverse(self):
        """
        Returns:
            None
        """
        return None

    @property
    def is_one_to_many(self) -> bool:
        """
        Returns:
            False
        """
        return False
×
1680

1681

1682
class AddAdsorbateTransformation(AbstractTransformation):
    """
    Create adsorbate structures.
    """

    def __init__(
        self,
        adsorbate,
        selective_dynamics=False,
        height=0.9,
        mi_vec=None,
        repeat=None,
        min_lw=5.0,
        translate=True,
        reorient=True,
        find_args=None,
    ):
        """
        Use AdsorbateSiteFinder to add an adsorbate to a slab.

        Args:
            adsorbate (Molecule): molecule to add as adsorbate
            selective_dynamics (bool): flag for whether to assign
                non-surface sites as fixed for selective dynamics
            height (float): height criteria for selection of surface sites
            mi_vec : vector corresponding to the vector
                concurrent with the miller index, this enables use with
                slabs that have been reoriented, but the miller vector
                must be supplied manually
            repeat (3-tuple or list): repeat argument for supercell generation
            min_lw (float): minimum length and width of the slab, only used
                if repeat is None
            translate (bool): flag on whether to translate the molecule so
                that its CoM is at the origin prior to adding it to the surface
            reorient (bool): flag on whether or not to reorient adsorbate
                along the miller index
            find_args (dict): dictionary of arguments forwarded as
                ``find_args`` to generate_adsorption_structures,
                e.g. {"distance":2.0}
        """
        self.adsorbate = adsorbate
        self.selective_dynamics = selective_dynamics
        self.height = height
        self.mi_vec = mi_vec
        self.repeat = repeat
        self.min_lw = min_lw
        self.translate = translate
        self.reorient = reorient
        self.find_args = find_args

    def apply_transformation(self, structure: Structure, return_ranked_list=False):
        """
        Args:
            structure: Must be a Slab structure
            return_ranked_list:  Whether or not multiple structures are
                returned. If return_ranked_list is a number, up to that number of
                structures is returned.

        Returns:
            The first generated slab with adsorbate if return_ranked_list is
            falsy, otherwise a list of dicts of the form
            {"structure": slab with adsorbate}.
        """
        finder = AdsorbateSiteFinder(
            structure,
            selective_dynamics=self.selective_dynamics,
            height=self.height,
            mi_vec=self.mi_vec,
        )
        candidates = finder.generate_adsorption_structures(
            self.adsorbate,
            repeat=self.repeat,
            min_lw=self.min_lw,
            translate=self.translate,
            reorient=self.reorient,
            find_args=self.find_args,
        )

        if return_ranked_list:
            # keep at most ``return_ranked_list`` candidates, in generation order
            return [{"structure": candidate} for candidate in candidates[:return_ranked_list]]
        return candidates[0]

    @property
    def inverse(self):
        """Returns: None"""
        return None

    @property
    def is_one_to_many(self) -> bool:
        """Returns: True"""
        return True
×
1770

1771

1772
def _round_and_make_arr_singular(arr: np.ndarray) -> np.ndarray:
1✔
1773
    """
1774
    This function rounds all elements of a matrix to the nearest integer,
1775
    unless the rounding scheme causes the matrix to be singular, in which
1776
    case elements of zero rows or columns in the rounded matrix with the
1777
    largest absolute valued magnitude in the unrounded matrix will be
1778
    rounded to the next integer away from zero rather than to the
1779
    nearest integer.
1780

1781
    The transformation is as follows. First, all entries in 'arr' will be
1782
    rounded to the nearest integer to yield 'arr_rounded'. If 'arr_rounded'
1783
    has any zero rows, then one element in each zero row of 'arr_rounded'
1784
    corresponding to the element in 'arr' of that row with the largest
1785
    absolute valued magnitude will be rounded to the next integer away from
1786
    zero (see the '_round_away_from_zero(x)' function) rather than the
1787
    nearest integer. This process is then repeated for zero columns. Also
1788
    note that if 'arr' already has zero rows or columns, then this function
1789
    will not change those rows/columns.
1790

1791
    Args:
1792
        arr: Input matrix
1793

1794
    Returns:
1795
        Transformed matrix.
1796
    """
1797

1798
    def round_away_from_zero(x):
1✔
1799
        """
1800
        Returns 'x' rounded to the next integer away from 0.
1801
        If 'x' is zero, then returns zero.
1802
        E.g. -1.2 rounds to -2.0. 1.2 rounds to 2.0.
1803
        """
1804
        abs_x = abs(x)
1✔
1805
        return math.ceil(abs_x) * (abs_x / x) if x != 0 else 0
1✔
1806

1807
    arr_rounded = np.around(arr)
1✔
1808

1809
    # Zero rows in 'arr_rounded' make the array singular, so force zero rows to
1810
    # be nonzero
1811
    if (~arr_rounded.any(axis=1)).any():
1✔
1812
        # Check for zero rows in T_rounded
1813

1814
        # indices of zero rows
1815
        zero_row_idxs = np.where(~arr_rounded.any(axis=1))[0]
1✔
1816

1817
        for zero_row_idx in zero_row_idxs:  # loop over zero rows
1✔
1818
            zero_row = arr[zero_row_idx, :]
1✔
1819

1820
            # Find the element of the zero row with the largest absolute
1821
            # magnitude in the original (non-rounded) array (i.e. 'arr')
1822
            matches = np.absolute(zero_row) == np.amax(np.absolute(zero_row))
1✔
1823
            col_idx_to_fix = np.where(matches)[0]
1✔
1824

1825
            # Break ties for the largest absolute magnitude
1826
            r_idx = np.random.randint(len(col_idx_to_fix))
1✔
1827
            col_idx_to_fix = col_idx_to_fix[r_idx]
1✔
1828

1829
            # Round the chosen element away from zero
1830
            arr_rounded[zero_row_idx, col_idx_to_fix] = round_away_from_zero(arr[zero_row_idx, col_idx_to_fix])
1✔
1831

1832
    # Repeat process for zero columns
1833
    if (~arr_rounded.any(axis=0)).any():
1✔
1834
        # Check for zero columns in T_rounded
1835
        zero_col_idxs = np.where(~arr_rounded.any(axis=0))[0]
×
1836
        for zero_col_idx in zero_col_idxs:
×
1837
            zero_col = arr[:, zero_col_idx]
×
1838
            matches = np.absolute(zero_col) == np.amax(np.absolute(zero_col))
×
1839
            row_idx_to_fix = np.where(matches)[0]
×
1840

1841
            for i in row_idx_to_fix:
×
1842
                arr_rounded[i, zero_col_idx] = round_away_from_zero(arr[i, zero_col_idx])
×
1843
    return arr_rounded.astype(int)
1✔
1844

1845

1846
class SubstituteSurfaceSiteTransformation(AbstractTransformation):
    """
    Use AdsorbateSiteFinder to perform substitution-type doping on the surface
    and returns all possible configurations where one dopant is substituted
    per surface. Can substitute one surface or both.
    """

    def __init__(
        self,
        atom,
        selective_dynamics=False,
        height=0.9,
        mi_vec=None,
        target_species=None,
        sub_both_sides=False,
        range_tol=1e-2,
        dist_from_surf=0,
    ):
        """
        Args:
            atom (str): atom corresponding to substitutional dopant
            selective_dynamics (bool): flag for whether to assign
                non-surface sites as fixed for selective dynamics
            height (float): height criteria for selection of surface sites
            mi_vec : vector corresponding to the vector
                concurrent with the miller index, this enables use with
                slabs that have been reoriented, but the miller vector
                must be supplied manually
            target_species:  List of specific species to substitute
            sub_both_sides (bool): If true, substitute an equivalent
                site on the other surface
            range_tol (float): Find viable substitution sites at a specific
                distance from the surface +- this tolerance
            dist_from_surf (float): Distance from the surface to find viable
                substitution sites, defaults to 0 to substitute at the surface
        """
        self.atom = atom
        self.selective_dynamics = selective_dynamics
        self.height = height
        self.mi_vec = mi_vec
        self.target_species = target_species
        self.sub_both_sides = sub_both_sides
        self.range_tol = range_tol
        self.dist_from_surf = dist_from_surf

    def apply_transformation(self, structure: Structure, return_ranked_list=False):
        """
        Args:
            structure: Must be a Slab structure
            return_ranked_list:  Whether or not multiple structures are
                returned. If return_ranked_list is a number, up to that number of
                structures is returned.

        Returns:
            The first generated slab with sites substituted if
            return_ranked_list is falsy, otherwise a list of dicts of the
            form {"structure": substituted slab}.
        """
        finder = AdsorbateSiteFinder(
            structure,
            selective_dynamics=self.selective_dynamics,
            height=self.height,
            mi_vec=self.mi_vec,
        )
        candidates = finder.generate_substitution_structures(
            self.atom,
            target_species=self.target_species,
            sub_both_sides=self.sub_both_sides,
            range_tol=self.range_tol,
            dist_from_surf=self.dist_from_surf,
        )

        if return_ranked_list:
            # keep at most ``return_ranked_list`` candidates, in generation order
            return [{"structure": candidate} for candidate in candidates[:return_ranked_list]]
        return candidates[0]

    @property
    def inverse(self):
        """Returns: None"""
        return None

    @property
    def is_one_to_many(self) -> bool:
        """Returns: True"""
        return True
×
1929

1930

1931
def _proj(b, a):
1✔
1932
    """
1933
    Returns vector projection (np.ndarray) of vector b (np.ndarray)
1934
    onto vector a (np.ndarray)
1935
    """
1936
    return (b.T @ (a / np.linalg.norm(a))) * (a / np.linalg.norm(a))
1✔
1937

1938

1939
class SQSTransformation(AbstractTransformation):
1✔
1940
    """
1941
    A transformation that creates a special quasirandom structure (SQS) from a structure with partial occupancies.
1942
    """
1943

1944
    def __init__(
1✔
1945
        self,
1946
        scaling,
1947
        cluster_size_and_shell=None,
1948
        search_time=60,
1949
        directory=None,
1950
        instances=None,
1951
        temperature=1,
1952
        wr=1,
1953
        wn=1,
1954
        wd=0.5,
1955
        tol=1e-3,
1956
        best_only=True,
1957
        remove_duplicate_structures=True,
1958
        reduction_algo="LLL",
1959
    ):
1960
        """
1961
        Args:
1962
            structure (Structure): Disordered pymatgen Structure object
1963
            scaling (int or list): Scaling factor to determine supercell. Two options are possible:
1964
                    a. (preferred) Scales number of atoms, e.g., for a structure with 8 atoms,
1965
                       scaling=4 would lead to a 32 atom supercell
1966
                    b. A sequence of three scaling factors, e.g., [2, 1, 1], which
1967
                       specifies that the supercell should have dimensions 2a x b x c
1968
            cluster_size_and_shell (Optional[Dict[int, int]]): Dictionary of cluster interactions with entries in
1969
                the form number of atoms: nearest neighbor shell
1970
        Keyword Args:
1971
            search_time (float): Time spent looking for the ideal SQS in minutes (default: 60)
1972
            directory (str): Directory to run mcsqs calculation and store files (default: None
1973
                runs calculations in a temp directory)
1974
            instances (int): Specifies the number of parallel instances of mcsqs to run
1975
                (default: number of cpu cores detected by Python)
1976
            temperature (int or float): Monte Carlo temperature (default: 1), "T" in atat code
1977
            wr (int or float): Weight assigned to range of perfect correlation match in objective
1978
                function (default = 1)
1979
            wn (int or float): Multiplicative decrease in weight per additional point in cluster (default: 1)
1980
            wd (int or float): Exponent of decay in weight as function of cluster diameter (default: 0)
1981
            tol (int or float): Tolerance for matching correlations (default: 1e-3)
1982
            best_only (bool): only return structures with lowest objective function
1983
            remove_duplicate_structures (bool): only return unique structures
1984
            reduction_algo (str): The lattice reduction algorithm to use.
1985
                Currently supported options are "niggli" or "LLL".
1986
                "False" does not reduce structure.
1987
        """
1988
        self.scaling = scaling
×
1989
        self.search_time = search_time
×
1990
        self.cluster_size_and_shell = cluster_size_and_shell
×
1991
        self.directory = directory
×
1992
        self.instances = instances
×
1993
        self.temperature = temperature
×
1994
        self.wr = wr
×
1995
        self.wn = wn
×
1996
        self.wd = wd
×
1997
        self.tol = tol
×
1998
        self.best_only = best_only
×
1999
        self.remove_duplicate_structures = remove_duplicate_structures
×
2000
        self.reduction_algo = reduction_algo
×
2001

2002
    @staticmethod
    def _get_max_neighbor_distance(struct, shell):
        """
        Calculate maximum nearest neighbor distance
        Args:
            struct: pymatgen Structure object
            shell: nearest neighbor shell, such that shell=1 is the first nearest
                neighbor, etc.

        Returns:
            maximum nearest neighbor distance, in angstroms
        """
        nn_finder = MinimumDistanceNN()
        # distance from every site to every entry of its requested neighbor
        # shell, measured to the periodic image reported for that entry
        shell_distances = [
            site.distance(struct[entry["site_index"]], jimage=entry["image"])
            for site_num, site in enumerate(struct)
            for entry in nn_finder.get_nn_shell_info(struct, site_num, shell)
        ]

        return max(shell_distances)
×
2025

2026
    @staticmethod
    def _get_disordered_substructure(struc_disordered):
        """
        Converts disordered structure into a substructure consisting of only disordered sites
        Args:
            struc_disordered: pymatgen disordered Structure object
        Returns:
            pymatgen Structure object representing a substructure of disordered sites
        """
        # work on a copy so the caller's structure keeps all of its sites
        substructure = struc_disordered.copy()
        ordered_idxs = [idx for idx, site in enumerate(substructure.sites) if site.is_ordered]
        substructure.remove_sites(ordered_idxs)

        return substructure
×
2044

2045
    @staticmethod
    def _sqs_cluster_estimate(struc_disordered, cluster_size_and_shell: dict[int, int] | None = None):
        """
        Estimate the cluster cutoff distances handed to mcsqs for a given
        structure and set of constraints
        Args:
            struc_disordered: disordered pymatgen Structure object
            cluster_size_and_shell: dict of integers {cluster: shell}; a
                missing or empty dict falls back to {2: 3, 3: 2, 4: 1}

        Returns:
            dict of {cluster size: distance in angstroms} for mcsqs calculation
        """
        if not cluster_size_and_shell:
            cluster_size_and_shell = {2: 3, 3: 2, 4: 1}

        # only the disordered sites take part in the distance estimate
        substructure = SQSTransformation._get_disordered_substructure(struc_disordered)

        return {
            # add small tolerance
            size: SQSTransformation._get_max_neighbor_distance(substructure, shell) + 0.01
            for size, shell in cluster_size_and_shell.items()
        }
×
2066

2067
    def apply_transformation(self, structure: Structure, return_ranked_list=False):
        """
        Applies SQS transformation.

        Args:
            structure (pymatgen Structure): pymatgen Structure with partial occupancies
            return_ranked_list (bool | int): If falsy, only the best SQS structure is
                returned. Otherwise a ranked list is requested; an int value caps the
                number of entries and must not exceed ``self.instances``.

        Returns:
            The result of ``_get_unique_bestsqs_strucs``: a single pymatgen Structure
            (an SQS of the input structure) when ``return_ranked_list`` is falsy,
            otherwise a list of {"structure": ..., "objective_function": ...} dicts.

        Raises:
            ValueError: if a ranked list is requested but ``self.instances`` is None, or
                if more structures are requested than there are instances.
        """
        if return_ranked_list and self.instances is None:
            raise ValueError("mcsqs has no instances, so cannot return a ranked list")
        if (
            isinstance(return_ranked_list, int)
            and isinstance(self.instances, int)
            and return_ranked_list > self.instances
        ):
            # there can be at most one candidate structure per mcsqs instance
            raise ValueError("return_ranked_list cannot be greater than number of instances")

        clusters = self._sqs_cluster_estimate(structure, self.cluster_size_and_shell)

        # useful for debugging and understanding
        self._last_used_clusters = clusters

        sqs = run_mcsqs(
            structure=structure,
            clusters=clusters,
            scaling=self.scaling,
            search_time=self.search_time,
            directory=self.directory,
            instances=self.instances,
            temperature=self.temperature,
            wr=self.wr,
            wn=self.wn,
            wd=self.wd,
            tol=self.tol,
        )

        return self._get_unique_bestsqs_strucs(
            sqs,
            best_only=self.best_only,
            return_ranked_list=return_ranked_list,
            remove_duplicate_structures=self.remove_duplicate_structures,
            reduction_algo=self.reduction_algo,
        )
2111

2112
    @staticmethod
1✔
2113
    def _get_unique_bestsqs_strucs(sqs, best_only, return_ranked_list, remove_duplicate_structures, reduction_algo):
1✔
2114
        """
2115
        Gets unique sqs structures with lowest objective function. Requires an mcsqs output that has been run
2116
            in parallel, otherwise returns Sqs.bestsqs
2117
        Args:
2118
            sqs (Sqs): Sqs class object.
2119
            best_only (bool): only return structures with lowest objective function.
2120
            return_ranked_list (bool): Number of structures to return.
2121
            remove_duplicate_structures (bool): only return unique structures.
2122
            reduction_algo (str): The lattice reduction algorithm to use.
2123
                Currently supported options are "niggli" or "LLL".
2124
                "False" does not reduce structure.
2125

2126
        Returns:
2127
            list of dicts of the form {'structure': Structure, 'objective_function': ...}, unless run in serial
2128
                (returns a single structure Sqs.bestsqs)
2129
        """
2130
        if not return_ranked_list:
×
2131
            return_struc = sqs.bestsqs
×
2132

2133
            # reduce structure
2134
            if reduction_algo:
×
2135
                return_struc = return_struc.get_reduced_structure(reduction_algo=reduction_algo)
×
2136

2137
            # return just the structure
2138
            return return_struc
×
2139

2140
        strucs = []
×
2141
        for d in sqs.allsqs:
×
2142
            # filter for best structures only if enabled, else use full sqs.all_sqs list
2143
            if (not best_only) or (best_only and d["objective_function"] == sqs.objective_function):
×
2144
                struct = d["structure"]
×
2145
                # add temporary objective_function attribute to access objective_function after grouping
2146
                struct.objective_function = d["objective_function"]
×
2147
                strucs.append(struct)
×
2148

2149
        if remove_duplicate_structures:
×
2150
            matcher = StructureMatcher()
×
2151
            # sort by unique structures ... can take a while for a long list of strucs
2152
            unique_strucs_grouped = matcher.group_structures(strucs)
×
2153
            # get unique structures only
2154
            strucs = [group[0] for group in unique_strucs_grouped]
×
2155

2156
        # sort structures by objective function
2157
        strucs.sort(key=lambda x: x.objective_function if isinstance(x.objective_function, float) else -np.inf)
×
2158

2159
        to_return = [{"structure": struct, "objective_function": struct.objective_function} for struct in strucs]
×
2160

2161
        for d in to_return:
×
2162
            # delete temporary objective_function attribute
2163
            del d["structure"].objective_function
×
2164

2165
            # reduce structure
2166
            if reduction_algo:
×
2167
                d["structure"] = d["structure"].get_reduced_structure(reduction_algo=reduction_algo)
×
2168

2169
        if isinstance(return_ranked_list, int):
×
2170
            return to_return[:return_ranked_list]
×
2171
        return to_return
×
2172

2173
    @property
1✔
2174
    def inverse(self):
1✔
2175
        """Returns: None"""
2176
        return None
×
2177

2178
    @property
1✔
2179
    def is_one_to_many(self) -> bool:
1✔
2180
        """Returns: True"""
2181
        return True
×
2182

2183

2184
class MonteCarloRattleTransformation(AbstractTransformation):
    r"""
    Uses a Monte Carlo rattle procedure to randomly perturb the sites in a
    structure.

    This class requires the hiPhive package to be installed.

    Rattling atom `i` is carried out as a Monte Carlo move that is accepted with
    a probability determined from the minimum interatomic distance
    :math:`d_{ij}`. If :math:`\min(d_{ij})` is smaller than :math:`d_{min}`
    the move is only accepted with a low probability.

    This process is repeated for each atom a number of times meaning
    the magnitude of the final displacements is not *directly*
    connected to `rattle_std`.
    """

    @requires(hiphive, "hiphive is required for MonteCarloRattleTransformation")
    def __init__(self, rattle_std: float, min_distance: float, seed: int | None = None, **kwargs):
        """
        Args:
            rattle_std: Rattle amplitude (standard deviation in normal
                distribution). Note: this value is not *directly* connected to the
                final average displacement for the structures
            min_distance: Interatomic distance used for computing the probability
                for each rattle move.
            seed: Seed for setting up NumPy random state from which random numbers
                are generated. If ``None``, a random seed will be generated
                (default). This option allows the output of this transformation
                to be deterministic. Any explicit integer, including ``0``, is
                used as given.
            **kwargs: Additional keyword arguments to be passed to the hiPhive
                mc_rattle function.
        """
        self.rattle_std = rattle_std
        self.min_distance = min_distance
        # keep the value the caller passed (possibly None), not the generated one
        self.seed = seed

        # compare against None explicitly: 0 is a valid seed and must stay deterministic
        if seed is None:
            # if seed is None, use a random RandomState seed but make sure
            # we store that the original seed was None
            seed = np.random.randint(1, 1000000000)

        self.random_state = np.random.RandomState(seed)  # pylint: disable=E1101
        self.kwargs = kwargs

    def apply_transformation(self, structure: Structure) -> Structure:
        """
        Apply the transformation.

        Args:
            structure: Input Structure

        Returns:
            Structure with sites perturbed. It is rebuilt from the input lattice,
            species and displaced cartesian coordinates.
        """
        from hiphive.structure_generation.rattle import mc_rattle

        atoms = AseAtomsAdaptor.get_atoms(structure)
        # draw a fresh per-call seed from the stored random state
        seed = self.random_state.randint(1, 1000000000)
        displacements = mc_rattle(atoms, self.rattle_std, self.min_distance, seed=seed, **self.kwargs)

        transformed_structure = Structure(
            structure.lattice,
            structure.species,
            structure.cart_coords + displacements,
            coords_are_cartesian=True,
        )

        return transformed_structure

    def __str__(self):
        # NOTE(review): __name__ is the module name here, not the class name
        return f"{__name__} : rattle_std = {self.rattle_std}"

    def __repr__(self):
        return str(self)

    @property
    def inverse(self):
        """
        Returns: None
        """
        return None

    @property
    def is_one_to_many(self) -> bool:
        """
        Returns: False
        """
        return False
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc