• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

materialsproject / pymatgen / 4075885785

pending completion
4075885785

push

github

Shyue Ping Ong
Merge branch 'master' of github.com:materialsproject/pymatgen

96 of 96 new or added lines in 27 files covered. (100.0%)

81013 of 102710 relevant lines covered (78.88%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.76
/pymatgen/analysis/tests/test_local_env.py
1
# Copyright (c) Pymatgen Development Team.
2
# Distributed under the terms of the MIT License.
3
from __future__ import annotations
1✔
4

5
import os
1✔
6
import unittest
1✔
7
import warnings
1✔
8
from math import pi
1✔
9
from shutil import which
1✔
10
from typing import get_args
1✔
11

12
import numpy as np
1✔
13
import pytest
1✔
14
from pytest import approx
1✔
15

16
from pymatgen.analysis.graphs import MoleculeGraph, StructureGraph
1✔
17
from pymatgen.analysis.local_env import (
1✔
18
    BrunnerNN_real,
19
    BrunnerNN_reciprocal,
20
    BrunnerNN_relative,
21
    CovalentBondNN,
22
    Critic2NN,
23
    CrystalNN,
24
    CutOffDictNN,
25
    EconNN,
26
    IsayevNN,
27
    JmolNN,
28
    LocalStructOrderParams,
29
    MinimumDistanceNN,
30
    MinimumOKeeffeNN,
31
    MinimumVIRENN,
32
    NearNeighbors,
33
    OpenBabelNN,
34
    ValenceIonicRadiusEvaluator,
35
    VoronoiNN,
36
    get_neighbors_of_site_with_index,
37
    metal_edge_extender,
38
    on_disorder_options,
39
    site_is_of_motif_type,
40
    solid_angle,
41
)
42
from pymatgen.core import Lattice, Molecule, Structure
1✔
43
from pymatgen.core.periodic_table import Element
1✔
44
from pymatgen.util.testing import PymatgenTest
1✔
45

46
# Directory of "fragmenter_files" fixtures inside the shared pymatgen test-files root.
test_dir = os.path.join(PymatgenTest.TEST_FILES_DIR, "fragmenter_files")
1✔
47

48

49
class ValenceIonicRadiusEvaluatorTest(PymatgenTest):
    """Check the valences and ionic radii ValenceIonicRadiusEvaluator assigns to rocksalt MgO."""

    def setUp(self):
        """Build a conventional rocksalt MgO cell (4 Mg + 4 O) and its evaluator."""
        mgo_latt = [[4.212, 0, 0], [0, 4.212, 0], [0, 0, 4.212]]
        mgo_specie = ["Mg"] * 4 + ["O"] * 4
        mgo_frac_cord = [
            [0, 0, 0],
            [0.5, 0.5, 0],
            [0.5, 0, 0.5],
            [0, 0.5, 0.5],
            [0.5, 0, 0],
            [0, 0.5, 0],
            [0, 0, 0.5],
            [0.5, 0.5, 0.5],
        ]
        # Pass the flags by keyword: in current pymatgen the 4th positional parameter of
        # Structure is `charge`, so `(..., True, True)` silently meant charge=True,
        # validate_proximity=True instead of validate_proximity/to_unit_cell.
        self._mgo_uc = Structure(
            mgo_latt,
            mgo_specie,
            mgo_frac_cord,
            validate_proximity=True,
            to_unit_cell=True,
        )
        self._mgo_valrad_evaluator = ValenceIonicRadiusEvaluator(self._mgo_uc)

    def test_valences_ionic_structure(self):
        # Every assigned valence must be Mg2+ or O2-.
        valence_dict = self._mgo_valrad_evaluator.valences
        for val in valence_dict.values():
            assert val in {2, -2}

    def test_radii_ionic_structure(self):
        # Every assigned radius must be one of the two expected ionic radii.
        radii_dict = self._mgo_valrad_evaluator.radii
        for rad in radii_dict.values():
            assert rad in {0.86, 1.26}

    def tearDown(self):
        del self._mgo_uc
        del self._mgo_valrad_evaluator
82

83

84
class VoronoiNNTest(PymatgenTest):
    """Tests for VoronoiNN: polyhedra, coordination numbers, neighbor shells and face filtering."""

    def setUp(self):
        self.s = self.get_structure("LiFePO4")
        # Neighbor finder restricted to O targets.
        self.nn = VoronoiNN(targets=[Element("O")])
        # Si structure whose Si sites are replaced by a 50/50 Si/C disordered species,
        # used to exercise the `on_disorder` option.
        self.s_sic = self.get_structure("Si")
        self.s_sic["Si"] = {"Si": 0.5, "C": 0.5}
        self.nn_sic = VoronoiNN()

    def test_get_voronoi_polyhedra(self):
        assert len(self.nn.get_voronoi_polyhedra(self.s, 0)) == 8

    def test_get_cn(self):
        site_0_coord_num = self.nn.get_cn(self.s, 0, use_weights=True, on_disorder="take_max_species")
        assert site_0_coord_num == approx(5.809265748999465, abs=1e-7)

        site_0_coord_num = self.nn_sic.get_cn(self.s_sic, 0, use_weights=True, on_disorder="take_max_species")
        assert site_0_coord_num == approx(4.5381161643940668, abs=1e-7)

    def test_get_coordinated_sites(self):
        assert len(self.nn.get_nn(self.s, 0)) == 8

    def test_volume(self):
        # Summed over all sites, the per-neighbor volumes must equal the cell volume.
        self.nn.targets = None
        volume = sum(
            nn_info["volume"]
            for idx in range(len(self.s))
            for nn_info in self.nn.get_voronoi_polyhedra(self.s, idx).values()
        )
        assert self.s.volume == approx(volume)

    def test_solid_angle(self):
        # The solid angles of all faces around one site must cover the full sphere.
        self.nn.targets = None
        for idx in range(len(self.s)):
            angle = sum(nn_info["solid_angle"] for nn_info in self.nn.get_voronoi_polyhedra(self.s, idx).values())
            assert 4 * np.pi == approx(angle)
        assert solid_angle([0, 0, 0], [[1, 0, 0], [-1, 0, 0], [0, 1, 0]]) == pi

    def test_nn_shell(self):
        # First, make a SC lattice. Make my math easier
        s = Structure([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ["Cu"], [[0, 0, 0]])

        # Get the 1NN shell
        self.nn.targets = None
        nns = self.nn.get_nn_shell_info(s, 0, 1)
        assert len(nns) == 6

        # Test the 2nd NN shell
        nns = self.nn.get_nn_shell_info(s, 0, 2)
        assert len(nns) == 18
        self.assertArrayAlmostEqual([1] * 6, [x["weight"] for x in nns if max(np.abs(x["image"])) == 2])
        self.assertArrayAlmostEqual([2] * 12, [x["weight"] for x in nns if max(np.abs(x["image"])) == 1])

        # Test the 3rd NN shell
        nns = self.nn.get_nn_shell_info(s, 0, 3)
        for nn in nns:
            # Check that the coordinates were set correctly
            self.assertArrayAlmostEqual(nn["site"].frac_coords, nn["image"])

        # Test with a structure that has unequal faces
        cscl = Structure(
            Lattice([[4.209, 0, 0], [0, 4.209, 0], [0, 0, 4.209]]),
            ["Cl1-", "Cs1+"],
            [[2.1045, 2.1045, 2.1045], [0, 0, 0]],
            validate_proximity=False,
            to_unit_cell=False,
            coords_are_cartesian=True,
            site_properties=None,
        )
        self.nn.weight = "area"
        nns = self.nn.get_nn_shell_info(cscl, 0, 1)
        assert len(nns) == 14
        assert np.isclose([x["weight"] for x in nns], 0.125 / 0.32476).sum() == 6  # Square faces
        assert np.isclose([x["weight"] for x in nns], 1).sum() == 8

        nns = self.nn.get_nn_shell_info(cscl, 0, 2)
        # Weight of getting back on to own site
        #  Square-square hop: 6*5 options times (0.125/0.32476)^2 weight each
        #  Hex-hex hop: 8*7 options times 1 weight each
        assert 60.4444 == approx(np.sum([x["weight"] for x in nns if x["site_index"] == 0]), abs=1e-3)

    def test_adj_neighbors(self):
        # Make a simple cubic structure
        s = Structure([[1, 0, 0], [0, 1, 0], [0, 0, 1]], ["Cu"], [[0, 0, 0]])

        # Compute the NNs with adjacency
        self.nn.targets = None
        neighbors = self.nn.get_voronoi_polyhedra(s, 0)

        # Each neighbor has 4 adjacent neighbors, all orthogonal
        for nn_info in neighbors.values():
            assert len(nn_info["adj_neighbors"]) == 4

            for adj_key in nn_info["adj_neighbors"]:
                # The normals are computed in floating point, so compare with a tolerance
                # rather than requiring an exact zero.
                assert np.dot(nn_info["normal"], neighbors[adj_key]["normal"]) == approx(0, abs=1e-8)

    def test_all_at_once(self):
        # Get all of the sites for LiFePO4
        all_polyhedra = self.nn.get_all_voronoi_polyhedra(self.s)

        # Make sure they are the same as the single-atom ones
        for i, polyhedra in enumerate(all_polyhedra):
            # Compute the tessellation using only one site
            by_one = self.nn.get_voronoi_polyhedra(self.s, i)

            # Match the coordinates of the neighbors, as site matching does not seem to work?
            all_coords = np.sort([x["site"].coords for x in polyhedra.values()], axis=0)
            by_one_coords = np.sort([x["site"].coords for x in by_one.values()], axis=0)

            self.assertArrayAlmostEqual(all_coords, by_one_coords)

        # Test the nn_info operation
        all_nn_info = self.nn.get_all_nn_info(self.s)
        for i, info in enumerate(all_nn_info):
            # Compute using the by-one method
            by_one = self.nn.get_nn_info(self.s, i)

            # Compare the sorted weights from both methods
            all_weights = sorted(x["weight"] for x in info)
            by_one_weights = sorted(x["weight"] for x in by_one)

            self.assertArrayAlmostEqual(all_weights, by_one_weights)

    def test_Cs2O(self):
        """A problematic structure in the Materials Project"""
        strc = Structure(
            [
                [4.358219, 0.192833, 6.406960],
                [2.114414, 3.815824, 6.406960],
                [0.311360, 0.192833, 7.742498],
            ],
            ["O", "Cs", "Cs"],
            [[0, 0, 0], [0.264318, 0.264318, 0.264318], [0.735682, 0.735682, 0.735682]],
            coords_are_cartesian=False,
        )

        # Compute the voronoi tessellation
        result = VoronoiNN().get_all_voronoi_polyhedra(strc)
        assert len(result) == 3

    def test_filtered(self):
        nn = VoronoiNN(weight="area")

        # Make a bcc crystal
        bcc = Structure(
            [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
            ["Cu", "Cu"],
            [[0, 0, 0], [0.5, 0.5, 0.5]],
            coords_are_cartesian=False,
        )

        # Compute the weight of the little face
        big_face_area = np.sqrt(3) * 3 / 2 * (2 / 4 / 4)
        small_face_area = 0.125
        little_weight = small_face_area / big_face_area

        # Run one test where you get the small neighbors
        nn.tol = little_weight * 0.99
        nns = nn.get_nn_info(bcc, 0)
        assert len(nns) == 14

        # Run a second test where we screen out little faces
        nn.tol = little_weight * 1.01
        nns = nn.get_nn_info(bcc, 0)
        assert len(nns) == 8

        # Make sure it works for the `get_all` operation
        all_nns = nn.get_all_nn_info(bcc * [2, 2, 2])
        assert [len(x) for x in all_nns] == [8] * 16

    def tearDown(self):
        # Drop every fixture created in setUp, including the disordered SiC ones.
        del self.s
        del self.nn
        del self.s_sic
        del self.nn_sic
259

260

261
class JmolNNTest(PymatgenTest):
    """Coordination numbers from JmolNN, with default radii and with a per-element override."""

    def setUp(self):
        self.jmol = JmolNN()
        # Larger Li radius; the test expects this to give Li two neighbors.
        self.jmol_update = JmolNN(el_radius_updates={"Li": 1})

    def test_get_nn(self):
        struct = self.get_structure("LiFePO4")

        # Expected CN per element under the default radii; other elements are not checked.
        expected_by_element = (
            (Element("Li"), 0),
            (Element("Fe"), 6),
            (Element("P"), 4),
        )
        checked = 0
        for idx, site in enumerate(struct):
            expected = next((cn for elem, cn in expected_by_element if site.specie == elem), None)
            if expected is None:
                continue
            assert self.jmol.get_cn(struct, idx) == expected
            checked += 1
        assert checked == 12

        # With the enlarged Li radius, site 0 becomes 2-coordinated ...
        assert self.jmol_update.get_cn(struct, 0) == 2
        # ... and get_nn agrees with get_cn.
        assert len(self.jmol_update.get_nn(struct, 0)) == 2

    def tearDown(self):
        del self.jmol
        del self.jmol_update
293

294

295
class TestIsayevNN(PymatgenTest):
    """Spot-check IsayevNN coordination numbers on LiFePO4."""

    def test_get_nn(self):
        finder = IsayevNN()
        struct = self.get_structure("LiFePO4")

        # (site index, expected coordination number)
        for site_idx, expected_cn in ((0, 2), (5, 6), (10, 4)):
            assert finder.get_cn(struct, site_idx) == expected_cn
        assert len(finder.get_nn(struct, 0)) == 2
304

305

306
class OpenBabelNNTest(PymatgenTest):
    """Bond orders and bond lengths from OpenBabelNN; skipped when openbabel is unavailable."""

    def setUp(self):
        pytest.importorskip("openbabel", reason="OpenBabel not installed")
        files_dir = PymatgenTest.TEST_FILES_DIR
        self.benzene = Molecule.from_file(os.path.join(files_dir, "benzene.xyz"))
        self.acetylene = Molecule.from_file(os.path.join(files_dir, "acetylene.xyz"))

    def test_nn_orders(self):
        strategy = OpenBabelNN()
        acetylene_bonds = strategy.get_nn_info(self.acetylene, 0)
        assert acetylene_bonds[0]["weight"] == 3
        assert acetylene_bonds[1]["weight"] == 1

        # Benzene bonds currently register as single or double, not aromatic.
        # So rather than looking for aromatic bonds, check that the first bond
        # reported for atom 0 and for atom 1 have the same order.
        weight_from_0 = strategy.get_nn_info(self.benzene, 0)[0]["weight"]
        weight_from_1 = strategy.get_nn_info(self.benzene, 1)[0]["weight"]
        assert weight_from_0 == weight_from_1

    def test_nn_length(self):
        strategy = OpenBabelNN(order=False)

        bonds = strategy.get_nn_info(self.benzene, 0)
        bonds_to = {elem: [b for b in bonds if str(b["site"].specie) == elem] for elem in ("C", "H")}

        assert bonds_to["C"][0]["weight"] == approx(1.41, abs=1e-2)
        assert bonds_to["H"][0]["weight"] == approx(1.02, abs=1e-2)

        assert strategy.get_nn_info(self.acetylene, 0)[0]["weight"] == approx(1.19, abs=1e-2)

    def tearDown(self):
        del self.benzene
        del self.acetylene
340

341

342
class CovalentBondNNTest(PymatgenTest):
    """Bond orders, bond lengths and bonded graphs produced by CovalentBondNN."""

    def setUp(self):
        files_dir = PymatgenTest.TEST_FILES_DIR
        self.benzene = Molecule.from_file(os.path.join(files_dir, "benzene.xyz"))
        self.acetylene = Molecule.from_file(os.path.join(files_dir, "acetylene.xyz"))

    def test_nn_orders(self):
        strategy = CovalentBondNN()

        acetylene_bonds = strategy.get_nn_info(self.acetylene, 0)
        assert acetylene_bonds[0]["weight"] == 3
        assert acetylene_bonds[1]["weight"] == 1

        first_benzene_bond = strategy.get_nn_info(self.benzene, 0)[0]
        assert first_benzene_bond["weight"] == approx(1.6596, abs=1e-4)

    def test_nn_length(self):
        strategy = CovalentBondNN(order=False)

        bonds = strategy.get_nn_info(self.benzene, 0)
        bonds_to = {elem: [b for b in bonds if str(b["site"].specie) == elem] for elem in ("C", "H")}

        assert bonds_to["C"][0]["weight"] == approx(1.41, abs=1e-2)
        assert bonds_to["H"][0]["weight"] == approx(1.02, abs=1e-2)

        first_acetylene_bond = strategy.get_nn_info(self.acetylene, 0)[0]
        assert first_acetylene_bond["weight"] == approx(1.19, abs=1e-2)

    def test_bonded_structure(self):
        strategy = CovalentBondNN()

        # Benzene's bonded graph contains exactly one ring.
        benzene_graph = strategy.get_bonded_structure(self.benzene)
        assert len(benzene_graph.find_rings()) == 1

        # Acetylene's bonded graph has one node per atom.
        acetylene_graph = strategy.get_bonded_structure(self.acetylene)
        assert len(acetylene_graph.graph.nodes) == 4

    def tearDown(self):
        del self.benzene
        del self.acetylene
383

384

385
class MiniDistNNTest(PymatgenTest):
    """Cross-check coordination numbers of many NearNeighbors strategies on simple crystals."""

    def setUp(self):
        self.diamond = Structure(
            Lattice([[2.189, 0, 1.264], [0.73, 2.064, 1.264], [0, 0, 2.528]]),
            ["C0+", "C0+"],
            [[2.554, 1.806, 4.423], [0.365, 0.258, 0.632]],
            validate_proximity=False,
            to_unit_cell=False,
            coords_are_cartesian=True,
            site_properties=None,
        )
        self.nacl = Structure(
            Lattice([[3.485, 0, 2.012], [1.162, 3.286, 2.012], [0, 0, 4.025]]),
            ["Na1+", "Cl1-"],
            [[0, 0, 0], [2.324, 1.643, 4.025]],
            validate_proximity=False,
            to_unit_cell=False,
            coords_are_cartesian=True,
            site_properties=None,
        )
        self.cscl = Structure(
            Lattice([[4.209, 0, 0], [0, 4.209, 0], [0, 0, 4.209]]),
            ["Cl1-", "Cs1+"],
            [[2.105, 2.105, 2.105], [0, 0, 0]],
            validate_proximity=False,
            to_unit_cell=False,
            coords_are_cartesian=True,
            site_properties=None,
        )
        self.mos2 = Structure(
            Lattice([[3.19, 0, 0], [-1.595, 2.763, 0], [0, 0, 17.44]]),
            ["Mo", "S", "S"],
            [[-1e-06, 1.842, 3.72], [1.595, 0.92, 5.29], [1.595, 0.92, 2.155]],
            coords_are_cartesian=True,
        )
        self.lifepo4 = self.get_structure("LiFePO4")
        self.lifepo4.add_oxidation_state_by_guess()

    def test_all_nn_classes(self):
        # MinimumDistanceNN under various settings.
        assert MinimumDistanceNN(cutoff=5, get_all_sites=True).get_cn(self.cscl, 0) == 14
        assert MinimumDistanceNN().get_cn(self.diamond, 0) == 4
        assert MinimumDistanceNN().get_cn(self.nacl, 0) == 6
        assert MinimumDistanceNN().get_cn(self.lifepo4, 0) == 6
        assert MinimumDistanceNN(tol=0.01).get_cn(self.cscl, 0) == 8
        assert MinimumDistanceNN(tol=0.1).get_cn(self.mos2, 0) == 6

        # Every neighbor image of Mo site 0 must be one of these (the original list
        # repeated each entry twice; membership is unchanged by removing duplicates).
        allowed_images = [(0, 0, 0), (0, 1, 0), (-1, 0, 0)]
        for image in MinimumDistanceNN(tol=0.1).get_nn_images(self.mos2, 0):
            assert image in allowed_images

        # Expected CN of site 0 in (diamond, nacl, cscl, lifepo4) for each strategy.
        structures = (self.diamond, self.nacl, self.cscl, self.lifepo4)
        cases = [
            (MinimumOKeeffeNN(tol=0.01), (4, 6, 8, 2)),
            (MinimumVIRENN(tol=0.01), (4, 6, 8, 2)),
            (BrunnerNN_reciprocal(tol=0.01), (4, 6, 14, 6)),
            (BrunnerNN_relative(tol=0.01), (4, 6, 14, 6)),
            (BrunnerNN_real(tol=0.01), (4, 6, 14, 30)),
            (EconNN(), (4, 6, 14, 6)),
            (VoronoiNN(tol=0.5), (4, 6, 8, 6)),
            (CrystalNN(), (4, 6, 8, 6)),
        ]
        for strategy, expected_cns in cases:
            for struct, expected in zip(structures, expected_cns):
                # Name the strategy and formula so a failure is attributable.
                assert strategy.get_cn(struct, 0) == expected, f"{type(strategy).__name__} on {struct.formula}"

    def test_get_local_order_params(self):
        nn = MinimumDistanceNN()
        ops = nn.get_local_order_parameters(self.diamond, 0)
        assert ops["tetrahedral"] == approx(0.9999934389036574)

        ops = nn.get_local_order_parameters(self.nacl, 0)
        assert ops["octahedral"] == approx(0.9999995266669)
489

490

491
class MotifIdentificationTest(PymatgenTest):
    """Motif classification and neighbor lookup on crystals and isolated clusters."""

    def setUp(self):
        # Shared keyword arguments for every fixture structure below.
        common = dict(validate_proximity=False, to_unit_cell=False, site_properties=None)

        self.silicon = Structure(
            Lattice.cubic(5.47),
            ["Si"] * 8,
            [
                [0.000000, 0.000000, 0.500000],
                [0.750000, 0.750000, 0.750000],
                [0.000000, 0.500000, 1.000000],
                [0.750000, 0.250000, 0.250000],
                [0.500000, 0.000000, 1.000000],
                [0.250000, 0.750000, 0.250000],
                [0.500000, 0.500000, 0.500000],
                [0.250000, 0.250000, 0.750000],
            ],
            coords_are_cartesian=False,
            **common,
        )
        self.diamond = Structure(
            Lattice([[2.189, 0, 1.264], [0.73, 2.064, 1.264], [0, 0, 2.528]]),
            ["C0+", "C0+"],
            [[2.554, 1.806, 4.423], [0.365, 0.258, 0.632]],
            coords_are_cartesian=True,
            **common,
        )
        self.nacl = Structure(
            Lattice([[3.485, 0, 2.012], [1.162, 3.286, 2.012], [0, 0, 4.025]]),
            ["Na1+", "Cl1-"],
            [[0, 0, 0], [2.324, 1.643, 4.025]],
            coords_are_cartesian=True,
            **common,
        )
        self.cscl = Structure(
            Lattice([[4.209, 0, 0], [0, 4.209, 0], [0, 0, 4.209]]),
            ["Cl1-", "Cs1+"],
            [[2.105, 2.105, 2.105], [0, 0, 0]],
            coords_are_cartesian=True,
            **common,
        )
        # Clusters placed in a 100 A box, so periodic images are far away.
        big_box = [[100, 0, 0], [0, 100, 0], [0, 0, 100]]
        self.square_pyramid = Structure(
            Lattice(big_box),
            ["C"] * 6,
            [[0, 0, 0], [1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1]],
            coords_are_cartesian=True,
            **common,
        )
        self.trigonal_bipyramid = Structure(
            Lattice(big_box),
            ["P"] + ["Cl"] * 5,
            [
                [0, 0, 0],
                [0, 0, 2.14],
                [0, 2.02, 0],
                [1.74937, -1.01, 0],
                [-1.74937, -1.01, 0],
                [0, 0, -2.14],
            ],
            coords_are_cartesian=True,
            **common,
        )

    def test_site_is_of_motif_type(self):
        # In these crystals every site shares one motif.
        crystals = ((self.diamond, "tetrahedral"), (self.nacl, "octahedral"), (self.cscl, "bcc"))
        for struct, motif in crystals:
            for idx in range(struct.num_sites):
                assert site_is_of_motif_type(struct, idx) == motif

        # In the clusters only the central atom (index 0) has a recognized motif.
        clusters = (
            (self.square_pyramid, "square pyramidal"),
            (self.trigonal_bipyramid, "trigonal bipyramidal"),
        )
        for cluster, center_motif in clusters:
            assert site_is_of_motif_type(cluster, 0) == center_motif
            for idx in range(1, cluster.num_sites):
                assert site_is_of_motif_type(cluster, idx) == "unrecognized"

    def test_get_neighbors_of_site_with_index(self):
        for struct, n_expected in ((self.diamond, 4), (self.nacl, 6), (self.cscl, 8)):
            assert len(get_neighbors_of_site_with_index(struct, 0)) == n_expected

        # Diamond site 0 keeps 4 neighbors under alternative settings and approaches.
        variants = (
            {"delta": 0.01},
            {"cutoff": 6},
            {"approach": "voronoi"},
            {"approach": "min_OKeeffe"},
            {"approach": "min_VIRE"},
        )
        for kwargs in variants:
            assert len(get_neighbors_of_site_with_index(self.diamond, 0, **kwargs)) == 4

    def tearDown(self):
        del self.silicon
        del self.diamond
        del self.nacl
        del self.cscl
593

594

595
class NearNeighborTest(PymatgenTest):
    """Checks that apply across NearNeighbors strategies in general."""

    def setUp(self):
        self.diamond = Structure(
            Lattice([[2.189, 0, 1.264], [0.73, 2.064, 1.264], [0, 0, 2.528]]),
            ["C0+", "C0+"],
            [[2.554, 1.806, 4.423], [0.365, 0.258, 0.632]],
            validate_proximity=False,
            to_unit_cell=False,
            coords_are_cartesian=True,
            site_properties=None,
        )

    def set_nn_info(self):
        """Check that each direct NearNeighbors subclass finds site 1 (image a=1) first for diamond site 0.

        NOTE(review): the method name lacks the ``test_`` prefix, so unittest/pytest never
        collect or run it. Renaming it would enable the check, but first confirm that every
        subclass can be constructed without arguments and passes.
        """
        # Implicitly assumes every subclass bonds diamond correctly; if one can't,
        # there are probably bigger problems.
        for subclass in NearNeighbors.__subclasses__():
            # Critic2NN has an external dependency and is tested separately.
            if "Critic2" in str(subclass):
                continue
            first_neighbor = subclass().get_nn_info(self.diamond, 0)[0]
            assert first_neighbor["site_index"] == 1
            assert first_neighbor["image"][0] == 1

    def test_on_disorder_options(self):
        expected_options = (
            "take_majority_strict",
            "take_majority_drop",
            "take_max_species",
            "error",
        )
        assert get_args(on_disorder_options) == expected_options

    def tearDown(self):
        del self.diamond
630

631

632
class LocalStructOrderParamsTest(PymatgenTest):
    """Tests for LocalStructOrderParams on idealized motifs and simple lattices."""

    @staticmethod
    def _make_structure(lattice, species, coords, coords_are_cartesian=True):
        """Return a Structure built with the fixed settings shared by every fixture here.

        Proximity validation and wrapping into the unit cell are switched off so the
        hand-placed coordinates are accepted as given.
        """
        return Structure(
            lattice,
            species,
            coords,
            validate_proximity=False,
            to_unit_cell=False,
            coords_are_cartesian=coords_are_cartesian,
            site_properties=None,
        )

    def setUp(self):
        make = self._make_structure

        self.single_bond = make(Lattice.cubic(10), ["H"] * 3, [[1, 0, 0], [0, 0, 0], [6, 0, 0]])
        self.linear = make(Lattice.cubic(10), ["H"] * 3, [[1, 0, 0], [0, 0, 0], [2, 0, 0]])
        self.bent45 = make(Lattice.cubic(10), ["H"] * 3, [[0, 0, 0], [0.707, 0.707, 0], [0.707, 0, 0]])

        # Simple lattices, given in fractional coordinates.
        self.cubic = make(Lattice.cubic(1), ["H"], [[0, 0, 0]], coords_are_cartesian=False)
        self.bcc = make(Lattice.cubic(1), ["H"] * 2, [[0, 0, 0], [0.5, 0.5, 0.5]], coords_are_cartesian=False)
        self.fcc = make(
            Lattice.cubic(1),
            ["H"] * 4,
            [[0, 0, 0], [0, 0.5, 0.5], [0.5, 0, 0.5], [0.5, 0.5, 0]],
            coords_are_cartesian=False,
        )
        self.hcp = make(
            Lattice.hexagonal(1, 1.633),
            ["H"] * 2,
            [[0.3333, 0.6667, 0.25], [0.6667, 0.3333, 0.75]],
            coords_are_cartesian=False,
        )
        self.diamond = make(
            Lattice.cubic(1),
            ["H"] * 8,
            [
                [0, 0, 0.5],
                [0.75, 0.75, 0.75],
                [0, 0.5, 0],
                [0.75, 0.25, 0.25],
                [0.5, 0, 0],
                [0.25, 0.75, 0.25],
                [0.5, 0.5, 0.5],
                [0.25, 0.25, 0.75],
            ],
            coords_are_cartesian=False,
        )

        # Isolated motifs, given in Cartesian coordinates; site 0 is the central atom.
        self.trigonal_off_plane = make(
            Lattice.cubic(100),
            ["H"] * 4,
            [
                [0.50, 0.50, 0.50],
                [0.25, 0.75, 0.25],
                [0.25, 0.25, 0.75],
                [0.75, 0.25, 0.25],
            ],
        )
        self.regular_triangle = make(
            Lattice.cubic(30),
            ["H"] * 4,
            [[15, 15.28867, 15.65], [14.5, 15, 15], [15.5, 15, 15], [15, 15.866, 15]],
        )
        self.trigonal_planar = make(
            Lattice.cubic(30),
            ["H"] * 4,
            [[15, 15.28867, 15], [14.5, 15, 15], [15.5, 15, 15], [15, 15.866, 15]],
        )
        self.square_planar = make(
            Lattice.cubic(30),
            ["H"] * 5,
            [
                [15, 15, 15],
                [14.75, 14.75, 15],
                [14.75, 15.25, 15],
                [15.25, 14.75, 15],
                [15.25, 15.25, 15],
            ],
        )
        self.square = make(
            Lattice.cubic(30),
            ["H"] * 5,
            [
                [15, 15, 15.707],
                [14.75, 14.75, 15],
                [14.75, 15.25, 15],
                [15.25, 14.75, 15],
                [15.25, 15.25, 15],
            ],
        )
        self.T_shape = make(
            Lattice.cubic(30),
            ["H"] * 4,
            [[15, 15, 15], [15, 15, 15.5], [15, 15.5, 15], [15, 14.5, 15]],
        )
        self.square_pyramid = make(
            Lattice.cubic(30),
            ["H"] * 6,
            [
                [15, 15, 15],
                [15, 15, 15.3535],
                [14.75, 14.75, 15],
                [14.75, 15.25, 15],
                [15.25, 14.75, 15],
                [15.25, 15.25, 15],
            ],
        )
        self.pentagonal_planar = make(
            Lattice.cubic(30),
            ["Xe", "F", "F", "F", "F", "F"],
            [
                [0, -1.6237, 0],
                [1.17969, 0, 0],
                [-1.17969, 0, 0],
                [1.90877, -2.24389, 0],
                [-1.90877, -2.24389, 0],
                [0, -3.6307, 0],
            ],
        )
        self.pentagonal_pyramid = make(
            Lattice.cubic(30),
            ["Xe", "F", "F", "F", "F", "F", "F"],
            [
                [0, -1.6237, 0],
                [0, -1.6237, 1.17969],
                [1.17969, 0, 0],
                [-1.17969, 0, 0],
                [1.90877, -2.24389, 0],
                [-1.90877, -2.24389, 0],
                [0, -3.6307, 0],
            ],
        )
        self.pentagonal_bipyramid = make(
            Lattice.cubic(30),
            ["Xe", "F", "F", "F", "F", "F", "F", "F"],
            [
                [0, -1.6237, 0],
                [0, -1.6237, -1.17969],
                [0, -1.6237, 1.17969],
                [1.17969, 0, 0],
                [-1.17969, 0, 0],
                [1.90877, -2.24389, 0],
                [-1.90877, -2.24389, 0],
                [0, -3.6307, 0],
            ],
        )
        self.hexagonal_planar = make(
            Lattice.cubic(30),
            ["H", "C", "C", "C", "C", "C", "C"],
            [
                [0, 0, 0],
                [0.71, 1.2298, 0],
                [-0.71, 1.2298, 0],
                [0.71, -1.2298, 0],
                [-0.71, -1.2298, 0],
                [1.4199, 0, 0],
                [-1.4199, 0, 0],
            ],
        )
        self.hexagonal_pyramid = make(
            Lattice.cubic(30),
            ["H", "Li", "C", "C", "C", "C", "C", "C"],
            [
                [0, 0, 0],
                [0, 0, 1.675],
                [0.71, 1.2298, 0],
                [-0.71, 1.2298, 0],
                [0.71, -1.2298, 0],
                [-0.71, -1.2298, 0],
                [1.4199, 0, 0],
                [-1.4199, 0, 0],
            ],
        )
        self.hexagonal_bipyramid = make(
            Lattice.cubic(30),
            ["H", "Li", "Li", "C", "C", "C", "C", "C", "C"],
            [
                [0, 0, 0],
                [0, 0, 1.675],
                [0, 0, -1.675],
                [0.71, 1.2298, 0],
                [-0.71, 1.2298, 0],
                [0.71, -1.2298, 0],
                [-0.71, -1.2298, 0],
                [1.4199, 0, 0],
                [-1.4199, 0, 0],
            ],
        )
        self.trigonal_pyramid = make(
            Lattice.cubic(30),
            ["P", "Cl", "Cl", "Cl", "Cl"],
            [
                [0, 0, 0],
                [0, 0, 2.14],
                [0, 2.02, 0],
                [1.74937, -1.01, 0],
                [-1.74937, -1.01, 0],
            ],
        )
        self.trigonal_bipyramidal = make(
            Lattice.cubic(30),
            ["P", "Cl", "Cl", "Cl", "Cl", "Cl"],
            [
                [0, 0, 0],
                [0, 0, 2.14],
                [0, 2.02, 0],
                [1.74937, -1.01, 0],
                [-1.74937, -1.01, 0],
                [0, 0, -2.14],
            ],
        )
        self.cuboctahedron = make(
            Lattice.cubic(30),
            ["H"] * 13,
            [
                [15, 15, 15],
                [15, 14.5, 14.5],
                [15, 14.5, 15.5],
                [15, 15.5, 14.5],
                [15, 15.5, 15.5],
                [14.5, 15, 14.5],
                [14.5, 15, 15.5],
                [15.5, 15, 14.5],
                [15.5, 15, 15.5],
                [14.5, 14.5, 15],
                [14.5, 15.5, 15],
                [15.5, 14.5, 15],
                [15.5, 15.5, 15],
            ],
        )
        self.see_saw_rect = make(
            Lattice.cubic(30),
            ["H"] * 5,
            [
                [0.0, 0.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, -1.0, 0.0],
                [0.0, 0.0, -1.0],
                [-1.0, 0.0, 0.0],
            ],
        )
        self.sq_face_capped_trig_pris = make(
            Lattice.cubic(30),
            ["H"] * 8,
            [
                [0, 0, 0],
                [-0.6546536707079771, -0.37796447300922725, 0.6546536707079771],
                [0.6546536707079771, -0.37796447300922725, 0.6546536707079771],
                [0.0, 0.7559289460184545, 0.6546536707079771],
                [-0.6546536707079771, -0.37796447300922725, -0.6546536707079771],
                [0.6546536707079771, -0.37796447300922725, -0.6546536707079771],
                [0.0, 0.7559289460184545, -0.6546536707079771],
                [0.0, -1.0, 0.0],
            ],
        )

    def test_init(self):
        assert LocalStructOrderParams(["cn"], parameters=None, cutoff=0.99) is not None

        # Mutating the caller's parameter dicts after construction must not
        # change the parameters stored by the instance.
        parameters = [{"norm": 2}]
        lostops = LocalStructOrderParams(["cn"], parameters=parameters)
        tmp = lostops.get_parameters(0)
        parameters[0]["norm"] = 3
        assert tmp == lostops.get_parameters(0)

    def test_get_order_parameters(self):
        # Order-parameter types; the trailing number is the index of each value
        # in the list returned by get_order_parameters.
        op_types = [
            "cn",  # 0
            "bent",  # 1
            "bent",  # 2
            "tet",  # 3
            "oct",  # 4
            "bcc",  # 5
            "q2",  # 6
            "q4",  # 7
            "q6",  # 8
            "reg_tri",  # 9
            "sq",  # 10
            "sq_pyr_legacy",  # 11
            "tri_bipyr",  # 12
            "sgl_bd",  # 13
            "tri_plan",  # 14
            "sq_plan",  # 15
            "pent_plan",  # 16
            "sq_pyr",  # 17
            "tri_pyr",  # 18
            "pent_pyr",  # 19
            "hex_pyr",  # 20
            "pent_bipyr",  # 21
            "hex_bipyr",  # 22
            "T",  # 23
            "cuboct",  # 24
            "see_saw_rect",  # 25
            "hex_plan_max",  # 26
            "tet_max",  # 27
            "oct_max",  # 28
            "tri_plan_max",  # 29
            "sq_plan_max",  # 30
            "pent_plan_max",  # 31
            "cuboct_max",  # 32
            "tet_max",  # 33 (custom parameters below)
            "sq_face_cap_trig_pris",  # 34
        ]
        op_params = [None] * len(op_types)
        op_params[1] = {"TA": 1, "IGW_TA": 1.0 / 0.0667}
        op_params[2] = {"TA": 45.0 / 180, "IGW_TA": 1.0 / 0.0667}
        op_params[33] = {
            "TA": 0.6081734479693927,
            "IGW_TA": 18.33,
            "fac_AA": 1.5,
            "exp_cos_AA": 2,
        }
        ops_044 = LocalStructOrderParams(op_types, parameters=op_params, cutoff=0.44)
        ops_071 = LocalStructOrderParams(op_types, parameters=op_params, cutoff=0.71)
        ops_087 = LocalStructOrderParams(op_types, parameters=op_params, cutoff=0.87)
        ops_099 = LocalStructOrderParams(op_types, parameters=op_params, cutoff=0.99)
        ops_101 = LocalStructOrderParams(op_types, parameters=op_params, cutoff=1.01)
        ops_501 = LocalStructOrderParams(op_types, parameters=op_params, cutoff=5.01)
        # Construction with the default cutoff must also succeed.
        _ = LocalStructOrderParams(op_types, parameters=op_params)

        # Single bond.
        op_vals = ops_101.get_order_parameters(self.single_bond, 0)
        assert int(op_vals[13] * 1000) == approx(1000)
        op_vals = ops_501.get_order_parameters(self.single_bond, 0)
        assert int(op_vals[13] * 1000) == approx(799)

        # Linear motif: no single-bond character, fully straight "bent" angle.
        op_vals = ops_101.get_order_parameters(self.linear, 0)
        assert int(op_vals[13] * 1000) == approx(0)
        assert int(op_vals[1] * 1000) == approx(1000)

        # 45 degrees-bent motif.
        op_vals = ops_101.get_order_parameters(self.bent45, 0)
        assert int(op_vals[2] * 1000) == approx(1000)

        # T-shape motif.
        op_vals = ops_101.get_order_parameters(self.T_shape, 0, indices_neighs=[1, 2, 3])
        assert int(op_vals[23] * 1000) == approx(1000)

        # Cubic structure: no neighbors inside 0.99, six inside 1.01.
        op_vals = ops_099.get_order_parameters(self.cubic, 0)
        assert op_vals[0] == approx(0.0)
        for idx in range(3, 9):
            assert op_vals[idx] is None, idx
        op_vals = ops_101.get_order_parameters(self.cubic, 0)
        assert op_vals[0] == approx(6.0)
        assert int(op_vals[3] * 1000) == approx(23)
        assert int(op_vals[4] * 1000) == approx(1000)
        assert int(op_vals[5] * 1000) == approx(333)
        assert int(op_vals[6] * 1000) == approx(0)
        assert int(op_vals[7] * 1000) == approx(763)
        assert int(op_vals[8] * 1000) == approx(353)
        assert int(op_vals[28] * 1000) == approx(1000)

        # Bcc structure.
        op_vals = ops_087.get_order_parameters(self.bcc, 0)
        assert op_vals[0] == approx(8.0)
        assert int(op_vals[3] * 1000) == approx(200)
        assert int(op_vals[4] * 1000) == approx(145)
        assert int(op_vals[5] * 1000 + 0.5) == approx(1000)
        assert int(op_vals[6] * 1000) == approx(0)
        assert int(op_vals[7] * 1000) == approx(509)
        assert int(op_vals[8] * 1000) == approx(628)

        # Fcc structure.
        op_vals = ops_071.get_order_parameters(self.fcc, 0)
        assert op_vals[0] == approx(12.0)
        assert int(op_vals[3] * 1000) == approx(36)
        assert int(op_vals[4] * 1000) == approx(78)
        assert int(op_vals[5] * 1000) == approx(-2)
        assert int(op_vals[6] * 1000) == approx(0)
        assert int(op_vals[7] * 1000) == approx(190)
        assert int(op_vals[8] * 1000) == approx(574)

        # Hcp structure.
        op_vals = ops_101.get_order_parameters(self.hcp, 0)
        assert op_vals[0] == approx(12.0)
        assert int(op_vals[3] * 1000) == approx(33)
        assert int(op_vals[4] * 1000) == approx(82)
        # NOTE: the bcc value (index 5) is deliberately not asserted for hcp; a
        # previously disabled check expected int(op_vals[5] * 1000) == -26.
        assert int(op_vals[6] * 1000) == approx(0)
        assert int(op_vals[7] * 1000) == approx(97)
        assert int(op_vals[8] * 1000) == approx(484)

        # Diamond structure.
        op_vals = ops_044.get_order_parameters(self.diamond, 0)
        assert op_vals[0] == approx(4.0)
        assert int(op_vals[3] * 1000) == approx(1000)
        assert int(op_vals[4] * 1000) == approx(37)
        assert op_vals[5] == approx(0.75)
        assert int(op_vals[6] * 1000) == approx(0)
        assert int(op_vals[7] * 1000) == approx(509)
        assert int(op_vals[8] * 1000) == approx(628)
        assert int(op_vals[27] * 1000) == approx(1000)

        # Trigonal off-plane molecule.
        op_vals = ops_044.get_order_parameters(self.trigonal_off_plane, 0)
        assert op_vals[0] == approx(3.0)
        assert int(op_vals[3] * 1000) == approx(1000)
        assert int(op_vals[33] * 1000) == approx(1000)

        # Trigonal-planar motif.
        op_vals = ops_101.get_order_parameters(self.trigonal_planar, 0)
        assert int(op_vals[0] + 0.5) == 3
        assert int(op_vals[14] * 1000 + 0.5) == approx(1000)
        assert int(op_vals[29] * 1000 + 0.5) == approx(1000)

        # Regular triangle motif.
        op_vals = ops_101.get_order_parameters(self.regular_triangle, 0)
        assert int(op_vals[9] * 1000) == approx(999)

        # Square-planar motif.
        op_vals = ops_101.get_order_parameters(self.square_planar, 0)
        assert int(op_vals[15] * 1000 + 0.5) == approx(1000)
        assert int(op_vals[30] * 1000 + 0.5) == approx(1000)

        # Square motif.
        op_vals = ops_101.get_order_parameters(self.square, 0)
        assert int(op_vals[10] * 1000) == approx(1000)

        # Pentagonal planar.
        op_vals = ops_101.get_order_parameters(self.pentagonal_planar.sites, 0, indices_neighs=[1, 2, 3, 4, 5])
        assert int(op_vals[12] * 1000 + 0.5) == approx(126)
        assert int(op_vals[16] * 1000 + 0.5) == approx(1000)
        assert int(op_vals[31] * 1000 + 0.5) == approx(1000)

        # Trigonal pyramid motif.
        op_vals = ops_101.get_order_parameters(self.trigonal_pyramid, 0, indices_neighs=[1, 2, 3, 4])
        assert int(op_vals[18] * 1000 + 0.5) == approx(1000)

        # Square pyramid motif.
        op_vals = ops_101.get_order_parameters(self.square_pyramid, 0)
        assert int(op_vals[11] * 1000 + 0.5) == approx(1000)
        assert int(op_vals[12] * 1000 + 0.5) == approx(667)
        assert int(op_vals[17] * 1000 + 0.5) == approx(1000)

        # Pentagonal pyramid motif.
        op_vals = ops_101.get_order_parameters(self.pentagonal_pyramid, 0, indices_neighs=[1, 2, 3, 4, 5, 6])
        assert int(op_vals[19] * 1000 + 0.5) == approx(1000)

        # Hexagonal pyramid motif.
        op_vals = ops_101.get_order_parameters(self.hexagonal_pyramid, 0, indices_neighs=[1, 2, 3, 4, 5, 6, 7])
        assert int(op_vals[20] * 1000 + 0.5) == approx(1000)

        # Trigonal bipyramidal.
        op_vals = ops_101.get_order_parameters(self.trigonal_bipyramidal.sites, 0, indices_neighs=[1, 2, 3, 4, 5])
        assert int(op_vals[12] * 1000 + 0.5) == approx(1000)

        # Pentagonal bipyramidal.
        op_vals = ops_101.get_order_parameters(self.pentagonal_bipyramid.sites, 0, indices_neighs=[1, 2, 3, 4, 5, 6, 7])
        assert int(op_vals[21] * 1000 + 0.5) == approx(1000)

        # Hexagonal bipyramid motif.
        op_vals = ops_101.get_order_parameters(self.hexagonal_bipyramid, 0, indices_neighs=[1, 2, 3, 4, 5, 6, 7, 8])
        assert int(op_vals[22] * 1000 + 0.5) == approx(1000)

        # Cuboctahedral motif.
        op_vals = ops_101.get_order_parameters(self.cuboctahedron, 0, indices_neighs=list(range(1, 13)))
        assert int(op_vals[24] * 1000 + 0.5) == approx(1000)
        assert int(op_vals[32] * 1000 + 0.5) == approx(1000)

        # See-saw motif.
        op_vals = ops_101.get_order_parameters(self.see_saw_rect, 0, indices_neighs=list(range(1, 5)))
        assert int(op_vals[25] * 1000 + 0.5) == approx(1000)

        # Hexagonal planar motif.
        op_vals = ops_101.get_order_parameters(self.hexagonal_planar, 0, indices_neighs=[1, 2, 3, 4, 5, 6])
        assert int(op_vals[26] * 1000 + 0.5) == approx(1000)

        # Square face capped trigonal prism.
        op_vals = ops_101.get_order_parameters(self.sq_face_capped_trig_pris, 0, indices_neighs=list(range(1, 8)))
        assert int(op_vals[34] * 1000 + 0.5) == approx(1000)

        # Test providing explicit neighbor lists.
        op_vals = ops_101.get_order_parameters(self.bcc, 0, indices_neighs=[1])
        assert op_vals[0] is not None
        assert op_vals[3] is None
        # bcc only has sites 0 and 1, so neighbor index 2 is invalid.
        with pytest.raises(ValueError):
            ops_101.get_order_parameters(self.bcc, 0, indices_neighs=[2])

    def tearDown(self):
        # Release every fixture created in setUp; several were previously left behind.
        for name in (
            "single_bond",
            "linear",
            "bent45",
            "cubic",
            "fcc",
            "bcc",
            "hcp",
            "diamond",
            "regular_triangle",
            "square",
            "square_pyramid",
            "trigonal_off_plane",
            "trigonal_pyramid",
            "trigonal_planar",
            "square_planar",
            "pentagonal_planar",
            "pentagonal_pyramid",
            "hexagonal_pyramid",
            "pentagonal_bipyramid",
            "hexagonal_planar",
            "hexagonal_bipyramid",
            "trigonal_bipyramidal",
            "T_shape",
            "cuboctahedron",
            "see_saw_rect",
            "sq_face_capped_trig_pris",
        ):
            delattr(self, name)
1249

1250

1251
class CrystalNNTest(PymatgenTest):
    """Tests for CrystalNN: discrete/weighted CNs, options, and disorder handling."""

    def setUp(self):
        self.lifepo4 = self.get_structure("LiFePO4")
        self.lifepo4.add_oxidation_state_by_guess()
        self.he_bcc = self.get_structure("He_BCC")
        self.he_bcc.add_oxidation_state_by_guess()

        # Silence warnings for this test only. Saving a bare reference to
        # ``warnings.filters`` (as done previously) does not work: ``simplefilter``
        # mutates that same list in place, so restoring it was a no-op.
        # ``catch_warnings`` snapshots and restores the filter state, and
        # ``addCleanup`` guarantees the restore even if the rest of setUp fails.
        warnings_ctx = warnings.catch_warnings()
        warnings_ctx.__enter__()
        self.addCleanup(warnings_ctx.__exit__, None, None, None)
        warnings.simplefilter("ignore")

        # Site 0 has no species above 50% occupancy here...
        self.disordered_struct = Structure(
            Lattice.cubic(3), [{"Fe": 0.4, "C": 0.3, "Mn": 0.3}, "O"], [[0, 0, 0], [0.5, 0.5, 0.5]]
        )
        # ...whereas here Fe is the 60% majority species on site 0.
        self.disordered_struct_with_majority = Structure(
            Lattice.cubic(3), [{"Fe": 0.6, "C": 0.4}, "O"], [[0, 0, 0], [0.5, 0.5, 0.5]]
        )

    def test_sanity(self):
        # Asking for weights the instance was not configured for must raise.
        # Objects are built outside ``raises`` so a constructor error cannot
        # satisfy the check by accident.
        cnn = CrystalNN()
        with pytest.raises(ValueError):
            cnn.get_cn(self.lifepo4, 0, use_weights=True)

        cnn = CrystalNN(weighted_cn=True)
        with pytest.raises(ValueError):
            cnn.get_cn(self.lifepo4, 0, use_weights=False)

    def test_discrete_cn(self):
        cnn = CrystalNN()
        cn_array = [cnn.get_cn(self.lifepo4, idx) for idx in range(len(self.lifepo4))]
        assert cn_array == 8 * [6] + 20 * [4]

    def test_weighted_cn(self):
        cnn = CrystalNN(weighted_cn=True)

        # fmt: off
        expected_array = [
            5.863, 5.8716, 5.863, 5.8716, 5.7182, 5.7182, 5.719, 5.7181, 3.991, 3.991, 3.991,
            3.9907, 3.5997, 3.525, 3.4133, 3.4714, 3.4727, 3.4133, 3.525, 3.5997, 3.5997, 3.525,
            3.4122, 3.4738, 3.4728, 3.4109, 3.5259, 3.5997,
        ]
        # fmt: on
        cn_array = [cnn.get_cn(self.lifepo4, idx, use_weights=True) for idx in range(len(self.lifepo4))]

        self.assertArrayAlmostEqual(expected_array, cn_array, 2)

    def test_weighted_cn_no_oxid(self):
        cnn = CrystalNN(weighted_cn=True)
        # fmt: off
        expected_array = [
            5.8962, 5.8996, 5.8962, 5.8996, 5.7195, 5.7195, 5.7202, 5.7194, 4.0012, 4.0012,
            4.0012, 4.0009, 3.3897, 3.2589, 3.1218, 3.1914, 3.1914, 3.1218, 3.2589, 3.3897,
            3.3897, 3.2589, 3.1207, 3.1924, 3.1915, 3.1207, 3.2598, 3.3897,
        ]
        # fmt: on
        # Work on a copy so the shared fixture keeps its oxidation states.
        struct = self.lifepo4.copy()
        struct.remove_oxidation_states()
        cn_array = [cnn.get_cn(struct, idx, use_weights=True) for idx in range(len(struct))]

        self.assertArrayAlmostEqual(expected_array, cn_array, 2)

    def test_fixed_length(self):
        cnn = CrystalNN(fingerprint_length=30)
        nndata = cnn.get_nn_data(self.lifepo4, 0)
        assert len(nndata.cn_weights) == 30
        assert len(nndata.cn_nninfo) == 30

    def test_cation_anion(self):
        cnn = CrystalNN(weighted_cn=True, cation_anion=True)
        assert cnn.get_cn(self.lifepo4, 0, use_weights=True) == approx(5.8630, abs=1e-2)

    def test_x_diff_weight(self):
        cnn = CrystalNN(weighted_cn=True, x_diff_weight=0)
        assert cnn.get_cn(self.lifepo4, 0, use_weights=True) == approx(5.8630, abs=1e-2)

    def test_noble_gas_material(self):
        cnn = CrystalNN()

        assert cnn.get_cn(self.he_bcc, 0, use_weights=False) == 0

        cnn = CrystalNN(distance_cutoffs=(1.25, 5))
        assert cnn.get_cn(self.he_bcc, 0, use_weights=False) == 8

    def test_shifted_sites(self):
        # Shifting a site by a full lattice vector must not change its bonding.
        cnn = CrystalNN()

        sites = [[0.0, 0.2, 0.2], [0, 0, 0]]
        struct = Structure([7, 0, 0, 0, 7, 0, 0, 0, 7], ["I"] * len(sites), sites)
        bonded_struct = cnn.get_bonded_structure(struct)

        sites_shifted = [[1.0, 0.2, 0.2], [0, 0, 0]]
        struct_shifted = Structure([7, 0, 0, 0, 7, 0, 0, 0, 7], ["I"] * len(sites_shifted), sites_shifted)
        bonded_struct_shifted = cnn.get_bonded_structure(struct_shifted)

        assert len(bonded_struct.get_connected_sites(0)) == len(bonded_struct_shifted.get_connected_sites(0))

    def test_get_cn(self):
        cnn = CrystalNN()

        site_0_coord_num = cnn.get_cn(self.disordered_struct, 0, on_disorder="take_max_species")
        site_0_coord_num_strict_majority = cnn.get_cn(
            self.disordered_struct_with_majority, 0, on_disorder="take_majority_strict"
        )
        assert site_0_coord_num == 8
        assert site_0_coord_num == site_0_coord_num_strict_majority

        with pytest.raises(ValueError):
            cnn.get_cn(self.disordered_struct, 0, on_disorder="take_majority_strict")
        with pytest.raises(ValueError):
            cnn.get_cn(self.disordered_struct, 0, on_disorder="error")

    def test_get_bonded_structure(self):
        cnn = CrystalNN()

        structure_graph = cnn.get_bonded_structure(self.disordered_struct, on_disorder="take_max_species")
        structure_graph_strict_majority = cnn.get_bonded_structure(
            self.disordered_struct_with_majority, on_disorder="take_majority_strict"
        )
        structure_graph_drop_majority = cnn.get_bonded_structure(
            self.disordered_struct_with_majority, on_disorder="take_majority_drop"
        )

        assert isinstance(structure_graph, StructureGraph)
        assert len(structure_graph) == 2
        assert structure_graph == structure_graph_strict_majority == structure_graph_drop_majority

        # Both policies are expected to reject site 0 of disordered_struct, which
        # has no majority species. (The stray positional ``0`` copied from the
        # get_cn calls and a duplicated "error" check were removed.)
        with pytest.raises(ValueError):
            cnn.get_bonded_structure(self.disordered_struct, on_disorder="take_majority_strict")
        with pytest.raises(ValueError):
            cnn.get_bonded_structure(self.disordered_struct, on_disorder="error")
1393

1394

1395
class CutOffDictNNTest(PymatgenTest):
    """Tests for CutOffDictNN on a two-site diamond cell."""

    def setUp(self):
        self.diamond = Structure(
            Lattice([[2.189, 0, 1.264], [0.73, 2.064, 1.264], [0, 0, 2.528]]),
            ["C", "C"],
            [[2.554, 1.806, 4.423], [0.365, 0.258, 0.632]],
            coords_are_cartesian=True,
        )
        # Silence warnings for each test and restore the previous filters afterwards.
        # catch_warnings snapshots a *copy* of the filter list; keeping a plain reference to
        # ``warnings.filters`` would not work because ``simplefilter`` mutates that list in place.
        self._warnings_ctx = warnings.catch_warnings()
        self._warnings_ctx.__enter__()
        warnings.simplefilter("ignore")

    def tearDown(self):
        # Restore the warning filters saved in setUp.
        self._warnings_ctx.__exit__(None, None, None)

    def test_cn(self):
        """A C-C cutoff of 2 gives 4 neighbours in diamond; an empty cutoff dict gives none."""
        nn = CutOffDictNN({("C", "C"): 2})
        assert nn.get_cn(self.diamond, 0) == 4

        nn_null = CutOffDictNN()
        assert nn_null.get_cn(self.diamond, 0) == 0

    def test_from_preset(self):
        """The "vesta_2019" preset finds 4 neighbours in diamond; unknown presets raise."""
        nn = CutOffDictNN.from_preset("vesta_2019")
        assert nn.get_cn(self.diamond, 0) == 4

        # test error thrown on unknown preset
        with pytest.raises(ValueError):
            CutOffDictNN.from_preset("test")
1423

1424

1425
@unittest.skipIf(not which("critic2"), "critic2 executable not present")
class Critic2NNTest(PymatgenTest):
    """Tests for Critic2NN; skipped unless the ``critic2`` executable is on PATH."""

    def setUp(self):
        self.diamond = Structure(
            Lattice([[2.189, 0, 1.264], [0.73, 2.064, 1.264], [0, 0, 2.528]]),
            ["C", "C"],
            [[2.554, 1.806, 4.423], [0.365, 0.258, 0.632]],
            coords_are_cartesian=True,
        )
        # Silence warnings for each test and restore the previous filters afterwards.
        # catch_warnings snapshots a *copy* of the filter list; keeping a plain reference to
        # ``warnings.filters`` would not work because ``simplefilter`` mutates that list in place.
        self._warnings_ctx = warnings.catch_warnings()
        self._warnings_ctx.__enter__()
        warnings.simplefilter("ignore")

    def tearDown(self):
        # Restore the warning filters saved in setUp.
        self._warnings_ctx.__exit__(None, None, None)

    def test_cn(self):
        """Smoke test: Critic2NN can be constructed."""
        Critic2NN()
        # TODO: also assert a coordination number, e.g. Critic2NN().get_cn(self.diamond, 0) == 4.
        # That check was left disabled in this test; presumably it did not hold - confirm first.
1443

1444

1445
class MetalEdgeExtenderTest(PymatgenTest):
    """Tests for metal_edge_extender on Li-EC and on charged K / Mg water clusters."""

    @staticmethod
    def _charged_cluster_graph(filename, charge):
        """Load ``filename`` from ``test_dir``, rebuild it with ``charge``, return an edgeless MoleculeGraph."""
        uncharged_cluster = Molecule.from_file(os.path.join(test_dir, filename))
        coords = [site.coords for site in uncharged_cluster.sites]
        species = [site.species for site in uncharged_cluster.sites]
        charged_cluster = Molecule(species, coords, charge=charge)
        return MoleculeGraph.with_empty_graph(charged_cluster)

    def setUp(self):
        self.LiEC = Molecule.from_file(os.path.join(test_dir, "LiEC.xyz"))
        self.LiEC_graph = MoleculeGraph.with_edges(
            molecule=self.LiEC,
            edges={
                (0, 2): None,
                (0, 1): None,
                (1, 3): None,
                (1, 4): None,
                (2, 7): None,
                (2, 5): None,
                (2, 8): None,
                (3, 6): None,
                (4, 5): None,
                (5, 9): None,
                (5, 10): None,
            },
        )

        # potassium + 7 H2O. 4 at ~2.5 Ang and 3 more within 4.25 Ang
        self.water_cluster_K = self._charged_cluster_graph("water_cluster_K.xyz", charge=1)
        assert len(self.water_cluster_K.graph.edges) == 0

        # Mg + 6 H2O at 1.94 Ang from Mg
        self.water_cluster_Mg = self._charged_cluster_graph("water_cluster_Mg.xyz", charge=2)

    def test_metal_edge_extender(self):
        """With default settings, one extra edge is added to the 11-edge Li-EC graph."""
        assert len(self.LiEC_graph.graph.edges) == 11
        extended_mol_graph = metal_edge_extender(self.LiEC_graph)
        assert len(extended_mol_graph.graph.edges) == 12

    def test_custom_metals(self):
        """The ``metals`` argument controls which elements may gain coordination edges."""
        extended_mol_graph = metal_edge_extender(self.LiEC_graph, metals={"K"})
        assert len(extended_mol_graph.graph.edges) == 11

        # empty metals should exit cleanly with no change to graph
        mol_graph = metal_edge_extender(self.water_cluster_K, metals=set(), cutoff=2.5)
        assert len(mol_graph.graph.edges) == 0

        mol_graph = metal_edge_extender(self.water_cluster_K, metals={"K"}, cutoff=2.5)
        assert len(mol_graph.graph.edges) == 4

        extended_graph = metal_edge_extender(self.water_cluster_K, metals={"K"}, cutoff=4.5)
        assert len(extended_graph.graph.edges) == 7

        # if None, should auto-detect Li
        extended_mol_graph = metal_edge_extender(self.LiEC_graph, metals=None)
        assert len(extended_mol_graph.graph.edges) == 12

    def test_custom_coordinators(self):
        """The ``coordinators`` argument controls which elements may bond to the metal."""
        # leave out Oxygen, graph should not change
        extended_mol_graph = metal_edge_extender(self.LiEC_graph, coordinators={"N", "F", "S", "Cl"})
        assert len(extended_mol_graph.graph.edges) == 11
        # empty coordinators should exit cleanly with no change
        extended_mol_graph = metal_edge_extender(self.LiEC_graph, coordinators=set())
        assert len(extended_mol_graph.graph.edges) == 11

    def test_custom_cutoff(self):
        """The ``cutoff`` argument bounds the metal-coordinator search distance."""
        short_mol_graph = metal_edge_extender(self.LiEC_graph, cutoff=0.5)
        assert len(short_mol_graph.graph.edges) == 11

        # With a cutoff of 1.5, no edges should be found on the first pass; check that the
        # second pass (automatically increasing the cutoff to 2.5) picks up the six
        # coordination bonds.
        short_mol_graph = metal_edge_extender(self.water_cluster_Mg, cutoff=1.5)
        assert len(short_mol_graph.graph.edges) == 6
1520

1521

1522
if __name__ == "__main__":
    # Allow running this module directly (python test_local_env.py) in addition to pytest.
    unittest.main(module="__main__")
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc