• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

materialsproject / pymatgen / 4075885785

pending completion
4075885785

push

github

Shyue Ping Ong
Merge branch 'master' of github.com:materialsproject/pymatgen

96 of 96 new or added lines in 27 files covered. (100.0%)

81013 of 102710 relevant lines covered (78.88%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.87
/pymatgen/analysis/chemenv/coordination_environments/structure_environments.py
1
# Copyright (c) Pymatgen Development Team.
2
# Distributed under the terms of the MIT License.
3

4
"""
1✔
5
This module contains objects that are used to describe the environments in a structure. The most detailed object
6
(StructureEnvironments) contains a very thorough analysis of the environments of a given atom but is difficult to
7
use as such. The LightStructureEnvironments object is a lighter version that is obtained by applying a "strategy"
8
on the StructureEnvironments object. Basically, the LightStructureEnvironments provides the coordination environment(s)
9
and possibly some fraction corresponding to these.
10
"""
11

12
from __future__ import annotations
1✔
13

14
import numpy as np
1✔
15
from monty.json import MontyDecoder, MSONable, jsanitize
1✔
16

17
from pymatgen.analysis.chemenv.coordination_environments.coordination_geometries import (
1✔
18
    AllCoordinationGeometries,
19
)
20
from pymatgen.analysis.chemenv.coordination_environments.voronoi import (
1✔
21
    DetailedVoronoiContainer,
22
)
23
from pymatgen.analysis.chemenv.utils.chemenv_errors import ChemenvError
1✔
24
from pymatgen.analysis.chemenv.utils.defs_utils import AdditionalConditions
1✔
25
from pymatgen.core.periodic_table import Element, Species
1✔
26
from pymatgen.core.sites import PeriodicSite
1✔
27
from pymatgen.core.structure import PeriodicNeighbor, Structure
1✔
28

29
__author__ = "David Waroquiers"
__copyright__ = "Copyright 2012, The Materials Project"
__credits__ = "Geoffroy Hautier"
__version__ = "2.0"
__maintainer__ = "David Waroquiers"
__email__ = "david.waroquiers@gmail.com"
__date__ = "Feb 20, 2016"


# Single AllCoordinationGeometries instance, built once when this module is imported and shared by the
# module-level lookup table below.
allcg = AllCoordinationGeometries()
# Lookup table keyed by coordination-geometry symbol (mp_symbol); StructureEnvironments.get_csms indexes it
# with an mp_symbol to obtain the corresponding coordination number.
symbol_cn_mapping = allcg.get_symbol_cn_mapping()
40

41

42
class StructureEnvironments(MSONable):
1✔
43
    """
44
    Class used to store the chemical environments of a given structure.
45
    """
46

47
    AC = AdditionalConditions()
1✔
48

49
    class NeighborsSet:
        """
        Class used to store a given set of neighbors of a given site (based on the detailed_voronoi).

        Neighbors are referenced through their indices in the Voronoi neighbor list of the site, i.e.
        ``detailed_voronoi.voronoi_list2[isite]``. These indices are stored sorted and without duplicates.
        """

        def __init__(self, structure: "Structure", isite, detailed_voronoi, site_voronoi_indices, sources=None):
            """
            Constructor for NeighborsSet.

            Args:
                structure: Structure object.
                isite: Index of the site for which neighbors are stored in this NeighborsSet.
                detailed_voronoi: Corresponding DetailedVoronoiContainer object containing all the possible
                    neighbors of the given site.
                site_voronoi_indices: Indices of the voronoi sites in the DetailedVoronoiContainer object that
                    make up this NeighborsSet.
                sources: Sources for this NeighborsSet, i.e. how this NeighborsSet was generated. A list is
                    used as-is, a single source is wrapped into a list, and None is recorded as
                    [{"origin": "UNKNOWN"}].

            Raises:
                ValueError: If site_voronoi_indices contains duplicates.
            """
            self.structure = structure
            self.isite = isite
            self.detailed_voronoi = detailed_voronoi
            self.voronoi = detailed_voronoi.voronoi_list2[isite]
            unique_indices = set(site_voronoi_indices)
            if len(unique_indices) != len(site_voronoi_indices):
                raise ValueError("Set of neighbors contains duplicates !")
            self.site_voronoi_indices = sorted(unique_indices)
            if sources is None:
                self.sources = [{"origin": "UNKNOWN"}]
            elif isinstance(sources, list):
                self.sources = sources
            else:
                self.sources = [sources]

        def _voronoi_values(self, key):
            """
            Return the entry ``key`` of the Voronoi neighbor dict of each neighbor in this NeighborsSet.

            Args:
                key: Key to read from each Voronoi neighbor dict (e.g. "distance", "normalized_angle").

            Returns: List of values, in the order of site_voronoi_indices.
            """
            return [self.voronoi[inb][key] for inb in self.site_voronoi_indices]

        def get_neighb_voronoi_indices(self, permutation):
            """
            Return the indices in the detailed_voronoi corresponding to the current permutation.

            Args:
                permutation: Sequence of positions within site_voronoi_indices giving the order in which the
                    neighbors' Voronoi indices are returned.

            Returns: List of indices in the detailed_voronoi.
            """
            return [self.site_voronoi_indices[ii] for ii in permutation]

        @property
        def neighb_coords(self):
            """
            Coordinates of neighbors for this NeighborsSet (the ``coords`` of each neighbor site).
            """
            return [site.coords for site in self._voronoi_values("site")]

        @property
        def neighb_coordsOpt(self):
            """
            Optimized access to the coordinates of neighbors for this NeighborsSet.

            Rows of ``detailed_voronoi.voronoi_list_coords[isite]`` are selected at site_voronoi_indices.
            """
            return self.detailed_voronoi.voronoi_list_coords[self.isite].take(self.site_voronoi_indices, axis=0)

        @property
        def neighb_sites(self):
            """
            Neighbors for this NeighborsSet as pymatgen Sites.
            """
            return self._voronoi_values("site")

        @property
        def neighb_sites_and_indices(self):
            """
            List of neighbors for this NeighborsSet as pymatgen Sites and their index in the original structure.

            Each item is a dict with keys "site" and "index".
            """
            return [
                {"site": self.voronoi[inb]["site"], "index": self.voronoi[inb]["index"]}
                for inb in self.site_voronoi_indices
            ]

        @property
        def coords(self):
            """
            Coordinates of the current central atom (first item) followed by those of its neighbors.
            """
            return [self.structure[self.isite].coords, *self.neighb_coords]

        @property
        def normalized_distances(self):
            """
            Normalized distances to each neighbor in this NeighborsSet.
            """
            return self._voronoi_values("normalized_distance")

        @property
        def normalized_angles(self):
            """
            Normalized angles for each neighbor in this NeighborsSet.
            """
            return self._voronoi_values("normalized_angle")

        @property
        def distances(self):
            """
            Distances to each neighbor in this NeighborsSet.
            """
            return self._voronoi_values("distance")

        @property
        def angles(self):
            """
            Angles for each neighbor in this NeighborsSet.
            """
            return self._voronoi_values("angle")

        @property
        def info(self):
            """
            Summarized information about this NeighborsSet.

            Returns: Dict with the per-neighbor (normalized) angles and distances and their sum/mean/std/min/max.
            """
            was = self.normalized_angles
            wds = self.normalized_distances
            angles = self.angles
            distances = self.distances
            return {
                "normalized_angles": was,
                "normalized_distances": wds,
                "normalized_angles_sum": np.sum(was),
                "normalized_angles_mean": np.mean(was),
                "normalized_angles_std": np.std(was),
                "normalized_angles_min": np.min(was),
                "normalized_angles_max": np.max(was),
                "normalized_distances_mean": np.mean(wds),
                "normalized_distances_std": np.std(wds),
                "normalized_distances_min": np.min(wds),
                "normalized_distances_max": np.max(wds),
                "angles": angles,
                "distances": distances,
                "angles_sum": np.sum(angles),
                "angles_mean": np.mean(angles),
                "angles_std": np.std(angles),
                "angles_min": np.min(angles),
                "angles_max": np.max(angles),
                "distances_mean": np.mean(distances),
                "distances_std": np.std(distances),
                "distances_min": np.min(distances),
                "distances_max": np.max(distances),
            }

        def distance_plateau(self):
            """
            Returns the distance plateau for this NeighborsSet.

            The plateau is the difference between the next larger normalized distance among all the Voronoi
            neighbors of the site and the largest normalized distance of this NeighborsSet. It is np.inf when
            the largest normalized distance of this set is also the largest among all Voronoi neighbors.

            Raises:
                ValueError: If the largest normalized distance of this set is not matched (within
                    normalized_distance_tolerance) by any Voronoi neighbor.
            """
            all_nbs_normalized_distances_sorted = sorted(
                (nb["normalized_distance"] for nb in self.voronoi), reverse=True
            )
            maxdist = np.max(self.normalized_distances)
            tolerance = self.detailed_voronoi.normalized_distance_tolerance
            for idist, dist in enumerate(all_nbs_normalized_distances_sorted):
                if np.isclose(dist, maxdist, rtol=0.0, atol=tolerance):
                    if idist == 0:
                        return np.inf
                    # Sorted in decreasing order: the previous entry is the next larger distance.
                    return all_nbs_normalized_distances_sorted[idist - 1] - maxdist
            raise ValueError("Plateau not found ...")

        def angle_plateau(self):
            """
            Returns the angle plateau for this NeighborsSet.

            The plateau is the difference between the smallest normalized angle of this NeighborsSet and the
            next smaller normalized angle among all the Voronoi neighbors of the site. When no Voronoi neighbor
            has a smaller normalized angle, the smallest normalized angle itself is returned.

            Raises:
                ValueError: If the smallest normalized angle of this set is not matched (within
                    normalized_angle_tolerance) by any Voronoi neighbor.
            """
            all_nbs_normalized_angles_sorted = sorted(nb["normalized_angle"] for nb in self.voronoi)
            minang = np.min(self.normalized_angles)
            tolerance = self.detailed_voronoi.normalized_angle_tolerance
            for iang, ang in enumerate(all_nbs_normalized_angles_sorted):
                if np.isclose(ang, minang, rtol=0.0, atol=tolerance):
                    if iang == 0:
                        return minang
                    # Sorted in increasing order: the previous entry is the next smaller angle.
                    return minang - all_nbs_normalized_angles_sorted[iang - 1]
            raise ValueError("Plateau not found ...")

        def voronoi_grid_surface_points(self, additional_condition=1, other_origins="DO_NOTHING"):
            """
            Get the surface points in the Voronoi grid for this neighbor from the sources.
            The general shape of the points should look like a staircase such as in the following figure :

               ^
            0.0|
               |
               |      B----C
               |      |    |
               |      |    |
            a  |      k    D-------E
            n  |      |            |
            g  |      |            |
            l  |      |            |
            e  |      j            F----n---------G
               |      |                           |
               |      |                           |
               |      A----g-------h----i---------H
               |
               |
            1.0+------------------------------------------------->
              1.0              distance              2.0   ->+Inf

            Only sources whose origin is "dist_ang_ac_voronoi" and whose "ac" equals additional_condition are
            used. Each such source contributes the rectangle spanned by its (min, next) distance parameters and
            (max, next) angle parameters. Corners shared by an even number of rectangles cancel out, and the
            remaining corners are chained alternately along the angle and distance directions.

            Args:
                additional_condition: Additional condition for the neighbors.
                other_origins: What to do with sources that do not come from the Voronoi grid (e.g. "from hints").
                    Only "DO_NOTHING" (ignore them) is supported.

            Returns: List of (distance parameter, angle parameter) points, or None if no source matches.

            Raises:
                NotImplementedError: If a non-Voronoi source is present and other_origins is not "DO_NOTHING".
                ValueError: If the grid parameters are inconsistent or the points cannot be ordered.
            """
            mysrc = []
            for src in self.sources:
                if src["origin"] == "dist_ang_ac_voronoi":
                    if src["ac"] != additional_condition:
                        continue
                    mysrc.append(src)
                else:
                    if other_origins == "DO_NOTHING":
                        continue
                    raise NotImplementedError("Nothing implemented for other sources ...")
            if not mysrc:
                return None

            dists = [src["dp_dict"]["min"] for src in mysrc]
            angs = [src["ap_dict"]["max"] for src in mysrc]
            next_dists = [src["dp_dict"]["next"] for src in mysrc]
            next_angs = [src["ap_dict"]["next"] for src in mysrc]

            # Number of rectangles touching each (distance index, angle index) corner.
            points_dict = {}

            # Distinct (up to np.isclose) distance and angle parameter values, in order of first appearance.
            pdists = []
            pangs = []

            for dist, ang, next_dist, next_ang in zip(dists, angs, next_dists, next_angs):
                if not any(np.isclose(pdists, dist)):
                    pdists.append(dist)
                if not any(np.isclose(pdists, next_dist)):
                    pdists.append(next_dist)
                if not any(np.isclose(pangs, ang)):
                    pangs.append(ang)
                if not any(np.isclose(pangs, next_ang)):
                    pangs.append(next_ang)
                d1_indices = np.argwhere(np.isclose(pdists, dist)).flatten()
                if len(d1_indices) != 1:
                    raise ValueError("Distance parameter not found ...")
                d2_indices = np.argwhere(np.isclose(pdists, next_dist)).flatten()
                if len(d2_indices) != 1:
                    raise ValueError("Distance parameter not found ...")
                a1_indices = np.argwhere(np.isclose(pangs, ang)).flatten()
                if len(a1_indices) != 1:
                    raise ValueError("Angle parameter not found ...")
                a2_indices = np.argwhere(np.isclose(pangs, next_ang)).flatten()
                if len(a2_indices) != 1:
                    raise ValueError("Angle parameter not found ...")
                id1 = d1_indices[0]
                id2 = d2_indices[0]
                ia1 = a1_indices[0]
                ia2 = a2_indices[0]
                for id_ia in [(id1, ia1), (id1, ia2), (id2, ia1), (id2, ia2)]:
                    points_dict[id_ia] = points_dict.get(id_ia, 0) + 1

            # Corners shared by an even number of rectangles are interior and cancel out.
            new_pts = [pt for pt, pt_nb in points_dict.items() if pt_nb % 2 == 1]

            # Walk the outline starting from corner (0, 0), alternately moving along the angle and the distance
            # directions, until coming back to (0, 0).
            sorted_points = [(0, 0)]
            move_ap_index = True
            while True:
                last_pt = sorted_points[-1]
                if move_ap_index:  # "Move" the angle parameter
                    idp = last_pt[0]
                    iap = None
                    for pt in new_pts:
                        if pt[0] == idp and pt != last_pt:
                            iap = pt[1]
                            break
                else:  # "Move" the distance parameter
                    idp = None
                    iap = last_pt[1]
                    for pt in new_pts:
                        if pt[1] == iap and pt != last_pt:
                            idp = pt[0]
                            break
                if (idp, iap) == (0, 0):
                    break
                if (idp, iap) in sorted_points:
                    raise ValueError("Error sorting points ...")
                sorted_points.append((idp, iap))
                move_ap_index = not move_ap_index

            return [(pdists[idp], pangs[iap]) for (idp, iap) in sorted_points]

        @property
        def source(self):
            """
            Returns the source of this NeighborsSet (how it was generated, e.g. from which Voronoi cut-offs, or from
            hints).

            Raises:
                RuntimeError: If this NeighborsSet does not have exactly one source.
            """
            if len(self.sources) != 1:
                raise RuntimeError("Number of sources different from 1 !")
            return self.sources[0]

        def add_source(self, source):
            """
            Add a source to this NeighborsSet, unless an equal source is already registered.

            Args:
                source: Information about the generation of this NeighborsSet.
            """
            if source not in self.sources:
                self.sources.append(source)

        def __len__(self):
            # The length is the coordination number of this NeighborsSet.
            return len(self.site_voronoi_indices)

        def __hash__(self):
            # Hash on the number of neighbors only. This is cheap and consistent with __eq__, because equal
            # NeighborsSets have identical site_voronoi_indices and hence the same length.
            return len(self.site_voronoi_indices)

        def __eq__(self, other: object) -> bool:
            # Two NeighborsSets are equal when they concern the same site and the same Voronoi neighbors;
            # sources are deliberately ignored.
            needed_attrs = ("isite", "site_voronoi_indices")
            if not all(hasattr(other, attr) for attr in needed_attrs):
                return NotImplemented
            return all(getattr(self, attr) == getattr(other, attr) for attr in needed_attrs)

        def __str__(self):
            out = f"Neighbors Set for site #{self.isite:d} :\n"
            out += f" - Coordination number : {len(self):d}\n"
            voro_indices = ", ".join(f"{site_voronoi_index:d}" for site_voronoi_index in self.site_voronoi_indices)
            out += f" - Voronoi indices : {voro_indices}\n"
            return out

        def as_dict(self):
            """
            A JSON-serializable dict representation of the NeighborsSet.

            The structure and detailed_voronoi are not included; they must be supplied to from_dict.
            """
            return {
                "isite": self.isite,
                "site_voronoi_indices": self.site_voronoi_indices,
                "sources": self.sources,
            }

        @classmethod
        def from_dict(cls, dd, structure: "Structure", detailed_voronoi):
            """
            Reconstructs the NeighborsSet algorithm from its JSON-serializable dict representation, together with
            the structure and the DetailedVoronoiContainer.

            As an inner (nested) class, the NeighborsSet is not supposed to be used anywhere else than inside the
            StructureEnvironments. The from_dict method thus uses the structure and detailed_voronoi when
            reconstructing itself. These two are both in the StructureEnvironments object.

            Args:
                dd: a JSON-serializable dict representation of a NeighborsSet.
                structure: The structure.
                detailed_voronoi: The Voronoi object containing all the neighboring atoms from which the subset of
                    neighbors for this NeighborsSet is extracted.

            Returns: a NeighborsSet.
            """
            return cls(
                structure=structure,
                isite=dd["isite"],
                detailed_voronoi=detailed_voronoi,
                site_voronoi_indices=dd["site_voronoi_indices"],
                sources=dd["sources"],
            )
437

438
    def __init__(
1✔
439
        self,
440
        voronoi,
441
        valences,
442
        sites_map,
443
        equivalent_sites,
444
        ce_list,
445
        structure,
446
        neighbors_sets=None,
447
        info=None,
448
    ):
449
        """
450
        Constructor for the StructureEnvironments object.
451

452
        Args:
453
            voronoi: VoronoiContainer object for the structure.
454
            valences: Valences provided.
455
            sites_map: Mapping of equivalent sites to the unequivalent sites that have been computed.
456
            equivalent_sites: List of list of equivalent sites of the structure.
457
            ce_list: List of chemical environments.
458
            structure: Structure object.
459
            neighbors_sets: List of neighbors sets.
460
            info: Additional information for this StructureEnvironments object.
461
        """
462
        self.voronoi = voronoi
1✔
463
        self.valences = valences
1✔
464
        self.sites_map = sites_map
1✔
465
        self.equivalent_sites = equivalent_sites
1✔
466
        # self.struct_sites_to_irreducible_site_list_map = struct_sites_to_irreducible_site_list_map
467
        self.ce_list = ce_list
1✔
468
        self.structure = structure
1✔
469
        if neighbors_sets is None:
1✔
470
            self.neighbors_sets = [None] * len(self.structure)
1✔
471
        else:
472
            self.neighbors_sets = neighbors_sets
1✔
473
        self.info = info
1✔
474

475
    def init_neighbors_sets(self, isite, additional_conditions=None, valences=None):
        """
        Initialize the list of neighbors sets for the current site.

        One NeighborsSet is built and registered through add_neighbors_set for every combination of
        distance parameter, angle parameter and additional condition of the site's Voronoi analysis. It
        contains the Voronoi neighbors satisfying all three conditions. Nothing is done when the site has no
        Voronoi information (voronoi_list2[isite] is None).

        Args:
            isite: Index of the site under consideration.
            additional_conditions: Additional conditions to be used for the initialization of the list of
                neighbors sets, e.g. "Only anion-cation bonds", ... Defaults to AC.ALL.
            valences: List of valences for each site in the structure (needed if an additional condition based on the
                valence is used, e.g. only anion-cation bonds).

        Raises:
            ChemenvError: If ONLY_ACB or ONLY_ACB_AND_NO_E2SEB is requested while valences is None.
        """
        site_voronoi = self.voronoi.voronoi_list2[isite]
        if site_voronoi is None:
            return
        if additional_conditions is None:
            additional_conditions = self.AC.ALL
        # Each valence-dependent condition must be tested for membership separately. A bare constant
        # inside the "or" would only test the constant's truthiness, not whether it was requested.
        needs_valences = (
            self.AC.ONLY_ACB in additional_conditions or self.AC.ONLY_ACB_AND_NO_E2SEB in additional_conditions
        )
        if needs_valences and valences is None:
            raise ChemenvError(
                "StructureEnvironments",
                "init_neighbors_sets",
                "Valences are not given while only_anion_cation_bonds are allowed. Cannot continue",
            )
        site_distance_parameters = self.voronoi.neighbors_normalized_distances[isite]
        site_angle_parameters = self.voronoi.neighbors_normalized_angles[isite]
        n_neighbors = len(site_voronoi)
        # Precompute, for each distance (resp. angle) parameter, whether each Voronoi neighbor belongs to it.
        distance_conditions = [
            [inb in dp_dict["nb_indices"] for inb in range(n_neighbors)] for dp_dict in site_distance_parameters
        ]
        angle_conditions = [
            [inb in ap_dict["nb_indices"] for inb in range(n_neighbors)] for ap_dict in site_angle_parameters
        ]
        # Precompute, for each additional condition, whether each Voronoi neighbor satisfies it.
        precomputed_additional_conditions = {ac: [] for ac in additional_conditions}
        for voro_nb_dict in site_voronoi:
            for ac in additional_conditions:
                cond = self.AC.check_condition(
                    condition=ac,
                    structure=self.structure,
                    parameters={
                        "valences": valences,
                        "neighbor_index": voro_nb_dict["index"],
                        "site_index": isite,
                    },
                )
                precomputed_additional_conditions[ac].append(cond)
        # Add the neighbors sets based on the distance/angle/additional parameters
        for idp, dp_dict in enumerate(site_distance_parameters):
            for iap, ap_dict in enumerate(site_angle_parameters):
                for iac, ac in enumerate(additional_conditions):
                    src = {
                        "origin": "dist_ang_ac_voronoi",
                        "idp": idp,
                        "iap": iap,
                        "dp_dict": dp_dict,
                        "ap_dict": ap_dict,
                        "iac": iac,
                        "ac": ac,
                        "ac_name": self.AC.CONDITION_DESCRIPTION[ac],
                    }
                    site_voronoi_indices = [
                        inb
                        for inb in range(n_neighbors)
                        if (
                            distance_conditions[idp][inb]
                            and angle_conditions[iap][inb]
                            and precomputed_additional_conditions[ac][inb]
                        )
                    ]
                    nb_set = self.NeighborsSet(
                        structure=self.structure,
                        isite=isite,
                        detailed_voronoi=self.voronoi,
                        site_voronoi_indices=site_voronoi_indices,
                        sources=src,
                    )
                    self.add_neighbors_set(isite=isite, nb_set=nb_set)
558

559
    def add_neighbors_set(self, isite, nb_set):
1✔
560
        """
561
        Adds a neighbor set to the list of neighbors sets for this site.
562

563
        Args:
564
            isite: Index of the site under consideration.
565
            nb_set: NeighborsSet to be added.
566
        """
567
        if self.neighbors_sets[isite] is None:
1✔
568
            self.neighbors_sets[isite] = {}
1✔
569
            self.ce_list[isite] = {}
1✔
570
        cn = len(nb_set)
1✔
571
        if cn not in self.neighbors_sets[isite]:
1✔
572
            self.neighbors_sets[isite][cn] = []
1✔
573
            self.ce_list[isite][cn] = []
1✔
574
        try:
1✔
575
            nb_set_index = self.neighbors_sets[isite][cn].index(nb_set)
1✔
576
            self.neighbors_sets[isite][cn][nb_set_index].add_source(nb_set.source)
1✔
577
        except ValueError:
1✔
578
            self.neighbors_sets[isite][cn].append(nb_set)
1✔
579
            self.ce_list[isite][cn].append(None)
1✔
580

581
    def update_coordination_environments(self, isite, cn, nb_set, ce):
1✔
582
        """
583
        Updates the coordination environment for this site, coordination and neighbor set.
584

585
        Args:
586
            isite: Index of the site to be updated.
587
            cn: Coordination to be updated.
588
            nb_set: Neighbors set to be updated.
589
            ce: ChemicalEnvironments object for this neighbors set.
590
        """
591
        if self.ce_list[isite] is None:
1✔
592
            self.ce_list[isite] = {}
×
593
        if cn not in self.ce_list[isite]:
1✔
594
            self.ce_list[isite][cn] = []
×
595
        try:
1✔
596
            nb_set_index = self.neighbors_sets[isite][cn].index(nb_set)
1✔
597
        except ValueError:
×
598
            raise ValueError("Neighbors set not found in the structure environments")
×
599
        if nb_set_index == len(self.ce_list[isite][cn]):
1✔
600
            self.ce_list[isite][cn].append(ce)
×
601
        elif nb_set_index < len(self.ce_list[isite][cn]):
1✔
602
            self.ce_list[isite][cn][nb_set_index] = ce
1✔
603
        else:
604
            raise ValueError("Neighbors set not yet in ce_list !")
×
605

606
    def update_site_info(self, isite, info_dict):
1✔
607
        """
608
        Update information about this site.
609

610
        Args:
611
            isite: Index of the site for which info has to be updated.
612
            info_dict: Dictionary of information to be added for this site.
613
        """
614
        if "sites_info" not in self.info:
1✔
615
            self.info["sites_info"] = [{} for _ in range(len(self.structure))]
1✔
616
        self.info["sites_info"][isite].update(info_dict)
1✔
617

618
    def get_coordination_environments(self, isite, cn, nb_set):
1✔
619
        """
620
        Get the ChemicalEnvironments for a given site, coordination and neighbors set.
621

622
        Args:
623
            isite: Index of the site for which the ChemicalEnvironments is looked for.
624
            cn: Coordination for which the ChemicalEnvironments is looked for.
625
            nb_set: Neighbors set for which the ChemicalEnvironments is looked for.
626

627
        Returns: a ChemicalEnvironments object.
628
        """
629
        if self.ce_list[isite] is None:
1✔
630
            return None
×
631
        if cn not in self.ce_list[isite]:
1✔
632
            return None
×
633
        try:
1✔
634
            nb_set_index = self.neighbors_sets[isite][cn].index(nb_set)
1✔
635
        except ValueError:
×
636
            return None
×
637
        return self.ce_list[isite][cn][nb_set_index]
1✔
638

639
    def get_csm(self, isite, mp_symbol):
1✔
640
        """
641
        Get the continuous symmetry measure for a given site in the given coordination environment.
642

643
        Args:
644
            isite: Index of the site.
645
            mp_symbol: Symbol of the coordination environment for which we want the continuous symmetry measure.
646

647
        Returns: Continuous symmetry measure of the given site in the given environment.
648
        """
649
        csms = self.get_csms(isite, mp_symbol)
1✔
650
        if len(csms) != 1:
1✔
651
            raise ChemenvError(
×
652
                "StructureEnvironments",
653
                "get_csm",
654
                f"Number of csms for site #{str(isite)} with mp_symbol {mp_symbol!r} = {str(len(csms))}",
655
            )
656
        return csms[0]
1✔
657

658
    def get_csms(self, isite, mp_symbol):
        """
        Returns the continuous symmetry measure(s) of site with index isite with respect to the
        perfect coordination environment with mp_symbol.

        For some environments, a given mp_symbol might not be available. This happens when no voronoi
        parameters lead to a number of neighbors corresponding to the coordination number of environment
        mp_symbol. For other environments, a given mp_symbol might lead to more than one csm, when two or more
        different voronoi parameters lead to different neighbors but with the same number of neighbors.

        Args:
            isite: Index of the site.
            mp_symbol: MP symbol of the perfect environment for which the csm has to be given.

        Returns:
            List of csms for site isite with respect to geometry mp_symbol. The list is empty when the site has
            no environments at all, or none for the coordination number of mp_symbol.
        """
        cn = symbol_cn_mapping[mp_symbol]
        site_ces = self.ce_list[isite]
        # ce_list[isite] may be None (get_coordination_environments handles the same case); report
        # "no csm" instead of failing with a TypeError on the membership test.
        if site_ces is None or cn not in site_ces:
            return []
        return [envs[mp_symbol] for envs in site_ces[cn]]
678

679
    def plot_csm_and_maps(self, isite, max_csm=8.0):
        """
        Show the continuous symmetry measures and distance/angle parameter maps of a site.

        The figure is built by get_csm_and_maps (one column per neighbors set of the site, with the continuous
        symmetry measures of the best geometries and the normalized distance/angle parameters), then displayed.

        Args:
            isite: Index of the site for which the plot has to be done
            max_csm: Maximum continuous symmetry measure to be shown.

        Returns:
            None. Nothing is shown if matplotlib cannot be imported or no figure could be built.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            print('Plotting Chemical Environments requires matplotlib ... exiting "plot" function')
            return None
        figure_and_axes = self.get_csm_and_maps(isite=isite, max_csm=max_csm)
        if figure_and_axes is not None:
            plt.show()
        return None
699

700
    def get_csm_and_maps(self, isite, max_csm=8.0, figsize=None, symmetry_measure_type=None):
1✔
701
        """
702
        Plotting of the coordination numbers of a given site for all the distfactor/angfactor parameters. If the
703
        chemical environments are given, a color map is added to the plot, with the lowest continuous symmetry measure
704
        as the value for the color of that distfactor/angfactor set.
705

706
        Args:
707
            isite: Index of the site for which the plot has to be done.
708
            max_csm: Maximum continuous symmetry measure to be shown.
709
            figsize: Size of the figure.
710
            symmetry_measure_type: Type of continuous symmetry measure to be used.
711

712
        Returns:
713
            Matplotlib figure and axes representing the csm and maps.
714
        """
715
        try:
1✔
716
            import matplotlib.pyplot as plt
1✔
717
            from matplotlib.gridspec import GridSpec
1✔
718
        except ImportError:
×
719
            print('Plotting Chemical Environments requires matplotlib ... exiting "plot" function')
×
720
            return None
×
721

722
        if symmetry_measure_type is None:
1✔
723
            symmetry_measure_type = "csm_wcs_ctwcc"
1✔
724
        # Initializes the figure
725
        if figsize is None:
1✔
726
            fig = plt.figure()
1✔
727
        else:
728
            fig = plt.figure(figsize=figsize)
×
729
        gs = GridSpec(2, 1, hspace=0.0, wspace=0.0)
1✔
730
        subplot = fig.add_subplot(gs[:])
1✔
731
        subplot_distang = subplot.twinx()
1✔
732

733
        ix = 0
1✔
734
        cn_maps = []
1✔
735
        all_wds = []
1✔
736
        all_was = []
1✔
737
        max_wd = 0.0
1✔
738
        for cn, nb_sets in self.neighbors_sets[isite].items():
1✔
739
            for inb_set, nb_set in enumerate(nb_sets):
1✔
740
                ce = self.ce_list[isite][cn][inb_set]
1✔
741
                if ce is None:
1✔
742
                    continue
1✔
743
                mingeoms = ce.minimum_geometries(max_csm=max_csm)
1✔
744
                if len(mingeoms) == 0:
1✔
745
                    continue
1✔
746
                wds = nb_set.normalized_distances
1✔
747
                max_wd = max(max_wd, max(wds))
1✔
748
                all_wds.append(wds)
1✔
749
                all_was.append(nb_set.normalized_angles)
1✔
750
                for mp_symbol, cg_dict in mingeoms:
1✔
751
                    csm = cg_dict["other_symmetry_measures"][symmetry_measure_type]
1✔
752
                    subplot.plot(ix, csm, "ob")
1✔
753
                    subplot.annotate(mp_symbol, xy=(ix, csm))
1✔
754
                cn_maps.append((cn, inb_set))
1✔
755
                ix += 1
1✔
756

757
        if max_wd < 1.225:
1✔
758
            ymax_wd = 1.25
1✔
759
            yticks_wd = np.linspace(1.0, ymax_wd, 6)
1✔
760
        elif max_wd < 1.36:
×
761
            ymax_wd = 1.4
×
762
            yticks_wd = np.linspace(1.0, ymax_wd, 5)
×
763
        elif max_wd < 1.45:
×
764
            ymax_wd = 1.5
×
765
            yticks_wd = np.linspace(1.0, ymax_wd, 6)
×
766
        elif max_wd < 1.55:
×
767
            ymax_wd = 1.6
×
768
            yticks_wd = np.linspace(1.0, ymax_wd, 7)
×
769
        elif max_wd < 1.75:
×
770
            ymax_wd = 1.8
×
771
            yticks_wd = np.linspace(1.0, ymax_wd, 5)
×
772
        elif max_wd < 1.95:
×
773
            ymax_wd = 2.0
×
774
            yticks_wd = np.linspace(1.0, ymax_wd, 6)
×
775
        elif max_wd < 2.35:
×
776
            ymax_wd = 2.5
×
777
            yticks_wd = np.linspace(1.0, ymax_wd, 7)
×
778
        else:
779
            ymax_wd = np.ceil(1.1 * max_wd)
×
780
            yticks_wd = np.linspace(1.0, ymax_wd, 6)
×
781

782
        yticks_wa = np.linspace(0.0, 1.0, 6)
1✔
783

784
        frac_bottom = 0.05
1✔
785
        frac_top = 0.05
1✔
786
        frac_middle = 0.1
1✔
787
        yamin = frac_bottom
1✔
788
        yamax = 0.5 - frac_middle / 2
1✔
789
        ydmin = 0.5 + frac_middle / 2
1✔
790
        ydmax = 1.0 - frac_top
1✔
791

792
        def yang(wa):
1✔
793
            return (yamax - yamin) * np.array(wa) + yamin
1✔
794

795
        def ydist(wd):
1✔
796
            return (np.array(wd) - 1.0) / (ymax_wd - 1.0) * (ydmax - ydmin) + ydmin
1✔
797

798
        for ix, was in enumerate(all_was):
1✔
799
            subplot_distang.plot(0.2 + ix * np.ones_like(was), yang(was), "<g")
1✔
800
            if np.mod(ix, 2) == 0:
1✔
801
                alpha = 0.3
1✔
802
            else:
803
                alpha = 0.1
1✔
804
            subplot_distang.fill_between(
1✔
805
                [-0.5 + ix, 0.5 + ix],
806
                [1.0, 1.0],
807
                0.0,
808
                facecolor="k",
809
                alpha=alpha,
810
                zorder=-1000,
811
            )
812
        for ix, wds in enumerate(all_wds):
1✔
813
            subplot_distang.plot(0.2 + ix * np.ones_like(wds), ydist(wds), "sm")
1✔
814

815
        subplot_distang.plot([-0.5, len(cn_maps)], [0.5, 0.5], "k--", alpha=0.5)
1✔
816

817
        yticks = yang(yticks_wa).tolist()
1✔
818
        yticks.extend(ydist(yticks_wd).tolist())
1✔
819
        yticklabels = yticks_wa.tolist()
1✔
820
        yticklabels.extend(yticks_wd.tolist())
1✔
821
        subplot_distang.set_yticks(yticks)
1✔
822
        subplot_distang.set_yticklabels(yticklabels)
1✔
823

824
        fake_subplot_ang = fig.add_subplot(gs[1], frame_on=False)
1✔
825
        fake_subplot_dist = fig.add_subplot(gs[0], frame_on=False)
1✔
826
        fake_subplot_ang.set_yticks([])
1✔
827
        fake_subplot_dist.set_yticks([])
1✔
828
        fake_subplot_ang.set_xticks([])
1✔
829
        fake_subplot_dist.set_xticks([])
1✔
830
        fake_subplot_ang.set_ylabel("Angle parameter", labelpad=45, rotation=-90)
1✔
831
        fake_subplot_dist.set_ylabel("Distance parameter", labelpad=45, rotation=-90)
1✔
832
        fake_subplot_ang.yaxis.set_label_position("right")
1✔
833
        fake_subplot_dist.yaxis.set_label_position("right")
1✔
834

835
        subplot_distang.set_ylim([0.0, 1.0])
1✔
836
        subplot.set_xticks(range(len(cn_maps)))
1✔
837
        subplot.set_ylabel("Continuous symmetry measure")
1✔
838
        subplot.set_xlim([-0.5, len(cn_maps) - 0.5])
1✔
839
        subplot_distang.set_xlim([-0.5, len(cn_maps) - 0.5])
1✔
840
        subplot.set_xticklabels([str(cn_map) for cn_map in cn_maps])
1✔
841

842
        return fig, subplot
1✔
843

844
    def get_environments_figure(
        self,
        isite,
        plot_type=None,
        title="Coordination numbers",
        max_dist=2.0,
        colormap=None,
        figsize=None,
        strategy=None,
    ):
        """
        Plotting of the coordination environments of a given site for all the distfactor/angfactor regions. The
        chemical environment with the lowest continuous symmetry measure is shown for each distfactor/angfactor
        region as the value for the color of that distfactor/angfactor region (using a colormap).

        Args:
            isite: Index of the site for which the plot has to be done.
            plot_type: How to plot the coordinations. Dict mapping "distance_parameter" and "angle_parameter" to
                (kind, options) tuples. Supported distance kinds: "initial_normalized" (default) and
                "one_minus_inverse_alpha_power_n" (options None or {"exponent": int}, the exponent defaulting to 3).
                Supported angle kinds: "initial_normalized_inverted" (default), "initial_normalized" and
                "one_minus_gamma". A key missing from the dict falls back to its default.
            title: Title for the figure.
            max_dist: Maximum distance to be plotted when the plotting of the distance is set to 'initial_normalized'
                or 'initial_real' (Warning: this is not the same meaning in both cases! In the first case, the
                closest atom lies at a "normalized" distance of 1.0 so that 2.0 means refers to this normalized
                distance while in the second case, the real distance is used).
            colormap: Color map to be used for the continuous symmetry measure (matplotlib's "jet" if None).
            figsize: Size of the figure.
            strategy: Chemenv strategy whose visualization is added to the plot. If adding it fails, a warning is
                emitted and the figure is returned without it.

        Returns:
            Tuple (figure, axes) representing the environments, or None if matplotlib is not available.

        Raises:
            ValueError: If an unsupported distance or angle parameter plot type is requested.
        """
        try:
            import matplotlib.pyplot as mpl
            from matplotlib import cm
            from matplotlib.colors import Normalize
            from matplotlib.patches import Polygon
        except ImportError:
            print('Plotting Chemical Environments requires matplotlib ... exiting "plot" function')
            return None
        import warnings

        # Initializes the figure
        fig = mpl.figure() if figsize is None else mpl.figure(figsize=figsize)
        subplot = fig.add_subplot(111)

        # Initializes the distance and angle parameters; a partial plot_type only overrides the keys it gives
        default_plot_type = {
            "distance_parameter": ("initial_normalized", None),
            "angle_parameter": ("initial_normalized_inverted", None),
        }
        plot_type = {**default_plot_type, **(plot_type or {})}
        mycm = cm.jet if colormap is None else colormap  # pylint: disable=E1101
        mymin = 0.0
        mymax = 10.0
        norm = Normalize(vmin=mymin, vmax=mymax)
        scalarmap = cm.ScalarMappable(norm=norm, cmap=mycm)
        dist_limits = [1.0, max_dist]
        ang_limits = [0.0, 1.0]
        if plot_type["distance_parameter"][0] == "one_minus_inverse_alpha_power_n":
            if plot_type["distance_parameter"][1] is None:
                exponent = 3
            else:
                exponent = plot_type["distance_parameter"][1]["exponent"]
            xlabel = f"Distance parameter : $1.0-\\frac{{1.0}}{{\\alpha^{{{exponent:d}}}}}$"

            def dp_func(dp):
                return 1.0 - 1.0 / np.power(dp, exponent)

        elif plot_type["distance_parameter"][0] == "initial_normalized":
            xlabel = "Distance parameter : $\\alpha$"

            def dp_func(dp):
                return dp

        else:
            raise ValueError(f"Wrong value for distance parameter plot type \"{plot_type['distance_parameter'][0]}\"")

        if plot_type["angle_parameter"][0] == "one_minus_gamma":
            ylabel = "Angle parameter : $1.0-\\gamma$"

            def ap_func(ap):
                return 1.0 - ap

        elif plot_type["angle_parameter"][0] in [
            "initial_normalized_inverted",
            "initial_normalized",
        ]:
            ylabel = "Angle parameter : $\\gamma$"

            def ap_func(ap):
                return ap

        else:
            raise ValueError(f"Wrong value for angle parameter plot type \"{plot_type['angle_parameter'][0]}\"")
        dist_limits = [dp_func(dp) for dp in dist_limits]
        ang_limits = [ap_func(ap) for ap in ang_limits]

        for cn, cn_nb_sets in self.neighbors_sets[isite].items():
            for inb_set, nb_set in enumerate(cn_nb_sets):
                nb_set_surface_pts = nb_set.voronoi_grid_surface_points()
                if nb_set_surface_pts is None:
                    continue
                ce = self.ce_list[isite][cn][inb_set]
                mingeom = None if ce is None else ce.minimum_geometry()
                if mingeom is None:
                    # No chemical environment available: white patch labelled with the coordination number
                    mycolor = "w"
                    myinvcolor = "k"
                    mytext = f"{cn:d}"
                else:
                    mp_symbol = mingeom[0]
                    csm = mingeom[1]["symmetry_measure"]
                    mycolor = scalarmap.to_rgba(csm)
                    # Inverted (opaque) color so that the label stays readable on the patch
                    myinvcolor = [
                        1.0 - mycolor[0],
                        1.0 - mycolor[1],
                        1.0 - mycolor[2],
                        1.0,
                    ]
                    mytext = f"{mp_symbol}"
                nb_set_surface_pts = [(dp_func(pt[0]), ap_func(pt[1])) for pt in nb_set_surface_pts]
                polygon = Polygon(
                    nb_set_surface_pts,
                    closed=True,
                    edgecolor="k",
                    facecolor=mycolor,
                    linewidth=1.2,
                )
                subplot.add_patch(polygon)
                n_pts = len(nb_set_surface_pts)
                if n_pts % 2 != 0:
                    raise RuntimeError("Number of surface points not even")
                # Index half-way through the surface points (NOTE(review): presumably the point opposite the first)
                ipt = n_pts // 2

                first_pt = nb_set_surface_pts[0]
                last_pt = nb_set_surface_pts[-1]
                before_last_pt = nb_set_surface_pts[-2]
                # Place the label only where the patch is large enough; distances are clipped at the plot limit
                xytext = None
                if (
                    np.abs(last_pt[1] - before_last_pt[1]) > 0.06
                    and np.abs(min(last_pt[0], dist_limits[1]) - first_pt[0]) > 0.125
                ):
                    xytext = (
                        (min(last_pt[0], dist_limits[1]) + first_pt[0]) / 2,
                        (last_pt[1] + before_last_pt[1]) / 2,
                    )
                else:
                    mid_pt = nb_set_surface_pts[ipt]
                    if (
                        np.abs(mid_pt[1] - first_pt[1]) > 0.1
                        and np.abs(min(mid_pt[0], dist_limits[1]) - first_pt[0]) > 0.125
                    ):
                        xytext = (
                            (first_pt[0] + min(mid_pt[0], dist_limits[1])) / 2,
                            (first_pt[1] + mid_pt[1]) / 2,
                        )
                if xytext is not None:
                    subplot.annotate(
                        mytext,
                        xy=xytext,
                        ha="center",
                        va="center",
                        color=myinvcolor,
                        fontsize="x-small",
                    )

        subplot.set_title(title)
        subplot.set_xlabel(xlabel)
        subplot.set_ylabel(ylabel)

        dist_limits.sort()
        ang_limits.sort()
        subplot.set_xlim(dist_limits)
        subplot.set_ylim(ang_limits)
        if strategy is not None:
            try:
                strategy.add_strategy_visualization_to_subplot(subplot=subplot)
            except Exception as exc:
                # The strategy visualization is optional: keep the figure, but do not hide the failure
                warnings.warn(
                    f"Could not add the strategy visualization to the environments figure: {exc}",
                    stacklevel=2,
                )
        if plot_type["angle_parameter"][0] == "initial_normalized_inverted":
            subplot.axes.invert_yaxis()

        scalarmap.set_array([mymin, mymax])
        cb = fig.colorbar(scalarmap, ax=subplot, extend="max")
        cb.set_label("Continuous symmetry measure")
        return fig, subplot
1040

1041
    def plot_environments(
        self,
        isite,
        plot_type=None,
        title="Coordination numbers",
        max_dist=2.0,
        figsize=None,
        strategy=None,
        colormap=None,
    ):
        """
        Plotting of the coordination numbers of a given site for all the distfactor/angfactor parameters. If the
        chemical environments are given, a color map is added to the plot, with the lowest continuous symmetry measure
        as the value for the color of that distfactor/angfactor set.

        Args:
            isite: Index of the site for which the plot has to be done.
            plot_type: How to plot the coordinations.
            title: Title for the figure.
            max_dist: Maximum distance to be plotted when the plotting of the distance is set to 'initial_normalized'
                or 'initial_real' (Warning: this is not the same meaning in both cases! In the first case, the
                closest atom lies at a "normalized" distance of 1.0 so that 2.0 means refers to this normalized
                distance while in the second case, the real distance is used).
            figsize: Size of the figure.
            strategy: Whether to plot information about one of the Chemenv Strategies.
            colormap: Color map used for the continuous symmetry measure (get_environments_figure's default if None).

        Returns:
            None. Nothing is shown when get_environments_figure returns None (matplotlib not available).
        """
        fig_and_axes = self.get_environments_figure(
            isite=isite,
            plot_type=plot_type,
            title=title,
            max_dist=max_dist,
            colormap=colormap,
            figsize=figsize,
            strategy=strategy,
        )
        # get_environments_figure returns None rather than a (figure, axes) tuple when matplotlib is missing,
        # so check before unpacking
        if fig_and_axes is None:
            return
        fig, _ = fig_and_axes
        fig.show()
1077

1078
    def save_environments_figure(
        self,
        isite,
        imagename="image.png",
        plot_type=None,
        title="Coordination numbers",
        max_dist=2.0,
        figsize=None,
        strategy=None,
        colormap=None,
    ):
        """
        Saves the environments figure to a given file.

        Args:
            isite: Index of the site for which the plot has to be done.
            imagename: Name of the file to which the figure has to be saved.
            plot_type: How to plot the coordinations.
            title: Title for the figure.
            max_dist: Maximum distance to be plotted when the plotting of the distance is set to 'initial_normalized'
                or 'initial_real' (Warning: this is not the same meaning in both cases! In the first case, the
                closest atom lies at a "normalized" distance of 1.0 so that 2.0 means refers to this normalized
                distance while in the second case, the real distance is used).
            figsize: Size of the figure.
            strategy: Chemenv strategy whose visualization is added to the figure (none by default).
            colormap: Color map used for the continuous symmetry measure (get_environments_figure's default if None).

        Returns:
            None. Nothing is saved when get_environments_figure returns None (matplotlib not available).
        """
        fig_and_axes = self.get_environments_figure(
            isite=isite,
            plot_type=plot_type,
            title=title,
            max_dist=max_dist,
            colormap=colormap,
            figsize=figsize,
            strategy=strategy,
        )
        # get_environments_figure returns None rather than a (figure, axes) tuple when matplotlib is missing,
        # so check before unpacking
        if fig_and_axes is None:
            return
        fig, _ = fig_and_axes
        fig.savefig(imagename)
1111

1112
    def differences_wrt(self, other):
        """
        List the differences between this StructureEnvironments and another one.

        The comparison stops early, appending a "PREVIOUS DIFFERENCE IS DISMISSIVE" marker, when the structures or
        the voronoi containers differ; the per-site comparisons are then skipped.

        Args:
            other: A StructureEnvironments object.

        Returns:
            List of dicts, each describing one difference between the two StructureEnvironments objects.
        """

        def _diff(what, how, mine, theirs):
            # One recorded difference between self and other
            return {"difference": what, "comparison": how, "self": mine, "other": theirs}

        def _dismissive():
            # Marker telling the reader that the previous difference stopped the comparison
            return {"difference": "PREVIOUS DIFFERENCE IS DISMISSIVE", "comparison": "differences_wrt"}

        found = []
        if self.structure != other.structure:
            found.extend([_diff("structure", "__eq__", self.structure, other.structure), _dismissive()])
            return found
        if self.valences != other.valences:
            found.append(_diff("valences", "__eq__", self.valences, other.valences))
        if self.info != other.info:
            found.append(_diff("info", "__eq__", self.info, other.info))
        if self.voronoi != other.voronoi:
            # TODO: make it possible to have "close" voronoi's
            voronoi_comparison = "__eq__" if self.voronoi.is_close_to(other.voronoi) else "is_close_to"
            found.extend([_diff("voronoi", voronoi_comparison, self.voronoi, other.voronoi), _dismissive()])
            return found

        # NOTE: the names isite and cn are part of the "{isite=}" / "{cn=}" f-string outputs; keep them.
        for isite, self_site_nb_sets in enumerate(self.neighbors_sets):
            other_site_nb_sets = other.neighbors_sets[isite]
            if self_site_nb_sets is None and other_site_nb_sets is None:
                continue
            if self_site_nb_sets is None or other_site_nb_sets is None:
                found.append(
                    _diff(
                        f"neighbors_sets[{isite=:d}]",
                        "has_neighbors",
                        "None" if self_site_nb_sets is None else set(self_site_nb_sets),
                        "None" if other_site_nb_sets is None else set(other_site_nb_sets),
                    )
                )
                continue

            self_cns = set(self_site_nb_sets)
            other_cns = set(other_site_nb_sets)
            if self_cns != other_cns:
                found.append(_diff(f"neighbors_sets[{isite=:d}]", "coordination_numbers", self_cns, other_cns))

            for cn in self_cns.intersection(other_cns):
                other_cn_nb_sets = other_site_nb_sets[cn]
                self_cn_nb_sets = self_site_nb_sets[cn]
                self_nb_sets_set = set(self_cn_nb_sets)
                other_nb_sets_set = set(other_cn_nb_sets)
                if self_nb_sets_set != other_nb_sets_set:
                    found.append(
                        _diff(
                            f"neighbors_sets[{isite=:d}][{cn=:d}]",
                            "neighbors_sets",
                            self_cn_nb_sets,
                            other_cn_nb_sets,
                        )
                    )
                # Neighbors sets present in both objects may sit at different positions in their lists
                for nb_set in self_nb_sets_set.intersection(other_nb_sets_set):
                    self_position = self_cn_nb_sets.index(nb_set)
                    other_position = other_cn_nb_sets.index(nb_set)
                    self_ce = self.ce_list[isite][cn][self_position]
                    other_ce = other.ce_list[isite][cn][other_position]
                    if self_ce != other_ce:
                        ce_comparison = "__eq__" if self_ce.is_close_to(other_ce) else "is_close_to"
                        found.append(
                            _diff(
                                f"ce_list[{isite=}][{cn=}][inb_set={self_position}]",
                                ce_comparison,
                                self_ce,
                                other_ce,
                            )
                        )
        return found
1267

1268
    def __eq__(self, other: object) -> bool:
        """
        Check equality with another StructureEnvironments object.

        Two objects are equal when their voronoi containers, valences, sites maps, equivalent sites, structures,
        info dicts, and per-site neighbors sets and chemical environments all compare equal.

        Args:
            other: Object to compare with.

        Returns:
            True if equal, False otherwise; NotImplemented if other is not a StructureEnvironments.
        """
        if not isinstance(other, StructureEnvironments):
            return NotImplemented

        if len(self.ce_list) != len(other.ce_list):
            return False
        if self.voronoi != other.voronoi:
            return False
        # Compare the valences themselves, not only how many there are (differences_wrt also compares values).
        # list() puts both sides in the same container type, e.g. a tuple vs the list produced by serialization.
        if list(self.valences) != list(other.valences):
            return False
        if self.sites_map != other.sites_map:
            return False
        if self.equivalent_sites != other.equivalent_sites:
            return False
        if self.structure != other.structure:
            return False
        if self.info != other.info:
            return False
        for isite, site_ces in enumerate(self.ce_list):
            if self.neighbors_sets[isite] != other.neighbors_sets[isite]:
                return False
            if site_ces != other.ce_list[isite]:
                return False
        return True
1294

1295
    def as_dict(self):
        """
        Bson-serializable dict representation of the StructureEnvironments object.

        Coordination numbers and neighbors-set indices used as dict keys are converted to strings; from_dict
        converts them back to ints. The "sites_info" entry of info is only serialized when it is present, since
        from_dict treats it as optional.

        Returns:
            Bson-serializable dict representation of the StructureEnvironments object.
        """
        ce_list_dict = [
            {str(cn): [ce.as_dict() if ce is not None else None for ce in ce_dict[cn]] for cn in ce_dict}
            if ce_dict is not None
            else None
            for ce_dict in self.ce_list
        ]
        nbs_sets_dict = [
            {str(cn): [nb_set.as_dict() for nb_set in nb_sets] for cn, nb_sets in site_nbs_sets.items()}
            if site_nbs_sets is not None
            else None
            for site_nbs_sets in self.neighbors_sets
        ]
        info_dict = {key: val for key, val in self.info.items() if key != "sites_info"}
        if "sites_info" in self.info:
            info_dict["sites_info"] = [
                {
                    "nb_sets_info": {
                        str(cn): {str(inb_set): nb_set_info for inb_set, nb_set_info in cn_sets.items()}
                        for cn, cn_sets in site_info["nb_sets_info"].items()
                    },
                    "time": site_info["time"],
                }
                if "nb_sets_info" in site_info
                else {}
                for site_info in self.info["sites_info"]
            ]

        return {
            "@module": type(self).__module__,
            "@class": type(self).__name__,
            "voronoi": self.voronoi.as_dict(),
            "valences": self.valences,
            "sites_map": self.sites_map,
            "equivalent_sites": [[ps.as_dict() for ps in psl] for psl in self.equivalent_sites],
            "ce_list": ce_list_dict,
            "structure": self.structure.as_dict(),
            "neighbors_sets": nbs_sets_dict,
            "info": info_dict,
        }
1340

1341
    @classmethod
    def from_dict(cls, d):
        """
        Rebuild a StructureEnvironments object from the dict representation produced by as_dict.

        String keys (coordination numbers, neighbors-set indices) are converted back to ints. Entries of the
        chemical environments list may be None or the string "None"; both are read back as None.

        Args:
            d: dict representation of the StructureEnvironments object.

        Returns:
            StructureEnvironments object.
        """

        def _is_null(value):
            # Null chemical environment entries may be stored either as None or as the string "None"
            return value is None or value == "None"

        ce_list = []
        for site_ces in d["ce_list"]:
            if _is_null(site_ces):
                ce_list.append(None)
                continue
            ce_list.append(
                {
                    int(cn): [None if _is_null(ced) else ChemicalEnvironments.from_dict(ced) for ced in site_ces[cn]]
                    for cn in site_ces
                }
            )

        voronoi = DetailedVoronoiContainer.from_dict(d["voronoi"])
        structure = Structure.from_dict(d["structure"])

        neighbors_sets = []
        for site_nbs in d["neighbors_sets"]:
            if site_nbs is None:
                neighbors_sets.append(None)
                continue
            neighbors_sets.append(
                {
                    int(cn): [
                        cls.NeighborsSet.from_dict(dd=nb_set_dict, structure=structure, detailed_voronoi=voronoi)
                        for nb_set_dict in nb_set_dicts
                    ]
                    for cn, nb_set_dicts in site_nbs.items()
                }
            )

        raw_info = d["info"]
        info = {key: val for key, val in raw_info.items() if key != "sites_info"}
        if "sites_info" in raw_info:
            info["sites_info"] = [
                {
                    "nb_sets_info": {
                        int(cn): {int(inb): nb_info for inb, nb_info in cn_sets.items()}
                        for cn, cn_sets in site_info["nb_sets_info"].items()
                    },
                    "time": site_info["time"],
                }
                if "nb_sets_info" in site_info
                else {}
                for site_info in raw_info["sites_info"]
            ]

        return cls(
            voronoi=voronoi,
            valences=d["valences"],
            sites_map=d["sites_map"],
            equivalent_sites=[[PeriodicSite.from_dict(psd) for psd in psl] for psl in d["equivalent_sites"]],
            ce_list=ce_list,
            structure=structure,
            neighbors_sets=neighbors_sets,
            info=info,
        )
1403

1404

1405
class LightStructureEnvironments(MSONable):
1✔
1406
    """
1407
    Class used to store the chemical environments of a given structure obtained from a given ChemenvStrategy. Currently,
1408
    only strategies leading to the determination of a unique environment for each site is allowed
1409
    This class does not store all the information contained in the StructureEnvironments object, only the coordination
1410
    environment found.
1411
    """
1412

1413
    DELTA_MAX_OXIDATION_STATE = 0.1
1✔
1414
    DEFAULT_STATISTICS_FIELDS = [
1✔
1415
        "anion_list",
1416
        "anion_atom_list",
1417
        "cation_list",
1418
        "cation_atom_list",
1419
        "neutral_list",
1420
        "neutral_atom_list",
1421
        "atom_coordination_environments_present",
1422
        "ion_coordination_environments_present",
1423
        "fraction_atom_coordination_environments_present",
1424
        "fraction_ion_coordination_environments_present",
1425
        "coordination_environments_atom_present",
1426
        "coordination_environments_ion_present",
1427
    ]
1428

1429
    class NeighborsSet:
        """
        Class used to store a given set of neighbors of a given site (based on a list of sites, the voronoi
        container is not part of the LightStructureEnvironments object).

        Neighbors are referenced by their position in ``all_nbs_sites``, a list of neighbor records (dicts from which
        the "site", "index" and "image_cell" entries are read by the neighb_* properties).
        """

        def __init__(self, structure: Structure, isite, all_nbs_sites, all_nbs_sites_indices):
            """
            Constructor for NeighborsSet.

            Args:
                structure: Structure object.
                isite: Index of the site for which neighbors are stored in this NeighborsSet.
                all_nbs_sites: All the possible neighbors for this site.
                all_nbs_sites_indices: Indices of the sites in all_nbs_sites that make up this NeighborsSet.

            Raises:
                ValueError: If all_nbs_sites_indices contains duplicates.
            """
            self.structure = structure
            self.isite = isite
            self.all_nbs_sites = all_nbs_sites
            myset = set(all_nbs_sites_indices)
            if len(myset) != len(all_nbs_sites_indices):
                raise ValueError("Set of neighbors contains duplicates !")
            # The sorted indices define the identity of the set (used by __eq__ and __hash__), while the original
            # order is kept for the neighb_* properties and for serialization (as_dict).
            self.all_nbs_sites_indices = sorted(myset)
            self.all_nbs_sites_indices_unsorted = all_nbs_sites_indices

        @property
        def neighb_coords(self):
            """
            Coordinates of neighbors for this NeighborsSet (in the order given at construction).
            """
            return [self.all_nbs_sites[inb]["site"].coords for inb in self.all_nbs_sites_indices_unsorted]

        @property
        def neighb_sites(self):
            """
            Neighbors for this NeighborsSet as pymatgen Sites (in the order given at construction).
            """
            return [self.all_nbs_sites[inb]["site"] for inb in self.all_nbs_sites_indices_unsorted]

        @property
        def neighb_sites_and_indices(self):
            """
            List of neighbors for this NeighborsSet as pymatgen Sites and their index in the original structure.
            """
            return [
                {
                    "site": self.all_nbs_sites[inb]["site"],
                    "index": self.all_nbs_sites[inb]["index"],
                }
                for inb in self.all_nbs_sites_indices_unsorted
            ]

        @property
        def neighb_indices_and_images(self) -> list[dict[str, int]]:
            """
            List of indices and images with respect to the original unit cell sites for this NeighborsSet.
            """
            return [
                {
                    "index": self.all_nbs_sites[inb]["index"],
                    "image_cell": self.all_nbs_sites[inb]["image_cell"],
                }
                for inb in self.all_nbs_sites_indices_unsorted
            ]

        def __len__(self) -> int:
            return len(self.all_nbs_sites_indices)

        def __hash__(self) -> int:
            # Hash exactly the fields compared by __eq__, so equal sets hash equally, while distinct sets with the
            # same number of neighbors no longer all fall into the same bucket.
            return hash((self.isite, tuple(self.all_nbs_sites_indices)))

        def __eq__(self, other: object) -> bool:
            needed_attrs = ("isite", "all_nbs_sites_indices")

            if not all(hasattr(other, attr) for attr in needed_attrs):
                return NotImplemented

            return all(getattr(self, attr) == getattr(other, attr) for attr in needed_attrs)

        def __str__(self):
            return (
                f"Neighbors Set for site #{self.isite:d} :\n"
                f" - Coordination number : {len(self):d}\n"
                f" - Neighbors sites indices : {', '.join(f'{nb_idxs:d}' for nb_idxs in self.all_nbs_sites_indices)}\n"
            )

        def as_dict(self):
            """
            A JSON-serializable dict representation of the NeighborsSet (neighbor indices in their original order).
            """
            return {
                "isite": self.isite,
                "all_nbs_sites_indices": self.all_nbs_sites_indices_unsorted,
            }

        @classmethod
        def from_dict(cls, dd, structure: Structure, all_nbs_sites):
            """
            Reconstructs the NeighborsSet algorithm from its JSON-serializable dict representation, together with
            the structure and all the possible neighbors sites.

            As an inner (nested) class, the NeighborsSet is not supposed to be used anywhere else than inside the
            LightStructureEnvironments. The from_dict method is thus using the structure and all_nbs_sites when
            reconstructing itself. These two are both in the LightStructureEnvironments object.

            Args:
                dd: a JSON-serializable dict representation of a NeighborsSet.
                structure: The structure.
                all_nbs_sites: The list of all the possible neighbors for a given site.

            Returns: a NeighborsSet.
            """
            return cls(
                structure=structure,
                isite=dd["isite"],
                all_nbs_sites=all_nbs_sites,
                all_nbs_sites_indices=dd["all_nbs_sites_indices"],
            )
1547

1548
    def __init__(
1✔
1549
        self,
1550
        strategy,
1551
        coordination_environments=None,
1552
        all_nbs_sites=None,
1553
        neighbors_sets=None,
1554
        structure=None,
1555
        valences=None,
1556
        valences_origin=None,
1557
    ):
1558
        """
1559
        Constructor for the LightStructureEnvironments object.
1560

1561
        Args:
1562
            strategy: ChemEnv strategy used to get the environments.
1563
            coordination_environments: The coordination environments identified.
1564
            all_nbs_sites: All the possible neighbors for each site in the structure.
1565
            neighbors_sets: The neighbors sets of each site in the structure.
1566
            structure: The structure.
1567
            valences: The valences used to get the environments (if needed).
1568
            valences_origin: How the valences were obtained (e.g. from the Bond-valence analysis or from the original
1569
                structure).
1570
        """
1571
        self.strategy = strategy
1✔
1572
        self.statistics_dict = None
1✔
1573
        self.coordination_environments = coordination_environments
1✔
1574
        self._all_nbs_sites = all_nbs_sites
1✔
1575
        self.neighbors_sets = neighbors_sets
1✔
1576
        self.structure = structure
1✔
1577
        self.valences = valences
1✔
1578
        self.valences_origin = valences_origin
1✔
1579

1580
    @classmethod
1✔
1581
    def from_structure_environments(cls, strategy, structure_environments, valences=None, valences_origin=None):
1✔
1582
        """
1583
        Construct a LightStructureEnvironments object from a strategy and a StructureEnvironments object.
1584

1585
        Args:
1586
            strategy: ChemEnv strategy used.
1587
            structure_environments: StructureEnvironments object from which to construct the LightStructureEnvironments.
1588
            valences: The valences of each site in the structure.
1589
            valences_origin: How the valences were obtained (e.g. from the Bond-valence analysis or from the original
1590
                structure).
1591

1592
        Returns: a LightStructureEnvironments object.
1593
        """
1594
        structure = structure_environments.structure
1✔
1595
        strategy.set_structure_environments(structure_environments=structure_environments)
1✔
1596
        coordination_environments = [None] * len(structure)
1✔
1597
        neighbors_sets = [None] * len(structure)
1✔
1598
        _all_nbs_sites = []
1✔
1599
        my_all_nbs_sites = []
1✔
1600
        if valences is None:
1✔
1601
            valences = structure_environments.valences
1✔
1602
            if valences_origin is None:
1✔
1603
                valences_origin = "from_structure_environments"
1✔
1604
        else:
1605
            if valences_origin is None:
1✔
1606
                valences_origin = "user-specified"
1✔
1607

1608
        for isite, site in enumerate(structure):
1✔
1609
            site_ces_and_nbs_list = strategy.get_site_ce_fractions_and_neighbors(site, strategy_info=True)
1✔
1610
            if site_ces_and_nbs_list is None:
1✔
1611
                continue
1✔
1612
            coordination_environments[isite] = []
1✔
1613
            neighbors_sets[isite] = []
1✔
1614
            site_ces = []
1✔
1615
            site_nbs_sets = []
1✔
1616
            for ce_and_neighbors in site_ces_and_nbs_list:
1✔
1617
                _all_nbs_sites_indices = []
1✔
1618
                # Coordination environment
1619
                ce_dict = {
1✔
1620
                    "ce_symbol": ce_and_neighbors["ce_symbol"],
1621
                    "ce_fraction": ce_and_neighbors["ce_fraction"],
1622
                }
1623
                if ce_and_neighbors["ce_dict"] is not None:
1✔
1624
                    csm = ce_and_neighbors["ce_dict"]["other_symmetry_measures"][strategy.symmetry_measure_type]
1✔
1625
                else:
1626
                    csm = None
×
1627
                ce_dict["csm"] = csm
1✔
1628
                ce_dict["permutation"] = ce_and_neighbors["ce_dict"]["permutation"]
1✔
1629
                site_ces.append(ce_dict)
1✔
1630
                # Neighbors
1631
                neighbors = ce_and_neighbors["neighbors"]
1✔
1632
                for nb_site_and_index in neighbors:
1✔
1633
                    nb_site = nb_site_and_index["site"]
1✔
1634
                    try:
1✔
1635
                        nb_allnbs_sites_index = my_all_nbs_sites.index(nb_site)
1✔
1636
                    except ValueError:
1✔
1637
                        nb_index_unitcell = nb_site_and_index["index"]
1✔
1638
                        diff = nb_site.frac_coords - structure[nb_index_unitcell].frac_coords
1✔
1639
                        rounddiff = np.round(diff)
1✔
1640
                        if not np.allclose(diff, rounddiff):
1✔
1641
                            raise ValueError(
×
1642
                                "Weird, differences between one site in a periodic image cell is not integer ..."
1643
                            )
1644
                        nb_image_cell = np.array(rounddiff, int)
1✔
1645
                        nb_allnbs_sites_index = len(_all_nbs_sites)
1✔
1646
                        _all_nbs_sites.append(
1✔
1647
                            {
1648
                                "site": nb_site,
1649
                                "index": nb_index_unitcell,
1650
                                "image_cell": nb_image_cell,
1651
                            }
1652
                        )
1653
                        my_all_nbs_sites.append(nb_site)
1✔
1654
                    _all_nbs_sites_indices.append(nb_allnbs_sites_index)
1✔
1655

1656
                nb_set = cls.NeighborsSet(
1✔
1657
                    structure=structure,
1658
                    isite=isite,
1659
                    all_nbs_sites=_all_nbs_sites,
1660
                    all_nbs_sites_indices=_all_nbs_sites_indices,
1661
                )
1662
                site_nbs_sets.append(nb_set)
1✔
1663
            coordination_environments[isite] = site_ces
1✔
1664
            neighbors_sets[isite] = site_nbs_sets
1✔
1665
        return cls(
1✔
1666
            strategy=strategy,
1667
            coordination_environments=coordination_environments,
1668
            all_nbs_sites=_all_nbs_sites,
1669
            neighbors_sets=neighbors_sets,
1670
            structure=structure,
1671
            valences=valences,
1672
            valences_origin=valences_origin,
1673
        )
1674

1675
    def setup_statistic_lists(self):
1✔
1676
        """
1677
        Set up the statistics of environments for this LightStructureEnvironments.
1678
        """
1679
        self.statistics_dict = {
1✔
1680
            "valences_origin": self.valences_origin,
1681
            "anion_list": {},  # OK
1682
            "anion_number": None,  # OK
1683
            "anion_atom_list": {},  # OK
1684
            "anion_atom_number": None,  # OK
1685
            "cation_list": {},  # OK
1686
            "cation_number": None,  # OK
1687
            "cation_atom_list": {},  # OK
1688
            "cation_atom_number": None,  # OK
1689
            "neutral_list": {},  # OK
1690
            "neutral_number": None,  # OK
1691
            "neutral_atom_list": {},  # OK
1692
            "neutral_atom_number": None,  # OK
1693
            "atom_coordination_environments_present": {},  # OK
1694
            "ion_coordination_environments_present": {},  # OK
1695
            "coordination_environments_ion_present": {},  # OK
1696
            "coordination_environments_atom_present": {},  # OK
1697
            "fraction_ion_coordination_environments_present": {},  # OK
1698
            "fraction_atom_coordination_environments_present": {},  # OK
1699
            "fraction_coordination_environments_ion_present": {},  # OK
1700
            "fraction_coordination_environments_atom_present": {},  # OK
1701
            "count_ion_present": {},  # OK
1702
            "count_atom_present": {},  # OK
1703
            "count_coordination_environments_present": {},
1704
        }
1705
        atom_stat = self.statistics_dict["atom_coordination_environments_present"]
1✔
1706
        ce_atom_stat = self.statistics_dict["coordination_environments_atom_present"]
1✔
1707
        fraction_atom_stat = self.statistics_dict["fraction_atom_coordination_environments_present"]
1✔
1708
        fraction_ce_atom_stat = self.statistics_dict["fraction_coordination_environments_atom_present"]
1✔
1709
        count_atoms = self.statistics_dict["count_atom_present"]
1✔
1710
        count_ce = self.statistics_dict["count_coordination_environments_present"]
1✔
1711
        for isite, site in enumerate(self.structure):
1✔
1712
            # Building anion and cation list
1713
            site_species = []
1✔
1714
            if self.valences != "undefined":
1✔
1715
                for sp, occ in site.species.items():
1✔
1716
                    valence = self.valences[isite]
1✔
1717
                    strspecie = str(Species(sp.symbol, valence))
1✔
1718
                    if valence < 0:
1✔
1719
                        specielist = self.statistics_dict["anion_list"]
1✔
1720
                        atomlist = self.statistics_dict["anion_atom_list"]
1✔
1721
                    elif valence > 0:
1✔
1722
                        specielist = self.statistics_dict["cation_list"]
1✔
1723
                        atomlist = self.statistics_dict["cation_atom_list"]
1✔
1724
                    else:
1725
                        specielist = self.statistics_dict["neutral_list"]
×
1726
                        atomlist = self.statistics_dict["neutral_atom_list"]
×
1727
                    if strspecie not in specielist:
1✔
1728
                        specielist[strspecie] = occ
1✔
1729
                    else:
1730
                        specielist[strspecie] += occ
1✔
1731
                    if sp.symbol not in atomlist:
1✔
1732
                        atomlist[sp.symbol] = occ
1✔
1733
                    else:
1734
                        atomlist[sp.symbol] += occ
1✔
1735
                    site_species.append((sp.symbol, valence, occ))
1✔
1736
            # Building environments lists
1737
            if self.coordination_environments[isite] is not None:
1✔
1738
                site_envs = [
1✔
1739
                    (ce_piece_dict["ce_symbol"], ce_piece_dict["ce_fraction"])
1740
                    for ce_piece_dict in self.coordination_environments[isite]
1741
                ]
1742
                for ce_symbol, fraction in site_envs:
1✔
1743
                    if fraction is None:
1✔
1744
                        continue
×
1745
                    if ce_symbol not in count_ce:
1✔
1746
                        count_ce[ce_symbol] = 0.0
1✔
1747
                    count_ce[ce_symbol] += fraction
1✔
1748
                for sp, occ in site.species.items():
1✔
1749
                    elmt = sp.symbol
1✔
1750
                    if elmt not in atom_stat:
1✔
1751
                        atom_stat[elmt] = {}
1✔
1752
                        count_atoms[elmt] = 0.0
1✔
1753
                    count_atoms[elmt] += occ
1✔
1754
                    for ce_symbol, fraction in site_envs:
1✔
1755
                        if fraction is None:
1✔
1756
                            continue
×
1757
                        if ce_symbol not in atom_stat[elmt]:
1✔
1758
                            atom_stat[elmt][ce_symbol] = 0.0
1✔
1759

1760
                        atom_stat[elmt][ce_symbol] += occ * fraction
1✔
1761
                        if ce_symbol not in ce_atom_stat:
1✔
1762
                            ce_atom_stat[ce_symbol] = {}
1✔
1763
                        if elmt not in ce_atom_stat[ce_symbol]:
1✔
1764
                            ce_atom_stat[ce_symbol][elmt] = 0.0
1✔
1765
                        ce_atom_stat[ce_symbol][elmt] += occ * fraction
1✔
1766

1767
                if self.valences != "undefined":
1✔
1768
                    ion_stat = self.statistics_dict["ion_coordination_environments_present"]
1✔
1769
                    ce_ion_stat = self.statistics_dict["coordination_environments_ion_present"]
1✔
1770
                    count_ions = self.statistics_dict["count_ion_present"]
1✔
1771
                    for elmt, oxi_state, occ in site_species:
1✔
1772
                        if elmt not in ion_stat:
1✔
1773
                            ion_stat[elmt] = {}
1✔
1774
                            count_ions[elmt] = {}
1✔
1775
                        if oxi_state not in ion_stat[elmt]:
1✔
1776
                            ion_stat[elmt][oxi_state] = {}
1✔
1777
                            count_ions[elmt][oxi_state] = 0.0
1✔
1778
                        count_ions[elmt][oxi_state] += occ
1✔
1779
                        for ce_symbol, fraction in site_envs:
1✔
1780
                            if fraction is None:
1✔
1781
                                continue
×
1782
                            if ce_symbol not in ion_stat[elmt][oxi_state]:
1✔
1783
                                ion_stat[elmt][oxi_state][ce_symbol] = 0.0
1✔
1784
                            ion_stat[elmt][oxi_state][ce_symbol] += occ * fraction
1✔
1785
                            if ce_symbol not in ce_ion_stat:
1✔
1786
                                ce_ion_stat[ce_symbol] = {}
1✔
1787
                            if elmt not in ce_ion_stat[ce_symbol]:
1✔
1788
                                ce_ion_stat[ce_symbol][elmt] = {}
1✔
1789
                            if oxi_state not in ce_ion_stat[ce_symbol][elmt]:
1✔
1790
                                ce_ion_stat[ce_symbol][elmt][oxi_state] = 0.0
1✔
1791
                            ce_ion_stat[ce_symbol][elmt][oxi_state] += occ * fraction
1✔
1792
        self.statistics_dict["anion_number"] = len(self.statistics_dict["anion_list"])
1✔
1793
        self.statistics_dict["anion_atom_number"] = len(self.statistics_dict["anion_atom_list"])
1✔
1794
        self.statistics_dict["cation_number"] = len(self.statistics_dict["cation_list"])
1✔
1795
        self.statistics_dict["cation_atom_number"] = len(self.statistics_dict["cation_atom_list"])
1✔
1796
        self.statistics_dict["neutral_number"] = len(self.statistics_dict["neutral_list"])
1✔
1797
        self.statistics_dict["neutral_atom_number"] = len(self.statistics_dict["neutral_atom_list"])
1✔
1798

1799
        for elmt, envs in atom_stat.items():
1✔
1800
            sumelement = count_atoms[elmt]
1✔
1801
            fraction_atom_stat[elmt] = {env: fraction / sumelement for env, fraction in envs.items()}
1✔
1802
        for ce_symbol, atoms in ce_atom_stat.items():
1✔
1803
            sumsymbol = count_ce[ce_symbol]
1✔
1804
            fraction_ce_atom_stat[ce_symbol] = {atom: fraction / sumsymbol for atom, fraction in atoms.items()}
1✔
1805
        ion_stat = self.statistics_dict["ion_coordination_environments_present"]
1✔
1806
        fraction_ion_stat = self.statistics_dict["fraction_ion_coordination_environments_present"]
1✔
1807
        ce_ion_stat = self.statistics_dict["coordination_environments_ion_present"]
1✔
1808
        fraction_ce_ion_stat = self.statistics_dict["fraction_coordination_environments_ion_present"]
1✔
1809
        count_ions = self.statistics_dict["count_ion_present"]
1✔
1810
        for elmt, oxi_states_envs in ion_stat.items():
1✔
1811
            fraction_ion_stat[elmt] = {}
1✔
1812
            for oxi_state, envs in oxi_states_envs.items():
1✔
1813
                sumspecie = count_ions[elmt][oxi_state]
1✔
1814
                fraction_ion_stat[elmt][oxi_state] = {env: fraction / sumspecie for env, fraction in envs.items()}
1✔
1815
        for ce_symbol, ions in ce_ion_stat.items():
1✔
1816
            fraction_ce_ion_stat[ce_symbol] = {}
1✔
1817
            sum_ce = np.sum([np.sum(list(oxistates.values())) for elmt, oxistates in ions.items()])
1✔
1818
            for elmt, oxistates in ions.items():
1✔
1819
                fraction_ce_ion_stat[ce_symbol][elmt] = {
1✔
1820
                    oxistate: fraction / sum_ce for oxistate, fraction in oxistates.items()
1821
                }
1822

1823
    def get_site_info_for_specie_ce(self, specie, ce_symbol):
1✔
1824
        """
1825
        Get list of indices that have the given specie with a given Coordination environment.
1826

1827
        Args:
1828
            specie: Species to get.
1829
            ce_symbol: Symbol of the coordination environment to get.
1830

1831
        Returns: Dictionary with the list of indices in the structure that have the given specie in the given
1832
            environment, their fraction and continuous symmetry measures.
1833
        """
1834
        element = specie.symbol
1✔
1835
        oxi_state = specie.oxi_state
1✔
1836
        isites = []
1✔
1837
        csms = []
1✔
1838
        fractions = []
1✔
1839
        for isite, site in enumerate(self.structure):
1✔
1840
            if element in [sp.symbol for sp in site.species]:
1✔
1841
                if self.valences == "undefined" or oxi_state == self.valences[isite]:
1✔
1842
                    for ce_dict in self.coordination_environments[isite]:
1✔
1843
                        if ce_symbol == ce_dict["ce_symbol"]:
1✔
1844
                            isites.append(isite)
1✔
1845
                            csms.append(ce_dict["csm"])
1✔
1846
                            fractions.append(ce_dict["ce_fraction"])
1✔
1847
        return {"isites": isites, "fractions": fractions, "csms": csms}
1✔
1848

1849
    def get_site_info_for_specie_allces(self, specie, min_fraction=0):
1✔
1850
        """
1851
        Get list of indices that have the given specie.
1852

1853
        Args:
1854
            specie: Species to get.
1855
            min_fraction: Minimum fraction of the coordination environment.
1856

1857
        Returns: Dictionary with the list of coordination environments for the given species, the indices of the sites
1858
            in which they appear, their fractions and continuous symmetry measures.
1859
        """
1860
        allces = {}
1✔
1861
        element = specie.symbol
1✔
1862
        oxi_state = specie.oxi_state
1✔
1863
        for isite, site in enumerate(self.structure):
1✔
1864
            if element in [sp.symbol for sp in site.species]:
1✔
1865
                if self.valences == "undefined" or oxi_state == self.valences[isite]:
1✔
1866
                    if self.coordination_environments[isite] is None:
1✔
1867
                        continue
×
1868
                    for ce_dict in self.coordination_environments[isite]:
1✔
1869
                        if ce_dict["ce_fraction"] < min_fraction:
1✔
1870
                            continue
×
1871
                        if ce_dict["ce_symbol"] not in allces:
1✔
1872
                            allces[ce_dict["ce_symbol"]] = {
1✔
1873
                                "isites": [],
1874
                                "fractions": [],
1875
                                "csms": [],
1876
                            }
1877
                        allces[ce_dict["ce_symbol"]]["isites"].append(isite)
1✔
1878
                        allces[ce_dict["ce_symbol"]]["fractions"].append(ce_dict["ce_fraction"])
1✔
1879
                        allces[ce_dict["ce_symbol"]]["csms"].append(ce_dict["csm"])
1✔
1880
        return allces
1✔
1881

1882
    def get_statistics(self, statistics_fields=DEFAULT_STATISTICS_FIELDS, bson_compatible=False):
1✔
1883
        """
1884
        Get the statistics of environments for this structure.
1885
        Args:
1886
            statistics_fields: Which statistics to get.
1887
            bson_compatible: Whether to make the dictionary BSON-compatible.
1888

1889
        Returns:
1890
            A dictionary with the requested statistics.
1891
        """
1892
        if self.statistics_dict is None:
1✔
1893
            self.setup_statistic_lists()
1✔
1894
        if statistics_fields == "ALL":
1✔
1895
            statistics_fields = list(self.statistics_dict)
×
1896
        if bson_compatible:
1✔
1897
            dd = jsanitize({field: self.statistics_dict[field] for field in statistics_fields})
×
1898
        else:
1899
            dd = {field: self.statistics_dict[field] for field in statistics_fields}
1✔
1900
        return dd
1✔
1901

1902
    def contains_only_one_anion_atom(self, anion_atom):
1✔
1903
        """
1904
        Whether this LightStructureEnvironments concerns a structure with only one given anion atom type.
1905

1906
        Args:
1907
            anion_atom: Anion (e.g. O, ...). The structure could contain O2- and O- though.
1908

1909
        Returns: True if this LightStructureEnvironments concerns a structure with only one given anion_atom.
1910
        """
1911
        return (
1✔
1912
            len(self.statistics_dict["anion_atom_list"]) == 1 and anion_atom in self.statistics_dict["anion_atom_list"]
1913
        )
1914

1915
    def contains_only_one_anion(self, anion):
1✔
1916
        """
1917
        Whether this LightStructureEnvironments concerns a structure with only one given anion type.
1918

1919
        Args:
1920
            anion: Anion (e.g. O2-, ...).
1921

1922
        Returns: True if this LightStructureEnvironments concerns a structure with only one given anion.
1923
        """
1924
        return len(self.statistics_dict["anion_list"]) == 1 and anion in self.statistics_dict["anion_list"]
1✔
1925

1926
    def site_contains_environment(self, isite, ce_symbol):
1✔
1927
        """
1928
        Whether a given site contains a given coordination environment.
1929

1930
        Args:
1931
            isite: Index of the site.
1932
            ce_symbol: Symbol of the coordination environment.
1933

1934
        Returns: True if the site contains the given coordination environment.
1935
        """
1936
        if self.coordination_environments[isite] is None:
1✔
1937
            return False
1✔
1938
        return ce_symbol in [ce_dict["ce_symbol"] for ce_dict in self.coordination_environments[isite]]
1✔
1939

1940
    def site_has_clear_environment(self, isite, conditions=None):
1✔
1941
        """
1942
        Whether a given site has a "clear" environments.
1943

1944
        A "clear" environment is somewhat arbitrary. You can pass (multiple) conditions, e.g. the environment should
1945
        have a continuous symmetry measure lower than this, a fraction higher than that, ...
1946

1947
        Args:
1948
            isite: Index of the site.
1949
            conditions: Conditions to be checked for an environment to be "clear".
1950

1951
        Returns: True if the site has a clear environment.
1952
        """
1953
        if self.coordination_environments[isite] is None:
×
1954
            raise ValueError(f"Coordination environments have not been determined for site {isite:d}")
×
1955
        if conditions is None:
×
1956
            return len(self.coordination_environments[isite]) == 1
×
1957
        ce = max(self.coordination_environments[isite], key=lambda x: x["ce_fraction"])
×
1958
        for condition in conditions:
×
1959
            target = condition["target"]
×
1960
            if target == "ce_fraction":
×
1961
                if ce[target] < condition["minvalue"]:
×
1962
                    return False
×
1963
            elif target == "csm":
×
1964
                if ce[target] > condition["maxvalue"]:
×
1965
                    return False
×
1966
            elif target == "number_of_ces":
×
1967
                if ce[target] > condition["maxnumber"]:
×
1968
                    return False
×
1969
            else:
1970
                raise ValueError(f"Target {target!r} for condition of clear environment is not allowed")
×
1971
        return True
×
1972

1973
    def structure_has_clear_environments(self, conditions=None, skip_none=True, skip_empty=False):
1✔
1974
        """
1975
        Whether all sites in a structure have "clear" environments.
1976

1977
        Args:
1978
            conditions: Conditions to be checked for an environment to be "clear".
1979
            skip_none: Whether to skip sites for which no environments have been computed.
1980
            skip_empty: Whether to skip sites for which no environments could be found.
1981

1982
        Returns:
1983
            bool: True if all the sites in the structure have clear environments.
1984
        """
1985
        for isite in range(len(self.structure)):
×
1986
            if self.coordination_environments[isite] is None:
×
1987
                if skip_none:
×
1988
                    continue
×
1989
                return False
×
1990
            if len(self.coordination_environments[isite]) == 0:
×
1991
                if skip_empty:
×
1992
                    continue
×
1993
                return False
×
1994
            if not self.site_has_clear_environment(isite=isite, conditions=conditions):
×
1995
                return False
×
1996
        return True
×
1997

1998
    def clear_environments(self, conditions=None):
1✔
1999
        """
2000
        Get the clear environments in the structure.
2001

2002
        Args:
2003
            conditions: Conditions to be checked for an environment to be "clear".
2004

2005
        Returns: Set of clear environments in this structure.
2006
        """
2007
        clear_envs_list = set()
×
2008
        for isite in range(len(self.structure)):
×
2009
            if self.coordination_environments[isite] is None:
×
2010
                continue
×
2011
            if len(self.coordination_environments[isite]) == 0:
×
2012
                continue
×
2013
            if self.site_has_clear_environment(isite=isite, conditions=conditions):
×
2014
                ce = max(
×
2015
                    self.coordination_environments[isite],
2016
                    key=lambda x: x["ce_fraction"],
2017
                )
2018
                clear_envs_list.add(ce["ce_symbol"])
×
2019
        return list(clear_envs_list)
×
2020

2021
    def structure_contains_atom_environment(self, atom_symbol, ce_symbol):
        """
        Check whether a site containing the element ``atom_symbol`` is found in the environment ``ce_symbol``.

        Sites are scanned in order and the scan stops at the first match.

        Args:
            atom_symbol: Symbol of the atom (element) to look for.
            ce_symbol: Symbol of the coordination environment.

        Returns:
            True if such a site exists, False otherwise.
        """
        return any(
            Element(atom_symbol) in site.species.element_composition
            and self.site_contains_environment(isite, ce_symbol)
            for isite, site in enumerate(self.structure)
        )
1✔
2038

2039
    def environments_identified(self):
1✔
2040
        """
2041
        Return the set of environments identified in this structure.
2042

2043
        Returns: Set of environments identified in this structure.
2044
        """
2045
        return {ce["ce_symbol"] for celist in self.coordination_environments if celist is not None for ce in celist}
1✔
2046

2047
    @property
1✔
2048
    def uniquely_determines_coordination_environments(self):
1✔
2049
        """
2050
        True if the coordination environments are uniquely determined.
2051
        """
2052
        return self.strategy.uniquely_determines_coordination_environments
1✔
2053

2054
    def __eq__(self, other: object) -> bool:
1✔
2055
        """
2056
        Equality method that checks if the LightStructureEnvironments object is equal to another
2057
        LightStructureEnvironments object. Two LightStructureEnvironments objects are equal if the strategy used
2058
        is the same, if the structure is the same, if the valences used in the strategies are the same, if the
2059
        coordination environments and the neighbors determined by the strategy are the same.
2060

2061
        Args:
2062
            other: LightStructureEnvironments object to compare with.
2063

2064
        Returns:
2065
            True if both objects are equal, False otherwise.
2066
        """
2067
        if not isinstance(other, LightStructureEnvironments):
1✔
2068
            return NotImplemented
×
2069

2070
        is_equal = (
1✔
2071
            self.strategy == other.strategy
2072
            and self.structure == other.structure
2073
            and self.coordination_environments == other.coordination_environments
2074
            and self.valences == other.valences
2075
            and self.neighbors_sets == other.neighbors_sets
2076
        )
2077
        this_sites = [ss["site"] for ss in self._all_nbs_sites]
1✔
2078
        other_sites = [ss["site"] for ss in other._all_nbs_sites]
1✔
2079
        this_indices = [ss["index"] for ss in self._all_nbs_sites]
1✔
2080
        other_indices = [ss["index"] for ss in other._all_nbs_sites]
1✔
2081
        return is_equal and this_sites == other_sites and this_indices == other_indices
1✔
2082

2083
    def as_dict(self):
        """
        Serialize this LightStructureEnvironments object.

        Each neighbor site is stored as a PeriodicSite dict built from its fractional coordinates (not wrapped
        into the unit cell), together with its "index" and its "image_cell" converted to a list of ints.

        Returns:
            dict: Bson-serializable representation of the LightStructureEnvironments object.
        """
        strategy_dict = self.strategy.as_dict()
        structure_dict = self.structure.as_dict()

        serialized_nbs_sites = []
        for nb_site in self._all_nbs_sites:
            neighbor = nb_site["site"]
            plain_site = PeriodicSite(
                species=neighbor.species,
                coords=neighbor.frac_coords,
                lattice=neighbor.lattice,
                to_unit_cell=False,
                coords_are_cartesian=False,
                properties=neighbor.properties,
            )
            serialized_nbs_sites.append(
                {
                    "site": plain_site.as_dict(),
                    "index": nb_site["index"],
                    "image_cell": list(map(int, nb_site["image_cell"])),
                }
            )

        serialized_nb_sets = []
        for site_nb_sets in self.neighbors_sets:
            if site_nb_sets is None:
                serialized_nb_sets.append(None)
            else:
                serialized_nb_sets.append([nb_set.as_dict() for nb_set in site_nb_sets])

        return {
            "@module": type(self).__module__,
            "@class": type(self).__name__,
            "strategy": strategy_dict,
            "structure": structure_dict,
            "coordination_environments": self.coordination_environments,
            "all_nbs_sites": serialized_nbs_sites,
            "neighbors_sets": serialized_nb_sets,
            "valences": self.valences,
        }
2115

2116
    @classmethod
    def from_dict(cls, d):
        """
        Rebuild a LightStructureEnvironments object from the dict produced by the as_dict method.

        When a serialized neighbor site has no "image_cell" entry, the image cell is recovered by rounding the
        difference between its fractional coordinates and those of the structure site at its "index".

        Args:
            d: dict representation of the LightStructureEnvironments object.

        Returns:
            LightStructureEnvironments object.

        Raises:
            ValueError: If a recovered image cell is not (close to) an integer vector.
        """
        decoder = MontyDecoder()
        structure = decoder.process_decoded(d["structure"])

        all_nbs_sites = []
        for nb_site_dict in d["all_nbs_sites"]:
            decoded_site = decoder.process_decoded(nb_site_dict["site"])
            neighbor = PeriodicNeighbor(
                species=decoded_site.species,
                coords=decoded_site.frac_coords,
                lattice=decoded_site.lattice,
                properties=decoded_site.properties,
            )
            if "image_cell" not in nb_site_dict:
                # Older dicts: infer the periodic image from the coordinate offset to the reference site.
                frac_offset = neighbor.frac_coords - structure[nb_site_dict["index"]].frac_coords
                rounded_offset = np.round(frac_offset)
                if not np.allclose(frac_offset, rounded_offset):
                    raise ValueError("Weird, differences between one site in a periodic image cell is not integer ...")
                image_cell = np.array(rounded_offset, int)
            else:
                image_cell = np.array(nb_site_dict["image_cell"], int)
            all_nbs_sites.append({"site": neighbor, "index": nb_site_dict["index"], "image_cell": image_cell})

        neighbors_sets = []
        for site_nb_sets in d["neighbors_sets"]:
            if site_nb_sets is None:
                neighbors_sets.append(None)
                continue
            neighbors_sets.append(
                [
                    cls.NeighborsSet.from_dict(dd=nb_set, structure=structure, all_nbs_sites=all_nbs_sites)
                    for nb_set in site_nb_sets
                ]
            )

        return cls(
            strategy=decoder.process_decoded(d["strategy"]),
            coordination_environments=d["coordination_environments"],
            all_nbs_sites=all_nbs_sites,
            neighbors_sets=neighbors_sets,
            structure=structure,
            valences=d["valences"],
        )
2165

2166

2167
class ChemicalEnvironments(MSONable):
    """
    Class used to store all the information about the chemical environment of a given site for a given list of
    coordinated neighbors (internally called "cn_map").

    The information is kept in ``self.coord_geoms``, a dict mapping the mp_symbol of each coordination geometry
    added with add_coord_geom to a dict with the keys "symmetry_measure", "algo", "permutation",
    "local2perfect_map", "perfect2local_map", "detailed_voronoi_index", "other_symmetry_measures",
    "rotation_matrix" and "scaling_factor".
    """

    # Keys of "other_symmetry_measures" compared by is_close_to and __eq__.
    _CSM_TYPES = (
        "csm_wcs_ctwcc",
        "csm_wcs_ctwocc",
        "csm_wcs_csc",
        "csm_wocs_ctwcc",
        "csm_wocs_ctwocc",
        "csm_wocs_csc",
    )

    def __init__(self, coord_geoms=None):
        """
        Initializes the ChemicalEnvironments object containing all the information about the chemical
        environment of a given site.

        Args:
            coord_geoms: coordination geometries to be added to the chemical environment. Only None is
                currently supported.

        Raises:
            NotImplementedError: If coord_geoms is not None.
        """
        if coord_geoms is None:
            self.coord_geoms = {}
        else:
            raise NotImplementedError(
                "Constructor for ChemicalEnvironments with the coord_geoms argument is not yet implemented"
            )

    def __getitem__(self, mp_symbol):
        """
        Returns the data dict stored for the coordination geometry with the given mp_symbol.

        Raises:
            KeyError: If no coordination geometry with this mp_symbol has been added.
        """
        return self.coord_geoms[mp_symbol]

    def __len__(self):
        """
        Returns the number of coordination geometries in this ChemicalEnvironments object.

        Returns:
            Number of coordination geometries in this ChemicalEnvironments object.
        """
        return len(self.coord_geoms)

    def __iter__(self):
        """Iterates over the (mp_symbol, data dict) pairs of the coordination geometries."""
        yield from self.coord_geoms.items()

    def _symbols_and_csms(self, symmetry_measure_type=None):
        """Returns the list of mp_symbols and the numpy array of their csms for the given measure type."""
        if symmetry_measure_type is None:
            symmetry_measure_type = "csm_wcs_ctwcc"
        cglist = list(self.coord_geoms)
        csms = np.array([self.coord_geoms[cg]["other_symmetry_measures"][symmetry_measure_type] for cg in cglist])
        return cglist, csms

    def minimum_geometry(self, symmetry_measure_type=None, max_csm=None):
        """
        Returns the geometry with the minimum continuous symmetry measure of this ChemicalEnvironments.

        Args:
            symmetry_measure_type: Key of "other_symmetry_measures" used for the comparison
                ("csm_wcs_ctwcc" if None).
            max_csm: If not None, None is returned when the minimum csm is larger than this value.

        Returns:
            tuple (symbol, cg_dict) with symbol being the geometry with the minimum continuous symmetry measure
            and cg_dict the data dict stored for it, or None if no coordination geometry is present or if the
            minimum csm is larger than max_csm.
        """
        if len(self.coord_geoms) == 0:
            return None
        cglist, csms = self._symbols_and_csms(symmetry_measure_type)
        imin = np.argmin(csms)
        # Compare the csm value itself: the stored data is a dict and cannot be compared with a number.
        if max_csm is not None and csms[imin] > max_csm:
            return None
        return cglist[imin], self.coord_geoms[cglist[imin]]

    def minimum_geometries(self, n=None, symmetry_measure_type=None, max_csm=None):
        """
        Returns a list of geometries with increasing continuous symmetry measure in this ChemicalEnvironments object.

        Args:
            n: Number of geometries to be included in the list (all of them if None). The max_csm filter is
                applied after taking the n lowest ones.
            symmetry_measure_type: Key of "other_symmetry_measures" used for the sorting
                ("csm_wcs_ctwcc" if None).
            max_csm: If not None, geometries with a csm larger than this value are left out.

        Returns:
            List of (symbol, cg_dict) tuples sorted by increasing continuous symmetry measure. The list is empty
            if no coordination geometry is present.
        """
        cglist, csms = self._symbols_and_csms(symmetry_measure_type)
        isorted = np.argsort(csms)
        if n is not None:
            isorted = isorted[:n]
        return [
            (cglist[ii], self.coord_geoms[cglist[ii]]) for ii in isorted if max_csm is None or csms[ii] <= max_csm
        ]

    def add_coord_geom(
        self,
        mp_symbol,
        symmetry_measure,
        algo="UNKNOWN",
        permutation=None,
        override=False,
        local2perfect_map=None,
        perfect2local_map=None,
        detailed_voronoi_index=None,
        other_symmetry_measures=None,
        rotation_matrix=None,
        scaling_factor=None,
    ):
        """
        Adds a coordination geometry to the ChemicalEnvironments object.

        Args:
            mp_symbol: Symbol of the coordination geometry added.
            symmetry_measure: Symmetry measure of the coordination geometry added.
            algo: Algorithm used for the search of the coordination geometry added.
            permutation: Permutation of the neighbors that leads to the csm stored (stored as a list of ints,
                or None if not given).
            override: If set to True, the coordination geometry will override the existent one if present.
            local2perfect_map: Mapping of the local indices to the perfect indices.
            perfect2local_map: Mapping of the perfect indices to the local indices.
            detailed_voronoi_index: Index in the voronoi containing the neighbors set.
            other_symmetry_measures: Other symmetry measure of the coordination geometry added (with/without the
                central atom, centered on the central atom or on the centroid with/without the central atom).
            rotation_matrix: Rotation matrix mapping the local geometry to the perfect geometry.
            scaling_factor: Scaling factor mapping the local geometry to the perfect geometry.

        Raises:
            ChemenvError if the mp_symbol is not a valid coordination geometry, or if the coordination geometry
            is already added and override is set to False.
        """
        if not allcg.is_a_valid_coordination_geometry(mp_symbol=mp_symbol):
            raise ChemenvError(
                self.__class__,
                "add_coord_geom",
                f"Coordination geometry with mp_symbol {mp_symbol!r} is not valid",
            )
        if mp_symbol in self.coord_geoms and not override:
            raise ChemenvError(
                self.__class__,
                "add_coord_geom",
                "This coordination geometry is already present and override is set to False",
            )

        self.coord_geoms[mp_symbol] = {
            "symmetry_measure": float(symmetry_measure),
            "algo": algo,
            # permutation defaults to None: only convert it when it is given.
            "permutation": None if permutation is None else [int(i) for i in permutation],
            "local2perfect_map": local2perfect_map,
            "perfect2local_map": perfect2local_map,
            "detailed_voronoi_index": detailed_voronoi_index,
            "other_symmetry_measures": other_symmetry_measures,
            "rotation_matrix": rotation_matrix,
            "scaling_factor": scaling_factor,
        }

    def __str__(self):
        """
        Returns a string representation of the ChemicalEnvironments object.

        The coordination number is taken from the first coordination geometry added; geometries are listed by
        increasing "csm_wcs_ctwcc".

        Returns:
            String representation of the ChemicalEnvironments object.
        """
        out = "Chemical environments object :\n"
        if len(self.coord_geoms) == 0:
            out += " => No coordination in it <=\n"
            return out
        cn = symbol_cn_mapping[next(iter(self.coord_geoms))]
        out += f" => Coordination {cn} <=\n"
        mp_symbols = list(self.coord_geoms)
        csms_wcs = [self.coord_geoms[mp_symbol]["other_symmetry_measures"]["csm_wcs_ctwcc"] for mp_symbol in mp_symbols]
        for ii in np.argsort(csms_wcs):
            mp_symbol = mp_symbols[ii]
            cg_dict = self.coord_geoms[mp_symbol]
            other_csms = cg_dict["other_symmetry_measures"]
            out += f"   - {mp_symbol}\n"
            out += f"      csm1 (with central site) : {other_csms['csm_wcs_ctwcc']}"
            out += f"      csm2 (without central site) : {other_csms['csm_wocs_ctwocc']}"
            out += f"     algo : {cg_dict['algo']}"
            out += f"     perm : {cg_dict['permutation']}\n"
            out += f"       local2perfect : {cg_dict['local2perfect_map']!s}\n"
            out += f"       perfect2local : {cg_dict['perfect2local_map']!s}\n"
        return out

    def is_close_to(self, other, rtol=0.0, atol=1e-8) -> bool:
        """
        Whether this ChemicalEnvironments object is close to another one.

        Both objects must contain the same coordination geometries, and all their continuous symmetry measures
        (all six "other_symmetry_measures" variants) must be close according to numpy.isclose.

        Args:
            other: Another ChemicalEnvironments object.
            rtol: Relative tolerance for the comparison of Continuous Symmetry Measures.
            atol: Absolute tolerance for the comparison of Continuous Symmetry Measures.

        Returns:
            True if the two ChemicalEnvironments objects are close to each other.
        """
        if set(self.coord_geoms) != set(other.coord_geoms):
            return False
        for mp_symbol, cg_dict_self in self.coord_geoms.items():
            csms_self = cg_dict_self["other_symmetry_measures"]
            csms_other = other[mp_symbol]["other_symmetry_measures"]
            for csm_type in self._CSM_TYPES:
                if not np.isclose(csms_self[csm_type], csms_other[csm_type], rtol=rtol, atol=atol):
                    return False
        return True

    def __eq__(self, other: object) -> bool:
        """
        Equality method that checks if the ChemicalEnvironments object is equal to another ChemicalEnvironments
        object.

        The coordination geometries, their symmetry measure, algo, permutation, detailed voronoi index and all
        six "other_symmetry_measures" variants are compared exactly.

        Args:
            other: ChemicalEnvironments object to compare with.

        Returns:
            True if both objects are equal, False otherwise (NotImplemented for other types).
        """
        if not isinstance(other, ChemicalEnvironments):
            return NotImplemented

        if set(self.coord_geoms) != set(other.coord_geoms):
            return False
        for mp_symbol, cg_dict_self in self.coord_geoms.items():
            cg_dict_other = other.coord_geoms[mp_symbol]
            for key in ("symmetry_measure", "algo", "permutation", "detailed_voronoi_index"):
                if cg_dict_self[key] != cg_dict_other[key]:
                    return False
            csms_self = cg_dict_self["other_symmetry_measures"]
            csms_other = cg_dict_other["other_symmetry_measures"]
            for csm_type in self._CSM_TYPES:
                if csms_self[csm_type] != csms_other[csm_type]:
                    return False
        return True

    def as_dict(self):
        """
        Returns a dictionary representation of the ChemicalEnvironments object.

        Returns:
            A dictionary representation of the ChemicalEnvironments object.
        """
        return {
            "@module": type(self).__module__,
            "@class": type(self).__name__,
            "coord_geoms": jsanitize(self.coord_geoms),
        }

    @classmethod
    def from_dict(cls, d):
        """
        Reconstructs the ChemicalEnvironments object from a dict representation of the ChemicalEnvironments created
        using the as_dict method.

        Missing "other_symmetry_measures", "rotation_matrix" or "scaling_factor" entries are treated as None.

        Args:
            d: dict representation of the ChemicalEnvironments object.

        Returns:
            ChemicalEnvironments object.
        """
        ce = cls()
        for cg, cg_dict in d["coord_geoms"].items():
            l2p = cg_dict["local2perfect_map"]
            p2l = cg_dict["perfect2local_map"]
            ce.add_coord_geom(
                cg,
                cg_dict["symmetry_measure"],
                algo=cg_dict["algo"],
                permutation=cg_dict["permutation"],
                # Map keys may come back as strings after a JSON round trip: restore int -> int mappings.
                local2perfect_map=None if l2p is None else {int(key): int(val) for key, val in l2p.items()},
                perfect2local_map=None if p2l is None else {int(key): int(val) for key, val in p2l.items()},
                detailed_voronoi_index=cg_dict["detailed_voronoi_index"],
                other_symmetry_measures=cg_dict.get("other_symmetry_measures"),
                rotation_matrix=cg_dict.get("rotation_matrix"),
                scaling_factor=cg_dict.get("scaling_factor"),
            )
        return ce
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc