• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

materialsproject / pymatgen / 4075885785

pending completion
4075885785

push

github

Shyue Ping Ong
Merge branch 'master' of github.com:materialsproject/pymatgen

96 of 96 new or added lines in 27 files covered. (100.0%)

81013 of 102710 relevant lines covered (78.88%)

0.79 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

47.94
/pymatgen/analysis/molecule_matcher.py
1
# Copyright (c) Pymatgen Development Team.
2
# Distributed under the terms of the MIT License.
3

4
"""
1✔
5
This module provides classes to perform fitting of molecules with arbitrary
6
atom orders.
7
This module is supposed to perform exact comparisons without the atom order
8
correspondence prerequisite, while molecule_structure_comparator is supposed
9
to do rough comparisons with the atom order correspondence prerequisite.
10

11
The implementation is based on an excellent python package called `rmsd` that
12
you can find at https://github.com/charnley/rmsd.
13
"""
14

15
from __future__ import annotations
1✔
16

17
import abc
1✔
18
import copy
1✔
19
import itertools
1✔
20
import logging
1✔
21
import math
1✔
22
import re
1✔
23

24
import numpy as np
1✔
25
from monty.dev import requires
1✔
26
from monty.json import MSONable
1✔
27

28
try:
1✔
29
    from openbabel import openbabel
1✔
30

31
    from pymatgen.io.babel import BabelMolAdaptor
×
32
except ImportError:
1✔
33
    openbabel = None
1✔
34

35
from scipy.optimize import linear_sum_assignment
1✔
36
from scipy.spatial.distance import cdist
1✔
37

38
from pymatgen.core.structure import Molecule  # pylint: disable=ungrouped-imports
1✔
39

40
__author__ = "Xiaohui Qu, Adam Fekete"
1✔
41
__version__ = "1.0"
1✔
42
__email__ = "xhqu1981@gmail.com"
1✔
43

44
logger = logging.getLogger(__name__)
1✔
45

46

47
class AbstractMolAtomMapper(MSONable, metaclass=abc.ABCMeta):
    """
    Abstract molecular atom order mapping class. A mapping will be able to
    find the uniform atom order of two molecules that can pair the
    geometrically equivalent atoms.
    """

    @abc.abstractmethod
    def uniform_labels(self, mol1, mol2):
        """
        Pair the geometrically equivalent atoms of the molecules.

        Args:
            mol1: First molecule. OpenBabel OBMol or pymatgen Molecule object.
            mol2: Second molecule. OpenBabel OBMol or pymatgen Molecule object.

        Returns:
            (list1, list2) if a uniform atom order is found. list1 and list2
            are for mol1 and mol2, respectively. Their lengths equal
            the number of atoms. They represent the uniform atom order
            of the two molecules. The value of each element is the original
            atom index in mol1 or mol2 of the current atom in uniform atom
            order.
            (None, None) if a uniform atom order is not available.
        """

    @abc.abstractmethod
    def get_molecule_hash(self, mol):
        """
        Defines a hash for molecules. This allows molecules to be grouped
        efficiently for comparison.

        Args:
            mol: The molecule. OpenBabel OBMol or pymatgen Molecule object

        Returns:
            A hashable object. Examples can be string formulas, etc.
        """

    @classmethod
    def from_dict(cls, d):
        """
        Rebuild a concrete mapper from its dict representation.

        The attribute named by ``d["@class"]`` is looked up in
        ``pymatgen.analysis.molecule_matcher`` and its ``from_dict`` is
        called with ``d``.

        Args:
            d (dict): Dict representation. Must contain the key "@class".

        Returns:
            AbstractMolAtomMapper

        Raises:
            ValueError: If "@class" does not name an attribute of the module
                that provides its own ``from_dict``.
        """
        class_name = d["@class"]
        mod = __import__(
            "pymatgen.analysis.molecule_matcher",
            globals(),
            locals(),
            [class_name],
            0,  # level 0: absolute import
        )
        loader = getattr(getattr(mod, class_name, None), "from_dict", None)
        # A class that merely inherits this very method would call straight
        # back into it and recurse without end, so reject it as invalid input.
        inherited = getattr(loader, "__func__", None) is AbstractMolAtomMapper.from_dict.__func__
        if loader is None or inherited:
            raise ValueError("Invalid Comparator dict")
        return loader(d)
×
108

109

110
class IsomorphismMolAtomMapper(AbstractMolAtomMapper):
    """
    Pair atoms by isomorphism permutations in the OpenBabel::OBAlign class
    """

    def uniform_labels(self, mol1, mol2):
        """
        Pair the geometrically equivalent atoms of the molecules.
        Calculate RMSD on all possible isomorphism mappings and return the
        mapping with the least RMSD.

        Args:
            mol1: First molecule. OpenBabel OBMol or pymatgen Molecule object.
            mol2: Second molecule. OpenBabel OBMol or pymatgen Molecule object.

        Returns:
            (list1, list2) if a uniform atom order is found. list1 and list2
            are for mol1 and mol2, respectively. Their lengths equal
            the number of atoms. They represent the uniform atom order
            of the two molecules. The value of each element is the original
            (1-based) atom index in mol1 or mol2 of the current atom in
            uniform atom order.
            (None, None) if the molecule hashes differ or no isomorphism
            mapping with matching elements yields a usable RMSD.
        """
        obmol1 = BabelMolAdaptor(mol1).openbabel_mol
        obmol2 = BabelMolAdaptor(mol2).openbabel_mol

        h1 = self.get_molecule_hash(obmol1)
        h2 = self.get_molecule_hash(obmol2)
        if h1 != h2:
            return None, None

        query = openbabel.CompileMoleculeQuery(obmol1)
        isomapper = openbabel.OBIsomorphismMapper.GetInstance(query)
        isomorph = openbabel.vvpairUIntUInt()
        isomapper.MapAll(obmol2, isomorph)

        # Order every mapping by its first pair member, then keep the second
        # member shifted by one so it can be passed to OBMol.GetAtom.
        sorted_isomorph = [sorted(x, key=lambda morp: morp[0]) for x in isomorph]
        label2_list = tuple(tuple(p[1] + 1 for p in x) for x in sorted_isomorph)

        aligner = openbabel.OBAlign(True, False)
        aligner.SetRefMol(obmol1)
        least_rmsd = float("Inf")
        best_label2 = None
        label1 = list(range(1, obmol1.NumAtoms() + 1))
        # noinspection PyProtectedMember
        elements1 = InchiMolAtomMapper._get_elements(obmol1, label1)
        for label2 in label2_list:
            # noinspection PyProtectedMember
            elements2 = InchiMolAtomMapper._get_elements(obmol2, label2)
            if elements1 != elements2:
                continue
            # Build a copy of mol2 whose atoms follow the candidate order.
            vmol2 = openbabel.OBMol()
            for i in label2:
                vmol2.AddAtom(obmol2.GetAtom(i))
            aligner.SetTargetMol(vmol2)
            aligner.Align()
            rmsd = aligner.GetRMSD()
            if rmsd < least_rmsd:
                least_rmsd = rmsd
                best_label2 = label2  # tuples are immutable, no copy needed

        if best_label2 is None:
            # Honour the documented contract instead of returning a half result.
            return None, None
        return label1, best_label2

    def get_molecule_hash(self, mol):
        """
        Return inchi as molecular hash

        Args:
            mol: The molecule. OpenBabel OBMol object.

        Returns:
            The text following "InChI=" on the InChI line written by OpenBabel.

        Raises:
            ValueError: If no InChI line can be found in the OpenBabel output.
        """
        obconv = openbabel.OBConversion()
        obconv.SetOutFormat("inchi")
        obconv.AddOption("X", openbabel.OBConversion.OUTOPTIONS, "DoNotAddH")
        inchi_text = obconv.WriteString(mol)
        match = re.search(r"InChI=(?P<inchi>.+)\n", inchi_text)
        if match is None:
            raise ValueError("Could not parse an InChI string from the OpenBabel output")
        return match.group("inchi")

    def as_dict(self):
        """
        Returns:
            Jsonable dict.
        """
        return {
            "version": __version__,
            "@module": type(self).__module__,
            "@class": type(self).__name__,
        }

    @classmethod
    def from_dict(cls, d):
        """
        Args:
            d (dict): Dict representation. Its content is not used because
                the mapper has no parameters.

        Returns:
            IsomorphismMolAtomMapper
        """
        return cls()
×
206

207

208
class InchiMolAtomMapper(AbstractMolAtomMapper):
    """
    Pair atoms by inchi labels.
    """

    def __init__(self, angle_tolerance=10.0):
        """
        Args:
            angle_tolerance (float): Angle threshold to assume linear molecule. In degrees.
        """
        self._angle_tolerance = angle_tolerance
        # Fallback mapper used by uniform_labels for small or linear molecules.
        self._assistant_mapper = IsomorphismMolAtomMapper()

    def as_dict(self):
        """
        Returns:
            MSONAble dict.
        """
        return {
            "version": __version__,
            "@module": type(self).__module__,
            "@class": type(self).__name__,
            "angle_tolerance": self._angle_tolerance,
        }

    @classmethod
    def from_dict(cls, d):
        """
        Args:
            d (dict): Dict Representation. Must contain "angle_tolerance".

        Returns:
            InchiMolAtomMapper
        """
        return cls(angle_tolerance=d["angle_tolerance"])

    @staticmethod
    def _inchi_labels(mol):
        """
        Get the inchi canonical labels of the heavy atoms in the molecule

        Args:
            mol: The molecule. OpenBabel OBMol object

        Returns:
            Tuple ``(heavy_atom_labels, eq_atoms, inchi)``:
            heavy_atom_labels is a tuple of the integers found in the "/N:"
            layer of the AuxInfo line; eq_atoms is a tuple of tuples of
            integers, one per parenthesised group in the "/E:" layer (empty
            tuple when that layer is absent); inchi is the text following
            "InChI=".

        Raises:
            ValueError: If the OpenBabel output does not contain the expected
                InChI and AuxInfo lines.
        """
        ob_conv = openbabel.OBConversion()
        ob_conv.SetOutFormat("inchi")
        ob_conv.AddOption("a", openbabel.OBConversion.OUTOPTIONS)
        ob_conv.AddOption("X", openbabel.OBConversion.OUTOPTIONS, "DoNotAddH")
        inchi_text = ob_conv.WriteString(mol)
        match = re.search(
            r"InChI=(?P<inchi>.+)\nAuxInfo=.+/N:(?P<labels>[0-9,;]+)/(E:(?P<eq_atoms>[0-9,;\(\)]*)/)?",
            inchi_text,
        )
        if match is None:
            raise ValueError("Could not parse InChI and AuxInfo from the OpenBabel output")
        inchi = match.group("inchi")
        label_text = match.group("labels")
        eq_atom_text = match.group("eq_atoms")
        heavy_atom_labels = tuple(int(i) for i in label_text.replace(";", ",").split(","))
        # Always a tuple, so results of two calls compare equal whenever their
        # content is equal (callers compare them with !=).
        eq_atoms = ()
        if eq_atom_text is not None:
            eq_tokens = re.findall(r"\(((?:[0-9]+,)+[0-9]+)\)", eq_atom_text.replace(";", ","))
            eq_atoms = tuple(tuple(int(i) for i in t.split(",")) for t in eq_tokens)
        return heavy_atom_labels, eq_atoms, inchi

    @staticmethod
    def _group_centroid(mol, ilabels, group_atoms):
        """
        Calculate the centroid of a group of atoms indexed by the labels of inchi

        Args:
            mol: The molecule. OpenBabel OBMol object
            ilabels: inchi label map; ``ilabels[i - 1]`` is the atom index in
                mol for inchi label i
            group_atoms: inchi labels (1-based) of the atoms in the group

        Returns:
            Centroid. Tuple (x, y, z)
        """
        c1x, c1y, c1z = 0.0, 0.0, 0.0
        for i in group_atoms:
            oa1 = mol.GetAtom(ilabels[i - 1])
            c1x += float(oa1.x())
            c1y += float(oa1.y())
            c1z += float(oa1.z())
        num_atoms = len(group_atoms)
        return c1x / num_atoms, c1y / num_atoms, c1z / num_atoms

    def _virtual_molecule(self, mol, ilabels, eq_atoms):
        """
        Create a virtual molecule by unique atoms, the centroids of the
        equivalent atoms

        Args:
            mol: The molecule. OpenBabel OBMol object
            ilabels: inchi label map
            eq_atoms: equivalent atom labels

        Return:
            The virtual molecule
        """
        vmol = openbabel.OBMol()

        non_unique_atoms = {a for g in eq_atoms for a in g}
        all_atoms = set(range(1, len(ilabels) + 1))
        unique_atom_labels = sorted(all_atoms - non_unique_atoms)

        # try to align molecules using unique atoms
        for i in unique_atom_labels:
            oa1 = mol.GetAtom(ilabels[i - 1])
            a1 = vmol.NewAtom()
            a1.SetAtomicNum(oa1.GetAtomicNum())
            a1.SetVector(oa1.GetVector())

        # try to align using centroids of the equivalent atoms
        if vmol.NumAtoms() < 3:
            for symm in eq_atoms:
                c1x, c1y, c1z = self._group_centroid(mol, ilabels, symm)
                # Distance from this centroid to the closest atom already in
                # vmol (infinite while vmol is still empty).
                min_distance = float("inf")
                for i in range(1, vmol.NumAtoms() + 1):
                    va = vmol.GetAtom(i)
                    distance = math.sqrt((c1x - va.x()) ** 2 + (c1y - va.y()) ** 2 + (c1z - va.z()) ** 2)
                    if distance < min_distance:
                        min_distance = distance
                # Skip centroids that sit within 0.2 of an existing atom.
                if min_distance > 0.2:
                    a1 = vmol.NewAtom()
                    a1.SetAtomicNum(9)  # centroid pseudo-atoms get atomic number 9
                    a1.SetVector(c1x, c1y, c1z)

        return vmol

    @staticmethod
    def _align_heavy_atoms(mol1, mol2, vmol1, vmol2, ilabel1, ilabel2, eq_atoms):
        """
        Align the label of topologically identical atoms of second molecule
        towards first molecule

        Note that vmol1 and vmol2 are extended in place: one extra atom per
        entry of ilabel2 is appended to each of them.

        Args:
            mol1: First molecule. OpenBabel OBMol object
            mol2: Second molecule. OpenBabel OBMol object
            vmol1: First virtual molecule constructed by centroids. OpenBabel
                OBMol object
            vmol2: Second virtual molecule constructed by centroids. OpenBabel
                OBMol object
            ilabel1: inchi label map of the first molecule
            ilabel2: inchi label map of the second molecule
            eq_atoms: equivalent atom labels

        Return:
            corrected inchi labels of heavy atoms of the second molecule
        """
        nvirtual = vmol1.NumAtoms()
        nheavy = len(ilabel1)

        for i in ilabel2:  # add all heavy atoms
            a1 = vmol1.NewAtom()
            a1.SetAtomicNum(1)
            a1.SetVector(0.0, 0.0, 0.0)  # useless, just to pair with vmol2
            oa2 = mol2.GetAtom(i)
            a2 = vmol2.NewAtom()
            a2.SetAtomicNum(1)
            # align using the virtual atoms, these atoms are not
            # used to align, but match by positions
            a2.SetVector(oa2.GetVector())

        aligner = openbabel.OBAlign(False, False)
        aligner.SetRefMol(vmol1)
        aligner.SetTargetMol(vmol2)
        aligner.Align()
        aligner.UpdateCoords(vmol2)

        # Heavy atoms of mol1 in inchi order, at their original positions.
        canon_mol1 = openbabel.OBMol()
        for i in ilabel1:
            oa1 = mol1.GetAtom(i)
            a1 = canon_mol1.NewAtom()
            a1.SetAtomicNum(oa1.GetAtomicNum())
            a1.SetVector(oa1.GetVector())

        # Heavy atoms of mol2 in inchi order, at their aligned positions
        # (they were appended to vmol2 after the nvirtual original atoms).
        aligned_mol2 = openbabel.OBMol()
        for i in range(nvirtual + 1, nvirtual + nheavy + 1):
            oa2 = vmol2.GetAtom(i)
            a2 = aligned_mol2.NewAtom()
            a2.SetAtomicNum(oa2.GetAtomicNum())
            a2.SetVector(oa2.GetVector())

        canon_label2 = list(range(1, nheavy + 1))
        for symm in eq_atoms:
            for i in symm:
                canon_label2[i - 1] = -1
        # Within each equivalence group, greedily give every mol2 atom the
        # label of the closest mol1 atom that has not been taken yet.
        for symm in eq_atoms:
            candidates1 = list(symm)
            for c2 in symm:
                distance = 99999.0
                canon_idx = candidates1[0]
                a2 = aligned_mol2.GetAtom(c2)
                for c1 in candidates1:
                    a1 = canon_mol1.GetAtom(c1)
                    d = a1.GetDistance(a2)
                    if d < distance:
                        distance = d
                        canon_idx = c1
                canon_label2[c2 - 1] = canon_idx
                candidates1.remove(canon_idx)

        # Reorder the original mol2 indices by their corrected label (stable sort).
        canon_orig_map2 = sorted(zip(canon_label2, ilabel2), key=lambda m: m[0])
        return tuple(x[1] for x in canon_orig_map2)

    @staticmethod
    def _align_hydrogen_atoms(mol1, mol2, heavy_indices1, heavy_indices2):
        """
        Align the label of topologically identical atoms of second molecule
        towards first molecule

        Args:
            mol1: First molecule. OpenBabel OBMol object
            mol2: Second molecule. OpenBabel OBMol object
            heavy_indices1: inchi label map of the first molecule (tuple)
            heavy_indices2: label map of the second molecule (tuple)

        Return:
            (canon_label1, canon_label2): label maps of all atoms of the first
            and of the second molecule, heavy atoms first.
        """
        num_atoms = mol2.NumAtoms()
        all_atom = set(range(1, num_atoms + 1))
        # Every atom index that is not listed as a heavy atom.
        hydrogen_atoms1 = all_atom - set(heavy_indices1)
        hydrogen_atoms2 = all_atom - set(heavy_indices2)
        label1 = heavy_indices1 + tuple(hydrogen_atoms1)
        label2 = heavy_indices2 + tuple(hydrogen_atoms2)

        cmol1 = openbabel.OBMol()
        for i in label1:
            oa1 = mol1.GetAtom(i)
            a1 = cmol1.NewAtom()
            a1.SetAtomicNum(oa1.GetAtomicNum())
            a1.SetVector(oa1.GetVector())
        cmol2 = openbabel.OBMol()
        for i in label2:
            oa2 = mol2.GetAtom(i)
            a2 = cmol2.NewAtom()
            a2.SetAtomicNum(oa2.GetAtomicNum())
            a2.SetVector(oa2.GetVector())

        aligner = openbabel.OBAlign(False, False)
        aligner.SetRefMol(cmol1)
        aligner.SetTargetMol(cmol2)
        aligner.Align()
        aligner.UpdateCoords(cmol2)

        # Greedily pair each remaining atom of cmol2 with the closest
        # not-yet-used remaining atom of cmol1.
        hydrogen_label2 = []
        hydrogen_label1 = list(range(len(heavy_indices1) + 1, num_atoms + 1))
        for h2 in range(len(heavy_indices2) + 1, num_atoms + 1):
            distance = 99999.0
            idx = hydrogen_label1[0]
            a2 = cmol2.GetAtom(h2)
            for h1 in hydrogen_label1:
                a1 = cmol1.GetAtom(h1)
                d = a1.GetDistance(a2)
                if d < distance:
                    distance = d
                    idx = h1
            hydrogen_label2.append(idx)
            hydrogen_label1.remove(idx)

        hydrogen_orig_idx2 = label2[len(heavy_indices2) :]
        hydrogen_canon_orig_map2 = sorted(zip(hydrogen_label2, hydrogen_orig_idx2), key=lambda m: m[0])
        hydrogen_canon_indices2 = [x[1] for x in hydrogen_canon_orig_map2]

        canon_label1 = label1
        canon_label2 = heavy_indices2 + tuple(hydrogen_canon_indices2)

        return canon_label1, canon_label2

    @staticmethod
    def _get_elements(mol, label):
        """
        The elements of the atoms in the specified order

        Args:
            mol: The molecule. OpenBabel OBMol object.
            label: The atom indices. List of integers.

        Returns:
            Elements. List of integers (atomic numbers).
        """
        return [int(mol.GetAtom(i).GetAtomicNum()) for i in label]

    def _is_molecule_linear(self, mol):
        """
        Is the molecule a linear one

        Args:
            mol: The molecule. OpenBabel OBMol object.

        Returns:
            Boolean value. Molecules with fewer than three atoms always count
            as linear.
        """
        if mol.NumAtoms() < 3:
            return True
        a1 = mol.GetAtom(1)
        a2 = mol.GetAtom(2)
        for i in range(3, mol.NumAtoms() + 1):
            angle = abs(float(mol.GetAtom(i).GetAngle(a2, a1)))
            # Fold obtuse angles so 0 and 180 degrees are treated alike.
            if angle > 90.0:
                angle = 180.0 - angle
            if angle > self._angle_tolerance:
                return False
        return True

    def uniform_labels(self, mol1, mol2):
        """
        Pair the geometrically equivalent atoms of the molecules.

        Args:
            mol1 (Molecule): Molecule 1
            mol2 (Molecule): Molecule 2

        Returns:
            (label1, label2) atom index orders for mol1 and mol2, or
            (None, None) if the InChI strings differ, the virtual molecules
            have different sizes, or the mapped element sequences differ.

        Raises:
            Exception: If the two molecules share an InChI but report
                different groups of equivalent atoms.
        """
        obmol1 = BabelMolAdaptor(mol1).openbabel_mol
        obmol2 = BabelMolAdaptor(mol2).openbabel_mol

        ilabel1, iequal_atom1, inchi1 = self._inchi_labels(obmol1)
        ilabel2, iequal_atom2, inchi2 = self._inchi_labels(obmol2)

        if inchi1 != inchi2:
            return None, None  # Topologically different

        if iequal_atom1 != iequal_atom2:
            raise Exception("Design Error! Equavilent atoms are inconsistent")

        vmol1 = self._virtual_molecule(obmol1, ilabel1, iequal_atom1)
        vmol2 = self._virtual_molecule(obmol2, ilabel2, iequal_atom2)

        if vmol1.NumAtoms() != vmol2.NumAtoms():
            return None, None

        if vmol1.NumAtoms() < 3 or self._is_molecule_linear(vmol1) or self._is_molecule_linear(vmol2):
            # using isomorphism for difficult (actually simple) molecules
            clabel1, clabel2 = self._assistant_mapper.uniform_labels(mol1, mol2)
        else:
            heavy_atom_indices2 = self._align_heavy_atoms(obmol1, obmol2, vmol1, vmol2, ilabel1, ilabel2, iequal_atom1)
            clabel1, clabel2 = self._align_hydrogen_atoms(obmol1, obmol2, ilabel1, heavy_atom_indices2)
        if clabel1 and clabel2:
            elements1 = self._get_elements(obmol1, clabel1)
            elements2 = self._get_elements(obmol2, clabel2)

            if elements1 != elements2:
                return None, None

        return clabel1, clabel2

    def get_molecule_hash(self, mol):
        """
        Return inchi as molecular hash
        """
        obmol = BabelMolAdaptor(mol).openbabel_mol
        return self._inchi_labels(obmol)[2]
×
580

581

582
class MoleculeMatcher(MSONable):
    """
    Matches molecules and decides whether they are the same, independent of
    the order in which their atoms are listed.
    """

    @requires(
        openbabel,
        "BabelMolAdaptor requires openbabel to be installed with "
        "Python bindings. Please get it at http://openbabel.org "
        "(version >=3.0.0).",
    )
    def __init__(self, tolerance: float = 0.01, mapper=None) -> None:
        """
        Args:
            tolerance (float): RMSD threshold below which two molecules are
                considered the same
            mapper (AbstractMolAtomMapper): object able to map the atoms of
                two molecules to a uniform order. An InchiMolAtomMapper is
                created when none is given.
        """
        self._tolerance = tolerance
        self._mapper = mapper if mapper else InchiMolAtomMapper()

    def fit(self, mol1, mol2):
        """
        Fit two molecules.

        Args:
            mol1: First molecule. OpenBabel OBMol or pymatgen Molecule object
            mol2: Second molecule. OpenBabel OBMol or pymatgen Molecule object

        Returns:
            True if the RMSD between the two molecules is below the tolerance.
        """
        rmsd = self.get_rmsd(mol1, mol2)
        return rmsd < self._tolerance

    def get_rmsd(self, mol1, mol2):
        """
        Get RMSD between two molecules with arbitrary atom order.

        Returns:
            The RMSD if the mapper finds a uniform atom order for both
            molecules, infinity otherwise.
        """
        order1, order2 = self._mapper.uniform_labels(mol1, mol2)
        if order1 is not None and order2 is not None:
            return self._calc_rms(mol1, mol2, order1, order2)
        return float("Inf")

    @staticmethod
    def _calc_rms(mol1, mol2, clabel1, clabel2):
        """
        Calculate the RMSD.

        Args:
            mol1: The first molecule. OpenBabel OBMol or pymatgen Molecule
                object
            mol2: The second molecule. OpenBabel OBMol or pymatgen Molecule
                object
            clabel1: The atom indices that can reorder the first molecule to
                uniform atom order
            clabel2: The atom indices that can reorder the second molecule to
                uniform atom order

        Returns:
            The RMSD.
        """

        def reordered(source, order):
            # New OBMol holding the atoms of `source` in the given index order.
            target = openbabel.OBMol()
            for idx in order:
                src_atom = source.GetAtom(idx)
                new_atom = target.NewAtom()
                new_atom.SetAtomicNum(src_atom.GetAtomicNum())
                new_atom.SetVector(src_atom.GetVector())
            return target

        obmol1 = BabelMolAdaptor(mol1).openbabel_mol
        obmol2 = BabelMolAdaptor(mol2).openbabel_mol
        reference = reordered(obmol1, clabel1)
        candidate = reordered(obmol2, clabel2)

        aligner = openbabel.OBAlign(True, False)
        aligner.SetRefMol(reference)
        aligner.SetTargetMol(candidate)
        aligner.Align()
        return aligner.GetRMSD()

    def group_molecules(self, mol_list):
        """
        Group molecules by structural equality.

        Args:
            mol_list: List of OpenBabel OBMol or pymatgen objects

        Returns:
            A list of lists of matched molecules
            Assumption: if s1=s2 and s2=s3, then s1=s3
            This may not be true for small tolerances.
        """
        hashed = sorted(
            ((idx, self._mapper.get_molecule_hash(mol)) for idx, mol in enumerate(mol_list)),
            key=lambda pair: pair[1],
        )

        # Use molecular hash to pre-group molecules.
        raw_groups = tuple(
            tuple(entry[0] for entry in bucket) for _, bucket in itertools.groupby(hashed, key=lambda pair: pair[1])
        )

        group_indices = []
        for members in raw_groups:
            # Index pairs inside this hash bucket that pass the RMSD test.
            matched_pairs = {
                (i, j) for i, j in itertools.combinations(sorted(members), 2) if self.fit(mol_list[i], mol_list[j])
            }
            pending = set(itertools.chain.from_iterable(matched_pairs))
            singles = set(members) - pending
            group_indices.extend([idx] for idx in singles)

            while pending:
                cluster = {pending.pop()}
                # Grow the cluster until no pending molecule matches a member.
                while pending:
                    links = {tuple(sorted(pair)) for pair in itertools.product(cluster, pending)} & matched_pairs
                    if not links:
                        break
                    linked = set(itertools.chain.from_iterable(links))
                    cluster |= linked
                    pending -= linked
                group_indices.append(sorted(cluster))

        # Largest groups first; ties are broken by the smaller first index.
        group_indices.sort(key=lambda grp: (len(grp), -grp[0]), reverse=True)
        return [[mol_list[idx] for idx in grp] for grp in group_indices]

    def as_dict(self):
        """
        Returns:
            MSONAble dict.
        """
        return {
            "version": __version__,
            "@module": type(self).__module__,
            "@class": type(self).__name__,
            "tolerance": self._tolerance,
            "mapper": self._mapper.as_dict(),
        }

    @classmethod
    def from_dict(cls, d):
        """
        Args:
            d (dict): Dict representation with the keys "tolerance" and "mapper"

        Returns:
            MoleculeMatcher
        """
        mapper = AbstractMolAtomMapper.from_dict(d["mapper"])
        return cls(tolerance=d["tolerance"], mapper=mapper)
739

740

741
class KabschMatcher(MSONable):
    """Molecule matcher using the Kabsch algorithm.

    The Kabsch algorithm aligns two molecules by finding the parameters
    (translation, rotation) which minimize the root-mean-square-deviation (RMSD)
    of two molecules which are topologically (atom types, geometry) similar to
    each other.

    Notes:
        When aligning molecules, the atoms of the two molecules **must** be in the same
        order for the results to be sensible.
    """

    def __init__(self, target: Molecule):
        """Constructor of the matcher object.

        Args:
            target: a `Molecule` object used as a target during the alignment
        """
        self.target = target

    def match(self, p: Molecule):
        """Align `p` (P) with the target (Q) using the Kabsch algorithm:

        - translate P and Q so that their centroids sit at the origin
        - compute the optimal rotation matrix (U) using the Kabsch algorithm
        - compute the translation (V) and the rmsd

        The function returns the rotation matrix (U), translation vector (V),
        and RMSD between Q and P', where P' is:

            P' = P * U + V

        Args:
            p: a `Molecule` object that will be matched with the target one.

        Returns:
            U: Rotation matrix (D,D)
            V: Translation vector (D)
            RMSD : Root mean squared deviation between P and Q

        Raises:
            ValueError: If the atomic numbers of `p` and the target differ
                or are in a different order.
        """
        q = self.target
        if q.atomic_numbers != p.atomic_numbers:
            raise ValueError("The order of the species aren't matching! Please try using `PermInvMatcher`.")

        p_coord = p.cart_coords
        q_coord = q.cart_coords

        # Shift both coordinate sets so that each centroid coincides with the
        # origin of the coordinate system.
        p_trans = p_coord.mean(axis=0)
        q_trans = q_coord.mean(axis=0)
        p_centroid = p_coord - p_trans
        q_centroid = q_coord - q_trans

        # The optimal rotation matrix U using Kabsch algorithm
        U = self.kabsch(p_centroid, q_centroid)

        residual = p_centroid.dot(U) - q_centroid
        rmsd = np.sqrt(np.mean(np.square(residual)))

        V = q_trans - p_trans.dot(U)

        return U, V, rmsd

    def fit(self, p: Molecule):
        """Rotate and translate the `p` molecule according to the best match.

        Args:
            p: a `Molecule` object that will be matched with the target one.

        Returns:
            p_prime: Rotated and translated copy of the `p` `Molecule` object
            rmsd: Root-mean-square-deviation between `p_prime` and the `target`
        """
        U, V, rmsd = self.match(p)

        # Work on a copy so `p` itself is left untouched: P' = P * U + V
        p_prime = p.copy()
        for site in p_prime:
            site.coords = np.dot(site.coords, U) + V

        return p_prime, rmsd

    @staticmethod
    def kabsch(P: np.ndarray, Q: np.ndarray):
        """The Kabsch algorithm is a method for calculating the optimal rotation matrix
        that minimizes the root mean squared deviation (RMSD) between two paired sets of
        points P and Q, each centered around its centroid.

        For more info see:
        - http://en.wikipedia.org/wiki/Kabsch_algorithm and
        - https://cnx.org/contents/HV-RsdwL@23/Molecular-Distance-Measures

        Args:
            P: Nx3 matrix, where N is the number of points.
            Q: Nx3 matrix, where N is the number of points.

        Returns:
            U: 3x3 rotation matrix
        """
        # Cross-covariance matrix of the two point sets
        covariance = P.T.dot(Q)

        # Singular value decomposition; the singular values are not needed.
        V, _, WT = np.linalg.svd(covariance)

        # det(V * WT) decides whether the last axis has to be flipped to keep
        # a right-handed coordinate system.
        d = np.linalg.det(V.dot(WT))
        correction = np.diag([1, 1, d])

        return V.dot(correction).dot(WT)
1✔
853

854

855
class BruteForceOrderMatcher(KabschMatcher):
    """Finding the best match between molecules by selecting molecule order
    with the smallest RMSD from all the possible order combinations.

    Notes:
        When aligning molecules, the atoms of the two molecules **must** have same number
        of atoms from the same species.
    """

    def match(self, p: Molecule, ignore_warning=False):
        """Similar as `KabschMatcher.match` but this method also finds the order of
        atoms which belongs to the best match.

        A `ValueError` will be raised when the total number of possible combinations
        become unfeasible (more than a million combination).

        Args:
            p: a `Molecule` object what will be matched with the target one.
            ignore_warning: ignoring error when the number of combination is too large

        Returns:
            inds: The indices of atoms
            U: 3x3 rotation matrix
            V: Translation vector
            rmsd: Root mean squared deviation between P and Q

        Raises:
            ValueError: if the two molecules do not contain the same atoms, or if
                the number of permutations exceeds one million and
                `ignore_warning` is False.
        """
        q = self.target

        if sorted(p.atomic_numbers) != sorted(q.atomic_numbers):
            raise ValueError("The number of the same species aren't matching!")

        # Only atoms of the same element are interchanged, so the number of
        # candidate orders is the product of the factorials of the element counts.
        _, count = np.unique(p.atomic_numbers, return_counts=True)
        total_permutations = 1
        for c in count:
            # `np.math` was an undocumented alias of the stdlib module which is
            # gone in NumPy 2.0; use `math` directly (exact Python integers).
            total_permutations *= math.factorial(int(c))

        if not ignore_warning and total_permutations > 1_000_000:
            raise ValueError(
                f"The number of all possible permutations ({total_permutations}) is not feasible to run this method!"
            )

        p_coord, q_coord = p.cart_coords, q.cart_coords
        p_atoms, q_atoms = np.array(p.atomic_numbers), np.array(q.atomic_numbers)

        # Both sets of coordinates must be translated first, so that
        # their centroid coincides with the origin of the coordinate system.
        p_trans, q_trans = p_coord.mean(axis=0), q_coord.mean(axis=0)
        p_centroid, q_centroid = p_coord - p_trans, q_coord - q_trans

        # Sort the order of the target molecule by the elements, which is the
        # same grouping `permutations` produces for `p`.
        q_inds = np.argsort(q_atoms)
        q_centroid = q_centroid[q_inds]

        # Initializing return values
        rmsd = np.inf

        # Try every permutation grouped/sorted by the elements and keep the best one
        for p_inds_test in self.permutations(p_atoms):
            p_centroid_test = p_centroid[p_inds_test]
            U_test = self.kabsch(p_centroid_test, q_centroid)

            p_centroid_prime_test = np.dot(p_centroid_test, U_test)
            rmsd_test = np.sqrt(np.mean(np.square(p_centroid_prime_test - q_centroid)))

            if rmsd_test < rmsd:
                p_inds, U, rmsd = p_inds_test, U_test, rmsd_test

        # Rotate and translate matrix P unto matrix Q using Kabsch algorithm.
        # P' = P * U + V
        V = q_trans - np.dot(p_trans, U)

        # Undo the element sorting of the target so that the indices refer to
        # the original atom order of the target molecule.
        inds = p_inds[np.argsort(q_inds)]

        return inds, U, V, rmsd

    def fit(self, p: Molecule, ignore_warning=False):
        """Order, rotate and transform `p` molecule according to the best match.

        A `ValueError` will be raised when the total number of possible combinations
        become unfeasible (more than a million combinations).

        Args:
            p: a `Molecule` object what will be matched with the target one.
            ignore_warning: ignoring error when the number of combination is too large

        Returns:
            p_prime: Rotated and translated of the `p` `Molecule` object
            rmsd: Root-mean-square-deviation between `p_prime` and the `target`
        """
        inds, U, V, rmsd = self.match(p, ignore_warning=ignore_warning)

        # Build the reordered molecule first, then move it: P' = P * U + V
        p_prime = Molecule.from_sites([p[i] for i in inds])
        for site in p_prime:
            site.coords = np.dot(site.coords, U) + V

        return p_prime, rmsd

    @staticmethod
    def permutations(atoms):
        """Generates all the possible permutations of atom order. To achieve better
        performance all the cases where the atoms are different has been ignored.

        Args:
            atoms: array of atomic numbers (compared element-wise with `==`).

        Yields:
            Index arrays in which the atoms are grouped by element (in the order
            given by `np.unique`) and permuted only within each group.
        """
        # One permutation iterator per element, over the positions of that element
        element_iterators = [itertools.permutations(np.where(atoms == element)[0]) for element in np.unique(atoms)]

        # Every combination of the per-element permutations, flattened into one index array
        for inds in itertools.product(*element_iterators):
            yield np.array(list(itertools.chain(*inds)))
1✔
962

963

964
class HungarianOrderMatcher(KabschMatcher):
    """This method pre-aligns the molecules based on their principal inertia
    axis and then re-orders the input atom list using the Hungarian method.

    Notes:
        This method cannot guarantee the best match but is very fast.

        When aligning molecules, the atoms of the two molecules **must** have same number
        of atoms from the same species.
    """

    def match(self, p: Molecule):
        """Similar as `KabschMatcher.match` but this method also finds the order of
        atoms which belongs to the best match.

        Args:
            p: a `Molecule` object what will be matched with the target one.

        Returns:
            inds: The indices of atoms
            U: 3x3 rotation matrix
            V: Translation vector
            rmsd: Root mean squared deviation between P and Q

        Raises:
            ValueError: if the two molecules do not contain the same atoms.
        """
        if sorted(p.atomic_numbers) != sorted(self.target.atomic_numbers):
            raise ValueError("The number of the same species aren't matching!")

        p_coord, q_coord = p.cart_coords, self.target.cart_coords
        p_atoms, q_atoms = (
            np.array(p.atomic_numbers),
            np.array(self.target.atomic_numbers),
        )

        p_weights = np.array([site.species.weight for site in p])
        q_weights = np.array([site.species.weight for site in self.target])

        # Both sets of coordinates must be translated first, so that
        # their center of mass coincides with the origin of the coordinate system.
        p_trans, q_trans = p.center_of_mass, self.target.center_of_mass
        p_centroid, q_centroid = p_coord - p_trans, q_coord - q_trans

        # Initializing return values
        rmsd = np.inf

        # Evaluate the two candidate orders (one per pre-alignment) and keep the better one
        for p_inds_test in self.permutations(p_atoms, p_centroid, p_weights, q_atoms, q_centroid, q_weights):
            p_centroid_test = p_centroid[p_inds_test]
            U_test = self.kabsch(p_centroid_test, q_centroid)

            p_centroid_prime_test = np.dot(p_centroid_test, U_test)
            rmsd_test = np.sqrt(np.mean(np.square(p_centroid_prime_test - q_centroid)))

            if rmsd_test < rmsd:
                inds, U, rmsd = p_inds_test, U_test, rmsd_test

        # Rotate and translate matrix P unto matrix Q using Kabsch algorithm.
        # P' = P * U + V
        V = q_trans - np.dot(p_trans, U)

        return inds, U, V, rmsd

    def fit(self, p: Molecule):
        """Order, rotate and transform `p` molecule according to the best match.

        Args:
            p: a `Molecule` object what will be matched with the target one.

        Returns:
            p_prime: Rotated and translated of the `p` `Molecule` object
            rmsd: Root-mean-square-deviation between `p_prime` and the `target`
        """
        inds, U, V, rmsd = self.match(p)

        # Reorder `p`, then rotate and translate it onto the target: P' = P * U + V
        p_prime = Molecule.from_sites([p[i] for i in inds])
        for site in p_prime:
            site.coords = np.dot(site.coords, U) + V

        return p_prime, rmsd

    @staticmethod
    def permutations(p_atoms, p_centroid, p_weights, q_atoms, q_centroid, q_weights):
        """Generates two possible permutations of atom order. This method uses the principle component
        of the inertia tensor to prealign the molecules and hungarian method to determine the order.
        There are always two possible permutation depending on the way to pre-aligning the molecules.

        Args:
            p_atoms: atom numbers
            p_centroid: array of atom positions
            p_weights: array of atom weights
            q_atoms: atom numbers
            q_centroid: array of atom positions
            q_weights: array of atom weights

        Yield:
            perm_inds: array of atoms' order
        """
        # get the principal axis of P and Q
        p_axis = HungarianOrderMatcher.get_principal_axis(p_centroid, p_weights)
        q_axis = HungarianOrderMatcher.get_principal_axis(q_centroid, q_weights)

        # Find unique atoms
        species = np.unique(p_atoms)

        # An axis has no preferred sign, so P is pre-rotated twice: once with its
        # principal axis parallel and once antiparallel to the one of Q.
        for axis in (p_axis, -p_axis):
            U = HungarianOrderMatcher.rotation_matrix_vectors(q_axis, axis)
            p_centroid_test = np.dot(p_centroid, U)

            # generate full view from q shape to fill in atom view on the fly
            perm_inds = np.zeros(len(p_atoms), dtype=int)

            for specie in species:
                p_atom_inds = np.where(p_atoms == specie)[0]
                q_atom_inds = np.where(q_atoms == specie)[0]
                A = q_centroid[q_atom_inds]
                B = p_centroid_test[p_atom_inds]

                # Perform Hungarian analysis on distance matrix between atoms of the
                # target structure and the pre-rotated trial structure
                distances = cdist(A, B, "euclidean")
                _, b_inds = linear_sum_assignment(distances)

                perm_inds[q_atom_inds] = p_atom_inds[b_inds]

            yield perm_inds

    @staticmethod
    def get_principal_axis(coords, weights):
        """Get the molecule's principal axis.

        The inertia tensor is accumulated from the given coordinates and weights and
        the eigenvector stored in the first column returned by `np.linalg.eigh`
        (which sorts eigenvalues in ascending order) is returned.

        Args:
            coords: coordinates of atoms
            weights: the weight use for calculating the inertia tensor

        Returns:
            Array of dim 3 containing the principal axis
        """
        Ixx = Iyy = Izz = Ixy = Ixz = Iyz = 0.0

        for (x, y, z), wt in zip(coords, weights):
            # diagonal terms (moments of inertia)
            Ixx += wt * (y * y + z * z)
            Iyy += wt * (x * x + z * z)
            Izz += wt * (x * x + y * y)

            # off-diagonal terms (products of inertia)
            Ixy += -wt * x * y
            Ixz += -wt * x * z
            Iyz += -wt * y * z

        inertia_tensor = np.array([[Ixx, Ixy, Ixz], [Ixy, Iyy, Iyz], [Ixz, Iyz, Izz]])

        # The tensor is symmetric, hence `eigh`; the eigenvalues are not needed.
        _, eigvecs = np.linalg.eigh(inertia_tensor)

        principal_axis = eigvecs[:, 0]
        return principal_axis

    @staticmethod
    def rotation_matrix_vectors(v1, v2):
        """Returns the rotation matrix that rotates v1 onto v2 using
        Rodrigues' rotation formula.

        The formula uses the dot product as the cosine and the norm of the cross
        product as the sine of the rotation angle, so both vectors are expected
        to be unit vectors.

        See more: https://math.stackexchange.com/a/476311

        Args:
            v1: initial vector
            v2: target vector

        Returns:
            3x3 rotation matrix
        """
        if np.allclose(v1, v2):
            # same direction
            return np.eye(3)

        if np.allclose(v1, -v2):
            # Opposite direction: rotate by pi about an axis perpendicular to v1.
            # A fixed axis only works when it happens to be perpendicular to v1
            # (a rotation about y leaves a vector along y unchanged), so the axis
            # is built from v1. The y axis is preferred as reference, which yields
            # diag(-1, 1, -1) whenever v1 is perpendicular to y.
            unit = np.asarray(v1, dtype=float)
            unit = unit / np.linalg.norm(unit)

            reference = np.array([0.0, 1.0, 0.0])
            if abs(unit[1]) > 0.9:
                # v1 is (nearly) along y: use x as reference to stay well conditioned
                reference = np.array([1.0, 0.0, 0.0])

            # Component of the reference perpendicular to v1, normalized
            axis = reference - np.dot(reference, unit) * unit
            axis = axis / np.linalg.norm(axis)

            # Rotation by pi about `axis`: R = 2 * a * a^T - I
            return 2.0 * np.outer(axis, axis) - np.eye(3)

        v = np.cross(v1, v2)
        s = np.linalg.norm(v)
        c = np.vdot(v1, v2)

        # Skew-symmetric cross-product matrix of v
        vx = np.array([[0.0, -v[2], v[1]], [v[2], 0.0, -v[0]], [-v[1], v[0], 0.0]])

        return np.eye(3) + vx + np.dot(vx, vx) * ((1.0 - c) / (s * s))
1✔
1173

1174

1175
class GeneticOrderMatcher(KabschMatcher):
    """This method was inspired by genetic algorithms and tries to match molecules
    based on their already matched fragments.

    It uses the fact that when two molecules are matching, their sub-structures have to match as well.
    The main idea here is that in each iteration (generation) we can check the match of all possible
    fragments and ignore those which are not feasible.

    Although in the worst case this method has N! complexity (same as the brute force one),
    in practice it performs much faster because many of the combinations can be eliminated
    during the fragment matching.

    Notes:
        This method is very robust and returns all the possible orders.

        There is a well known weakness/corner case: an outlier with a large
        deviation and a small index might be ignored.
        This happens due to the nature of the average function
        used to calculate the RMSD for the fragments.

        When aligning molecules, the atoms of the two molecules **must** have the
        same number of atoms from the same species.
    """

    def __init__(self, target: Molecule, threshold: float):
        """Constructor of the matcher object.

        Args:
            target: a `Molecule` object used as a target during the alignment
            threshold: value used to match fragments and prune configuration
        """
        super().__init__(target)
        self.threshold = threshold
        # number of atoms in the target molecule
        self.N = len(target)

    def match(self, p: Molecule):
        """Similar as `KabschMatcher.match` but this method also finds all of the
        possible atomic orders according to the `threshold`.

        Args:
            p: a `Molecule` object what will be matched with the target one.

        Returns:
            Array of the possible matches where the elements are:
                inds: The indices of atoms
                U: 3x3 rotation matrix
                V: Translation vector
                rmsd: Root mean squared deviation between P and Q
        """
        out = []
        for inds in self.permutations(p):
            # Work on a reordered copy so that `p` itself is left untouched
            p_prime = p.copy()
            p_prime._sites = [p_prime[i] for i in inds]

            U, V, rmsd = super().match(p_prime)

            out.append((inds, U, V, rmsd))

        return out

    def fit(self, p: Molecule):
        """Order, rotate and transform all of the matched `p` molecule
        according to the given `threshold`.

        Args:
            p: a `Molecule` object what will be matched with the target one.

        Returns:
            Array of the possible matches where the elements are:
                p_prime: Rotated and translated of the `p` `Molecule` object
                rmsd: Root-mean-square-deviation between `p_prime` and the `target`
        """
        out = []
        for inds in self.permutations(p):
            # Work on a reordered copy so that `p` itself is left untouched
            p_prime = p.copy()
            p_prime._sites = [p_prime[i] for i in inds]

            U, V, rmsd = super().match(p_prime)

            # Rotate and translate matrix `p` onto the target molecule.
            # P' = P * U + V
            for site in p_prime:
                site.coords = np.dot(site.coords, U) + V

            out.append((p_prime, rmsd))

        return out

    def permutations(self, p: Molecule):
        """Generates all of possible permutations of atom order according the threshold.

        The target is grown one atom at a time; each partial order of `p` is extended
        by every unused atom of the right species and kept only when the RMSD of the
        aligned fragments does not exceed `self.threshold`.

        Args:
            p: a `Molecule` object what will be matched with the target one.

        Returns:
            Array of index arrays

        Raises:
            ValueError: if the two molecules do not contain the same atoms.
        """
        # caching atomic numbers and coordinates
        p_atoms, q_atoms = p.atomic_numbers, self.target.atomic_numbers
        p_coords, q_coords = p.cart_coords, self.target.cart_coords

        if sorted(p_atoms) != sorted(q_atoms):
            raise ValueError("The number of the same species aren't matching!")

        # starting matches (only based on element)
        partial_matches = [[j] for j in range(self.N) if p_atoms[j] == q_atoms[0]]

        for idx in range(1, self.N):
            # extending the target fragment with the next atom
            f_coords = q_coords[: idx + 1]
            f_atom = q_atoms[idx]

            f_trans = f_coords.mean(axis=0)
            f_centroid = f_coords - f_trans

            matches = []
            for indices in partial_matches:
                for jdx in range(self.N):
                    # skipping if this index is already matched
                    if jdx in indices:
                        continue

                    # skipping if they are different species
                    if p_atoms[jdx] != f_atom:
                        continue

                    inds = indices + [jdx]
                    P = p_coords[inds]

                    # Both sets of coordinates must be translated first, so that
                    # their centroid coincides with the origin of the coordinate system.
                    p_trans = P.mean(axis=0)
                    p_centroid = P - p_trans

                    # The optimal rotation matrix U using Kabsch algorithm
                    U = self.kabsch(p_centroid, f_centroid)

                    p_prime_centroid = np.dot(p_centroid, U)
                    rmsd = np.sqrt(np.mean(np.square(p_prime_centroid - f_centroid)))

                    # rejecting if the deviation is too large
                    if rmsd > self.threshold:
                        continue

                    logger.debug(f"match - rmsd: {rmsd}, inds: {inds}")
                    matches.append(inds)

            partial_matches = matches

            logger.info(f"number of atom in the fragment: {idx + 1}, number of possible matches: {len(matches)}")

        # `partial_matches` is the last generation. Returning it (rather than the
        # loop-local `matches`) also covers single-atom targets, for which the loop
        # above never runs.
        return partial_matches
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc