• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

materialsproject / pymatgen / 4075885785

pending completion
4075885785

push

github

Shyue Ping Ong
Merge branch 'master' of github.com:materialsproject/pymatgen

96 of 96 new or added lines in 27 files covered. (100.0%)

81013 of 102710 relevant lines covered (78.88%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

64.53
/pymatgen/analysis/chemenv/utils/graph_utils.py
1
"""
2
This module contains some graph utils that are used in the chemenv package.
3
"""
4

5
from __future__ import annotations
1✔
6

7
import itertools
1✔
8
import operator
1✔
9

10
import networkx as nx
1✔
11
import numpy as np
1✔
12
from monty.json import MSONable
1✔
13

14
__author__ = "waroquiers"
1✔
15

16

17
def get_delta(node1, node2, edge_data):
    """
    Return the delta stored on an edge, oriented from node1 to node2.

    The edge records its own orientation through ``edge_data["start"]`` and
    ``edge_data["end"]``, which are compared with the ``isite`` attribute of
    the two nodes. When the edge runs from node1 to node2 its delta is
    returned unchanged; when it runs from node2 to node1 the delta is negated.

    :param node1: Node the delta is measured from (must expose ``isite``).
    :param node2: Node the delta is measured to (must expose ``isite``).
    :param edge_data: Mapping with the keys "start", "end" and "delta".
    :return: Numpy integer array holding the delta from node1 to node2.
    :raises ValueError: If the edge does not link node1 and node2.
    """
    # Try the edge as stored first, then the opposite orientation.
    for origin, destination, flipped in ((node1, node2, False), (node2, node1, True)):
        if origin.isite == edge_data["start"] and destination.isite == edge_data["end"]:
            delta = np.array(edge_data["delta"], dtype=int)
            return -delta if flipped else delta
    raise ValueError("Trying to find a delta between two nodes with an edge that seems not to link these nodes.")
1✔
30

31

32
def get_all_simple_paths_edges(graph, source, target, cutoff=None, data=True):
    """
    Get all simple paths between two nodes, expressed as lists of edges.

    For a graph that is not a multigraph, each path is returned as a list of
    ``(node1, node2)`` tuples and the ``data`` argument is not used.

    For a multigraph, every distinct node path is expanded into one edge path
    per combination of the parallel edges found between consecutive nodes.
    Each edge is then ``(node1, node2, key, edge_data)`` if ``data`` is True,
    or ``(node1, node2, key)`` otherwise.

    :param graph: networkx graph in which the paths are searched.
    :param source: Starting node of the paths.
    :param target: Ending node of the paths.
    :param cutoff: Passed to ``nx.all_simple_paths`` as its ``cutoff`` argument.
    :param data: Whether to append the edge data to each edge of a multigraph path.
    :return: List of edge paths, each edge path being a list of edge tuples.
    """
    if not graph.is_multigraph():
        return [list(zip(path, path[1:])) for path in nx.all_simple_paths(graph, source, target, cutoff=cutoff)]

    edge_paths = []
    # The same node path may be yielded more than once for a multigraph
    # (presumably once per parallel edge), so each node path is expanded only
    # the first time it is seen. Nodes of a networkx graph are hashable, which
    # makes a set of tuples a constant-time replacement for a linear scan.
    seen_node_paths = set()
    for path in nx.all_simple_paths(graph, source, target, cutoff=cutoff):
        node_path = tuple(path)
        if node_path in seen_node_paths:
            continue
        seen_node_paths.add(node_path)

        partial_paths = [[]]
        for node1, node2 in zip(path, path[1:]):
            parallel_edges = graph[node1][node2].items()
            if data:
                steps = [(node1, node2, key, edge_data) for key, edge_data in parallel_edges]
            else:
                steps = [(node1, node2, key) for key, _ in parallel_edges]
            # Outer loop on the parallel edges, inner loop on the paths built
            # so far: this keeps the ordering of the resulting edge paths.
            partial_paths = [[*partial, step] for step in steps for partial in partial_paths]
        edge_paths.extend(partial_paths)
    return edge_paths
×
74

75

76
def _c2index_isreverse(c1, c2):
    """
    Private helper function locating the start of cycle c1 inside cycle c2.

    The function looks up the index (c2_0_index) of the first node of c1 in
    ``c2.nodes`` and decides whether c2 has to be read backwards to follow c1,
    based on where the second node of c1 sits in c2 (cyclically):

    - *just after* c2_0_index: reverse is False,
    - *just before* c2_0_index: reverse is True.

    :param c1: Cycle whose first two nodes are used (needs a ``nodes`` sequence).
    :param c2: Cycle in which these nodes are searched (needs a ``nodes`` sequence).
    :return: Tuple ``(c2_0_index, reverse, "")`` on success, or
        ``(None, None, message)`` when one of the two nodes is missing from c2
        or when the second node is not adjacent to the first one in c2.
    """
    first_node, second_node = c1.nodes[0], c1.nodes[1]
    if first_node not in c2.nodes:
        return None, None, "First node of cycle c1 not found in cycle c2."
    if second_node not in c2.nodes:
        return None, None, "Second node of cycle c1 not found in cycle c2."

    first_pos = c2.nodes.index(first_node)
    second_pos = c2.nodes.index(second_node)
    last_pos = len(c2.nodes) - 1

    # Expected positions of the neighbours of first_pos, with the wrap-around
    # at both ends of c2, and the message reported when neither matches.
    if first_pos == 0:
        pos_after, pos_before = 1, last_pos
        failure = (
            "Second node of cycle c1 is not second or last in cycle c2 "
            "(first node of cycle c1 is first in cycle c2)."
        )
    elif first_pos == last_pos:
        pos_after, pos_before = 0, first_pos - 1
        failure = (
            "Second node of cycle c1 is not first or before last in cycle c2 "
            "(first node of cycle c1 is last in cycle c2)."
        )
    else:
        pos_after, pos_before = first_pos + 1, first_pos - 1
        failure = (
            "Second node of cycle c1 in cycle c2 is not just after or "
            "just before first node of cycle c1 in cycle c2."
        )

    # The "just after" test takes precedence over the "just before" one.
    if second_pos == pos_after:
        return first_pos, False, ""
    if second_pos == pos_before:
        return first_pos, True, ""
    return None, None, failure
×
132

133

134
class SimpleGraphCycle(MSONable):
    """
    Class used to describe a cycle in a simple graph (graph without multiple edges).

    Note that the convention used here is the networkx convention, for which simple graphs
    are allowed to have self-loops: a cycle with a single node is valid.
    No simple graph cycle with two nodes is possible in a simple graph. The cycle will not
    be validated if validate is set to False.
    By default, the "ordered" parameter is None, in which case the SimpleGraphCycle will be ordered.
    If the user explicitly passes a value for ordered, the nodes are left untouched and the
    value is stored as is in the "ordered" attribute.
    """

    def __init__(self, nodes, validate=True, ordered=None):
        """
        :param nodes: Iterable of the nodes of the cycle (stored as a tuple).
        :param validate: Whether to validate the cycle (raises ValueError if it is not valid).
        :param ordered: None to order the cycle, otherwise the value to store in
            the "ordered" attribute without reordering the nodes.
        """
        self.nodes = tuple(nodes)
        if validate:
            self.validate()
        if ordered is not None:
            self.ordered = ordered
        else:
            self.order()

    def _is_valid(self, check_strict_ordering=False):
        """Check if a SimpleGraphCycle is valid.

        This method checks :
        - that there are no duplicate nodes,
        - that there are either 1 or more than 2 nodes,
        - if check_strict_ordering is True, that the nodes can be sorted and
          that the sorted nodes are strictly increasing.

        :param check_strict_ordering: Whether to also check the strict ordering of the nodes.
        :return: Tuple (is_valid, message), the message being empty when the cycle is valid.
        """
        n_nodes = len(self.nodes)
        if n_nodes == 1:
            return True, ""
        if n_nodes == 2:
            return False, "Simple graph cycle with 2 nodes is not valid."
        if n_nodes == 0:
            return False, "Empty cycle is not valid."
        if n_nodes != len(set(self.nodes)):  # Should not have duplicate nodes
            return False, "Duplicate nodes."
        if check_strict_ordering:
            try:
                sorted_nodes = sorted(self.nodes)
            except TypeError as te:
                # str() instead of args[0]: also safe for an exception without arguments
                if "'<' not supported between instances of" in str(te):
                    return False, "The nodes are not sortable."
                raise
            if not all(i < j for i, j in zip(sorted_nodes, sorted_nodes[1:])):
                return (
                    False,
                    "The list of nodes in the cycle cannot be strictly ordered.",
                )
        return True, ""

    def validate(self, check_strict_ordering=False):
        """Raise an error if the SimpleGraphCycle is not valid.

        :param check_strict_ordering: Whether to also check that the nodes can be strictly ordered.
        :return: None
        :raises ValueError: If the cycle is not valid.
        """
        is_valid, msg = self._is_valid(check_strict_ordering=check_strict_ordering)
        if not is_valid:
            raise ValueError(f"SimpleGraphCycle is not valid : {msg}")

    def order(self, raise_on_fail=True):
        """Orders the SimpleGraphCycle.

        The ordering is performed such that the first node is the "lowest" one
        and the second node is the lowest one of the two neighbor nodes of the
        first node. If raise_on_fail is set to True a ValueError will be
        raised if the ordering fails.

        :param raise_on_fail: If set to True, will raise a ValueError if the
                              ordering fails. If set to False, the "ordered"
                              attribute is set to False instead.
        :return: None
        """
        # always validate the cycle if it needs to be ordered
        # also validates that the nodes can be strictly ordered
        try:
            self.validate(check_strict_ordering=True)
        except ValueError as ve:
            if "SimpleGraphCycle is not valid :" in str(ve) and not raise_on_fail:
                self.ordered = False
                return
            raise

        if len(self.nodes) == 1:
            self.ordered = True
            return

        # Not sure whether the following should be checked here... if strict ordering was guaranteed by
        # the validate method, why would it be needed to have a unique class. One could have 2 subclasses of
        # the same parent class and things could be ok. To be checked what to do. (see also MultiGraphCycle)
        node_classes = {n.__class__ for n in self.nodes}
        if len(node_classes) > 1:
            if raise_on_fail:
                raise ValueError("Could not order simple graph cycle as the nodes are of different classes.")
            self.ordered = False
            return

        n_nodes = len(self.nodes)
        min_index, _ = min(enumerate(self.nodes), key=operator.itemgetter(1))
        # The cycle is read backwards when the node before the lowest node is
        # lower than the node after it, so that it comes second.
        reverse = self.nodes[(min_index - 1) % n_nodes] < self.nodes[(min_index + 1) % n_nodes]
        if reverse:
            self.nodes = self.nodes[min_index::-1] + self.nodes[:min_index:-1]
        else:
            self.nodes = self.nodes[min_index:] + self.nodes[:min_index]
        self.ordered = True

    def __hash__(self):
        # Coarse hash: equal cycles have the same nodes, hence the same length.
        return len(self.nodes)

    def __len__(self):
        return len(self.nodes)

    def __str__(self):
        out = ["Simple cycle with nodes :"]
        out.extend(str(node) for node in self.nodes)
        return "\n".join(out)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SimpleGraphCycle):
            return NotImplemented
        if not self.ordered or not other.ordered:
            raise RuntimeError("Simple cycles should be ordered in order to be compared.")
        return self.nodes == other.nodes

    @classmethod
    def from_edges(cls, edges, edges_are_ordered=True):
        """Constructs SimpleGraphCycle from a list edges.

        By default, the edges list is supposed to be ordered as it will be
        much faster to construct the cycle. If edges_are_ordered is set to
        False, the code will automatically try to find the corresponding edge
        order in the list.

        :param edges: Edges of the cycle, each edge being a pair of nodes.
        :param edges_are_ordered: Whether each edge starts at the node at which
            the previous edge ends (the last edge ending at the start of the first one).
        :return: SimpleGraphCycle built from the nodes of the edges.
        :raises ValueError: If the edges are empty or do not form a cycle, or if
            the resulting cycle is not valid.
        """
        edges = list(edges)
        if not edges:
            # An empty list of edges used to fail with an IndexError
            raise ValueError("Could not construct a cycle from edges.")

        if edges_are_ordered:
            nodes = [e[0] for e in edges]
            is_chained = all(e1[1] == e2[0] for e1, e2 in zip(edges, edges[1:]))
            if not is_chained or edges[-1][1] != edges[0][0]:
                raise ValueError("Could not construct a cycle from edges.")
        else:
            remaining_edges = list(edges)
            nodes = list(remaining_edges.pop())
            while remaining_edges:
                prev_node = nodes[-1]
                # Look for an edge touching the last node, in either direction
                for ie, e in enumerate(remaining_edges):
                    if prev_node == e[0]:
                        remaining_edges.pop(ie)
                        nodes.append(e[1])
                        break
                    if prev_node == e[1]:
                        remaining_edges.pop(ie)
                        nodes.append(e[0])
                        break
                else:  # did not find the next edge
                    raise ValueError("Could not construct a cycle from edges.")
            if nodes[0] != nodes[-1]:
                raise ValueError("Could not construct a cycle from edges.")
            # The closing node duplicates the first one
            nodes.pop()
        return cls(nodes)

    def as_dict(self):
        """
        :return: MSONAble dict
        """
        d = MSONable.as_dict(self)
        # Transforming tuple object to a list to allow BSON and MongoDB
        d["nodes"] = list(d["nodes"])
        return d

    @classmethod
    def from_dict(cls, d, validate=False):
        """
        Deserialize a SimpleGraphCycle from a dict.

        :param d: Dict with the keys "nodes" and "ordered".
        :param validate: Whether to validate the cycle.
        :return: SimpleGraphCycle
        """
        return cls(nodes=d["nodes"], validate=validate, ordered=d["ordered"])
1✔
316

317

318
class MultiGraphCycle(MSONable):
    """Class used to describe a cycle in a multigraph.

    nodes are the nodes of the cycle and edge_indices are the indices of the edges in the cycle.
    The nth index in edge_indices corresponds to the edge index between the nth node in nodes and
    the (n+1)th node in nodes with the exception of the last one being the edge index between
    the last node in nodes and the first node in nodes

    Example: A cycle
        nodes:          1 - 3 - 4 - 0 - 2 - (1)
        edge_indices:     0 . 1 . 0 . 2 . 0 . (0)
    """

    def __init__(self, nodes, edge_indices, validate=True, ordered=None):
        """
        :param nodes: Iterable of the nodes of the cycle (stored as a tuple).
        :param edge_indices: Iterable of the edge indices of the cycle (stored as a tuple).
        :param validate: Whether to validate the cycle (raises ValueError if it is not valid).
        :param ordered: None to order the cycle, otherwise the value to store in
            the "ordered" attribute without reordering nodes and edge indices.
        """
        self.nodes = tuple(nodes)
        self.edge_indices = tuple(edge_indices)
        if validate:
            self.validate()
        if ordered is None:
            self.order()
        else:
            self.ordered = ordered
        self.edge_deltas = None
        self.per = None

    def _is_valid(self, check_strict_ordering=False):
        """Check if a MultiGraphCycle is valid.

        This method checks :
        - that there are as many edge indices as nodes,
        - that the cycle is not empty,
        - that there are no duplicate nodes,
        - that a cycle with two nodes does not use the same edge index twice,
        - if check_strict_ordering is True, that the nodes can be sorted and
          that the sorted nodes are strictly increasing.

        :param check_strict_ordering: Whether to also check the strict ordering of the nodes.
        :return: Tuple (is_valid, message), the message being empty when the cycle is valid.
        """
        node_count = len(self.nodes)
        if node_count != len(self.edge_indices):  # one edge index per node is needed
            return False, "Number of nodes different from number of edge indices."
        if node_count == 0:
            return False, "Empty cycle is not valid."
        if node_count != len(set(self.nodes)):  # a node cannot appear twice
            return False, "Duplicate nodes."
        # Going back and forth on the same edge is not a cycle
        if node_count == 2 and self.edge_indices[0] == self.edge_indices[1]:
            return (
                False,
                "Cycles with two nodes cannot use the same edge for the cycle.",
            )
        if not check_strict_ordering:
            return True, ""

        try:
            ranked = sorted(self.nodes)
        except TypeError as type_error:
            text = type_error.args[0]
            if "'<' not supported between instances of" in text:
                return False, "The nodes are not sortable."
            raise
        if any(not (lower < upper) for lower, upper in zip(ranked, ranked[1:])):
            return (
                False,
                "The list of nodes in the cycle cannot be strictly ordered.",
            )
        return True, ""

    def validate(self, check_strict_ordering=False):
        """Raise an error if the MultiGraphCycle is not valid.

        :param check_strict_ordering: Whether to also check that the nodes can be strictly ordered.
        :return: None
        :raises ValueError: If the cycle is not valid.
        """
        valid, reason = self._is_valid(check_strict_ordering=check_strict_ordering)
        if valid:
            return
        raise ValueError(f"MultiGraphCycle is not valid : {reason}")

    def order(self, raise_on_fail=True):
        """Orders the MultiGraphCycle.

        The ordering is performed such that the first node is the "lowest" one
        and the second node is the lowest one of the two neighbor nodes of the
        first node, the edge indices being reordered accordingly. If
        raise_on_fail is set to True a ValueError will be raised if the
        ordering fails.

        :param raise_on_fail: If set to True, will raise a ValueError if the
                              ordering fails. If set to False, the "ordered"
                              attribute is set to False instead.
        :return: None
        """
        # Ordering always validates the cycle, strict ordering of the nodes included
        try:
            self.validate(check_strict_ordering=True)
        except ValueError as value_error:
            is_validation_failure = "MultiGraphCycle is not valid :" in value_error.args[0]
            if is_validation_failure and not raise_on_fail:
                self.ordered = False
                return
            raise

        node_count = len(self.nodes)
        if node_count == 1:
            self.ordered = True
            return

        # Not sure whether the following should be checked here... if strict ordering was guaranteed by
        # the validate method, why would it be needed to have a unique class. One could have 2 subclasses of
        # the same parent class and things could be ok. To be checked what to do. (see also SimpleGraphCycle)
        if len({node.__class__ for node in self.nodes}) > 1:
            if raise_on_fail:
                raise ValueError("Could not order simple graph cycle as the nodes are of different classes.")
            self.ordered = False
            return

        # Special case when number of nodes is 2 because the two
        # edge_indices refer to the same pair of nodes
        if node_count == 2:
            self.nodes = tuple(sorted(self.nodes))
            self.edge_indices = tuple(sorted(self.edge_indices))
            self.ordered = True
            return

        start = min(range(node_count), key=self.nodes.__getitem__)
        node_before = self.nodes[(start - 1) % node_count]
        node_after = self.nodes[(start + 1) % node_count]
        if node_before < node_after:
            # Read the cycle backwards; the first edge is then the one that
            # preceded the lowest node in the original direction.
            first_edge = (start - 1) % node_count
            self.nodes = self.nodes[start::-1] + self.nodes[:start:-1]
            self.edge_indices = self.edge_indices[first_edge::-1] + self.edge_indices[:first_edge:-1]
        else:
            # Plain rotation bringing the lowest node first
            self.nodes = self.nodes[start:] + self.nodes[:start]
            self.edge_indices = self.edge_indices[start:] + self.edge_indices[:start]
        self.ordered = True

    def __hash__(self):
        return len(self.nodes)

    def __len__(self):
        return len(self.nodes)

    def __str__(self):
        lines = ["Multigraph cycle with nodes :"]
        lines.extend(
            f"{node1} -*{self.edge_indices[position]:d}*- {node2}"
            for position, (node1, node2) in enumerate(zip(self.nodes, self.nodes[1:]))
        )
        # Closing edge, from the last node back to the first one
        lines.append(f"{self.nodes[-1]} -*{self.edge_indices[-1]:d}*- {self.nodes[0]}")
        return "\n".join(lines)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, MultiGraphCycle):
            return NotImplemented
        if not (self.ordered and other.ordered):
            raise RuntimeError("Multigraph cycles should be ordered in order to be compared.")
        return self.nodes == other.nodes and self.edge_indices == other.edge_indices
1✔
473

474

475
def get_all_elementary_cycles(graph):
    """
    Get the elementary cycles of a graph as SimpleGraphCycle objects.

    A cycle basis of the graph is obtained from networkx. Every non-empty
    combination of basis cycles is then merged edge by edge, an edge being
    kept when it belongs to an odd number of cycles of the combination. The
    merged edges are retained when a valid SimpleGraphCycle can be built
    from them.

    :param graph: networkx Graph object.
    :return: Set of SimpleGraphCycle when the cycle basis contains less than
        two cycles, list of SimpleGraphCycle otherwise.
    :raises TypeError: If graph is not a networkx Graph object.
    """
    if not isinstance(graph, nx.Graph):
        raise TypeError("graph should be a networkx Graph object.")

    basis = nx.cycle_basis(graph)
    if len(basis) < 2:
        return {SimpleGraphCycle(cycle) for cycle in basis}

    # Number the edges; both orientations of an edge share the same number
    edge_list = []
    edge_number = {}
    for n1, n2 in graph.edges:
        edge_number[(n1, n2)] = edge_number[(n2, n1)] = len(edge_list)
        edge_list.append((n1, n2))

    # membership[i, j] is True when edge j belongs to basis cycle i
    membership = np.zeros(shape=(len(basis), len(edge_list)), dtype=bool)
    for row, cycle in enumerate(basis):
        for n1, n2 in zip(cycle, cycle[1:] + cycle[:1]):
            membership[row, edge_number[(n1, n2)]] = True

    skipped_messages = (
        "SimpleGraphCycle is not valid : Duplicate nodes.",
        "Could not construct a cycle from edges.",
    )
    cycles = []
    for size in range(1, len(basis) + 1):
        for combination in itertools.combinations(membership, size):
            # Edges present an odd number of times in the combination
            odd_edges = np.array(np.mod(np.sum(combination, axis=0), 2), dtype=bool)
            candidate_edges = [edge for edge, is_odd in zip(edge_list, odd_edges) if is_odd]
            try:
                candidate = SimpleGraphCycle.from_edges(candidate_edges, edges_are_ordered=False)
            except ValueError as value_error:
                # Combinations that do not yield an elementary cycle are dropped
                if value_error.args[0] in skipped_messages:
                    continue
                raise
            else:
                cycles.append(candidate)

    return cycles
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc