• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

materialsproject / pymatgen / 4075885785

pending completion
4075885785

push

github

Shyue Ping Ong
Merge branch 'master' of github.com:materialsproject/pymatgen

96 of 96 new or added lines in 27 files covered. (100.0%)

81013 of 102710 relevant lines covered (78.88%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

52.98
/pymatgen/analysis/chemenv/connectivity/connected_components.py
1
"""
2
Connected components.
3
"""
4

5
from __future__ import annotations
1✔
6

7
import itertools
1✔
8
import logging
1✔
9

10
import networkx as nx
1✔
11
import numpy as np
1✔
12
from matplotlib.patches import Circle, FancyArrowPatch
1✔
13
from monty.json import MSONable, jsanitize
1✔
14
from networkx.algorithms.components import is_connected
1✔
15
from networkx.algorithms.traversal import bfs_tree
1✔
16

17
from pymatgen.analysis.chemenv.connectivity.environment_nodes import EnvironmentNode
1✔
18
from pymatgen.analysis.chemenv.utils.chemenv_errors import ChemenvError
1✔
19
from pymatgen.analysis.chemenv.utils.graph_utils import get_delta
1✔
20
from pymatgen.analysis.chemenv.utils.math_utils import get_linearly_independent_vectors
1✔
21

22

23
def draw_network(env_graph, pos, ax, sg=None, periodicity_vectors=None):
    """Draw network of environments in a matplotlib figure axes.

    Each node is drawn as a semi-transparent circle annotated with its label, and each edge as an arrow
    annotated with the cell-image delta returned by ``get_delta``. Parallel edges between the same pair of
    nodes are drawn as arcs of alternating side and increasing curvature.

    Args:
        env_graph: Graph of environments. As a side effect, a "patch" attribute holding the node's Circle
            is set on every node.
        pos: Positions of the nodes of the environments in the 2D figure (mapping node -> (x, y)).
        ax: Axes object in which the network should be drawn.
        sg: Not used currently (drawing of supergraphs).
        periodicity_vectors: List of periodicity vectors that should be drawn. When exactly one vector is
            given, edges whose delta equals it (or its opposite) are drawn as thicker red dashed arrows and
            all other edges in black. Otherwise, edges with a non-zero delta are drawn in red.

    Returns: None
    """
    for n in env_graph:
        circle = Circle(pos[n], radius=0.02, alpha=0.5)
        ax.add_patch(circle)
        # Graph.node was removed in networkx 2.4; Graph.nodes[n] is the node attribute dict.
        env_graph.nodes[n]["patch"] = circle
        ax.annotate(str(n), pos[n], ha="center", va="center", xycoords="data")

    single_pv = None
    if periodicity_vectors is not None and len(periodicity_vectors) == 1:
        single_pv = np.array(periodicity_vectors[0])

    seen = {}
    for u, v, d in env_graph.edges(data=True):
        n1 = env_graph.nodes[u]["patch"]
        n2 = env_graph.nodes[v]["patch"]
        rad = 0.1
        if (u, v) in seen:
            # Each additional parallel edge flips side and gets a larger curvature.
            rad = seen[(u, v)]
            rad = (rad + np.sign(rad) * 0.1) * -1
        # Plain float: with NumPy >= 2 the repr of a NumPy scalar is "np.float64(...)", which would
        # produce a connection style string matplotlib cannot parse.
        rad = float(rad)

        delta = get_delta(u, v, d)

        n1center = np.array(n1.center)
        n2center = np.array(n2.center)
        midpoint = (n1center + n2center) / 2
        n1c_to_n2c = n2center - n1center
        dist = np.linalg.norm(n1c_to_n2c)
        # Unit vector perpendicular to the edge in the drawing plane, used to place the edge label.
        vv = np.cross(
            np.array([n1c_to_n2c[0], n1c_to_n2c[1], 0], float),
            np.array([0, 0, 1], float),
        )
        vv_norm = np.linalg.norm(vv)
        if vv_norm > 0:  # zero when both ends are drawn at the same position (self loops)
            vv /= vv_norm
        perpendicular = np.array([vv[0], vv[1]], float)
        midarc = midpoint + rad * dist * perpendicular
        xytext_offset = 0.1 * dist * perpendicular

        arrow_style = {"mutation_scale": 10.0, "color": "k"}
        if single_pv is not None:
            delta_arr = np.array(delta)
            if np.all(delta_arr == single_pv) or np.all(delta_arr == -single_pv):
                arrow_style = {"mutation_scale": 15.0, "color": "r", "linestyle": "dashed"}
        elif not np.allclose(np.array(delta), np.zeros(3)):
            arrow_style["color"] = "r"

        arrow = FancyArrowPatch(
            n1center,
            n2center,
            patchA=n1,
            patchB=n2,
            arrowstyle="-|>",
            connectionstyle=f"arc3,rad={rad}",
            lw=2,
            alpha=0.5,
            **arrow_style,
        )
        ax.annotate(
            delta,
            midarc,
            ha="center",
            va="center",
            xycoords="data",
            xytext=xytext_offset,
            textcoords="offset points",
        )
        seen[(u, v)] = rad
        ax.add_patch(arrow)
126

127

128
def make_supergraph(graph, multiplicity, periodicity_vectors):
    """Make supergraph from a graph of environments.

    The graph is replicated ``mult`` times along its first periodicity vector. Nodes of the supergraph are
    integers ``replica_index * n_nodes + node_index``, where ``node_index`` is the position of the node in
    ``graph.nodes()``.
    - Edges whose delta equals the periodicity vector (or its opposite) link replica ``i`` to replica
      ``i + 1``. The last replica wraps around to the first one.
    - All other edges are copied inside each replica.
    Only edges are added, so isolated nodes are not part of the supergraph.

    Args:
        graph: Graph of environments (networkx MultiGraph whose edges carry "delta", "start" and "end").
        multiplicity: Multiplicity of the supergraph, as an int or a sequence of length 1.
        periodicity_vectors: Periodicity vectors needed to make the supergraph (only the first one is used).

    Returns: Super graph of the environments.

    Raises:
        NotImplementedError: If multiplicity has more than one component (2- and 3-periodic graphs).
        ValueError: If the multiplicity is smaller than 1 or no periodicity vector is given.
    """
    if not (isinstance(multiplicity, int) or len(multiplicity) == 1):
        raise NotImplementedError("make_supergraph not yet implemented for 2- and 3-periodic graphs")
    mult = multiplicity if isinstance(multiplicity, int) else multiplicity[0]
    if mult < 1:
        raise ValueError(f"Multiplicity should be a positive integer, got {mult!r}")
    if periodicity_vectors is None or len(periodicity_vectors) == 0:
        raise ValueError("At least one periodicity vector is needed to make the supergraph")

    logger = logging.getLogger(__name__)
    logger.debug("Periodicity vectors: %s", periodicity_vectors)
    periodicity_vector = np.array(periodicity_vectors[0])

    inodes = list(graph.nodes())
    n_nodes = len(inodes)
    indices_nodes = {isite: index for index, isite in enumerate(inodes)}

    connecting_edges = []
    other_edges = []
    for n1, n2, key, data in graph.edges(data=True, keys=True):
        logger.debug("Edge %s - %s (key %s): %s", n1, n2, key, data)
        delta = np.array(data["delta"])
        if np.all(delta == periodicity_vector):
            connecting_edges.append((n1, n2, key, data))
        elif np.all(delta == -periodicity_vector):
            # Express the edge in reverse so that its delta equals the periodicity vector.
            # NOTE(review): (n1, n2) is still used for the connection below; confirm this direction is intended.
            new_data = dict(data)
            new_data["delta"] = tuple(-delta)
            new_data["start"] = data["end"]
            new_data["end"] = data["start"]
            connecting_edges.append((n1, n2, key, new_data))
        else:
            if not np.all(delta == 0):
                # Warn instead of prompting on stdin: library code must not block waiting for input.
                logger.warning(
                    "delta not equal to periodicity nor 0 ... : %s %s %s %s %s",
                    n1,
                    n2,
                    key,
                    data["delta"],
                    data,
                )
            other_edges.append((n1, n2, key, data))

    supergraph = nx.MultiGraph()
    for imult in range(mult):
        offset = imult * n_nodes
        # The last replica connects back to the first one.
        next_offset = ((imult + 1) % mult) * n_nodes
        for n1, n2, key, data in other_edges:
            new_data = dict(data)
            new_data["start"] = offset + indices_nodes[n1]
            new_data["end"] = offset + indices_nodes[n2]
            # Edge attributes are stored flat (networkx >= 2 no longer supports attr_dict).
            supergraph.add_edge(new_data["start"], new_data["end"], key=key, **new_data)
        for n1, n2, key, data in connecting_edges:
            new_data = dict(data)
            new_data["start"] = offset + indices_nodes[n1]
            new_data["end"] = next_offset + indices_nodes[n2]
            if imult < mult - 1:
                # Inside the supercell the connection no longer crosses a cell boundary.
                new_data["delta"] = (0, 0, 0)
            supergraph.add_edge(new_data["start"], new_data["end"], key=key, **new_data)
    return supergraph
199

200

201
class ConnectedComponent(MSONable):
1✔
202
    """
203
    Class used to describe the connected components in a structure in terms of coordination environments.
204
    """
205

206
    def __init__(
1✔
207
        self,
208
        environments=None,
209
        links=None,
210
        environments_data=None,
211
        links_data=None,
212
        graph=None,
213
    ) -> None:
214
        """
215
        Constructor for the ConnectedComponent object.
216

217
        Args:
218
            environments: Environments in the connected component.
219
            links: Links between environments in the connected component.
220
            environments_data: Data of environment nodes.
221
            links_data: Data of links between environment nodes.
222
            graph: Graph of the connected component.
223

224
        Returns:
225
            ConnectedComponent: Instance of this class
226
        """
227
        self._periodicity_vectors: list[list] | None = None
1✔
228
        self._primitive_reduced_connected_subgraph = None
1✔
229
        self._projected = False
1✔
230
        if graph is None:
1✔
231
            self._connected_subgraph = nx.MultiGraph()
1✔
232
            if environments_data is None:
1✔
233
                self._connected_subgraph.add_nodes_from(environments)
×
234
            else:
235
                for env in environments:
1✔
236
                    if env in environments_data:
1✔
237
                        self._connected_subgraph.add_node(env, **environments_data[env])
1✔
238
                    else:
239
                        self._connected_subgraph.add_node(env)
1✔
240
            for edge in links:
1✔
241
                env_node1 = edge[0]
1✔
242
                env_node2 = edge[1]
1✔
243
                if len(edge) == 2:
1✔
244
                    key = None
1✔
245
                else:
246
                    key = edge[2]
1✔
247
                if (not self._connected_subgraph.has_node(env_node1)) or (
1✔
248
                    not self._connected_subgraph.has_node(env_node2)
249
                ):
250
                    raise ChemenvError(
×
251
                        self.__class__,
252
                        "__init__",
253
                        "Trying to add edge with some unexisting node ...",
254
                    )
255
                if links_data is not None:
1✔
256
                    if (env_node1, env_node2, key) in links_data:
1✔
257
                        edge_data = links_data[(env_node1, env_node2, key)]
1✔
258
                    elif (env_node2, env_node1, key) in links_data:
1✔
259
                        edge_data = links_data[(env_node2, env_node1, key)]
×
260
                    elif (env_node1, env_node2) in links_data:
1✔
261
                        edge_data = links_data[(env_node1, env_node2)]
×
262
                    elif (env_node2, env_node1) in links_data:
1✔
263
                        edge_data = links_data[(env_node2, env_node1)]
1✔
264
                    else:
265
                        edge_data = None
1✔
266
                else:
267
                    edge_data = None
×
268
                if edge_data:
1✔
269
                    self._connected_subgraph.add_edge(env_node1, env_node2, key, **edge_data)
1✔
270
                else:
271
                    self._connected_subgraph.add_edge(env_node1, env_node2, key)
1✔
272
        else:
273
            # TODO: should check a few requirements here ?
274
            self._connected_subgraph = graph
1✔
275

276
    def coordination_sequence(self, source_node, path_size=5, coordination="number", include_source=False):
        """Get the coordination sequence for a given node.

        Args:
            source_node: Node for which the coordination sequence is computed.
            path_size: Maximum length of the path for the coordination sequence.
            coordination: Type of coordination sequence. The default ("number") corresponds to the number
                of environment nodes that are reachable by following paths of sizes between 1 and path_size.
                For coordination "env:number", this resulting coordination sequence is a sequence of dictionaries
                mapping the type of environment to the number of such environment reachable by following paths of
                sizes between 1 and path_size.
            include_source: Whether to include the source_node in the coordination sequence.

        Returns:
            dict: Mapping between the nth "layer" of the connected component with the corresponding coordination.

        Raises:
            ValueError: If source_node is not in the connected component, or if coordination is neither
                "number" nor "env:number".

        Examples:
            The corner-sharing octahedral framework (as in perovskites) have the following coordination sequence (up to
            a path of size 6) :
            {1: 6, 2: 18, 3: 38, 4: 66, 5: 102, 6: 146}
            Considering both the octahedrons and the cuboctahedrons of the typical BaTiO3 perovskite, the "env:number"
            coordination sequence (up to a path of size 6) starting on the Ti octahedron and Ba cuboctahedron
            are the following :
            Starting on the Ti octahedron : {1: {'O:6': 6, 'C:12': 8}, 2: {'O:6': 26, 'C:12': 48},
                                             3: {'O:6': 90, 'C:12': 128}, 4: {'O:6': 194, 'C:12': 248},
                                             5: {'O:6': 338, 'C:12': 408}, 6: {'O:6': 522, 'C:12': 608}}
            Starting on the Ba cuboctahedron : {1: {'O:6': 8, 'C:12': 18}, 2: {'O:6': 48, 'C:12': 74},
                                                3: {'O:6': 128, 'C:12': 170}, 4: {'O:6': 248, 'C:12': 306},
                                                5: {'O:6': 408, 'C:12': 482}, 6: {'O:6': 608, 'C:12': 698}}
            If include_source is set to True, the source node is included in the sequence, e.g. for the corner-sharing
            octahedral framework : {0: 1, 1: 6, 2: 18, 3: 38, 4: 66, 5: 102, 6: 146}. For the "env:number" coordination
            starting on a Ba cuboctahedron (as shown above), the coordination sequence is then :
            {0: {'C:12': 1}, 1: {'O:6': 8, 'C:12': 18}, 2: {'O:6': 48, 'C:12': 74}, 3: {'O:6': 128, 'C:12': 170},
             4: {'O:6': 248, 'C:12': 306}, 5: {'O:6': 408, 'C:12': 482}, 6: {'O:6': 608, 'C:12': 698}}
        """
        from collections import Counter

        if source_node not in self._connected_subgraph:
            raise ValueError("Node not in Connected Component. Cannot find coordination sequence.")
        # Validate up front so an invalid type is reported even when no layer is computed (path_size <= 0).
        if coordination not in ("number", "env:number"):
            raise ValueError(f"Coordination type {coordination!r} is not valid for coordination_sequence.")
        # Example of an infinite periodic net in two dimensions consisting of a stacking of
        # A and B lines :
        #
        #     *     *     *     *     *
        #     *     *     *     *     *
        # * * A * * B * * A * * B * * A * *
        #     *     *     *     *     *
        #     *     *     *     *     *
        # * * A * * B * * A * * B * * A * *
        #     *     *     *     *     *
        #     *     *     *     *     *
        # * * A * * B * * A * * B * * A * *
        #     *     *     *     *     *
        #     *     *     *     *     *
        # * * A * * B * * A * * B * * A * *
        #     *     *     *     *     *
        #     *     *     *     *     *
        # * * A * * B * * A * * B * * A * *
        #     *     *     *     *     *
        #     *     *     *     *     *
        #
        # One possible quotient graph of this periodic net :
        #          __           __
        # (0,1,0) /  \         /  \ (0,1,0)
        #         `<--A--->---B--<´
        #            / (0,0,0) \
        #            \         /
        #             `--->---´
        #              (1,0,0)
        #
        # The "number" coordination sequence starting from any environment is : 4-8-12-16-...
        # The "env:number" coordination sequence starting from any environment is :
        # {A:2, B:2}-{A:4, B:4}-{A:6, B:6}-...
        current_delta = (0, 0, 0)
        current_ends = [(source_node, current_delta)]
        # A periodic image is identified by its site index and its cell-image delta.
        visited = {(source_node.isite, *current_delta)}
        path_len = 0
        cseq = {}
        if include_source:
            cseq[0] = 1 if coordination == "number" else {source_node.coordination_environment: 1}
        while path_len < path_size:
            new_ends = []
            for current_node_end, current_delta_end in current_ends:
                for nb in self._connected_subgraph.neighbors(current_node_end):
                    for edata in self._connected_subgraph[current_node_end][nb].values():
                        new_delta = current_delta_end + get_delta(current_node_end, nb, edata)
                        if (nb.isite, *new_delta) not in visited:
                            new_ends.append((nb, new_delta))
                            visited.add((nb.isite, *new_delta))
                        if nb.isite == current_node_end.isite:  # Handle self loops
                            # A self loop can be followed in both directions (+delta and -delta).
                            new_delta = current_delta_end - get_delta(current_node_end, nb, edata)
                            if (nb.isite, *new_delta) not in visited:
                                new_ends.append((nb, new_delta))
                                visited.add((nb.isite, *new_delta))
            current_ends = new_ends
            path_len += 1
            if coordination == "number":
                cseq[path_len] = len(current_ends)
            else:
                cseq[path_len] = dict(Counter(myend.coordination_environment for myend, _ in current_ends))
        return cseq
382

383
    def __len__(self):
1✔
384
        return len(self.graph)
1✔
385

386
    def compute_periodicity(self, algorithm="all_simple_paths"):
1✔
387
        """
388
        Args:
389
            algorithm ():
390

391
        Returns:
392
        """
393
        if algorithm == "all_simple_paths":
1✔
394
            self.compute_periodicity_all_simple_paths_algorithm()
1✔
395
        elif algorithm == "cycle_basis":
×
396
            self.compute_periodicity_cycle_basis()
×
397
        else:
398
            raise ValueError(f"Algorithm {algorithm!r} is not allowed to compute periodicity")
×
399
        self._order_periodicity_vectors()
1✔
400

401
    def compute_periodicity_all_simple_paths_algorithm(self):
        """Compute the periodicity vectors from the closed simple paths through each node.

        For each node, cell-image deltas are collected from three sources:
        (i) its self loops;
        (ii) pairs of parallel edges to a neighbor;
        (iii) the simple paths from the node to each neighbor, closed back to the node.
        The deltas are reduced to linearly independent vectors. The search stops as soon as three are found.
        The largest set found over the examined nodes is stored in ``self._periodicity_vectors``.

        Raises:
            ValueError: If a self loop has a (0, 0, 0) delta, or two self loops of a node have the same
                (or opposite) delta.
        """
        # A set gives O(1) membership tests in the loop over all the nodes.
        self_loop_nodes = set(nx.nodes_with_selfloops(self._connected_subgraph))
        all_nodes_independent_cell_image_vectors = []
        my_simple_graph = nx.Graph(self._connected_subgraph)
        for test_node in self._connected_subgraph.nodes():
            # TODO: do we need to go through all test nodes ?
            this_node_cell_img_vectors = []
            if test_node in self_loop_nodes:
                for edge_data in self._connected_subgraph[test_node][test_node].values():
                    if edge_data["delta"] == (0, 0, 0):
                        raise ValueError("There should not be self loops with delta image = (0, 0, 0).")
                    this_node_cell_img_vectors.append(edge_data["delta"])
            for d1, d2 in itertools.combinations(this_node_cell_img_vectors, 2):
                if d1 == d2 or d1 == tuple(-ii for ii in d2):
                    raise ValueError("There should not be self loops with the same (or opposite) delta image.")
            this_node_cell_img_vectors = get_linearly_independent_vectors(this_node_cell_img_vectors)
            # Here, we adopt a cutoff equal to the size of the graph, contrary to the default of networkX (size - 1),
            # because otherwise, the all_simple_paths algorithm fail when the source node is equal to the target node.
            # Already explored closed paths (tuples of site indices), kept in a set for O(1) duplicate checks.
            paths = set()
            # TODO: its probably possible to do just a dfs or bfs traversal instead of taking all simple paths!
            test_node_neighbors = my_simple_graph.neighbors(test_node)
            break_node_loop = False
            for test_node_neighbor in test_node_neighbors:
                # Special case for two nodes: parallel edges form 2-cycles that the simple graph cannot see.
                if len(self._connected_subgraph[test_node][test_node_neighbor]) > 1:
                    this_path_deltas = []
                    node_node_neighbor_edges_data = list(
                        self._connected_subgraph[test_node][test_node_neighbor].values()
                    )
                    for edge1_data, edge2_data in itertools.combinations(node_node_neighbor_edges_data, 2):
                        delta1 = get_delta(test_node, test_node_neighbor, edge1_data)
                        delta2 = get_delta(test_node_neighbor, test_node, edge2_data)
                        this_path_deltas.append(delta1 + delta2)
                    this_node_cell_img_vectors.extend(this_path_deltas)
                    this_node_cell_img_vectors = get_linearly_independent_vectors(this_node_cell_img_vectors)
                    if len(this_node_cell_img_vectors) == 3:
                        break
                for path in nx.all_simple_paths(
                    my_simple_graph,
                    test_node,
                    test_node_neighbor,
                    cutoff=len(self._connected_subgraph),
                ):
                    path_indices = [node_path.isite for node_path in path]
                    # The direct edge is not a cycle by itself (parallel edges are handled above).
                    if path_indices == [test_node.isite, test_node_neighbor.isite]:
                        continue
                    path_indices.append(test_node.isite)
                    path_indices = tuple(path_indices)
                    if path_indices in paths:
                        continue
                    paths.add(path_indices)
                    # Close the path back to the test node.
                    path.append(test_node)
                    # TODO: there are some paths that appears twice for cycles, and there are some paths that should
                    # probably not be considered
                    # Every combination of parallel edges along the path gives a candidate total delta.
                    this_path_deltas = [np.zeros(3, int)]
                    for node1, node2 in zip(path[:-1], path[1:]):
                        this_path_deltas_new = []
                        for edge_data in self._connected_subgraph[node1][node2].values():
                            delta = get_delta(node1, node2, edge_data)
                            for current_delta in this_path_deltas:
                                this_path_deltas_new.append(current_delta + delta)
                        this_path_deltas = this_path_deltas_new
                    this_node_cell_img_vectors.extend(this_path_deltas)
                    this_node_cell_img_vectors = get_linearly_independent_vectors(this_node_cell_img_vectors)
                    if len(this_node_cell_img_vectors) == 3:
                        break_node_loop = True
                        break
                if break_node_loop:
                    break
            this_node_cell_img_vectors = get_linearly_independent_vectors(this_node_cell_img_vectors)
            independent_cell_img_vectors = this_node_cell_img_vectors
            all_nodes_independent_cell_image_vectors.append(independent_cell_img_vectors)
            # If we have found that the sub structure network is 3D-connected, we can stop ...
            if len(independent_cell_img_vectors) == 3:
                break
        # Keep the largest set of independent vectors found over the examined nodes.
        self._periodicity_vectors = []
        for independent_cell_img_vectors in all_nodes_independent_cell_image_vectors:
            if len(independent_cell_img_vectors) > len(self._periodicity_vectors):
                self._periodicity_vectors = independent_cell_img_vectors
            if len(self._periodicity_vectors) == 3:
                break
487

488
    def compute_periodicity_cycle_basis(self):
        """Compute the periodicity vectors from a cycle basis of the connected component.

        Cell-image deltas are accumulated along each cycle of a cycle basis of the underlying simple graph. Deltas
        from pairs of parallel edges are also included, because the simple graph merges those edges into one. The
        deltas are reduced to linearly independent vectors, stopping as soon as three are found. The result is
        always stored in ``self._periodicity_vectors``.

        Raises:
            ValueError: If an edge of the simple graph has no counterpart in the multigraph (should not happen).
        """
        my_simple_graph = nx.Graph(self._connected_subgraph)
        cycles = nx.cycle_basis(my_simple_graph)
        all_deltas = []
        for cyc in cycles:
            mycyc = [*cyc, cyc[0]]
            # Every combination of parallel edges along the cycle gives a candidate total delta.
            this_cycle_deltas = [np.zeros(3, int)]
            for node1, node2 in zip(mycyc[:-1], mycyc[1:]):
                this_cycle_deltas_new = []
                for edge_data in self._connected_subgraph[node1][node2].values():
                    delta = get_delta(node1, node2, edge_data)
                    for current_delta in this_cycle_deltas:
                        this_cycle_deltas_new.append(current_delta + delta)
                this_cycle_deltas = this_cycle_deltas_new
            all_deltas.extend(this_cycle_deltas)
            all_deltas = get_linearly_independent_vectors(all_deltas)
            if len(all_deltas) == 3:
                # Store the result before returning; otherwise the periodicity vectors would stay unset.
                self._periodicity_vectors = all_deltas
                return
        # One has to consider pairs of nodes with parallel edges (these are not considered in the simple graph cycles)
        for n1, n2 in my_simple_graph.edges():
            if n1 == n2:
                continue
            parallel_edges = self._connected_subgraph[n1][n2]
            n_parallel = len(parallel_edges)
            if n_parallel == 1:
                continue
            if n_parallel == 0:
                raise ValueError("Should not be here ...")
            for iedge1, iedge2 in itertools.combinations(parallel_edges, 2):
                # Going n1 -> n2 along one edge and back along the other closes a 2-cycle.
                loop_delta = get_delta(n1, n2, parallel_edges[iedge1]) + get_delta(n2, n1, parallel_edges[iedge2])
                all_deltas.append(loop_delta)
            all_deltas = get_linearly_independent_vectors(all_deltas)
            if len(all_deltas) == 3:
                self._periodicity_vectors = all_deltas
                return
        self._periodicity_vectors = all_deltas

533
    def make_supergraph(self, multiplicity):
        """Build the supergraph of this connected component.

        Args:
            multiplicity: Multiplicity of the supergraph, forwarded to the module-level ``make_supergraph``.

        Returns:
            The supergraph built by the module-level ``make_supergraph`` from this component's graph and its
            periodicity vectors.
        """
        return make_supergraph(self._connected_subgraph, multiplicity, self._periodicity_vectors)
542

543
    def show_graph(self, graph=None, save_file=None, drawing_type="internal", pltshow=True) -> None:
        """Draw a graph of environments with matplotlib in a new figure.

        Args:
            graph: Graph to draw. Defaults to the graph of this connected component.
            save_file: If not None, path of the file in which the figure is saved.
            drawing_type: How to draw the graph. Supported values:
                - "internal": shell layout drawn with ``draw_network``, including the edge deltas.
                - "draw_graphviz": graphviz layout, which requires pydot and graphviz.
                - "draw_random": networkx random layout.
            pltshow: Whether to call ``plt.show()`` at the end.

        Raises:
            ValueError: If drawing_type is not one of the supported values.
        """
        import matplotlib.pyplot as plt

        if drawing_type not in ("internal", "draw_graphviz", "draw_random"):
            raise ValueError(f"Drawing type {drawing_type!r} is not supported")
        shown_graph = self._connected_subgraph if graph is None else graph

        plt.figure()
        if drawing_type == "internal":
            pos = nx.shell_layout(shown_graph)
            ax = plt.gca()
            draw_network(shown_graph, pos, ax, periodicity_vectors=self._periodicity_vectors)
            ax.autoscale()
            plt.axis("equal")
            plt.axis("off")
        elif drawing_type == "draw_graphviz":
            # Compute the graphviz layout AND draw with it (computing the layout alone draws nothing).
            pos = nx.nx_pydot.graphviz_layout(shown_graph)
            nx.draw(shown_graph, pos)
        else:
            nx.draw_random(shown_graph)
        if save_file is not None:
            plt.savefig(save_file)
        if pltshow:
            plt.show()
580

581
    @property
1✔
582
    def graph(self):
1✔
583
        """Return the graph of this connected component.
584

585
        Returns:
586
            MultiGraph: Networkx MultiGraph object with environment as nodes and links between these nodes as edges
587
                        with information about the image cell difference if any.
588
        """
589
        return self._connected_subgraph
1✔
590

591
    @property
1✔
592
    def is_periodic(self) -> bool:
1✔
593
        """
594
        Returns:
595
        """
596
        return not self.is_0d
1✔
597

598
    @property
1✔
599
    def is_0d(self) -> bool:
1✔
600
        """
601
        Returns:
602
        """
603
        if self._periodicity_vectors is None:
1✔
604
            self.compute_periodicity()
1✔
605
        assert self._periodicity_vectors is not None  # fix mypy arg 1 to len has incompatible type Optional
1✔
606
        return len(self._periodicity_vectors) == 0
1✔
607

608
    @property
1✔
609
    def is_1d(self) -> bool:
1✔
610
        """
611
        Returns:
612
        """
613
        if self._periodicity_vectors is None:
1✔
614
            self.compute_periodicity()
×
615
        assert self._periodicity_vectors is not None  # fix mypy arg 1 to len has incompatible type Optional
1✔
616
        return len(self._periodicity_vectors) == 1
1✔
617

618
    @property
1✔
619
    def is_2d(self) -> bool:
1✔
620
        """
621
        Returns:
622
        """
623
        if self._periodicity_vectors is None:
1✔
624
            self.compute_periodicity()
×
625
        assert self._periodicity_vectors is not None  # fix mypy arg 1 to len has incompatible type Optional
1✔
626
        return len(self._periodicity_vectors) == 2
1✔
627

628
    @property
1✔
629
    def is_3d(self) -> bool:
1✔
630
        """
631
        Returns:
632
        """
633
        if self._periodicity_vectors is None:
1✔
634
            self.compute_periodicity()
×
635
        assert self._periodicity_vectors is not None  # fix mypy arg 1 to len has incompatible type Optional
1✔
636
        return len(self._periodicity_vectors) == 3
1✔
637

638
    @staticmethod
1✔
639
    def _order_vectors(vectors):
1✔
640
        """Orders vectors.
641

642
        First, each vector is made such that the first non-zero dimension is positive.
643
        Example: a periodicity vector [0, -1, 1] is transformed to [0, 1, -1].
644
        Then vectors are ordered based on their first element, then (if the first element
645
        is identical) based on their second element, then (if the first and second element
646
        are identical) based on their third element and so on ...
647
        Example: [[1, 1, 0], [0, 1, -1], [0, 1, 1]] is ordered as [[0, 1, -1], [0, 1, 1], [1, 1, 0]]
648
        """
649
        for ipv, pv in enumerate(vectors):
1✔
650
            nonzeros = np.nonzero(pv)[0]
1✔
651
            if pv[nonzeros[0]] < 0 < len(nonzeros):
1✔
652
                vectors[ipv] = -pv
1✔
653
        return sorted(vectors, key=lambda x: x.tolist())
1✔
654

655
    def _order_periodicity_vectors(self):
1✔
656
        """Orders the periodicity vectors."""
657
        if len(self._periodicity_vectors) > 3:
1✔
658
            raise ValueError("Number of periodicity vectors is larger than 3.")
×
659
        self._periodicity_vectors = self._order_vectors(self._periodicity_vectors)
1✔
660
        # for ipv, pv in enumerate(self._periodicity_vectors):
661
        #     nonzeros = np.nonzero(pv)[0]
662
        #     if (len(nonzeros) > 0) and (pv[nonzeros[0]] < 0):
663
        #         self._periodicity_vectors[ipv] = -pv
664
        # self._periodicity_vectors = sorted(self._periodicity_vectors, key=lambda x: x.tolist())
665

666
    @property
1✔
667
    def periodicity_vectors(self):
1✔
668
        """
669
        Returns:
670
        """
671
        if self._periodicity_vectors is None:
1✔
672
            self.compute_periodicity()
×
673
        return [np.array(pp) for pp in self._periodicity_vectors]
1✔
674

675
    @property
1✔
676
    def periodicity(self):
1✔
677
        """
678
        Returns:
679
        """
680
        if self._periodicity_vectors is None:
1✔
681
            self.compute_periodicity()
1✔
682
        return f"{len(self._periodicity_vectors):d}D"
1✔
683

684
    def elastic_centered_graph(self, start_node=None):
        """Build a copy of the component graph in which nodes are "elastically" brought inside the cell.

        The graph is traversed level by level along a breadth-first tree rooted at ``start_node``.
        When a tree neighbor is linked to the current node only through edges with a non-zero
        ``delta``, the deltas of all the neighbor's edges (self-loops excepted) are shifted by the
        delta of one of these edges. After the shift, that edge has a (0, 0, 0) delta.

        Args:
            start_node: Node of ``self.graph`` from which the traversal starts. If None (default),
                the first node of ``self.graph`` is used.

        Returns:
            nx.MultiGraph: Elastic centered subgraph.

        Raises:
            ValueError: If the edge data is inconsistent with the nodes (e.g. more than one (0, 0, 0)
                delta between two nodes, or "start"/"end" matching neither node).
            RuntimeError: If the resulting graph is not connected once edges with a non-zero delta
                are removed.
        """
        logging.info("In elastic centering")
        # Some nodes cannot be elastically taken inside the cell if you start from a specific node,
        # so an explicitly requested start node is honored.
        if start_node is None:
            start_node = list(self.graph.nodes())[0]

        # Separate MultiGraph with the same nodes and edges; its edge deltas are modified below.
        centered_connected_subgraph = nx.MultiGraph()
        centered_connected_subgraph.add_nodes_from(self.graph.nodes())
        centered_connected_subgraph.add_edges_from(self.graph.edges(data=True))
        tree = bfs_tree(G=self.graph, source=start_node)

        current_nodes = [start_node]

        inode = 0
        # Loop on "levels" in the tree
        tree_level = 0
        while current_nodes:
            tree_level += 1
            logging.debug(f"In tree level {tree_level:d} ({len(current_nodes):d} nodes)")
            new_current_nodes = []
            # Loop on nodes in this level of the tree
            for node in current_nodes:
                inode += 1
                logging.debug(f"  In node #{inode:d}/{len(current_nodes):d} in level {tree_level:d} ({node})")
                # Neighbors of node in the breadth-first tree rooted at start_node
                node_neighbors = list(tree.neighbors(n=node))
                node_edges = centered_connected_subgraph.edges(nbunch=[node], data=True, keys=True)
                # Loop on neighbors of a node (from the tree used)
                for inode_neighbor, node_neighbor in enumerate(node_neighbors):
                    logging.debug(
                        f"    Testing neighbor #{inode_neighbor:d}/{len(node_neighbors):d} ({node_neighbor}) of "
                        f"node #{inode:d} ({node})"
                    )
                    already_inside = False
                    # Deltas of all edges between node and node_neighbor, oriented from node's side
                    ddeltas = []
                    for n1, n2, _key, edata in node_edges:
                        if (n1 == node and n2 == node_neighbor) or (n2 == node and n1 == node_neighbor):
                            if edata["delta"] == (0, 0, 0):
                                already_inside = True
                                thisdelta = edata["delta"]
                            else:
                                if edata["start"] == node.isite and edata["end"] != node.isite:
                                    thisdelta = edata["delta"]
                                elif edata["end"] == node.isite:
                                    thisdelta = tuple(-dd for dd in edata["delta"])
                                else:
                                    raise ValueError("Should not be here ...")
                            ddeltas.append(thisdelta)
                    logging.debug(
                        "        ddeltas : " + ", ".join(f"({', '.join(str(ddd) for ddd in dd)})" for dd in ddeltas)
                    )
                    if ddeltas.count((0, 0, 0)) > 1:
                        raise ValueError("Should not have more than one 000 delta ...")
                    if already_inside:
                        logging.debug("          Edge inside the cell ... continuing to next neighbor")
                        continue
                    logging.debug("          Edge outside the cell ... getting neighbor back inside")
                    if (0, 0, 0) in ddeltas:
                        ddeltas.remove((0, 0, 0))
                    myddelta = np.array(ddeltas[0], int)
                    node_neighbor_edges = centered_connected_subgraph.edges(
                        nbunch=[node_neighbor], data=True, keys=True
                    )
                    logging.debug(
                        f"            Delta image from node {str(node)} to neighbor {str(node_neighbor)} : "
                        f"({', '.join(map(str, myddelta))})"
                    )
                    # Loop on the edges of this neighbor (self-loops are left untouched)
                    for n1, n2, key, edata in node_neighbor_edges:
                        if (n1 == node_neighbor and n2 != node_neighbor) or (
                            n2 == node_neighbor and n1 != node_neighbor
                        ):
                            # Shift the delta so that it is expressed with node_neighbor moved by myddelta
                            if edata["start"] == node_neighbor.isite and edata["end"] != node_neighbor.isite:
                                centered_connected_subgraph[n1][n2][key]["delta"] = tuple(
                                    np.array(edata["delta"], int) + myddelta
                                )
                            elif edata["end"] == node_neighbor.isite:
                                centered_connected_subgraph[n1][n2][key]["delta"] = tuple(
                                    np.array(edata["delta"], int) - myddelta
                                )
                            else:
                                raise ValueError("DUHH")
                            logging.debug(
                                f"                  {n1} to node {n2} now has delta "
                                f"{centered_connected_subgraph[n1][n2][key]['delta']}"
                            )
                new_current_nodes.extend(node_neighbors)
            current_nodes = new_current_nodes

        # Check if the graph is indeed connected if "periodic" edges (i.e. whose "delta" is not 0, 0, 0) are removed
        check_centered_connected_subgraph = nx.MultiGraph()
        check_centered_connected_subgraph.add_nodes_from(centered_connected_subgraph.nodes())
        check_centered_connected_subgraph.add_edges_from(
            [e for e in centered_connected_subgraph.edges(data=True) if np.allclose(e[2]["delta"], np.zeros(3))]
        )
        if not is_connected(check_centered_connected_subgraph):
            raise RuntimeError("Could not find a centered graph.")
        return centered_connected_subgraph
794

795
    @staticmethod
1✔
796
    def _edgekey_to_edgedictkey(key):
1✔
797
        if isinstance(key, int):
1✔
798
            return str(key)
1✔
799
        if isinstance(key, str):
1✔
800
            try:
1✔
801
                int(key)
1✔
802
                raise RuntimeError("Cannot pass an edge key which is a str representation of an int.")
1✔
803
            except ValueError:
1✔
804
                return key
1✔
805
        raise ValueError("Edge key should be either a str or an int.")
1✔
806

807
    @staticmethod
1✔
808
    def _edgedictkey_to_edgekey(key):
1✔
809
        if isinstance(key, int):
1✔
810
            return key
×
811
        if isinstance(key, str):
1✔
812
            try:
1✔
813
                return int(key)
1✔
814
            except ValueError:
×
815
                return key
×
816
        else:
817
            raise ValueError("Edge key in a dict of dicts representation of a graph should be either a str or an int.")
×
818

819
    @staticmethod
1✔
820
    def _retuplify_edgedata(edata):
1✔
821
        """
822
        Private method used to cast back lists to tuples where applicable in an edge data.
823

824
        The format of the edge data is :
825
        {'start': STARTINDEX, 'end': ENDINDEX, 'delta': TUPLE(DELTAX, DELTAY, DELTAZ),
826
         'ligands': [TUPLE(LIGAND_1_INDEX, TUPLE(DELTAX_START_LIG_1, DELTAY_START_LIG_1, DELTAZ_START_LIG_1),
827
                                           TUPLE(DELTAX_END_LIG_1, DELTAY_END_LIG_1, DELTAZ_END_LIG_1)),
828
                     TUPLE(LIGAND_2_INDEX, ...),
829
                     ... ]}
830
        When serializing to json/bson, these tuples are transformed into lists. This method transforms these lists
831
        back to tuples.
832

833
        Args:
834
            edata (dict): Edge data dictionary with possibly the above tuples as lists.
835

836
        Returns:
837
            dict: Edge data dictionary with the lists transformed back into tuples when applicable.
838
        """
839
        edata["delta"] = tuple(edata["delta"])
1✔
840
        edata["ligands"] = [tuple([lig[0], tuple(lig[1]), tuple(lig[2])]) for lig in edata["ligands"]]
1✔
841
        return edata
1✔
842

843
    def as_dict(self):
        """
        Bson-serializable dict representation of the ConnectedComponent object.

        Nodes are keyed by the string form of their ``isite`` index. The graph is stored as a dict of
        dicts (node key -> neighbor key -> edge key -> edge data). Edge keys are converted with
        ``_edgekey_to_edgedictkey`` and edge data are passed through ``jsanitize``.

        Returns:
            dict: Bson-serializable dict representation of the ConnectedComponent object.
        """
        subgraph = self._connected_subgraph
        nodes = {f"{node.isite:d}": (node, data) for node, data in subgraph.nodes(data=True)}
        index_of_node = {node: strindex for strindex, (node, _data) in nodes.items()}
        serialized_graph = {}
        for node1, neighbors in nx.to_dict_of_dicts(subgraph).items():
            serialized_graph[index_of_node[node1]] = {
                index_of_node[node2]: {
                    self._edgekey_to_edgedictkey(edge_key): jsanitize(edge_data)
                    for edge_key, edge_data in edges.items()
                }
                for node2, edges in neighbors.items()
            }
        return {
            "@module": type(self).__module__,
            "@class": type(self).__name__,
            "nodes": {strindex: (node.as_dict(), data) for strindex, (node, data) in nodes.items()},
            "graph": serialized_graph,
        }
869

870
    @classmethod
    def from_dict(cls, d):
        """
        Reconstructs the ConnectedComponent object from a dict representation of the
        ConnectedComponent object created using the as_dict method.

        Note: the edge data dictionaries inside ``d`` are converted in place by ``_retuplify_edgedata``.

        Args:
            d (dict): dict representation of the ConnectedComponent object
        Returns:
            ConnectedComponent: The connected component representing the links of a given set of environments.
        """
        nodes_map = {}
        nodes_data = {}
        for inode_str, (nodedict, nodedata) in d["nodes"].items():
            nodes_map[inode_str] = EnvironmentNode.from_dict(nodedict)
            nodes_data[inode_str] = nodedata
        dod = {
            e1: {
                e2: {cls._edgedictkey_to_edgekey(ied): cls._retuplify_edgedata(edata) for ied, edata in e2dict.items()}
                for e2, e2dict in e1dict.items()
            }
            for e1, e1dict in d["graph"].items()
        }
        graph = nx.from_dict_of_dicts(dod, create_using=nx.MultiGraph, multigraph_input=True)
        # Node attributes are attached while nodes are still string keys, then nodes are relabeled in place.
        nx.set_node_attributes(graph, nodes_data)
        nx.relabel_nodes(graph, nodes_map, copy=False)
        return cls(graph=graph)
896

897
    @classmethod
1✔
898
    def from_graph(cls, g):
1✔
899
        """
900
        Constructor for the ConnectedComponent object from a graph of the connected component
901

902
        Args:
903
            g (MultiGraph): Graph of the connected component.
904
        Returns:
905
            ConnectedComponent: The connected component representing the links of a given set of environments.
906
        """
907
        return cls(graph=g)
1✔
908

909
    def description(self, full=False):
1✔
910
        """
911
        Args:
912
            full (bool): Whether to return a short or full description.
913

914
        Returns:
915
            str: A description of the connected component.
916
        """
917
        out = ["Connected component with environment nodes :"]
1✔
918
        if not full:
1✔
919
            out.extend(map(str, sorted(self.graph.nodes())))
1✔
920
            return "\n".join(out)
1✔
921
        for en in sorted(self.graph.nodes()):
1✔
922
            out.append(f"{en}, connected to :")
1✔
923
            en_neighbs = nx.neighbors(self.graph, en)
1✔
924
            for en_neighb in sorted(en_neighbs):
1✔
925
                out.append(f"  - {en_neighb} with delta image cells")
1✔
926
                all_deltas = sorted(
1✔
927
                    get_delta(node1=en, node2=en_neighb, edge_data=edge_data).tolist()
928
                    for iedge, edge_data in self.graph[en][en_neighb].items()
929
                )
930
                out.extend([f"     ({delta[0]:d} {delta[1]:d} {delta[2]:d})" for delta in all_deltas])
1✔
931
        return "\n".join(out)
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc