• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

materialsproject / pymatgen / 4075885785

pending completion
4075885785

push

github

Shyue Ping Ong
Merge branch 'master' of github.com:materialsproject/pymatgen

96 of 96 new or added lines in 27 files covered. (100.0%)

81013 of 102710 relevant lines covered (78.88%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.37
/pymatgen/analysis/graphs.py
1
# Copyright (c) Pymatgen Development Team.
2
# Distributed under the terms of the MIT License.
3

4
"""
1✔
5
Module for graph representations of crystals and molecules.
6
"""
7
from __future__ import annotations
1✔
8

9
import copy
1✔
10
import logging
1✔
11
import os.path
1✔
12
import subprocess
1✔
13
import warnings
1✔
14
from collections import defaultdict, namedtuple
1✔
15
from itertools import combinations
1✔
16
from operator import itemgetter
1✔
17
from shutil import which
1✔
18
from typing import Callable
1✔
19

20
import networkx as nx
1✔
21
import networkx.algorithms.isomorphism as iso
1✔
22
import numpy as np
1✔
23
from monty.json import MSONable
1✔
24
from networkx.drawing.nx_agraph import write_dot
1✔
25
from networkx.readwrite import json_graph
1✔
26
from scipy.spatial import KDTree
1✔
27
from scipy.stats import describe
1✔
28

29
from pymatgen.core import Lattice, Molecule, PeriodicSite, Structure
1✔
30
from pymatgen.core.structure import FunctionalGroups
1✔
31
from pymatgen.util.coord import lattice_points_in_supercell
1✔
32
from pymatgen.vis.structure_vtk import EL_COLORS
1✔
33

34
try:
1✔
35
    import igraph
1✔
36
except ImportError:
1✔
37
    igraph = None
1✔
38

39

40
# Module-level logger named after this module; its level is set to INFO here.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Module authorship / release metadata.
__author__ = "Matthew Horton, Evan Spotte-Smith, Samuel Blau"
__version__ = "0.1"
__maintainer__ = "Matthew Horton"
__email__ = "mkhorton@lbl.gov"
__status__ = "Production"
__date__ = "August 2017"

# Lightweight record with the fields site, jimage, index, weight and dist.
ConnectedSite = namedtuple("ConnectedSite", "site, jimage, index, weight, dist")
1✔
51

52

53
def _compare(g1, g2, i1, i2):
1✔
54
    """
55
    Helper function called by isomorphic to ensure comparison of node identities.
56
    """
57
    return g1.vs[i1]["species"] == g2.vs[i2]["species"]
×
58

59

60
def _igraph_from_nxgraph(graph):
    """
    Build an igraph graph from a networkx graph.

    Every node becomes a vertex named ``str(node)`` carrying a "species"
    attribute (taken from the node's "specie" data) and a "coords"
    attribute; every edge is added between the stringified endpoint names.

    :param graph: networkx graph whose nodes have "specie" and "coords" data
    :return: the newly created igraph.Graph
    """
    node_records = graph.nodes(data=True)
    converted = igraph.Graph()
    for label, attrs in node_records:
        converted.add_vertex(name=str(label), species=attrs["specie"], coords=attrs["coords"])
    named_edges = [(str(edge[0]), str(edge[1])) for edge in graph.edges()]
    converted.add_edges(named_edges)
    return converted
×
70

71

72
def _isomorphic(frag1, frag2):
1✔
73
    """
74
    Internal function to check if two graph objects are isomorphic, using igraph if
75
    if is available and networkx if it is not.
76
    """
77
    f1_nodes = frag1.nodes(data=True)
1✔
78
    f2_nodes = frag2.nodes(data=True)
1✔
79
    if len(f1_nodes) != len(f2_nodes):
1✔
80
        return False
×
81
    f2_edges = frag2.edges()
1✔
82
    if len(f2_edges) != len(f2_edges):
1✔
83
        return False
×
84
    f1_comp_dict = {}
1✔
85
    f2_comp_dict = {}
1✔
86
    for node in f1_nodes:
1✔
87
        if node[1]["specie"] not in f1_comp_dict:
1✔
88
            f1_comp_dict[node[1]["specie"]] = 1
1✔
89
        else:
90
            f1_comp_dict[node[1]["specie"]] += 1
1✔
91
    for node in f2_nodes:
1✔
92
        if node[1]["specie"] not in f2_comp_dict:
1✔
93
            f2_comp_dict[node[1]["specie"]] = 1
1✔
94
        else:
95
            f2_comp_dict[node[1]["specie"]] += 1
1✔
96
    if f1_comp_dict != f2_comp_dict:
1✔
97
        return False
×
98
    if igraph is not None:
1✔
99
        ifrag1 = _igraph_from_nxgraph(frag1)
×
100
        ifrag2 = _igraph_from_nxgraph(frag2)
×
101
        return ifrag1.isomorphic_vf2(ifrag2, node_compat_fn=_compare)
×
102
    nm = iso.categorical_node_match("specie", "ERROR")
1✔
103
    return nx.is_isomorphic(frag1.to_undirected(), frag2.to_undirected(), node_match=nm)
1✔
104

105

106
class StructureGraph(MSONable):
1✔
107
    """
108
    This is a class for annotating a Structure with
109
    bond information, stored in the form of a graph. A "bond" does
110
    not necessarily have to be a chemical bond, but can store any
111
    kind of information that connects two Sites.
112
    """
113

114
    def __init__(self, structure: Structure, graph_data=None):
        """
        If constructing this class manually, use the `with_empty_graph`
        method or `with_local_env_strategy` method (using an algorithm
        provided by the `local_env` module, such as O'Keeffe).

        This class contains connection information:
        relationships between sites represented by a Graph structure,
        and an associated structure object.

        This class uses the NetworkX package to store and operate
        on the graph itself, but contains a lot of helper methods
        to make associating a graph with a given crystallographic
        structure easier.

        Use cases for this include storing bonding information,
        NMR J-couplings, Heisenberg exchange parameters, etc.

        For periodic graphs, class stores information on the graph
        edges of what lattice image the edge belongs to.

        :param structure: a Structure object. A StructureGraph may also be
            passed, in which case its graph data is re-read from its dict
            representation and its underlying Structure is reused (shared,
            not copied).

        :param graph_data: dict containing graph information in
            dict format (not intended to be constructed manually,
            see as_dict method for format)
        """
        if isinstance(structure, StructureGraph):
            # just make a copy from input
            graph_data = structure.as_dict()["graphs"]
            # unwrap, so that self.structure is the Structure itself and not
            # the StructureGraph that was passed in
            structure = structure.structure

        self.structure = structure
        self.graph = nx.readwrite.json_graph.adjacency_graph(graph_data)

        # tidy up edge attr dicts, reading to/from json duplicates
        # information
        for _, _, _, d in self.graph.edges(keys=True, data=True):
            d.pop("id", None)
            d.pop("key", None)
            # ensure images are tuples (conversion to lists happens
            # when serializing back from json), it's important images
            # are hashable/immutable
            for image_key in ("to_jimage", "from_jimage"):
                if image_key in d:
                    d[image_key] = tuple(d[image_key])
×
162

163
    @classmethod
1✔
164
    def with_empty_graph(
1✔
165
        cls,
166
        structure: Structure,
167
        name="bonds",
168
        edge_weight_name=None,
169
        edge_weight_units=None,
170
    ):
171
        """
172
        Constructor for StructureGraph, returns a StructureGraph
173
        object with an empty graph (no edges, only nodes defined
174
        that correspond to Sites in Structure).
175

176
        :param structure (Structure):
177
        :param name (str): name of graph, e.g. "bonds"
178
        :param edge_weight_name (str): name of edge weights,
179
            e.g. "bond_length" or "exchange_constant"
180
        :param edge_weight_units (str): name of edge weight units
181
            e.g. "Ã…" or "eV"
182
        :return (StructureGraph):
183
        """
184
        if edge_weight_name and (edge_weight_units is None):
1✔
185
            raise ValueError(
×
186
                "Please specify units associated "
187
                "with your edge weights. Can be "
188
                "empty string if arbitrary or "
189
                "dimensionless."
190
            )
191

192
        # construct graph with one node per site
193
        # graph attributes don't change behavior of graph,
194
        # they're just for book-keeping
195
        graph = nx.MultiDiGraph(
1✔
196
            edge_weight_name=edge_weight_name,
197
            edge_weight_units=edge_weight_units,
198
            name=name,
199
        )
200
        graph.add_nodes_from(range(len(structure)))
1✔
201

202
        graph_data = json_graph.adjacency_data(graph)
1✔
203

204
        return cls(structure, graph_data=graph_data)
1✔
205

206
    @staticmethod
    def with_edges(structure, edges):
        """
        Constructor for StructureGraph, using pre-existing or pre-defined edges
        with optional edge parameters.

        :param structure: Structure object
        :param edges: dict representing the bonds of the structure
            (format: {(from_index, to_index, from_image, to_image): props},
            where props is a dictionary of properties, including weight.
            Props should be None if no additional properties are to be
            specified. The dicts passed in are not modified.
        :return: sg, a StructureGraph
        :raises ValueError: if an edge key is not a 4-tuple as described
            above, or if it refers to a node that is not in the graph
        """
        sg = StructureGraph.with_empty_graph(structure, name="bonds", edge_weight_name="weight", edge_weight_units="")

        for edge, props in edges.items():
            try:
                from_index = edge[0]
                to_index = edge[1]
                from_image = edge[2]
                to_image = edge[3]
            except (TypeError, IndexError) as err:
                # TypeError: key is not subscriptable; IndexError: key is too short
                raise ValueError("Edges must be given as (from_index, to_index, from_image, to_image) tuples") from err

            if props is not None:
                # work on a shallow copy so the caller's dict is left intact
                props = dict(props)
                weight = props.pop("weight", None)
                if not props:
                    props = None
            else:
                weight = None

            nodes = sg.graph.nodes
            if not (from_index in nodes and to_index in nodes):
                raise ValueError(
                    "Edges cannot be added if nodes are not present in the graph. Please check your indices."
                )

            sg.add_edge(
                from_index,
                to_index,
                from_jimage=from_image,
                to_jimage=to_image,
                weight=weight,
                edge_properties=props,
            )

        sg.set_node_attributes()
        return sg
1✔
260

261
    @staticmethod
1✔
262
    def with_local_env_strategy(structure, strategy, weights=False, edge_properties=False):
1✔
263
        """
264
        Constructor for StructureGraph, using a strategy
265
        from :class:`pymatgen.analysis.local_env`.
266

267
        :param structure: Structure object
268
        :param strategy: an instance of a
269
            :class:`pymatgen.analysis.local_env.NearNeighbors` object
270
        :param weights: if True, use weights from local_env class
271
            (consult relevant class for their meaning)
272
        :param edge_properties: if True, edge_properties from neighbors will be used
273
        :return:
274
        """
275
        if not strategy.structures_allowed:
1✔
276
            raise ValueError("Chosen strategy is not designed for use with structures! Please choose another strategy.")
1✔
277

278
        sg = StructureGraph.with_empty_graph(structure, name="bonds")
1✔
279

280
        for n, neighbors in enumerate(strategy.get_all_nn_info(structure)):
1✔
281
            for neighbor in neighbors:
1✔
282
                # local_env will always try to add two edges
283
                # for any one bond, one from site u to site v
284
                # and another form site v to site u: this is
285
                # harmless, so warn_duplicates=False
286
                if edge_properties:
1✔
287
                    sg.add_edge(
1✔
288
                        from_index=n,
289
                        from_jimage=(0, 0, 0),
290
                        to_index=neighbor["site_index"],
291
                        to_jimage=neighbor["image"],
292
                        weight=neighbor["weight"] if weights else None,
293
                        edge_properties=neighbor["edge_properties"],
294
                        warn_duplicates=False,
295
                    )
296
                else:
297
                    sg.add_edge(
1✔
298
                        from_index=n,
299
                        from_jimage=(0, 0, 0),
300
                        to_index=neighbor["site_index"],
301
                        to_jimage=neighbor["image"],
302
                        weight=neighbor["weight"] if weights else None,
303
                        edge_properties=None,
304
                        warn_duplicates=False,
305
                    )
306

307
        return sg
1✔
308

309
    @property
1✔
310
    def name(self):
1✔
311
        """
312
        :return: Name of graph
313
        """
314
        return self.graph.graph["name"]
1✔
315

316
    @property
1✔
317
    def edge_weight_name(self):
1✔
318
        """
319
        :return: Name of the edge weight property of graph
320
        """
321
        return self.graph.graph["edge_weight_name"]
1✔
322

323
    @property
1✔
324
    def edge_weight_unit(self):
1✔
325
        """
326
        :return: Units of the edge weight property of graph
327
        """
328
        return self.graph.graph["edge_weight_units"]
1✔
329

330
    def add_edge(
1✔
331
        self,
332
        from_index,
333
        to_index,
334
        from_jimage=(0, 0, 0),
335
        to_jimage=None,
336
        weight=None,
337
        warn_duplicates=True,
338
        edge_properties=None,
339
    ):
340
        """
341
        Add edge to graph.
342

343
        Since physically a 'bond' (or other connection
344
        between sites) doesn't have a direction, from_index,
345
        from_jimage can be swapped with to_index, to_jimage.
346

347
        However, images will always be shifted so that
348
        from_index < to_index and from_jimage becomes (0, 0, 0).
349

350
        :param from_index: index of site connecting from
351
        :param to_index: index of site connecting to
352
        :param from_jimage (tuple of ints): lattice vector of periodic
353
            image, e.g. (1, 0, 0) for periodic image in +x direction
354
        :param to_jimage (tuple of ints): lattice vector of image
355
        :param weight (float): e.g. bond length
356
        :param warn_duplicates (bool): if True, will warn if
357
            trying to add duplicate edges (duplicate edges will not
358
            be added in either case)
359
        :param edge_properties (dict): any other information to
360
            store on graph edges, similar to Structure's site_properties
361
        :return:
362
        """
363
        # this is not necessary for the class to work, but
364
        # just makes it neater
365
        if to_index < from_index:
1✔
366
            to_index, from_index = from_index, to_index
1✔
367
            to_jimage, from_jimage = from_jimage, to_jimage
1✔
368

369
        # constrain all from_jimages to be (0, 0, 0),
370
        # initial version of this class worked even if
371
        # from_jimage != (0, 0, 0), but making this
372
        # assumption simplifies logic later
373
        if not np.array_equal(from_jimage, (0, 0, 0)):
1✔
374
            shift = from_jimage
1✔
375
            from_jimage = np.subtract(from_jimage, shift)
1✔
376
            to_jimage = np.subtract(to_jimage, shift)
1✔
377

378
        # automatic detection of to_jimage if user doesn't specify
379
        # will try and detect all equivalent images and add multiple
380
        # edges if appropriate
381
        if to_jimage is None:
1✔
382
            # assume we want the closest site
383
            warnings.warn("Please specify to_jimage to be unambiguous, trying to automatically detect.")
1✔
384
            dist, to_jimage = self.structure[from_index].distance_and_image(self.structure[to_index])
1✔
385
            if dist == 0:
1✔
386
                # this will happen when from_index == to_index,
387
                # typically in primitive single-atom lattices
388
                images = [1, 0, 0], [0, 1, 0], [0, 0, 1]
1✔
389
                dists = []
1✔
390
                for image in images:
1✔
391
                    dists.append(
1✔
392
                        self.structure[from_index].distance_and_image(self.structure[from_index], jimage=image)[0]
393
                    )
394
                dist = min(dists)
1✔
395
            equiv_sites = self.structure.get_neighbors_in_shell(
1✔
396
                self.structure[from_index].coords, dist, dist * 0.01, include_index=True
397
            )
398
            for nnsite in equiv_sites:
1✔
399
                to_jimage = np.subtract(nnsite.frac_coords, self.structure[from_index].frac_coords)
1✔
400
                to_jimage = np.round(to_jimage).astype(int)
1✔
401
                self.add_edge(
1✔
402
                    from_index=from_index,
403
                    from_jimage=(0, 0, 0),
404
                    to_jimage=to_jimage,
405
                    to_index=nnsite.index,
406
                )
407
            return
1✔
408

409
        # sanitize types
410
        from_jimage, to_jimage = (
1✔
411
            tuple(map(int, from_jimage)),
412
            tuple(map(int, to_jimage)),
413
        )
414
        from_index, to_index = int(from_index), int(to_index)
1✔
415

416
        # if edge is from site i to site i, constrain direction of edge
417
        # this is a convention to avoid duplicate hops
418
        if to_index == from_index:
1✔
419
            if to_jimage == (0, 0, 0):
1✔
420
                warnings.warn("Tried to create a bond to itself, this doesn't make sense so was ignored.")
×
421
                return
×
422

423
            # ensure that the first non-zero jimage index is positive
424
            # assumes that at least one non-zero index is present
425
            is_positive = [idx for idx in to_jimage if idx != 0][0] > 0
1✔
426

427
            if not is_positive:
1✔
428
                # let's flip the jimage,
429
                # e.g. (0, 1, 0) is equivalent to (0, -1, 0) in this case
430
                to_jimage = tuple(-idx for idx in to_jimage)
1✔
431

432
        # check we're not trying to add a duplicate edge
433
        # there should only ever be at most one edge
434
        # between a given (site, jimage) pair and another
435
        # (site, jimage) pair
436
        existing_edge_data = self.graph.get_edge_data(from_index, to_index)
1✔
437
        if existing_edge_data:
1✔
438
            for d in existing_edge_data.values():
1✔
439
                if d["to_jimage"] == to_jimage:
1✔
440
                    if warn_duplicates:
1✔
441
                        warnings.warn(
1✔
442
                            "Trying to add an edge that already exists from "
443
                            f"site {from_index} to site {to_index} in {to_jimage}."
444
                        )
445
                    return
1✔
446

447
        # generic container for additional edge properties,
448
        # similar to site properties
449
        edge_properties = edge_properties or {}
1✔
450

451
        if weight:
1✔
452
            self.graph.add_edge(
1✔
453
                from_index,
454
                to_index,
455
                to_jimage=to_jimage,
456
                weight=weight,
457
                **edge_properties,
458
            )
459
        else:
460
            self.graph.add_edge(from_index, to_index, to_jimage=to_jimage, **edge_properties)
1✔
461

462
    def insert_node(
        self,
        i,
        species,
        coords,
        coords_are_cartesian=False,
        validate_proximity=False,
        site_properties=None,
        edges=None,
    ):
        """
        Insert a new site into self.structure (via its insert() method) and
        add the matching node to the graph, shifting the labels of existing
        nodes at or after position i up by one.

        :param i: Index at which to insert the new site
        :param species: Species for the new site
        :param coords: 3x1 array representing coordinates of the new site
        :param coords_are_cartesian: Whether coordinates are cartesian.
            Defaults to False.
        :param validate_proximity: Passed through to insert(); defaults to
            False.
        :param site_properties: Site properties for the new site
        :param edges: List of dicts representing edges to be added to the
            graph. Indices used for these edges should reflect the graph
            AFTER the insertion, NOT before. Each dict is read for the keys
            "from_index", "to_index" and "to_jimage", and optionally
            "weight" and "properties".
        :return:
        :raises RuntimeError: if adding one of the edges raises a KeyError
        """
        self.structure.insert(
            i,
            species,
            coords,
            coords_are_cartesian=coords_are_cartesian,
            validate_proximity=validate_proximity,
            properties=site_properties,
        )

        # Labels before the insertion point stay put, the rest move up one.
        previous_site_count = len(self.structure) - 1
        relabel = {old: (old if old < i else old + 1) for old in range(previous_site_count)}
        nx.relabel_nodes(self.graph, relabel, copy=False)

        self.graph.add_node(i)
        self.set_node_attributes()

        if edges is None:
            return

        for new_edge in edges:
            try:
                self.add_edge(
                    new_edge["from_index"],
                    new_edge["to_index"],
                    from_jimage=(0, 0, 0),
                    to_jimage=new_edge["to_jimage"],
                    weight=new_edge.get("weight", None),
                    edge_properties=new_edge.get("properties", None),
                )
            except KeyError:
                raise RuntimeError("Some edges are invalid.")
×
526

527
    def set_node_attributes(self):
        """
        Refresh the "specie", "coords" and "properties" attributes of every
        node from the corresponding site of self.structure.

        :return:
        """
        per_attribute = {"specie": {}, "coords": {}, "properties": {}}
        for node_id in self.graph.nodes():
            per_attribute["specie"][node_id] = self.structure[node_id].specie.symbol
            per_attribute["coords"][node_id] = self.structure[node_id].coords
            per_attribute["properties"][node_id] = self.structure[node_id].properties

        for attribute_name, values in per_attribute.items():
            nx.set_node_attributes(self.graph, values, attribute_name)
1✔
545

546
    def alter_edge(
1✔
547
        self,
548
        from_index,
549
        to_index,
550
        to_jimage=None,
551
        new_weight=None,
552
        new_edge_properties=None,
553
    ):
554
        """
555
        Alters either the weight or the edge_properties of
556
        an edge in the StructureGraph.
557

558
        :param from_index: int
559
        :param to_index: int
560
        :param to_jimage: tuple
561
        :param new_weight: alter_edge does not require
562
            that weight be altered. As such, by default, this
563
            is None. If weight is to be changed, it should be a
564
            float.
565
        :param new_edge_properties: alter_edge does not require
566
            that edge_properties be altered. As such, by default,
567
            this is None. If any edge properties are to be changed,
568
            it should be a dictionary of edge properties to be changed.
569
        :return:
570
        """
571
        existing_edges = self.graph.get_edge_data(from_index, to_index)
1✔
572

573
        # ensure that edge exists before attempting to change it
574
        if not existing_edges:
1✔
575
            raise ValueError(
×
576
                f"Edge between {from_index} and {to_index} cannot be altered; no edge exists between those sites."
577
            )
578

579
        if to_jimage is None:
1✔
580
            edge_index = 0
×
581
        else:
582
            for i, properties in existing_edges.items():
1✔
583
                if properties["to_jimage"] == to_jimage:
1✔
584
                    edge_index = i
1✔
585

586
        if new_weight is not None:
1✔
587
            self.graph[from_index][to_index][edge_index]["weight"] = new_weight
1✔
588

589
        if new_edge_properties is not None:
1✔
590
            for prop in list(new_edge_properties):
1✔
591
                self.graph[from_index][to_index][edge_index][prop] = new_edge_properties[prop]
1✔
592

593
    def break_edge(self, from_index, to_index, to_jimage=None, allow_reverse=False):
1✔
594
        """
595
        Remove an edge from the StructureGraph. If no image is given, this method will fail.
596

597
        :param from_index: int
598
        :param to_index: int
599
        :param to_jimage: tuple
600
        :param allow_reverse: If allow_reverse is True, then break_edge will
601
            attempt to break both (from_index, to_index) and, failing that,
602
            will attempt to break (to_index, from_index).
603
        :return:
604
        """
605
        # ensure that edge exists before attempting to remove it
606
        existing_edges = self.graph.get_edge_data(from_index, to_index)
1✔
607
        existing_reverse = None
1✔
608

609
        if to_jimage is None:
1✔
610
            raise ValueError("Image must be supplied, to avoid ambiguity.")
×
611

612
        if existing_edges:
1✔
613
            for i, properties in existing_edges.items():
1✔
614
                if properties["to_jimage"] == to_jimage:
1✔
615
                    edge_index = i
1✔
616

617
            self.graph.remove_edge(from_index, to_index, edge_index)
1✔
618

619
        else:
620
            if allow_reverse:
×
621
                existing_reverse = self.graph.get_edge_data(to_index, from_index)
×
622

623
            if existing_reverse:
×
624
                for i, properties in existing_reverse.items():
×
625
                    if properties["to_jimage"] == to_jimage:
×
626
                        edge_index = i
×
627

628
                self.graph.remove_edge(to_index, from_index, edge_index)
×
629
            else:
630
                raise ValueError(
×
631
                    f"Edge cannot be broken between {from_index} and {to_index}; "
632
                    f"no edge exists between those sites."
633
                )
634

635
    def remove_nodes(self, indices):
        """
        Remove sites from self.structure (via remove_sites()) together with
        the matching graph nodes, then renumber the remaining nodes
        0..n-1 in ascending order of their old labels.

        :param indices: list of indices in the current structure (and graph)
            to be removed.
        :return:
        """
        self.structure.remove_sites(indices)
        self.graph.remove_nodes_from(indices)

        surviving_labels = sorted(self.graph.nodes)
        relabel = {old_label: new_label for new_label, old_label in enumerate(surviving_labels)}

        nx.relabel_nodes(self.graph, relabel, copy=False)
        self.set_node_attributes()
1✔
652

653
    def substitute_group(
        self,
        index,
        func_grp,
        strategy,
        bond_order=1,
        graph_dict=None,
        strategy_params=None,
    ):
        """
        Builds off of Structure.substitute to replace an atom in self.structure
        with a functional group. This method also amends self.graph to
        incorporate the new functional group.

        NOTE: Care must be taken to ensure that the functional group that is
        substituted will not place atoms too close to each other, or violate the
        dimensions of the Lattice.

        :param index: Index of atom to substitute.
        :param func_grp: Substituent molecule. There are two options:

            1. Providing an actual Molecule as the input. The first atom
                must be a DummySpecies X, indicating the position of
                nearest neighbor. The second atom must be the next
                nearest atom. For example, for a methyl group
                substitution, func_grp should be X-CH3, where X is the
                first site and C is the second site. What the code will
                do is to remove the index site, and connect the nearest
                neighbor to the C atom in CH3. The X-C bond indicates the
                directionality to connect the atoms.
            2. A string name. The molecule will be obtained from the
                relevant template in func_groups.json.
        :param strategy: Class from pymatgen.analysis.local_env.
        :param bond_order: A specified bond order to calculate the bond
            length between the attached functional group and the nearest
            neighbor site. Defaults to 1.
        :param graph_dict: Dictionary representing the bonds of the functional
            group (format: {(u, v): props}, where props is a dictionary of
            properties, including weight. The optional keys "to_jimage"
            and "weight" of props are used for the edge image and weight;
            the dicts passed in are not modified. If None, then the algorithm
            will attempt to automatically determine bonds using the given
            strategy.
        :param strategy_params: dictionary of keyword arguments for strategy.
            If None, default parameters will be used.
        :return:
        :raises RuntimeError: if func_grp is not a Molecule and cannot be
            looked up in FunctionalGroups
        """
        if isinstance(func_grp, Molecule):
            func_grp = copy.deepcopy(func_grp)
        else:
            try:
                func_grp = copy.deepcopy(FunctionalGroups[func_grp])
            except Exception as err:
                raise RuntimeError("Can't find functional group in list. Provide explicit coordinate instead") from err

        self.structure.substitute(index, func_grp, bond_order=bond_order)

        # Map functional-group atom indices onto the structure indices they
        # now occupy: the group's atoms are the last ones in the structure.
        # Subtracting 1 because the dummy atom X should not count.
        atoms = len(func_grp) - 1
        offset = len(self.structure) - atoms
        mapping = {i: i + offset for i in range(atoms)}

        # Remove dummy atom "X"
        func_grp.remove_species("X")

        if graph_dict is not None:
            for (u, v), edge_props in graph_dict.items():
                # Shallow copy so the caller's dict is left untouched.
                edge_props = dict(edge_props) if edge_props else {}
                # By default, assume that all edges should remain
                # inside the initial image
                to_jimage = edge_props.pop("to_jimage", (0, 0, 0))
                # Default to None so an edge without a weight neither fails
                # nor inherits the weight of the previous edge.
                weight = edge_props.pop("weight", None)
                self.add_edge(
                    mapping[u],
                    mapping[v],
                    to_jimage=to_jimage,
                    weight=weight,
                    edge_properties=edge_props,
                )

        else:
            if strategy_params is None:
                strategy_params = {}
            strat = strategy(**strategy_params)

            for site in mapping.values():
                neighbors = strat.get_nn_info(self.structure, site)

                for neighbor in neighbors:
                    self.add_edge(
                        from_index=site,
                        from_jimage=(0, 0, 0),
                        to_index=neighbor["site_index"],
                        to_jimage=neighbor["image"],
                        weight=neighbor["weight"],
                        warn_duplicates=False,
                    )
765

766
    def get_connected_sites(self, n, jimage=(0, 0, 0)):
        """
        Collects the neighbors of site n as ConnectedSite named tuples
        with fields site, jimage, index, weight and dist.
        Index is the index of the neighboring site in the structure,
        weight is None when the edge carries no weight, and dist is
        the distance between site n and the neighbor image.
        :param n: index of Site in Structure
        :param jimage: lattice vector of site
        :return: list of ConnectedSite tuples,
            sorted by closest first
        """
        neighbors = set()
        seen_images = set()

        # tag each edge touching n with whether n is its target ("incoming")
        tagged_edges = [(src, dst, props, False) for src, dst, props in self.graph.out_edges(n, data=True)]
        tagged_edges += [(src, dst, props, True) for src, dst, props in self.graph.in_edges(n, data=True)]

        for src, dst, props, incoming in tagged_edges:
            image = props["to_jimage"]

            # an incoming edge is read backwards: swap the endpoints
            # and invert the image vector
            if incoming:
                src, dst = dst, src
                image = np.multiply(-1, image)

            # shift by the image the starting site itself lives in
            image = tuple(map(int, np.add(image, jimage)))

            # build a copy of the neighbor translated into that image
            site_dict = self.structure[dst].as_dict()
            site_dict["abc"] = np.add(site_dict["abc"], image).tolist()
            shifted_site = PeriodicSite.from_dict(site_dict)

            # distance is measured relative to the starting site,
            # so the jimage offset is removed again
            offset = np.subtract(image, jimage)
            separation = self.structure[src].distance(self.structure[dst], jimage=offset)

            edge_weight = props.get("weight", None)

            # each (neighbor index, image) pair is reported only once
            image_key = (dst, image)
            if image_key not in seen_images:
                neighbors.add(
                    ConnectedSite(site=shifted_site, jimage=image, index=dst, weight=edge_weight, dist=separation)
                )
                seen_images.add(image_key)

        # closest neighbors first
        return sorted(neighbors, key=lambda neighbor: neighbor.dist)
1✔
813

814
    def get_coordination_of_site(self, n):
        """
        Counts the neighbors of site n.
        This is the degree of the node for site n, minus one
        for every self-loop edge found on that node.
        :param n: index of site
        :return (int):
        """
        self_loops = 0
        for source, target in self.graph.edges(n):
            if source == target:
                self_loops += 1
        return self.graph.degree(n) - self_loops
1✔
824

825
    def draw_graph_to_file(
        self,
        filename="graph",
        diff=None,
        hide_unconnected_nodes=False,
        hide_image_edges=True,
        edge_colors=False,
        node_labels=False,
        weight_labels=False,
        image_labels=False,
        color_scheme="VESTA",
        keep_dot=False,
        algo="fdp",
    ):
        """
        Draws graph using GraphViz.

        The networkx graph object itself can also be drawn
        with networkx's in-built graph drawing methods, but
        note that this might give misleading results for
        multigraphs (edges are super-imposed on each other).

        If visualization is difficult to interpret,
        `hide_image_edges` can help, especially in larger
        graphs.

        :param filename: filename to output, will detect filetype
            from extension (any graphviz filetype supported, such as
            pdf or png)
        :param diff (StructureGraph): an additional graph to
            compare with, will color edges red that do not exist in diff
            and edges green that are in diff graph but not in the
            reference graph
        :param hide_unconnected_nodes: if True, hide unconnected
            nodes
        :param hide_image_edges: if True, do not draw edges that
            go through periodic boundaries
        :param edge_colors (bool): if True, use node colors to
            color edges
        :param node_labels (bool): if True, label nodes with
            species and site index
        :param weight_labels (bool): if True, label edges with
            weights
        :param image_labels (bool): if True, label edges with
            their periodic images (usually only used for debugging,
            edges to periodic images always appear as dashed lines)
        :param color_scheme (str): "VESTA" or "JMOL"
        :param keep_dot (bool): keep GraphViz .dot file for later
            visualization
        :param algo: any graphviz algo, "neato" (for simple graphs)
            or "fdp" (for more crowded graphs) usually give good outputs
        :return:
        :raises RuntimeError: if the `algo` executable is not on the
            path, or if it exits with a non-zero return code
        """
        if not which(algo):
            raise RuntimeError("StructureGraph graph drawing requires GraphViz binaries to be in the path.")

        # Developer note: NetworkX also has methods for drawing
        # graphs using matplotlib, these also work here. However,
        # a dedicated tool like GraphViz allows for much easier
        # control over graph appearance and also correctly displays
        # multi-graphs (matplotlib can superimpose multiple edges).

        g = self.graph.copy()

        # the graph-level attributes of the copy are replaced by display
        # options below, so read the weight units from the original graph
        # first (previously they were looked up after the overwrite and
        # were therefore always empty)
        units = self.graph.graph.get("edge_weight_units") or ""

        g.graph = {"nodesep": 10.0, "dpi": 300, "overlap": "false"}

        # add display options for nodes
        for n in g.nodes():
            # get label by species name
            label = f"{self.structure[n].specie}({n})" if node_labels else ""

            # use standard color scheme for nodes
            c = EL_COLORS[color_scheme].get(str(self.structure[n].specie.symbol), [0, 0, 0])

            # get contrasting font color
            # magic numbers account for perceived luminescence
            # https://stackoverflow.com/questions/1855884/determine-font-color-based-on-background-color
            fontcolor = "#000000" if 1 - (c[0] * 0.299 + c[1] * 0.587 + c[2] * 0.114) / 255 < 0.5 else "#ffffff"

            # convert color to hex string
            color = f"#{c[0]:02x}{c[1]:02x}{c[2]:02x}"

            g.add_node(
                n,
                fillcolor=color,
                fontcolor=fontcolor,
                label=label,
                fontname="Helvetica-bold",
                style="filled",
                shape="circle",
            )

        edges_to_delete = []

        # add display options for edges
        for u, v, k, d in g.edges(keys=True, data=True):
            # image of the target node for this edge
            to_image = d["to_jimage"]

            # set edge style: edges crossing a periodic boundary are dashed
            d["style"] = "solid"
            if to_image != (0, 0, 0):
                d["style"] = "dashed"
                if hide_image_edges:
                    edges_to_delete.append((u, v, k))

            # don't show edge directions
            d["arrowhead"] = "none"

            # only add labels for images that are not the origin
            if image_labels:
                d["headlabel"] = "" if to_image == (0, 0, 0) else f"to {to_image}"
                d["arrowhead"] = "normal" if d["headlabel"] else "none"

            # optionally color edges using node colors
            # NOTE(review): the attribute is stored as "color_uv"; GraphViz's
            # own edge attribute is "color" - confirm this key is intended
            color_u = g.nodes[u]["fillcolor"]
            color_v = g.nodes[v]["fillcolor"]
            d["color_uv"] = f"{color_u};0.5:{color_v};0.5" if edge_colors else "#000000"

            # optionally add weights to graph
            if weight_labels and d.get("weight"):
                d["label"] = f"{d['weight']:.2f} {units}"

            # update edge with our new style attributes
            g.edges[u, v, k].update(d)

        # optionally remove periodic image edges,
        # these can be confusing due to periodic boundaries
        if hide_image_edges:
            for edge_to_delete in edges_to_delete:
                g.remove_edge(*edge_to_delete)

        # optionally hide unconnected nodes,
        # these can appear when removing periodic edges
        if hide_unconnected_nodes:
            # iterating the degree view yields (node, degree) pairs
            g = g.subgraph([node for node, degree in g.degree() if degree != 0])

        # optionally highlight differences with another graph
        if diff:
            diff = self.diff(diff, strict=True)
            green_edges = []
            red_edges = []
            for u, v, k, d in g.edges(keys=True, data=True):
                if (u, v, d["to_jimage"]) in diff["self"]:
                    # edge has been deleted
                    red_edges.append((u, v, k))
                elif (u, v, d["to_jimage"]) in diff["other"]:
                    # edge has been added
                    green_edges.append((u, v, k))
            for u, v, k in green_edges:
                g.edges[u, v, k].update({"color_uv": "#00ff00"})
            for u, v, k in red_edges:
                g.edges[u, v, k].update({"color_uv": "#ff0000"})

        # the output format handed to GraphViz is the filename extension
        basename, extension = os.path.splitext(filename)
        extension = extension[1:]

        write_dot(g, basename + ".dot")

        # GraphViz writes the rendered output to its stdout, which is
        # redirected into the requested file
        with open(filename, "w") as f:
            args = [algo, "-T", extension, basename + ".dot"]
            with subprocess.Popen(args, stdout=f, stdin=subprocess.PIPE, close_fds=True) as rs:
                rs.communicate()
                if rs.returncode != 0:
                    raise RuntimeError(f"{algo} exited with return code {rs.returncode}.")

        if not keep_dot:
            os.remove(basename + ".dot")
×
995

996
    @property
1✔
997
    def types_and_weights_of_connections(self):
1✔
998
        """
999
        Extract a dictionary summarizing the types and weights
1000
        of edges in the graph.
1001

1002
        :return: A dictionary with keys specifying the
1003
            species involved in a connection in alphabetical order
1004
            (e.g. string 'Fe-O') and values which are a list of
1005
            weights for those connections (e.g. bond lengths).
1006
        """
1007

1008
        def get_label(u, v):
1✔
1009
            u_label = self.structure[u].species_string
1✔
1010
            v_label = self.structure[v].species_string
1✔
1011
            return "-".join(sorted((u_label, v_label)))
1✔
1012

1013
        types = defaultdict(list)
1✔
1014
        for u, v, d in self.graph.edges(data=True):
1✔
1015
            label = get_label(u, v)
1✔
1016
            types[label].append(d["weight"])
1✔
1017

1018
        return dict(types)
1✔
1019

1020
    @property
1✔
1021
    def weight_statistics(self):
1✔
1022
        """
1023
        Extract a statistical summary of edge weights present in
1024
        the graph.
1025

1026
        :return: A dict with an 'all_weights' list, 'minimum',
1027
            'maximum', 'median', 'mean', 'std_dev'
1028
        """
1029
        all_weights = [d.get("weight", None) for u, v, d in self.graph.edges(data=True)]
1✔
1030
        stats = describe(all_weights, nan_policy="omit")
1✔
1031

1032
        return {
1✔
1033
            "all_weights": all_weights,
1034
            "min": stats.minmax[0],
1035
            "max": stats.minmax[1],
1036
            "mean": stats.mean,
1037
            "variance": stats.variance,
1038
        }
1039

1040
    def types_of_coordination_environments(self, anonymous=False):
        """
        List the distinct co-ordination environments found in the graph.

        For every site, the species of its connected sites are counted
        and written as '<centre>-<species>(<count>),...', with the most
        frequent neighbor species first.

        :param anonymous: if anonymous, will replace specie names
            with A, B, C, etc. (the centre is always A)
        :return: a sorted list of co-ordination environments,
            e.g. ['Mo-S(6)', 'S-Mo(3)']
        """
        environments = set()

        for site_index, site in enumerate(self.structure):
            centre = site.species_string
            neighbor_species = [neighbor.site.species_string for neighbor in self.get_connected_sites(site_index)]

            # (count, species) pairs in descending order
            tallies = sorted(
                ((neighbor_species.count(species), species) for species in set(neighbor_species)),
                reverse=True,
            )

            if anonymous:
                # centre becomes "A", neighbors take B..Z in order of appearance
                aliases = {centre: "A"}
                spare_letters = [chr(code) for code in range(66, 91)]
                for _, species in tallies:
                    if species not in aliases:
                        aliases[species] = spare_letters.pop(0)
                centre = "A"
                tallies = [(count, aliases[species]) for count, species in tallies]

            shell = ",".join(f"{species}({count})" for count, species in tallies)
            environments.add(f"{centre}-{shell}")

        return sorted(environments)
1✔
1079

1080
    def as_dict(self):
        """
        As in :class:`pymatgen.core.Structure` except
        that the graph is stored under the "graphs" key,
        serialized with `json_graph.adjacency_data`
        from NetworkX.
        """
        own_type = type(self)
        return {
            "@module": own_type.__module__,
            "@class": own_type.__name__,
            "structure": self.structure.as_dict(),
            "graphs": json_graph.adjacency_data(self.graph),
        }
1✔
1094

1095
    @classmethod
    def from_dict(cls, d):
        """
        As in :class:`pymatgen.core.Structure` except
        that the value stored under "graphs" is also
        passed to the constructor to restore graph information.
        """
        structure = Structure.from_dict(d["structure"])
        graph_data = d["graphs"]
        return cls(structure, graph_data)
1✔
1104

1105
    def __mul__(self, scaling_matrix):
        """
        Replicates the graph, creating a supercell,
        intelligently joining together
        edges that lie on periodic boundaries.
        In principle, any operations on the expanded
        graph could also be done on the original
        graph, but a larger graph can be easier to
        visualize and reason about.
        :param scaling_matrix: same as Structure.__mul__, except that
            a full 3x3 matrix is currently rejected
        :return: a new StructureGraph for the supercell
        :raises NotImplementedError: if scaling_matrix has shape (3, 3)
        """
        # Developer note: a different approach was also trialed, using
        # a simple Graph (instead of MultiDiGraph), with node indices
        # representing both site index and periodic image. Here, the
        # number of nodes != number of sites in the Structure. This
        # approach has many benefits, but made it more difficult to
        # keep the graph in sync with its corresponding Structure.

        # Broadly, it would be easier to multiply the Structure
        # *before* generating the StructureGraph, but this isn't
        # possible when generating the graph using critic2 from
        # charge density.

        # Multiplication works by looking for the expected position
        # of an image node, and seeing if that node exists in the
        # supercell. If it does, the edge is updated. This is more
        # computationally expensive than just keeping track of
        # which new lattice images are present, but should hopefully be
        # easier to extend to a general 3x3 scaling matrix.

        # code adapted from Structure.__mul__
        scale_matrix = np.array(scaling_matrix, int)
        if scale_matrix.shape != (3, 3):
            # anything other than a 3x3 matrix is spread along the diagonal
            scale_matrix = np.array(scale_matrix * np.eye(3), int)
        else:
            # TODO: test __mul__ with full 3x3 scaling matrices
            raise NotImplementedError("Not tested with 3x3 scaling matrices yet.")
        new_lattice = Lattice(np.dot(scale_matrix, self.structure.lattice.matrix))

        f_lat = lattice_points_in_supercell(scale_matrix)
        c_lat = new_lattice.get_cartesian_coords(f_lat)

        new_sites = []
        new_graphs = []

        for v in c_lat:
            # create a map of nodes from original graph to its image
            mapping = {n: n + len(new_sites) for n in range(len(self.structure))}

            # copy every site, translated by this lattice point
            for site in self.structure:
                s = PeriodicSite(
                    site.species,
                    site.coords + v,
                    new_lattice,
                    properties=site.properties,
                    coords_are_cartesian=True,
                    to_unit_cell=False,
                )

                new_sites.append(s)

            new_graphs.append(nx.relabel_nodes(self.graph, mapping, copy=True))

        new_structure = Structure.from_sites(new_sites)

        # merge all graphs into one big graph
        new_g = nx.MultiDiGraph()
        for new_graph in new_graphs:
            new_g = nx.union(new_g, new_graph)

        edges_to_remove = []  # tuple of (u, v, k)
        edges_to_add = []  # tuple of (u, v, attr_dict)

        # list of new edges inside supercell
        # for duplicate checking
        edges_inside_supercell = [{u, v} for u, v, d in new_g.edges(data=True) if d["to_jimage"] == (0, 0, 0)]
        new_periodic_images = []

        orig_lattice = self.structure.lattice

        # use k-d tree to match given position to an
        # existing Site in Structure
        kd_tree = KDTree(new_structure.cart_coords)

        # tolerance in Å for sites to be considered equal
        # this could probably be a lot smaller
        tol = 0.05

        for u, v, k, d in new_g.edges(keys=True, data=True):
            to_jimage = d["to_jimage"]  # for node v

            # reduce unnecessary checking
            if to_jimage != (0, 0, 0):
                # get index in original site
                n_u = u % len(self.structure)
                n_v = v % len(self.structure)

                # get fractional coordinates of where atoms defined
                # by edge are expected to be, relative to original
                # lattice (keeping original lattice has
                # significant benefits)
                v_image_frac = np.add(self.structure[n_v].frac_coords, to_jimage)
                u_frac = self.structure[n_u].frac_coords

                # using the position of node u as a reference,
                # get relative Cartesian coordinates of where
                # atoms defined by edge are expected to be
                v_image_cart = orig_lattice.get_cartesian_coords(v_image_frac)
                u_cart = orig_lattice.get_cartesian_coords(u_frac)
                v_rel = np.subtract(v_image_cart, u_cart)

                # now take the position of node u in the
                # new supercell, and get absolute Cartesian
                # coordinates of where atoms defined by edge
                # are expected to be
                v_expec = new_structure[u].coords + v_rel

                # now search in new structure for these atoms
                # query returns (distance, index)
                v_present = kd_tree.query(v_expec)
                v_present = v_present[1] if v_present[0] <= tol else None

                # check if image sites now present in supercell
                # and if so, delete old edge that went through
                # periodic boundary
                if v_present is not None:
                    new_u = u
                    new_v = v_present
                    new_d = d.copy()

                    # node now inside supercell
                    new_d["to_jimage"] = (0, 0, 0)

                    edges_to_remove.append((u, v, k))

                    # make sure we don't try to add duplicate edges
                    # will remove two edges for every one we add
                    if {new_u, new_v} not in edges_inside_supercell:
                        # normalize direction
                        if new_v < new_u:
                            new_u, new_v = new_v, new_u

                        edges_inside_supercell.append({new_u, new_v})
                        edges_to_add.append((new_u, new_v, new_d))

                else:
                    # want to find new_v such that we have
                    # full periodic boundary conditions
                    # so that nodes on one side of supercell
                    # are connected to nodes on opposite side

                    v_expec_frac = new_structure.lattice.get_fractional_coords(v_expec)

                    # find new to_jimage
                    # use np.around to fix issues with finite precision leading to incorrect image
                    v_expec_image = np.around(v_expec_frac, decimals=3)
                    v_expec_image = v_expec_image - v_expec_image % 1

                    # wrap the expected position back into the supercell
                    # and look for a site there
                    v_expec_frac = np.subtract(v_expec_frac, v_expec_image)
                    v_expec = new_structure.lattice.get_cartesian_coords(v_expec_frac)
                    v_present = kd_tree.query(v_expec)
                    v_present = v_present[1] if v_present[0] <= tol else None

                    if v_present is not None:
                        new_u = u
                        new_v = v_present
                        new_d = d.copy()
                        new_to_jimage = tuple(map(int, v_expec_image))

                        # normalize direction
                        if new_v < new_u:
                            new_u, new_v = new_v, new_u
                            # NOTE(review): the flipped image is derived from the
                            # edge's original to_jimage, not from the new_to_jimage
                            # computed above - confirm this is intended
                            new_to_jimage = tuple(np.multiply(-1, d["to_jimage"]).astype(int))

                        new_d["to_jimage"] = new_to_jimage

                        edges_to_remove.append((u, v, k))

                        # skip periodic edges that were already recorded
                        if (new_u, new_v, new_to_jimage) not in new_periodic_images:
                            edges_to_add.append((new_u, new_v, new_d))
                            new_periodic_images.append((new_u, new_v, new_to_jimage))

        logger.debug(f"Removing {len(edges_to_remove)} edges, adding {len(edges_to_add)} new edges.")

        # add/delete marked edges
        for edge in edges_to_remove:
            new_g.remove_edge(*edge)
        for u, v, d in edges_to_add:
            new_g.add_edge(u, v, **d)

        # return new instance of StructureGraph with supercell
        d = {
            "@module": type(self).__module__,
            "@class": type(self).__name__,
            "structure": new_structure.as_dict(),
            "graphs": json_graph.adjacency_data(new_g),
        }

        sg = StructureGraph.from_dict(d)

        return sg
1✔
1307

1308
    def __rmul__(self, other):
1✔
1309
        return self.__mul__(other)
×
1310

1311
    @classmethod
1✔
1312
    def _edges_to_string(cls, g):
1✔
1313
        header = "from    to  to_image    "
1✔
1314
        header_line = "----  ----  ------------"
1✔
1315
        edge_weight_name = g.graph["edge_weight_name"]
1✔
1316
        if edge_weight_name:
1✔
1317
            print_weights = ["weight"]
1✔
1318
            edge_label = g.graph["edge_weight_name"]
1✔
1319
            edge_weight_units = g.graph["edge_weight_units"]
1✔
1320
            if edge_weight_units:
1✔
1321
                edge_label += f" ({edge_weight_units})"
1✔
1322
            header += f"  {edge_label}"
1✔
1323
            header_line += f"  {'-' * max([18, len(edge_label)])}"
1✔
1324
        else:
1325
            print_weights = False
1✔
1326

1327
        s = header + "\n" + header_line + "\n"
1✔
1328

1329
        edges = list(g.edges(data=True))
1✔
1330

1331
        # sort edges for consistent ordering
1332
        edges.sort(key=itemgetter(0, 1))
1✔
1333

1334
        if print_weights:
1✔
1335
            for u, v, data in edges:
1✔
1336
                s += f"{u:4}  {v:4}  {str(data.get('to_jimage', (0, 0, 0))):12}  {data.get('weight', 0):.3e}\n"
1✔
1337
        else:
1338
            for u, v, data in edges:
1✔
1339
                s += f"{u:4}  {v:4}  {str(data.get('to_jimage', (0, 0, 0))):12}\n"
1✔
1340

1341
        return s
1✔
1342

1343
    def __str__(self):
1✔
1344
        s = "Structure Graph"
1✔
1345
        s += f"\nStructure: \n{self.structure}"
1✔
1346
        s += f"\nGraph: {self.name}\n"
1✔
1347
        s += self._edges_to_string(self.graph)
1✔
1348
        return s
1✔
1349

1350
    def __repr__(self):
1✔
1351
        s = "Structure Graph"
×
1352
        s += f"\nStructure: \n{self.structure.__repr__()}"
×
1353
        s += f"\nGraph: {self.name}\n"
×
1354
        s += self._edges_to_string(self.graph)
×
1355
        return s
×
1356

1357
    def __len__(self):
1✔
1358
        """
1359
        :return: length of Structure / number of nodes in graph
1360
        """
1361
        return len(self.structure)
1✔
1362

1363
    def sort(self, key=None, reverse=False):
        """Sort the sites of the Structure in place and relabel the graph to match.

        After relabelling, every edge is oriented so that it runs from the
        lower node index to the higher one; an edge that has to be flipped
        gets its to_jimage negated.

        Args:
            key: key to sort by
            reverse: reverse sort order
        """
        previous_order = self.structure.copy()

        # sort Structure
        self.structure._sites = sorted(self.structure._sites, key=key, reverse=reverse)

        # old node index -> index of the same site after sorting
        relabel = {}
        for old_index, site in enumerate(previous_order):
            relabel[old_index] = self.structure.index(site)
        self.graph = nx.relabel_nodes(self.graph, relabel, copy=True)

        # collect edges pointing from a higher to a lower index; they are
        # replaced only after iteration so the edge view is not mutated
        stale_edges = []
        flipped_edges = []
        for u, v, edge_key, props in self.graph.edges(keys=True, data=True):
            if not v < u:
                continue
            flipped_props = props.copy()
            flipped_props["to_jimage"] = tuple(np.multiply(-1, props["to_jimage"]).astype(int))
            stale_edges.append((u, v, edge_key))
            flipped_edges.append((v, u, flipped_props))

        # add/delete marked edges
        for stale in stale_edges:
            self.graph.remove_edge(*stale)
        for source, target, flipped_props in flipped_edges:
            self.graph.add_edge(source, target, **flipped_props)
1✔
1394

1395
    def __copy__(self):
        """
        Copy by round-tripping through the dict representation
        (as_dict followed by StructureGraph.from_dict).
        """
        serialized = self.as_dict()
        return StructureGraph.from_dict(serialized)
1✔
1397

1398
    def __eq__(self, other: object) -> bool:
1✔
1399
        """
1400
        Two StructureGraphs are equal if they have equal Structures,
1401
        and have the same edges between Sites. Edge weights can be
1402
        different and StructureGraphs can still be considered equal.
1403

1404
        :param other: StructureGraph
1405
        :return (bool):
1406
        """
1407
        if not isinstance(other, StructureGraph):
1✔
1408
            return NotImplemented
×
1409
        # sort for consistent node indices
1410
        # PeriodicSite should have a proper __hash__() value,
1411
        # using its frac_coords as a convenient key
1412
        mapping = {tuple(site.frac_coords): self.structure.index(site) for site in other.structure}
1✔
1413
        other_sorted = other.__copy__()
1✔
1414
        other_sorted.sort(key=lambda site: mapping[tuple(site.frac_coords)])
1✔
1415

1416
        edges = {(u, v, d["to_jimage"]) for u, v, d in self.graph.edges(keys=False, data=True)}
1✔
1417

1418
        edges_other = {(u, v, d["to_jimage"]) for u, v, d in other_sorted.graph.edges(keys=False, data=True)}
1✔
1419

1420
        return (edges == edges_other) and (self.structure == other_sorted.structure)
1✔
1421

1422
    def diff(self, other, strict=True):
        """
        Compares two StructureGraphs. Returns dict with
        keys 'self', 'other', 'both' with edges that are
        present in only one StructureGraph ('self' and
        'other'), and edges that are present in both.

        The Jaccard distance is a simple measure of the
        dissimilarity between two StructureGraphs (ignoring
        edge weights), and is defined by 1 - (size of the
        intersection / size of the union) of the sets of
        edges. This is returned with key 'dist'.

        Important note: all node indices are in terms
        of the StructureGraph this method is called
        from, not the 'other' StructureGraph: there
        is no guarantee the node indices will be the
        same if the underlying Structures are ordered
        differently.

        :param other: StructureGraph
        :param strict: if False, will compare bonds
            from different Structures, with node indices
            replaced by Species strings, will not count
            number of occurrences of bonds
        :return: dict with keys 'self', 'other', 'both' (sets of
            edges) and 'dist' (Jaccard distance)
        :raises ValueError: if strict is True and the two
            Structures are different
        """
        if self.structure != other.structure and strict:
            # raise (not return) the error, so callers that index the
            # result never receive an exception object in place of the dict
            raise ValueError("Meaningless to compare StructureGraphs if corresponding Structures are different.")

        if strict:
            # sort for consistent node indices
            # PeriodicSite should have a proper __hash__() value,
            # using its frac_coords as a convenient key
            mapping = {tuple(site.frac_coords): self.structure.index(site) for site in other.structure}
            other_sorted = copy.copy(other)
            other_sorted.sort(key=lambda site: mapping[tuple(site.frac_coords)])

            edges = {(u, v, d["to_jimage"]) for u, v, d in self.graph.edges(keys=False, data=True)}

            edges_other = {(u, v, d["to_jimage"]) for u, v, d in other_sorted.graph.edges(keys=False, data=True)}

        else:
            # edges are identified by the species at both ends only
            edges = {
                (str(self.structure[u].specie), str(self.structure[v].specie))
                for u, v, d in self.graph.edges(keys=False, data=True)
            }

            edges_other = {
                (str(other.structure[u].specie), str(other.structure[v].specie))
                for u, v, d in other.graph.edges(keys=False, data=True)
            }

        if len(edges) == 0 and len(edges_other) == 0:
            jaccard_dist = 0  # by definition
        else:
            jaccard_dist = 1 - len(edges & edges_other) / len(edges | edges_other)

        return {
            "self": edges - edges_other,
            "other": edges_other - edges,
            "both": edges.intersection(edges_other),
            "dist": jaccard_dist,
        }
1486

1487
    def get_subgraphs_as_molecules(self, use_weights=False):
        """
        Retrieve subgraphs as molecules, useful for extracting
        molecules from periodic crystals.

        Will only return unique molecules, not any duplicates
        present in the crystal (a duplicate defined as an
        isomorphic subgraph).

        :param use_weights: (bool) If True, only treat subgraphs
            as isomorphic if edges have the same weights. Typically,
            this means molecules will need to have the same bond
            lengths to be defined as duplicates, otherwise bond
            lengths can differ. This is a fairly robust approach,
            but will treat e.g. enantiomers as being duplicates.

        :return: list of unique Molecules in Structure
        """
        # creating a supercell is an easy way to extract
        # molecules (and not, e.g., layers of a 2D crystal)
        # without adding extra logic
        if getattr(self, "_supercell_sg", None) is None:
            self._supercell_sg = self * (3, 3, 3)
        # always bind the local name, including when the cached supercell
        # is reused on a repeated call
        supercell_sg = self._supercell_sg

        # make undirected to find connected subgraphs
        supercell_sg.graph = nx.Graph(supercell_sg.graph)

        # find subgraphs
        all_subgraphs = [supercell_sg.graph.subgraph(c) for c in nx.connected_components(supercell_sg.graph)]

        # discount subgraphs that lie across *supercell* boundaries
        # these will be subgraphs representing crystals
        molecule_subgraphs = []
        for subgraph in all_subgraphs:
            intersects_boundary = any(d["to_jimage"] != (0, 0, 0) for u, v, d in subgraph.edges(data=True))
            if not intersects_boundary:
                molecule_subgraphs.append(nx.MultiDiGraph(subgraph))

        # add specie names to graph to be able to test for isomorphism
        for subgraph in molecule_subgraphs:
            for n in subgraph:
                subgraph.add_node(n, specie=str(supercell_sg.structure[n].specie))

        # now define how we test for isomorphism
        def node_match(n1, n2):
            return n1["specie"] == n2["specie"]

        def edge_match(e1, e2):
            if use_weights:
                return e1["weight"] == e2["weight"]
            return True

        # prune duplicate subgraphs
        unique_subgraphs = []
        for subgraph in molecule_subgraphs:
            already_present = [
                nx.is_isomorphic(subgraph, g, node_match=node_match, edge_match=edge_match) for g in unique_subgraphs
            ]

            if not any(already_present):
                unique_subgraphs.append(subgraph)

        # get Molecule objects for each subgraph
        molecules = []
        for subgraph in unique_subgraphs:
            coords = [supercell_sg.structure[n].coords for n in subgraph.nodes()]
            species = [supercell_sg.structure[n].specie for n in subgraph.nodes()]

            molecule = Molecule(species, coords)

            # shift so origin is at center of mass
            molecule = molecule.get_centered_molecule()

            molecules.append(molecule)

        return molecules
1✔
1563

1564

1565
class MolGraphSplitError(Exception):
    """
    Raised when a molecule graph fails to split into two (or more)
    disconnected subgraphs.
    """
1570

1571

1572
class MoleculeGraph(MSONable):
1✔
1573
    """
1574
    This is a class for annotating a Molecule with
1575
    bond information, stored in the form of a graph. A "bond" does
1576
    not necessarily have to be a chemical bond, but can store any
1577
    kind of information that connects two Sites.
1578
    """
1579

1580
    def __init__(self, molecule, graph_data=None):
        """
        If constructing this class manually, use the `with_empty_graph`
        method or `with_local_env_strategy` method (using an algorithm
        provided by the `local_env` module, such as O'Keeffe).

        This class that contains connection information:
        relationships between sites represented by a Graph structure,
        and an associated structure object.

        This class uses the NetworkX package to store and operate
        on the graph itself, but contains a lot of helper methods
        to make associating a graph with a given molecule easier.

        Use cases for this include storing bonding information,
        NMR J-couplings, Heisenberg exchange parameters, etc.

        :param molecule: Molecule object. An existing MoleculeGraph may
            also be given, in which case a copy of it is built and
            graph_data is ignored.

        :param graph_data: dict containing graph information in
            dict format (not intended to be constructed manually,
            see as_dict method for format)
        """
        if isinstance(molecule, MoleculeGraph):
            # just make a copy from input: take the graph from the serialized
            # form and unwrap the underlying Molecule, so that self.molecule
            # is a Molecule (indexable by site) rather than the wrapper itself
            graph_data = molecule.as_dict()["graphs"]
            molecule = copy.deepcopy(molecule.molecule)

        self.molecule = molecule
        self.graph = nx.readwrite.json_graph.adjacency_graph(graph_data)

        # tidy up edge attr dicts, reading to/from json duplicates
        # information
        for _, _, _, d in self.graph.edges(keys=True, data=True):
            if "id" in d:
                del d["id"]
            if "key" in d:
                del d["key"]
            # ensure images are tuples (conversion to lists happens
            # when serializing back from json), it's important images
            # are hashable/immutable
            if "to_jimage" in d:
                d["to_jimage"] = tuple(d["to_jimage"])
            if "from_jimage" in d:
                d["from_jimage"] = tuple(d["from_jimage"])

        self.set_node_attributes()
1✔
1626

1627
    @classmethod
    def with_empty_graph(cls, molecule, name="bonds", edge_weight_name=None, edge_weight_units=None):
        """
        Alternate constructor: build a MoleculeGraph that has one node
        per Site in the Molecule and no edges at all.

        :param molecule: (Molecule) molecule to be annotated
        :param name: (str) name of graph, e.g. "bonds"
        :param edge_weight_name: (str) name of edge weights,
            e.g. "bond_length" or "exchange_constant"
        :param edge_weight_units: (str) name of edge weight units
            e.g. "Å" or "eV"
        :return: (MoleculeGraph) graph without any edges
        """
        units_missing = edge_weight_units is None
        if edge_weight_name and units_missing:
            raise ValueError(
                "Please specify units associated with your edge weights. "
                "Can be empty string if arbitrary or dimensionless."
            )

        # the graph-level attributes are pure book-keeping, they do not
        # influence how the graph itself behaves
        bookkeeping = {
            "edge_weight_name": edge_weight_name,
            "edge_weight_units": edge_weight_units,
            "name": name,
        }
        empty = nx.MultiDiGraph(**bookkeeping)

        # one node for every site of the molecule
        empty.add_nodes_from(range(len(molecule)))

        return cls(molecule, graph_data=json_graph.adjacency_data(empty))
1✔
1663

1664
    @staticmethod
    def with_edges(molecule, edges):
        """
        Constructor for MoleculeGraph, using pre-existing or pre-defined edges
        with optional edge parameters.

        :param molecule: Molecule object
        :param edges: dict representing the bonds of the functional
            group (format: {(u, v): props}, where props is a dictionary of
            properties, including weight. Props should be None if no
            additional properties are to be specified. The props
            dictionaries are not modified.
        :return: mg, a MoleculeGraph
        :raises ValueError: if an edge key is not a (from_index, to_index)
            pair, or refers to a node that is not in the graph
        """
        mg = MoleculeGraph.with_empty_graph(molecule, name="bonds", edge_weight_name="weight", edge_weight_units="")

        # the node set is fixed by the molecule, adding edges does not change it
        nodes = mg.graph.nodes

        for edge, props in edges.items():
            try:
                from_index, to_index = edge[0], edge[1]
            except (TypeError, IndexError) as exc:
                raise ValueError("Edges must be given as (from_index, to_index) tuples") from exc

            weight = None
            if props is not None:
                # work on a copy so the caller's dictionary keeps its "weight" entry
                props = dict(props)
                weight = props.pop("weight", None)

                if not props:
                    props = None

            if not (from_index in nodes and to_index in nodes):
                raise ValueError(
                    "Edges cannot be added if nodes are not present in the graph. Please check your indices."
                )

            mg.add_edge(from_index, to_index, weight=weight, edge_properties=props)

        mg.set_node_attributes()
        return mg
1✔
1708

1709
    @staticmethod
    def with_local_env_strategy(molecule, strategy):
        """
        Build a MoleculeGraph whose edges are the neighbors reported by a
        strategy from :class:`pymatgen.analysis.local_env`.

        :param molecule: Molecule object
        :param strategy: an instance of a
            :class:`pymatgen.analysis.local_env.NearNeighbors` object
        :return: mg, a MoleculeGraph
        """
        if not strategy.molecules_allowed:
            raise ValueError("Chosen strategy is not designed for use with molecules! Please choose another strategy.")
        needs_box = strategy.extend_structure_molecules

        mg = MoleculeGraph.with_empty_graph(molecule, name="bonds", edge_weight_name="weight", edge_weight_units="")

        # NearNeighbor classes only (generally) work with structures,
        # so the molecule is put into a box when the strategy asks for it
        coords = molecule.cart_coords

        if needs_box:
            # box edge = extent of the molecule along each axis, plus 100
            box = [max(coords[:, axis]) - min(coords[:, axis]) + 100 for axis in range(3)]
            nn_target = molecule.get_boxed_structure(*box, no_cross=True, reorder=False)
        else:
            nn_target = molecule

        for n in range(len(molecule)):
            for neighbor in strategy.get_nn_info(nn_target, n):
                # bonds in a molecule must never cross the
                # (artificial) periodic boundaries
                if not np.array_equal(neighbor["image"], [0, 0, 0]):
                    continue

                partner = neighbor["site_index"]
                if n > partner:
                    first, second = partner, n
                else:
                    first, second = n, partner

                mg.add_edge(
                    from_index=first,
                    to_index=second,
                    weight=neighbor["weight"],
                    warn_duplicates=False,
                )

        # every edge with a non-zero key is a repeat of an existing bond
        repeated = [edge for edge in mg.graph.edges if edge[2] != 0]

        for u, v, edge_key in repeated:
            mg.graph.remove_edge(u, v, key=edge_key)

        mg.set_node_attributes()
        return mg
1✔
1774

1775
    @property
    def name(self):
        """
        :return: Name of graph, as stored in the graph-level attributes
        """
        graph_attributes = self.graph.graph
        return graph_attributes["name"]
1✔
1781

1782
    @property
    def edge_weight_name(self):
        """
        :return: Name of the edge weight property of graph, as stored in
            the graph-level attributes
        """
        graph_attributes = self.graph.graph
        return graph_attributes["edge_weight_name"]
1✔
1788

1789
    @property
    def edge_weight_unit(self):
        """
        :return: Units of the edge weight property of graph, as stored in
            the graph-level attributes
        """
        graph_attributes = self.graph.graph
        return graph_attributes["edge_weight_units"]
1✔
1795

1796
    def add_edge(
        self,
        from_index,
        to_index,
        weight=None,
        warn_duplicates=True,
        edge_properties=None,
    ):
        """
        Add edge to graph.

        Since physically a 'bond' (or other connection
        between sites) doesn't have a direction, from_index
        can be swapped with to_index.

        However, the indices will always be ordered so that
        from_index < to_index.

        :param from_index: index of site connecting from
        :param to_index: index of site connecting to
        :param weight: (float) e.g. bond length. A weight of 0 is stored;
            only None means "no weight".
        :param warn_duplicates: (bool) if True, will warn if
            trying to add duplicate edges (duplicate edges will not
            be added in either case)
        :param edge_properties: (dict) any other information to
            store on graph edges, similar to Structure's site_properties
        :return:
        """
        # this is not necessary for the class to work, but
        # just makes it neater
        if to_index < from_index:
            to_index, from_index = from_index, to_index

        # sanitize types
        from_index, to_index = int(from_index), int(to_index)

        # check we're not trying to add a duplicate edge
        # there should only ever be at most one edge
        # between two sites
        existing_edge_data = self.graph.get_edge_data(from_index, to_index)
        if existing_edge_data:
            # the duplicate is rejected whether or not a warning was requested
            if warn_duplicates:
                warnings.warn(f"Trying to add an edge that already exists from site {from_index} to site {to_index}.")
            return

        # generic container for additional edge properties,
        # similar to site properties
        edge_properties = edge_properties or {}

        # compare against None so that a legitimate weight of 0 is kept
        if weight is not None:
            self.graph.add_edge(from_index, to_index, weight=weight, **edge_properties)
        else:
            self.graph.add_edge(from_index, to_index, **edge_properties)
1✔
1848

1849
    def insert_node(
        self,
        i,
        species,
        coords,
        validate_proximity=False,
        site_properties=None,
        edges=None,
    ):
        """
        Insert a new site through Molecule.insert() and register it as a
        node of the MoleculeGraph as well.

        :param i: Index at which to insert the new site
        :param species: Species for the new site
        :param coords: 3x1 array representing coordinates of the new site
        :param validate_proximity: For Molecule.insert(); if True (default
            False), distance will be checked to ensure that site can be safely
            added.
        :param site_properties: Site properties for Molecule
        :param edges: List of dicts representing edges to be added to the
            MoleculeGraph. These edges must include the index of the new site i,
            and all indices used for these edges should reflect the
            MoleculeGraph AFTER the insertion, NOT before. Each dict should at
            least have a "to_index" and "from_index" key, and can also have a
            "weight" and a "properties" key.
        :return:
        """
        self.molecule.insert(
            i,
            species,
            coords,
            validate_proximity=validate_proximity,
            properties=site_properties,
        )

        # nodes that were at or after the insertion point move up by one;
        # the molecule already contains the new site, hence the "- 1"
        previous_count = len(self.molecule) - 1
        shifted = {old: (old if old < i else old + 1) for old in range(previous_count)}
        nx.relabel_nodes(self.graph, shifted, copy=False)

        self.graph.add_node(i)
        self.set_node_attributes()

        if edges is None:
            return

        for edge in edges:
            try:
                self.add_edge(
                    edge["from_index"],
                    edge["to_index"],
                    weight=edge.get("weight", None),
                    edge_properties=edge.get("properties", None),
                )
            except KeyError:
                raise RuntimeError("Some edges are invalid.")
×
1907

1908
    def set_node_attributes(self):
        """
        Copy the site data of the molecule (specie symbol, coords and
        site properties) onto the matching nodes of the graph, under the
        node attributes "specie", "coords" and "properties".

        :return:
        """
        # node index -> site of the molecule with the same index
        sites = {node: self.molecule[node] for node in self.graph.nodes()}

        symbols = {node: site.specie.symbol for node, site in sites.items()}
        positions = {node: site.coords for node, site in sites.items()}
        site_props = {node: site.properties for node, site in sites.items()}

        nx.set_node_attributes(self.graph, symbols, "specie")
        nx.set_node_attributes(self.graph, positions, "coords")
        nx.set_node_attributes(self.graph, site_props, "properties")
1✔
1926

1927
    def alter_edge(self, from_index, to_index, new_weight=None, new_edge_properties=None):
        """
        Change the weight and/or the edge properties of an existing
        edge of the MoleculeGraph.

        :param from_index: int
        :param to_index: int
        :param new_weight: new weight (float) for the edge; the default
            None leaves the weight as it is.
        :param new_edge_properties: dictionary of edge properties to set
            on the edge; the default None leaves the properties as they are.
        :return:
        :raises ValueError: if there is no edge from from_index to to_index
        """
        # an edge can only be altered if it is there in the first place
        if not self.graph.get_edge_data(from_index, to_index):
            raise ValueError(
                f"Edge between {from_index} and {to_index} cannot be altered; no edge exists between those sites."
            )

        # The edge key is always 0, as there should only ever be one edge
        # between any two nodes
        if new_weight is not None:
            self.graph[from_index][to_index][0]["weight"] = new_weight

        if new_edge_properties is not None:
            for prop, value in new_edge_properties.items():
                self.graph[from_index][to_index][0][prop] = value
1✔
1959

1960
    def break_edge(self, from_index, to_index, allow_reverse=False):
        """
        Delete one edge of the MoleculeGraph.

        :param from_index: int
        :param to_index: int
        :param allow_reverse: If allow_reverse is True, then break_edge will
            attempt to break both (from_index, to_index) and, failing that,
            will attempt to break (to_index, from_index).
        :return:
        :raises ValueError: if no matching edge exists
        """
        # preferred case: the edge exists in the direction that was asked for
        if self.graph.get_edge_data(from_index, to_index):
            self.graph.remove_edge(from_index, to_index)
            return

        # fallback: the opposite direction, only looked up when permitted
        if allow_reverse and self.graph.get_edge_data(to_index, from_index):
            self.graph.remove_edge(to_index, from_index)
            return

        raise ValueError(
            f"Edge cannot be broken between {from_index} and {to_index}; no edge exists between those sites."
        )
1989

1990
    def remove_nodes(self, indices):
        """
        Remove sites through Molecule.remove_sites() and drop the
        matching nodes from the graph.

        :param indices: list of indices in the current Molecule (and graph) to
            be removed.
        :return:
        """
        self.molecule.remove_sites(indices)
        self.graph.remove_nodes_from(indices)

        # close the gaps in the numbering: the surviving nodes are
        # renumbered 0, 1, 2, ... in ascending order of their old index
        renumbering = {old: new for new, old in enumerate(sorted(self.graph.nodes))}

        nx.relabel_nodes(self.graph, renumbering, copy=False)
        self.set_node_attributes()
1✔
2007

2008
    def get_disconnected_fragments(self):
        """
        Determine if the MoleculeGraph is connected. If it is not, separate the
        MoleculeGraph into different MoleculeGraphs, where each resulting
        MoleculeGraph is a disconnected subgraph of the original.
        Currently, this function naively assigns the charge
        of the total molecule to a single submolecule. A
        later effort will be to actually accurately assign
        charge.
        NOTE: This function does not modify the original
        MoleculeGraph. It creates a copy, modifies that, and
        returns two or more new MoleculeGraph objects.
        :return: list of MoleculeGraphs
        """
        if nx.is_weakly_connected(self.graph):
            return [copy.deepcopy(self)]

        original = copy.deepcopy(self)
        sub_mols = []

        # Had to use nx.weakly_connected_components because of deprecation
        # of nx.weakly_connected_component_subgraphs
        subgraphs = [original.graph.subgraph(c) for c in nx.weakly_connected_components(original.graph)]

        for subg in subgraphs:
            nodes = sorted(subg.nodes)

            # Molecule indices are essentially list-based, so node indices
            # must be remapped, incrementing from 0
            mapping = {n: i for i, n in enumerate(nodes)}

            # just give charge to whatever subgraph has node with index 0
            # TODO: actually figure out how to distribute charge
            if 0 in nodes:
                charge = self.molecule.charge
            else:
                charge = 0

            # relabel nodes in graph to match mapping
            new_graph = nx.relabel_nodes(subg, mapping)

            species = nx.get_node_attributes(new_graph, "specie")
            coords = nx.get_node_attributes(new_graph, "coords")
            raw_props = nx.get_node_attributes(new_graph, "properties")

            # collect the per-site values of every site property
            # NOTE(review): values are gathered in node iteration order, which
            # is assumed to match the new site order - confirm
            collected = {}
            for prop_set in raw_props.values():
                for prop, value in prop_set.items():
                    collected.setdefault(prop, []).append(value)

            # Site properties must be present for all atoms in the molecule
            # in order to be used for Molecule instantiation; build a filtered
            # dict, as deleting keys while iterating raises a RuntimeError
            properties = {k: v for k, v in collected.items() if len(v) == len(species)}

            new_mol = Molecule(species, coords, charge=charge, site_properties=properties)
            graph_data = json_graph.adjacency_data(new_graph)

            # create new MoleculeGraph
            sub_mols.append(MoleculeGraph(new_mol, graph_data=graph_data))

        return sub_mols
1✔
2076

2077
    def split_molecule_subgraphs(self, bonds, allow_reverse=False, alterations=None):
        """
        Split MoleculeGraph into two or more MoleculeGraphs by
        breaking a set of bonds. This function uses
        MoleculeGraph.break_edge repeatedly to create
        disjoint graphs (two or more separate molecules).
        This function does not only alter the graph
        information, but also changes the underlying
        Molecules.
        If the bonds parameter does not include sufficient
        bonds to separate two molecule fragments, then this
        function will fail.
        Currently, this function naively assigns the charge
        of the total molecule to a single submolecule. A
        later effort will be to actually accurately assign
        charge.
        NOTE: This function does not modify the original
        MoleculeGraph. It creates a copy, modifies that, and
        returns two or more new MoleculeGraph objects.
        :param bonds: list of tuples (from_index, to_index)
            representing bonds to be broken to split the MoleculeGraph.
        :param alterations: a dict {(from_index, to_index): alt},
            where alt is a dictionary including weight and/or edge
            properties to be changed following the split. The dict
            is not modified.
        :param allow_reverse: If allow_reverse is True, then break_edge will
            attempt to break both (from_index, to_index) and, failing that,
            will attempt to break (to_index, from_index).
        :return: list of MoleculeGraphs
        :raises MolGraphSplitError: if the graph is still connected after
            all given bonds have been broken
        """
        self.set_node_attributes()
        original = copy.deepcopy(self)

        for bond in bonds:
            original.break_edge(bond[0], bond[1], allow_reverse=allow_reverse)

        if nx.is_weakly_connected(original.graph):
            raise MolGraphSplitError("Cannot split molecule; MoleculeGraph is still connected.")

        # alter any bonds before partition, to avoid remapping
        if alterations is not None:
            for (u, v), alteration in alterations.items():
                # work on a copy so that the caller's dict keeps its "weight"
                changes = dict(alteration)
                weight = changes.pop("weight", None)
                original.alter_edge(u, v, new_weight=weight, new_edge_properties=changes or None)

        return original.get_disconnected_fragments()
1✔
2130

2131
    def build_unique_fragments(self):
        """
        Find all possible fragment combinations of the MoleculeGraphs (in other
        words, all connected induced subgraphs)

        :return: dict mapping a key of the form "<alphabetical formula> E<number
            of edges>" to a list of the unique MoleculeGraph fragments that
            share this formula and edge count
        """
        self.set_node_attributes()

        graph = self.graph.to_undirected()

        # find all possible fragments, aka connected induced subgraphs,
        # grouped by composition and number of edges so that the (costly)
        # isomorphism test only has to run within each group
        frag_dict = {}
        for ii in range(1, len(self.molecule)):
            for combination in combinations(graph.nodes, ii):
                mycomp = "".join(sorted(str(self.molecule[idx].specie) for idx in combination))
                subgraph = nx.subgraph(graph, combination)
                if nx.is_connected(subgraph):
                    mykey = mycomp + str(len(subgraph.edges()))
                    frag_dict.setdefault(mykey, []).append(copy.deepcopy(subgraph))

        # narrow to all unique fragments using graph isomorphism
        unique_frag_dict = {}
        for key, fragments in frag_dict.items():
            unique_frags = []
            for frag in fragments:
                if not any(_isomorphic(frag, f) for f in unique_frags):
                    unique_frags.append(frag)
            unique_frag_dict[key] = copy.deepcopy(unique_frags)

        # convert back to molecule graphs
        unique_mol_graph_dict = {}
        for fragments in unique_frag_dict.values():
            unique_mol_graph_list = []
            for fragment in fragments:
                mapping = {e: i for i, e in enumerate(sorted(fragment.nodes))}
                remapped = nx.relabel_nodes(fragment, mapping)

                species = nx.get_node_attributes(remapped, "specie")
                coords = nx.get_node_attributes(remapped, "coords")

                edges = {}

                for from_index, to_index, edge_key in remapped.edges:
                    # the indices come from the relabeled graph, so the edge
                    # data has to be read from that same graph
                    edge_props = remapped.get_edge_data(from_index, to_index, key=edge_key)

                    edges[(from_index, to_index)] = edge_props

                unique_mol_graph_list.append(
                    self.with_edges(
                        Molecule(species=species, coords=coords, charge=self.molecule.charge),
                        edges,
                    )
                )

            frag_key = (
                str(unique_mol_graph_list[0].molecule.composition.alphabetical_formula)
                + " E"
                + str(len(unique_mol_graph_list[0].graph.edges()))
            )
            unique_mol_graph_dict[frag_key] = copy.deepcopy(unique_mol_graph_list)
        return unique_mol_graph_dict
1✔
2204

2205
    def substitute_group(
        self,
        index,
        func_grp,
        strategy,
        bond_order=1,
        graph_dict=None,
        strategy_params=None,
    ):
        """
        Builds off of Molecule.substitute to replace an atom in self.molecule
        with a functional group. This method also amends self.graph to
        incorporate the new functional group.

        NOTE: using a MoleculeGraph will generally produce a different graph
        compared with using a Molecule or str (when not using graph_dict).

        :param index: Index of atom to substitute.
        :param func_grp: Substituent molecule. There are three options:

            1. Providing an actual molecule as the input. The first atom
                must be a DummySpecies X, indicating the position of
                nearest neighbor. The second atom must be the next
                nearest atom. For example, for a methyl group
                substitution, func_grp should be X-CH3, where X is the
                first site and C is the second site. What the code will
                do is to remove the index site, and connect the nearest
                neighbor to the C atom in CH3. The X-C bond indicates the
                directionality to connect the atoms.
            2. A string name. The molecule will be obtained from the
                relevant template in func_groups.json.
            3. A MoleculeGraph object.
        :param strategy: Class from pymatgen.analysis.local_env.
        :param bond_order: A specified bond order to calculate the bond
                length between the attached functional group and the nearest
                neighbor site. Defaults to 1.
        :param graph_dict: Dictionary representing the bonds of the functional
                group (format: {(u, v): props}, where props is a dictionary of
                properties, including weight). If None, then the algorithm
                will attempt to automatically determine bonds using one of
                a list of strategies defined in pymatgen.analysis.local_env.
                The dictionary is not modified.
        :param strategy_params: dictionary of keyword arguments for strategy.
                If None, default parameters will be used.
        :return: None. self.molecule and self.graph are updated in place.
        :raises RuntimeError: if func_grp is neither a MoleculeGraph nor a
                Molecule and cannot be looked up in FunctionalGroups.
        """

        def map_indices(grp):
            # Get indices now occupied by functional group
            # Subtracting 1 because the dummy atom X should not count
            atoms = len(grp) - 1
            offset = len(self.molecule) - atoms
            return {i: i + offset for i in range(atoms)}

        def split_weight(props):
            # Work on a copy so that neither the caller's graph_dict nor the
            # substituent's own graph loses its "weight" entry, and always
            # start from weight=None so no value leaks in from a previous edge.
            props = dict(props or {})
            weight = props.pop("weight", None)
            return weight, props

        # Work is simplified if a graph is already in place
        if isinstance(func_grp, MoleculeGraph):
            self.molecule.substitute(index, func_grp.molecule, bond_order=bond_order)

            mapping = map_indices(func_grp.molecule)

            for u, v in list(func_grp.graph.edges()):
                weight, edge_props = split_weight(func_grp.graph.get_edge_data(u, v)[0])
                self.add_edge(mapping[u], mapping[v], weight=weight, edge_properties=edge_props)

        else:
            if isinstance(func_grp, Molecule):
                func_grp = copy.deepcopy(func_grp)
            else:
                try:
                    func_grp = copy.deepcopy(FunctionalGroups[func_grp])
                except Exception as exc:
                    raise RuntimeError(
                        "Can't find functional group in list. Provide explicit coordinate instead"
                    ) from exc

            self.molecule.substitute(index, func_grp, bond_order=bond_order)

            # mapping must be computed while the dummy atom is still counted
            mapping = map_indices(func_grp)

            # Remove dummy atom "X"
            func_grp.remove_species("X")

            if graph_dict is not None:
                for (u, v), props in graph_dict.items():
                    weight, edge_props = split_weight(props)
                    self.add_edge(
                        mapping[u],
                        mapping[v],
                        weight=weight,
                        edge_properties=edge_props,
                    )

            else:
                if strategy_params is None:
                    strategy_params = {}
                strat = strategy(**strategy_params)
                graph = self.with_local_env_strategy(func_grp, strat)

                # If graph indices have different indexing (no node 0),
                # shift them down by one before mapping
                shift = 0 if 0 in graph.graph.nodes() else 1

                for u, v in list(graph.graph.edges()):
                    weight, edge_props = split_weight(graph.graph.get_edge_data(u, v)[0])
                    self.add_edge(
                        mapping[u - shift],
                        mapping[v - shift],
                        weight=weight,
                        edge_properties=edge_props,
                    )
2330

2331
    def replace_group(self, index, func_grp, strategy, bond_order=1, graph_dict=None, strategy_params=None):
        """
        Builds off of Molecule.substitute and MoleculeGraph.substitute_group
        to replace a functional group in self.molecule with a functional group.
        This method also amends self.graph to incorporate the new functional
        group.

        If the atom at index has exactly one neighbor it is substituted
        directly. Otherwise, of the neighbors of index, only the one with the
        most nodes reachable once index is removed is kept; the other
        neighbors are passed to self.remove_nodes before substituting.

        TODO: Figure out how to replace into a ring structure.

        :param index: Index of atom to substitute.
        :param func_grp: Substituent molecule. There are three options:

            1. Providing an actual molecule as the input. The first atom
               must be a DummySpecies X, indicating the position of
               nearest neighbor. The second atom must be the next
               nearest atom. For example, for a methyl group
               substitution, func_grp should be X-CH3, where X is the
               first site and C is the second site. What the code will
               do is to remove the index site, and connect the nearest
               neighbor to the C atom in CH3. The X-C bond indicates the
               directionality to connect the atoms.
            2. A string name. The molecule will be obtained from the
               relevant template in func_groups.json.
            3. A MoleculeGraph object.
        :param strategy: Class from pymatgen.analysis.local_env.
        :param bond_order: A specified bond order to calculate the bond
            length between the attached functional group and the nearest
            neighbor site. Defaults to 1.
        :param graph_dict: Dictionary representing the bonds of the functional
            group (format: {(u, v): props}, where props is a dictionary of
            properties, including weight). If None, then the algorithm
            will attempt to automatically determine bonds using one of
            a list of strategies defined in pymatgen.analysis.local_env.
        :param strategy_params: dictionary of keyword arguments for strategy.
            If None, default parameters will be used.
        :return:
        """
        self.set_node_attributes()
        neighbors = self.get_connected_sites(index)

        # Only a non-terminal atom needs pruning before the substitution
        if len(neighbors) != 1:
            if len(self.find_rings(including=[index])) != 0:
                raise RuntimeError(
                    "Currently functional group replacement cannot occur at an atom within a ring structure."
                )

            # Size of the branch hanging off each neighbor once index is cut out
            disconnected = self.graph.to_undirected()
            disconnected.remove_node(index)
            sizes = {nbr[2]: len(nx.descendants(disconnected, nbr[2])) for nbr in neighbors}

            keep = max(sizes, key=sizes.get)
            to_remove = {nbr_index for nbr_index in sizes if nbr_index != keep}
            self.remove_nodes(list(to_remove))

        self.substitute_group(
            index,
            func_grp,
            strategy,
            bond_order=bond_order,
            graph_dict=graph_dict,
            strategy_params=strategy_params,
        )
2418

2419
    def find_rings(self, including=None):
        """
        Find ring structures in the MoleculeGraph.

        :param including: list of site indices. If
            including is not None, then find_rings will
            only return those rings including the specified
            sites. By default, this parameter is None, and
            all rings will be returned.
        :return: list of rings. Each ring (cycle, in graph theory terms) is
            itself a list of (from_index, to_index) tuples tracing the cycle
            edge by edge. The list is empty if no ring is found.
        """
        # Copies self.graph such that all edges (u, v) matched by edges (v, u)
        directed = self.graph.to_undirected().to_directed()

        # Using to_directed() means each cycle always appears twice and every
        # bond shows up as a two-edge cycle, so skip the latter and keep only
        # the first cycle seen for each set of nodes.
        seen_node_sets = set()
        unique_cycles = []
        for cycle in nx.simple_cycles(directed):
            if len(cycle) <= 2:
                continue
            node_set = tuple(sorted(cycle))
            if node_set not in seen_node_sets:
                seen_node_sets.add(node_set)
                unique_cycles.append(cycle)

        if including is None:
            cycles_nodes = unique_cycles
        else:
            cycles_nodes = []
            for i in including:
                for cycle in unique_cycles:
                    if i in cycle and cycle not in cycles_nodes:
                        cycles_nodes.append(cycle)

        # Express each cycle as consecutive node pairs; cycle[i - 1] wraps
        # around so the first pair closes the ring.
        return [[(cycle[i - 1], node) for i, node in enumerate(cycle)] for cycle in cycles_nodes]
1✔
2467

2468
    def get_connected_sites(self, n):
        """
        Collect the sites bonded to site n as ConnectedSite named tuples.

        Every edge touching n, in either direction, contributes its other
        endpoint. For each one, jimage is always (0, 0, 0), index is the
        neighbor's index in self.molecule, weight is the edge's "weight"
        entry (None if not defined) and dist is the distance between the
        two sites.

        :param n: index of Site in Molecule
        :return: list of ConnectedSite tuples,
            sorted by closest first
        """
        incident_edges = list(self.graph.out_edges(n, data=True)) + list(self.graph.in_edges(n, data=True))

        unique_sites = set()
        for u, v, props in incident_edges:
            # whichever endpoint is not n is the neighbor
            origin, neighbor = (v, u) if v == n else (u, v)
            unique_sites.add(
                ConnectedSite(
                    site=self.molecule[neighbor],
                    jimage=(0, 0, 0),
                    index=neighbor,
                    weight=props.get("weight", None),
                    dist=self.molecule[origin].distance(self.molecule[neighbor]),
                )
            )

        # return list sorted by closest sites first
        return sorted(unique_sites, key=lambda connected: connected.dist)
1✔
2506

2507
    def get_coordination_of_site(self, n):
        """
        Returns the number of neighbors of site n.
        In graph terms, this is the degree of the node
        corresponding to site n, reduced by one for each
        self-loop edge reported for that node.
        :param n: index of site
        :return (int):
        """
        self_loop_count = 0
        for start, end in self.graph.edges(n):
            if start == end:
                self_loop_count += 1
        return self.graph.degree(n) - self_loop_count
1✔
2517

2518
    def draw_graph_to_file(
        self,
        filename="graph",
        diff=None,
        hide_unconnected_nodes=False,
        hide_image_edges=True,
        edge_colors=False,
        node_labels=False,
        weight_labels=False,
        image_labels=False,
        color_scheme="VESTA",
        keep_dot=False,
        algo="fdp",
    ):
        """
        Draws graph using GraphViz.

        The networkx graph object itself can also be drawn
        with networkx's in-built graph drawing methods, but
        note that this might give misleading results for
        multigraphs (edges are super-imposed on each other).

        If visualization is difficult to interpret,
        `hide_image_edges` can help, especially in larger
        graphs.

        :param filename: filename to output, will detect filetype
            from extension (any graphviz filetype supported, such as
            pdf or png)
        :param diff (MoleculeGraph): an additional graph to
            compare with (via self.diff with strict=True), will color
            edges red that do not exist in diff and edges green that
            are in diff graph but not in the reference graph
        :param hide_unconnected_nodes: if True, hide unconnected
            nodes
        :param hide_image_edges: if True, do not draw edges whose
            "to_jimage" is not the origin
        :param edge_colors (bool): if True, use node colors to
            color edges
        :param node_labels (bool): if True, label nodes with
            species and site index
        :param weight_labels (bool): if True, label edges with
            weights
        :param image_labels (bool): if True, label edges with
            their images (usually only used for debugging,
            edges to non-origin images always appear as dashed lines)
        :param color_scheme (str): "VESTA" or "JMOL"
        :param keep_dot (bool): keep GraphViz .dot file for later
            visualization
        :param algo: any graphviz algo, "neato" (for simple graphs)
            or "fdp" (for more crowded graphs) usually give good outputs
        :return:
        :raises RuntimeError: if the `algo` executable is not on the path
            or exits with a non-zero return code.
        """
        if not which(algo):
            raise RuntimeError("StructureGraph graph drawing requires GraphViz binaries to be in the path.")

        # Developer note: NetworkX also has methods for drawing
        # graphs using matplotlib, these also work here. However,
        # a dedicated tool like GraphViz allows for much easier
        # control over graph appearance and also correctly displays
        # multi-graphs (matplotlib can superimpose multiple edges).

        g = self.graph.copy()

        # The weight units live in the graph-level attributes, which are
        # replaced by GraphViz display options on the next line, so they
        # have to be read first.
        units = g.graph.get("edge_weight_units") or ""

        g.graph = {"nodesep": 10.0, "dpi": 300, "overlap": "false"}

        # add display options for nodes
        for n in g.nodes():
            # get label by species name
            label = f"{self.molecule[n].specie}({n})" if node_labels else ""

            # use standard color scheme for nodes
            c = EL_COLORS[color_scheme].get(str(self.molecule[n].specie.symbol), [0, 0, 0])

            # get contrasting font color
            # magic numbers account for perceived luminescence
            # https://stackoverflow.com/questions/1855884/determine-font-color-based-on-background-color
            fontcolor = "#000000" if 1 - (c[0] * 0.299 + c[1] * 0.587 + c[2] * 0.114) / 255 < 0.5 else "#ffffff"

            # convert color to hex string
            color = f"#{c[0]:02x}{c[1]:02x}{c[2]:02x}"

            g.add_node(
                n,
                fillcolor=color,
                fontcolor=fontcolor,
                label=label,
                fontname="Helvetica-bold",
                style="filled",
                shape="circle",
            )

        edges_to_delete = []

        # add display options for edges
        for u, v, k, d in g.edges(keys=True, data=True):
            # retrieve to image, set as origin if not defined
            to_image = tuple(d.get("to_jimage", (0, 0, 0)))

            # set edge style
            d["style"] = "solid"
            if to_image != (0, 0, 0):
                d["style"] = "dashed"
                if hide_image_edges:
                    edges_to_delete.append((u, v, k))

            # don't show edge directions
            d["arrowhead"] = "none"

            # only add labels for images that are not the origin
            if image_labels:
                d["headlabel"] = "" if to_image == (0, 0, 0) else f"to {to_image}"
                d["arrowhead"] = "normal" if d["headlabel"] else "none"

            # optionally color edges using node colors
            # (g.nodes, not g.node: the latter was removed from networkx)
            color_u = g.nodes[u]["fillcolor"]
            color_v = g.nodes[v]["fillcolor"]
            d["color_uv"] = f"{color_u};0.5:{color_v};0.5" if edge_colors else "#000000"

            # optionally add weights to graph
            if weight_labels and d.get("weight"):
                d["label"] = f"{d['weight']:.2f} {units}"

            # update edge with our new style attributes
            g.edges[u, v, k].update(d)

        # optionally remove image edges,
        # these can be confusing
        if hide_image_edges:
            for edge_to_delete in edges_to_delete:
                g.remove_edge(*edge_to_delete)

        # optionally hide unconnected nodes,
        # these can appear when removing image edges
        if hide_unconnected_nodes:
            # g.degree() yields (node, degree) pairs
            g = g.subgraph([node for node, degree in g.degree() if degree != 0])

        # optionally highlight differences with another graph
        if diff:
            diff = self.diff(diff, strict=True)
            for u, v, k, d in g.edges(keys=True, data=True):
                # same default as self.diff uses when building its edge sets
                edge = (u, v, d.get("to_jimage", (0, 0, 0)))
                if edge in diff["self"]:
                    # edge has been deleted
                    d["color_uv"] = "#ff0000"
                elif edge in diff["other"]:
                    # edge has been added
                    d["color_uv"] = "#00ff00"

        basename, extension = os.path.splitext(filename)
        extension = extension[1:]
        dot_file = basename + ".dot"

        write_dot(g, dot_file)

        with open(filename, "w") as f:
            args = [algo, "-T", extension, dot_file]
            with subprocess.Popen(args, stdout=f, stdin=subprocess.PIPE, close_fds=True) as rs:
                rs.communicate()
                if rs.returncode != 0:
                    raise RuntimeError(f"{algo} exited with return code {rs.returncode}.")

        if not keep_dot:
            os.remove(dot_file)
×
2691

2692
    def as_dict(self):
        """
        Serialize to a dict: the Molecule through its own as_dict() and
        the bond graph through NetworkX's `json_graph.adjacency_data`,
        together with "@module" and "@class" entries naming the type
        of this object.
        """
        own_type = type(self)
        return {
            "@module": own_type.__module__,
            "@class": own_type.__name__,
            "molecule": self.molecule.as_dict(),
            "graphs": json_graph.adjacency_data(self.graph),
        }
1✔
2706

2707
    @classmethod
    def from_dict(cls, d):
        """
        Rebuild an instance from a dict holding a "molecule" entry
        (restored with Molecule.from_dict) and a "graphs" entry
        (handed to the constructor unchanged).
        """
        return cls(Molecule.from_dict(d["molecule"]), d["graphs"])
1✔
2716

2717
    @classmethod
1✔
2718
    def _edges_to_string(cls, g):
1✔
2719
        header = "from    to  to_image    "
×
2720
        header_line = "----  ----  ------------"
×
2721
        edge_weight_name = g.graph["edge_weight_name"]
×
2722
        if edge_weight_name:
×
2723
            print_weights = ["weight"]
×
2724
            edge_label = g.graph["edge_weight_name"]
×
2725
            edge_weight_units = g.graph["edge_weight_units"]
×
2726
            if edge_weight_units:
×
2727
                edge_label += f" ({edge_weight_units})"
×
2728
            header += f"  {edge_label}"
×
2729
            header_line += f"  {'-' * max([18, len(edge_label)])}"
×
2730
        else:
2731
            print_weights = False
×
2732

2733
        s = f"{header}\n{header_line}\n"
×
2734

2735
        edges = list(g.edges(data=True))
×
2736

2737
        # sort edges for consistent ordering
2738
        edges.sort(key=itemgetter(0, 1))
×
2739

2740
        if print_weights:
×
2741
            for u, v, data in edges:
×
2742
                s += f"{u:4}  {v:4}  {str(data.get('to_jimage', (0, 0, 0))):12}  {data.get('weight', 0):.3e}\n"
×
2743
        else:
2744
            for u, v, data in edges:
×
2745
                s += f"{u:4}  {v:4}  {str(data.get('to_jimage', (0, 0, 0))):12}\n"
×
2746

2747
        return s
×
2748

2749
    def __str__(self) -> str:
1✔
2750
        s = "Molecule Graph"
×
2751
        s += f"\nMolecule: \n{self.molecule}"
×
2752
        s += f"\nGraph: {self.name}\n"
×
2753
        s += self._edges_to_string(self.graph)
×
2754
        return s
×
2755

2756
    def __repr__(self) -> str:
1✔
2757
        s = "Molecule Graph"
×
2758
        s += f"\nMolecule: \n{self.molecule.__repr__()}"
×
2759
        s += f"\nGraph: {self.name}\n"
×
2760
        s += self._edges_to_string(self.graph)
×
2761
        return s
×
2762

2763
    def __len__(self) -> int:
1✔
2764
        """
2765
        :return: length of Molecule / number of nodes in graph
2766
        """
2767
        return len(self.molecule)
1✔
2768

2769
    def sort(self, key: Callable[[Molecule], float] | None = None, reverse: bool = False) -> None:
        """Sort the sites of the Molecule in place and relabel graph nodes to match.

        After relabeling, every edge pointing from a higher to a lower node
        index is replaced by the reversed edge with the same data plus
        "to_jimage" set to (0, 0, 0).

        Args:
            key (callable, optional): Sort key. Defaults to None.
            reverse (bool, optional): Reverse sort order. Defaults to False.
        """
        previous_order = self.molecule.copy()

        # sort Molecule
        self.molecule._sites = sorted(self.molecule._sites, key=key, reverse=reverse)

        # apply Molecule ordering to graph
        relabeling = {old_idx: self.molecule.index(site) for old_idx, site in enumerate(previous_order)}
        self.graph = nx.relabel_nodes(self.graph, relabeling, copy=True)

        # normalize directions of edges: collect first, since the graph must
        # not change while its edges are being iterated
        backwards = [
            (u, v, edge_key, data.copy())
            for u, v, edge_key, data in self.graph.edges(keys=True, data=True)
            if v < u
        ]

        # delete all marked edges, then add their reversed counterparts
        for u, v, edge_key, _ in backwards:
            self.graph.remove_edge(u, v, edge_key)
        for u, v, _, props in backwards:
            props["to_jimage"] = (0, 0, 0)
            self.graph.add_edge(v, u, **props)
1✔
2800

2801
    def __copy__(self):
        """Return a MoleculeGraph rebuilt through from_dict(as_dict())."""
        serialized = self.as_dict()
        return MoleculeGraph.from_dict(serialized)
1✔
2803

2804
    def __eq__(self, other: object) -> bool:
1✔
2805
        """
2806
        Two MoleculeGraphs are equal if they have equal Molecules,
2807
        and have the same edges between Sites. Edge weights can be
2808
        different and MoleculeGraphs can still be considered equal.
2809

2810
        :param other: MoleculeGraph
2811
        :return (bool):
2812
        """
2813
        if not isinstance(other, type(self)):
1✔
2814
            return NotImplemented
×
2815

2816
        # sort for consistent node indices
2817
        # PeriodicSite should have a proper __hash__() value,
2818
        # using its frac_coords as a convenient key
2819
        try:
1✔
2820
            mapping = {tuple(site.coords): self.molecule.index(site) for site in other.molecule}
1✔
2821
        except ValueError:
×
2822
            return False
×
2823
        other_sorted = other.__copy__()
1✔
2824
        other_sorted.sort(key=lambda site: mapping[tuple(site.coords)])
1✔
2825

2826
        edges = {(u, v) for u, v, d in self.graph.edges(keys=False, data=True)}
1✔
2827

2828
        edges_other = {(u, v) for u, v, d in other_sorted.graph.edges(keys=False, data=True)}
1✔
2829

2830
        return (edges == edges_other) and (self.molecule == other_sorted.molecule)
1✔
2831

2832
    def isomorphic_to(self, other):
        """
        Checks if the graphs of two MoleculeGraphs are isomorphic to one
        another.

        Cheap checks run first: a different number of sites, a different
        alphabetical formula or a different number of edges returns False
        without running the isomorphism test (_isomorphic) at all.

        :param other: MoleculeGraph object to be compared.
        :return: bool
        """
        own, theirs = self.molecule, other.molecule
        obviously_different = (
            len(own) != len(theirs)
            or own.composition.alphabetical_formula != theirs.composition.alphabetical_formula
            or len(self.graph.edges()) != len(other.graph.edges())
        )
        if obviously_different:
            return False
        return _isomorphic(self.graph, other.graph)
1✔
2848

2849
    def diff(self, other, strict=True):
        """
        Compares two MoleculeGraphs. Returns dict with
        keys 'self', 'other', 'both' with edges that are
        present in only one MoleculeGraph ('self' and
        'other'), and edges that are present in both.

        The Jaccard distance is a simple measure of the
        dissimilarity between two MoleculeGraphs (ignoring
        edge weights), and is defined by 1 - (size of the
        intersection / size of the union) of the sets of
        edges. This is returned with key 'dist'.

        Important note: all node indices are in terms
        of the MoleculeGraph this method is called
        from, not the 'other' MoleculeGraph: there
        is no guarantee the node indices will be the
        same if the underlying Molecules are ordered
        differently.

        :param other: MoleculeGraph
        :param strict: if False, will compare bonds
            from different Molecules, with node indices
            replaced by Species strings, will not count
            number of occurrences of bonds
        :return: dict with keys 'self', 'other', 'both' (sets of edges)
            and 'dist' (Jaccard distance, 0 when neither graph has edges)
        :raises ValueError: if strict is True and the two Molecules differ
        """
        if self.molecule != other.molecule and strict:
            raise ValueError("Meaningless to compare MoleculeGraphs if corresponding Molecules are different.")

        if strict:
            # sort a copy of other for consistent node indices,
            # using each site's coords as the lookup key
            mapping = {tuple(site.coords): self.molecule.index(site) for site in other.molecule}
            other_sorted = copy.copy(other)
            other_sorted.sort(key=lambda site: mapping[tuple(site.coords)])

            edges = {(u, v, d.get("to_jimage", (0, 0, 0))) for u, v, d in self.graph.edges(keys=False, data=True)}

            edges_other = {
                (u, v, d.get("to_jimage", (0, 0, 0))) for u, v, d in other_sorted.graph.edges(keys=False, data=True)
            }

        else:
            edges = {
                (str(self.molecule[u].specie), str(self.molecule[v].specie))
                for u, v, d in self.graph.edges(keys=False, data=True)
            }

            edges_other = {
                (str(other.molecule[u].specie), str(other.molecule[v].specie))
                for u, v, d in other.graph.edges(keys=False, data=True)
            }

        shared = edges & edges_other

        if len(edges) == 0 and len(edges_other) == 0:
            jaccard_dist = 0  # by definition
        else:
            jaccard_dist = 1 - len(shared) / len(edges | edges_other)

        return {
            "self": edges - edges_other,
            "other": edges_other - edges,
            "both": shared,
            "dist": jaccard_dist,
        }
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc