• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

bjmorgan / site-analysis / 22255310399

21 Feb 2026 10:35AM UTC coverage: 97.253% (+0.5%) from 96.762%
22255310399

Pull #34

github

web-flow
Merge 836f05801 into 7d30a56e1
Pull Request #34: Code quality improvements from review

504 of 525 new or added lines in 19 files covered. (96.0%)

3 existing lines in 1 file now uncovered.

1310 of 1347 relevant lines covered (97.25%)

0.97 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.38
/site_analysis/reference_workflow/structure_aligner.py
1
"""Structure alignment tools for comparing and superimposing crystal structures.
2

3
This module provides the StructureAligner class, which finds the optimal
4
translation vector to superimpose one crystal structure onto another. This
5
alignment is important for:
6

7
1. Comparing structures from different sources with different coordinate origins
8
2. Analyzing structural changes while accounting for rigid translations
9
3. Preparing structures for site mapping in reference-based workflows
10

11
The alignment algorithm optimizes a translation vector to minimise distances
12
between corresponding atoms in the two structures, considering periodic
13
boundary conditions. It supports different optimisation metrics (RMSD or
14
maximum atom distance) and can align based on specific atom species.
15

16
This module is a key component of the reference-based workflow for defining
17
sites in one structure based on a template from another structure.
18
"""
19

20
import numpy as np
1✔
21
from pymatgen.core import Structure
1✔
22
from scipy.optimize import minimize
1✔
23
from typing import Any, Callable
1✔
24
from site_analysis.tools import calculate_species_distances
1✔
25

26
class StructureAligner:
1✔
27
    """Aligns crystal structures via translation optimization.
28
    
29
    This class provides methods to align a reference structure to a target structure
30
    by finding the optimal translation vector that minimizes distances between
31
    corresponding atoms, considering periodic boundary conditions.
32
    """
33
    
34
    def align(
        self,
        reference: Structure,
        target: Structure,
        species: list[str] | None = None,
        metric: str = 'rmsd',
        tolerance: float = 1e-4,
        algorithm: str = 'Nelder-Mead',
        minimizer_options: dict[str, Any] | None = None,
    ) -> tuple[Structure, np.ndarray, dict[str, float]]:
        """Translate ``reference`` so that it best superimposes onto ``target``.

        The species to compare are validated, an objective function for the
        chosen metric is built, the selected minimiser searches for the
        translation vector, and a translated copy of ``reference`` is returned
        together with summary distance statistics.

        Args:
            reference: Structure to be translated.
            target: Structure to align onto.
            species: Species used for the alignment. If None, the species are
                chosen by ``_validate_structures``.
            metric: Quantity minimised by the objective function
                ('rmsd' or 'max_dist').
            tolerance: Convergence tolerance forwarded to the minimiser.
            algorithm: Name of the minimiser ('Nelder-Mead' or
                'differential_evolution').
            minimizer_options: Extra options forwarded to the minimiser.

        Returns:
            A tuple ``(aligned_structure, translation_vector, metrics)``.
            ``metrics`` has the keys 'rmsd', 'max_dist' and 'mean_dist'; all
            three are ``inf`` when no distances were computed.

        Raises:
            ValueError: Propagated from structure validation, from an
                unsupported algorithm name, or from a failed optimisation.
        """
        species_for_fit = self._validate_structures(reference, target, species)
        objective = self._create_objective_function(
            reference, target, species_for_fit, metric)
        shift = self._run_minimizer(
            algorithm, objective, tolerance, minimizer_options)
        shifted_reference = self._apply_translation(reference, shift)

        # Only the flat list of distances is needed for the summary statistics.
        _, distances = calculate_species_distances(
            shifted_reference, target, species=species_for_fit)

        if distances:
            metrics = {
                'rmsd': np.sqrt(np.mean(np.array(distances)**2)),
                'max_dist': np.max(distances),
                'mean_dist': np.mean(distances),
            }
        else:
            # No matched atoms: report every statistic as infinite.
            metrics = {
                'rmsd': float('inf'),
                'max_dist': float('inf'),
                'mean_dist': float('inf'),
            }

        return shifted_reference, shift, metrics
66
        
67
    def _create_objective_function(self, 
        reference: Structure,
        target: Structure,
        valid_species: list[str],
        metric: str) -> Callable[[np.ndarray], float]:
        """Create the objective function for optimization.

        The returned closure shifts a copy of ``reference`` by a trial
        translation vector (in fractional coordinates, wrapped into [0, 1)),
        measures distances to ``target`` with ``calculate_species_distances``
        and reduces them to a single number.

        Args:
            reference: Reference structure
            target: Target structure
            valid_species: List of species to include in alignment
            metric: Metric to optimize ('rmsd' or 'max_dist')

        Returns:
            function: The objective function that takes a translation vector and
                    returns the distance metric value. It returns ``inf`` when
                    no distances are available and raises ``ValueError`` when
                    called with an unknown ``metric``.
        """
        def objective_function(
            translation_vector: np.ndarray) -> float:
            # Reject an unsupported metric before doing any work, so a bad
            # metric can never be masked by the empty-distance ``inf`` return.
            # The check stays inside the closure so that merely creating the
            # objective function still never raises.
            if metric not in ('rmsd', 'max_dist'):
                raise ValueError(f"Unknown metric: {metric}")

            # Wrap the trial vector into [0, 1), then shift and re-wrap the
            # coordinates through the shared helper (periodic boundaries).
            translated_coords = self._translate_coords(
                reference.frac_coords, translation_vector % 1.0)

            # Build a temporary translated structure for distance calculation
            temp_structure = reference.copy()
            for i, coords in enumerate(translated_coords):
                temp_structure[i] = temp_structure[i].species, coords

            # Only the flat list of distances is needed here
            _, all_distances = calculate_species_distances(
                temp_structure, target, species=valid_species)

            if not all_distances:  # Handle empty distance list
                return float('inf')

            if metric == 'rmsd':
                return float(np.sqrt(np.mean(np.array(all_distances)**2)))
            # metric == 'max_dist' (validated above)
            return float(np.max(all_distances))

        return objective_function
113
    
114
    def _validate_structures(self, 
                            reference: Structure, 
                            target: Structure, 
                            species: list[str] | None = None) -> list[str]:
        """Check that two structures can be aligned and pick the species to use.

        When ``species`` is None the two compositions must be identical and
        every species of the reference is used. Otherwise the given list is
        used as-is. In both cases each selected species must be present in
        both structures with the same number of atoms.

        Args:
            reference: Reference structure
            target: Target structure
            species: list of species to use for alignment, or None to use
                all species of the reference

        Returns:
            list of species to use for alignment

        Raises:
            ValueError: If the compositions differ (only checked when
                ``species`` is None), if a species is missing from either
                structure, or if its atom counts differ.
        """
        if species is not None:
            selected = species
        else:
            reference_counts = reference.composition.as_dict()
            target_counts = target.composition.as_dict()

            # Without an explicit selection the compositions must match exactly
            if reference_counts != target_counts:
                raise ValueError(
                    f"Structures have different compositions: "
                    f"{reference.composition.formula} vs {target.composition.formula}"
                )

            # NOTE(review): these keys are later passed to indices_from_symbol;
            # presumably both use plain element symbols - confirm this for
            # oxidation-state-decorated structures.
            selected = list(reference_counts.keys())

        for symbol in selected:
            in_reference = reference.indices_from_symbol(symbol)
            in_target = target.indices_from_symbol(symbol)

            if not in_reference:
                raise ValueError(f"Species {symbol} not found in reference structure")
            if not in_target:
                raise ValueError(f"Species {symbol} not found in target structure")

            n_reference = len(in_reference)
            n_target = len(in_target)
            if n_reference != n_target:
                raise ValueError(
                    f"Different number of {symbol} atoms: "
                    f"{n_reference} in reference vs "
                    f"{n_target} in target"
                )

        return selected
168
    
169
    def _translate_coords(self, 
1✔
170
                         coords: np.ndarray, 
171
                         translation_vector: np.ndarray) -> np.ndarray:
172
        """Apply translation to coordinates.
173
        
174
        Args:
175
            coords: Fractional coordinates to translate
176
            translation_vector: Translation vector to apply
177
            
178
        Returns:
179
            Translated coordinates
180
        """
NEW
181
        translated = coords + translation_vector
×
182
        # Ensure coordinates are within [0, 1)
NEW
183
        return np.array(translated % 1.0)
×
184
    
185
    def _apply_translation(self, 
                          structure: Structure, 
                          translation_vector: np.ndarray) -> Structure:
        """Return a copy of ``structure`` with every site shifted.

        The input structure is left untouched; the shift is applied to a copy
        in fractional coordinates and each site is wrapped back into [0, 1).

        Args:
            structure: Structure to translate
            translation_vector: Translation vector to apply

        Returns:
            Translated structure
        """
        translated = structure.copy()

        # Replace each site in turn, keeping its species and moving its position.
        for index, site in enumerate(translated):
            wrapped = np.mod(site.frac_coords + translation_vector, 1.0)
            translated[index] = site.species, wrapped

        return translated
208
    
209
    def _run_minimizer(self,
1✔
210
        algorithm: str, 
211
        objective_function: Callable[[np.ndarray], float],
212
        tolerance: float,
213
        minimizer_options: dict[str, Any] | None = None) -> np.ndarray:
214
        """Run the selected minimization algorithm.
215
        
216
        Args:
217
            algorithm: Name of the algorithm to run
218
            objective_function: Function to minimize
219
            tolerance: Convergence tolerance
220
            minimizer_options: Additional options for the minimizer
221
            
222
        Returns:
223
            np.ndarray: Optimal translation vector
224
            
225
        Raises:
226
            ValueError: If the algorithm is not supported
227
        """
228
        # Get the algorithm registry
229
        algorithm_registry = self._get_algorithm_registry()
1✔
230
        
231
        # Check if algorithm is supported
232
        if algorithm not in algorithm_registry:
1✔
233
            raise ValueError(f"Unsupported algorithm: {algorithm}. "
1✔
234
                            f"Supported algorithms: {', '.join(algorithm_registry.keys())}")
235
        
236
        # Get the appropriate implementation method
237
        run_algorithm = algorithm_registry[algorithm]
1✔
238
        
239
        # Call the selected algorithm implementation
240
        return run_algorithm(objective_function, tolerance, minimizer_options)
1✔
241
        
242
    def _get_algorithm_registry(self) -> dict[str, 
1✔
243
                                                Callable[
244
                                                    [Callable[[np.ndarray], float],
245
                                                    float,
246
                                                    dict[str, Any] | None
247
                                                ], np.ndarray]]:
248
        """Get the registry of supported optimization algorithms.
249
        
250
        Returns:
251
            dict: Dictionary mapping algorithm names to implementation methods
252
        """
253
        return {
1✔
254
            'Nelder-Mead': self._run_nelder_mead,
255
            'differential_evolution': self._run_differential_evolution
256
        }
257
        
258
    def _run_nelder_mead(self, 
1✔
259
                    objective_function: Callable[[np.ndarray], float], 
260
                    tolerance: float, 
261
                    minimizer_options: dict[str, Any] | None = None) -> np.ndarray:
262
        """Run Nelder-Mead optimization.
263
        
264
        Args:
265
            objective_function: Function to minimize
266
            tolerance: Convergence tolerance
267
            minimizer_options: Additional options for the minimizer
268
            
269
        Returns:
270
            np.ndarray: Optimised translation vector
271
            
272
        Raises:
273
            ValueError: If optimization fails
274
        """
275
        from scipy.optimize import minimize
1✔
276
        
277
        # Ensure minimizer_options is a dictionary
278
        minimizer_options = minimizer_options or {}
1✔
279
        
280
        # Default options - ensure they exactly match the original implementation
281
        options: dict[str, Any] = {
1✔
282
            'xatol': tolerance,
283
            'fatol': tolerance
284
        }
285
        
286
        # Update with user-provided options
287
        options.update(minimizer_options)
1✔
288
        
289
        # Run optimisation
290
        result = minimize(
1✔
291
            objective_function,
292
            x0=np.array([0, 0, 0]),  # Start with zero translation
293
            method='Nelder-Mead',
294
            options=options
295
        )
296
        
297
        if not result.success:
1✔
NEW
298
            raise ValueError(f"Optimization failed: {result.message}")
×
299
        
300
        # Ensure in [0,1) range
301
        return np.array(result.x) % 1.0
1✔
302
        
303
    def _run_differential_evolution(self,
1✔
304
            objective_function: Callable[[np.ndarray], float],
305
            tolerance: float,
306
            minimizer_options: dict[str, Any] | None = None) -> np.ndarray:
307
        """Run differential evolution optimization.
308
        
309
        Args:
310
            objective_function: Function to minimize
311
            tolerance: Convergence tolerance
312
            minimizer_options: Additional options for the minimizer
313
            
314
        Returns:
315
            np.ndarray: Optimal translation vector
316
        """
317
        from scipy.optimize import differential_evolution
1✔
318
        
319
        # Default options for differential evolution
320
        options = {
1✔
321
            'tol': tolerance,
322
            'popsize': 15,
323
            'maxiter': 1000,
324
            'strategy': 'best1bin',
325
            'updating': 'immediate',
326
            'workers': 1  # Default to single process for compatibility
327
        }
328
        
329
        # Bounds for translation vector (all components in [0,1))
330
        bounds = [(0, 1), (0, 1), (0, 1)]
1✔
331
        
332
        # Update with user-provided options
333
        if minimizer_options:
1✔
334
            options.update(minimizer_options)
1✔
335
        
336
        # Extract bounds if provided in options
337
        if minimizer_options and 'bounds' in minimizer_options:
1✔
338
            bounds = minimizer_options['bounds']
1✔
339
            options.pop('bounds')
1✔
340
            
341
        # Run optimization
342
        result = differential_evolution(
1✔
343
            objective_function,
344
            bounds=bounds,
345
            **options
346
        )
347
        
348
        if not result.success:
1✔
349
            raise ValueError(f"Differential evolution optimization failed: {result.message}")
1✔
350
        
351
        return np.array(result.x) % 1.0  # Ensure in [0,1) range
1✔
352
        
353
    
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc