• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

zincware / MDSuite / 3999396905

pending completion
3999396905

push

github-actions

GitHub
[merge before other PRs] ruff updates (#580)

960 of 1311 branches covered (73.23%)

Branch coverage included in aggregate %.

15 of 15 new or added lines in 11 files covered. (100.0%)

4034 of 4930 relevant lines covered (81.83%)

3.19 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.57
/mdsuite/calculators/trajectory_calculator.py
1
"""
2
MDSuite: A Zincwarecode package.
3

4
License
5
-------
6
This program and the accompanying materials are made available under the terms
7
of the Eclipse Public License v2.0 which accompanies this distribution, and is
8
available at https://www.eclipse.org/legal/epl-v20.html
9

10
SPDX-License-Identifier: EPL-2.0
11

12
Copyright Contributors to the Zincwarecode Project.
13

14
Contact Information
15
-------------------
16
email: zincwarecode@gmail.com
17
github: https://github.com/zincware
18
web: https://zincwarecode.com/
19

20
Citation
21
--------
22
If you use this module please cite us with:
23

24
Summary
25
-------
26
A parent class for calculators that operate on the trajectory.
27
"""
28
from __future__ import annotations
4✔
29

30
from abc import ABC
4✔
31
from typing import TYPE_CHECKING, List, Union
4✔
32

33
import numpy as np
4✔
34
import tensorflow as tf
4✔
35

36
import mdsuite.database.simulation_database
4✔
37
from mdsuite.calculators.transformations_reference import switcher_transformations
4✔
38
from mdsuite.database.data_manager import DataManager
4✔
39
from mdsuite.database.simulation_database import Database
4✔
40
from mdsuite.memory_management import MemoryManager
4✔
41
from mdsuite.utils.meta_functions import join_path
4✔
42

43
from .calculator import Calculator
4✔
44

45
if TYPE_CHECKING:
4!
46
    from mdsuite import Experiment
×
47

48

49
class TrajectoryCalculator(Calculator, ABC):
    """
    Parent class for calculators operating on the trajectory.

    Attributes
    ----------
    data_resolution : int
            Resolution of the data to be plotted. This is necessary because if someone
            wants a data_range of 500 they may not want every point plotted.
    loaded_property : PropertyInfo
            The property being loaded from the simulation database.
    dependency : PropertyInfo
            A dependency required for the analysis to run.
    scale_function : dict
            The scaling behaviour of the computer. e.g.
            {"linear": {"scale_factor": 150}}.  See mdsuite.utils.scale_functions.py for
            the list of possible functions.
    batch_size : int
            Batch size to use. This is the number of configurations that can be loaded
            given the complexity and data requirements of the operation.
    n_batches : int
            Number of batches that can be looped over given the batch size.
    remainder : int
            The remainder of configurations after the batch process.
    minibatch : bool
            If true, atom-wise mini-batching will be used.
    memory_manager : MemoryManager
            Memory manager object to handle computation of batch sizes.
    data_manager : DataManager
            Data manager parent to handle preparation of data generators.
    _database : Database
            Simulation database from which data should be loaded (lazily created,
            access via the ``database`` property).
    """

    def __init__(self, experiment: Experiment = None, experiments: List = None):
        """
        Constructor for the TrajectoryCalculator class.

        Parameters
        ----------
        experiment : Experiment
                Experiment for which the calculator will be run.
        experiments : List[Experiment]
                List of experiments on which to run the calculator.
        """
        super().__init__(experiment=experiment, experiments=experiments)

        self.data_resolution = None
        self.loaded_property: mdsuite.database.simulation_database.PropertyInfo = None
        self.dependency: mdsuite.database.simulation_database.PropertyInfo = None
        self.scale_function = None
        self.batch_size: int = None
        self.n_batches: int = None
        self.remainder: int = None
        self.minibatch: bool = None
        self.memory_manager = None
        self.data_manager = None
        self._database = None

    @property
    def database(self):
        """Get the database based on the experiment database path."""
        # Lazily construct and cache the Database handle on first access.
        if self._database is None:
            self._database = Database(self.experiment.database_path / "database.hdf5")
        return self._database

    def _run_dependency_check(self):
        """
        Check to see if the necessary property exists and build it if required.

        Returns
        -------
        Will call transformations if required.
        """
        # Nothing to check for calculators that do not load trajectory data.
        if self.loaded_property is None:
            return

        if self.dependency is not None:
            dependency_exists = self.database.check_existence(self.dependency.name)
            if not dependency_exists:
                self._resolve_dependencies(self.dependency)

        loaded_property = self.database.check_existence(self.loaded_property.name)
        if not loaded_property:
            self._resolve_dependencies(self.loaded_property)

    def _resolve_dependencies(
        self, dependency: mdsuite.database.simulation_database.PropertyInfo
    ):
        """
        Resolve any calculation dependencies if possible.

        Parameters
        ----------
        dependency : mdsuite.database.simulation_database.PropertyInfo
                The missing property; its ``name`` selects which transformation
                of the experiment is run to generate the data.

        Raises
        ------
        KeyError
                If no transformation is known that can generate the property.
        """

        def _string_to_function(argument):
            """
            Select a transformation based on an input.

            Parameters
            ----------
            argument : str
                    Name of the transformation required

            Returns
            -------
            transformation call.
            """
            switcher_unwrapping = {"Unwrapped_Positions": self._unwrap_choice()}

            # add the other transformations and merge the dictionaries
            switcher = {**switcher_unwrapping, **switcher_transformations}

            try:
                return switcher[argument]
            except KeyError:
                raise KeyError("Data not in database and cannot be generated.") from None

        transformation = getattr(
            self.experiment.run, _string_to_function(dependency.name)
        )
        transformation()

    def _unwrap_choice(self):
        """
        Unwrap either with indices or with box arrays.

        Returns
        -------
        str
                Name of the transformation to use: "UnwrapViaIndices" when box
                image data exists in the database, "CoordinateUnwrapper" otherwise.
        """
        indices = self.database.check_existence("Box_Images")
        if indices:
            return "UnwrapViaIndices"
        else:
            return "CoordinateUnwrapper"

    def _handle_tau_values(self) -> np.array:
        """
        Handle the parsing of custom tau values.

        Normalizes ``self.args.tau_values`` (int, list/array, or slice) into an
        explicit array of indices, and sets ``data_resolution`` (and, for
        explicit lists, ``data_range``) accordingly.

        Returns
        -------
        times : np.array
            The time values corresponding to the selected tau values
        """
        if isinstance(self.args.tau_values, int):
            # An int requests that many evenly-spaced points over the data range.
            self.data_resolution = self.args.tau_values
            self.args.tau_values = np.linspace(
                0, self.args.data_range - 1, self.args.tau_values, dtype=int
            )
        if isinstance(self.args.tau_values, (list, np.ndarray)):
            self.data_resolution = len(self.args.tau_values)
            # The data range must reach the largest requested tau index.
            self.args.data_range = self.args.tau_values[-1] + 1
        if isinstance(self.args.tau_values, slice):
            self.args.tau_values = np.linspace(
                0, self.args.data_range - 1, self.args.data_range, dtype=int
            )[self.args.tau_values]
            self.data_resolution = len(self.args.tau_values)

        times = (
            np.asarray(self.args.tau_values)
            * self.experiment.time_step
            * self.experiment.sample_rate
        )

        return times

    def _check_remainder(self):
        """
        Check that the remainder is compatible with the calculator.

        It may come to pass that the remainder computed by the memory manager is not
        divisible by your data range. In this case, it must be clipped such that it is.

        Returns
        -------
        int
                The remainder clipped down to the nearest multiple of
                ``data_range``. The caller is responsible for storing it.
        """
        return self.remainder - (self.remainder % self.args.data_range)

    def _prepare_managers(self, data_path: list, correct: bool = False):
        """
        Prepare the memory and tensor_values monitors for calculation.

        Parameters
        ----------
        data_path : list
                List of tensor_values paths to load from the hdf5
                database_path.
        correct : bool
                If true, call the calculator-specific
                ``_correct_batch_properties`` hook after batching is computed.

        Returns
        -------
        Updates the calculator class
        """
        self.memory_manager = MemoryManager(
            data_path=data_path,
            database=self.database,
            memory_fraction=0.8,
            scale_function=self.scale_function,
        )
        (
            self.batch_size,
            self.n_batches,
            self.remainder,
        ) = self.memory_manager.get_batch_size()
        self.ensemble_loop, self.minibatch = self.memory_manager.get_ensemble_loop(
            self.args.data_range, self.args.correlation_time
        )

        if self.minibatch:
            self.batch_size = self.memory_manager.batch_size
            self.n_batches = self.memory_manager.n_batches
            self.remainder = self.memory_manager.remainder

        # Bug fix: the clipped remainder returned by _check_remainder was
        # previously discarded, so a remainder not divisible by data_range
        # was passed on unchanged. Store the clipped value.
        self.remainder = self._check_remainder()

        if correct:
            self._correct_batch_properties()
        self.data_manager = DataManager(
            data_path=data_path,
            database=self.database,
            data_range=self.args.data_range,
            batch_size=self.batch_size,
            n_batches=self.n_batches,
            ensemble_loop=self.ensemble_loop,
            correlation_time=self.args.correlation_time,
            remainder=self.remainder,
            atom_selection=self.args.atom_selection,
            minibatch=self.minibatch,
            atom_batch_size=self.memory_manager.atom_batch_size,
            n_atom_batches=self.memory_manager.n_atom_batches,
            atom_remainder=self.memory_manager.atom_remainder,
        )

    def _correct_batch_properties(self):
        """
        Fix batch properties.

        Notes
        -----
        This method is called by some calculators; subclasses that pass
        ``correct=True`` to ``_prepare_managers`` must override it.
        """
        raise NotImplementedError

    def get_batch_dataset(
        self,
        subject_list: list = None,
        loop_array: np.ndarray = None,
        correct: bool = False,
    ) -> tf.data.Dataset:
        """
        Collect the batch loop dataset.

        Parameters
        ----------
        correct : bool
                If true, a calculator specific method is called to correct some
                of the batching properties. For example, the RDF code will over-ride
                the data range in favour of number of configurations as it does not
                require dynamic properties.
        subject_list : list (default = None)
                A str of subjects to collect data for in case this is necessary.
                e.g. subject = ['Na']
                     subject = ['Na', 'Cl', 'K']
                     subject = ['Ionic_Current']
        loop_array : np.ndarray (default = None)
                If this is not None, elements of this array will be looped over in
                the batches which load data at their indices. For example,
                    loop_array = [[1, 4, 7], [10, 13, 16], [19, 21, 24]]
                In this case, in the first batch, configurations 1, 4, and 7 will be
                loaded for the analysis. This is particularly important in the
                structural properties.

        Returns
        -------
        dataset : tf.data.Dataset
                A TensorFlow dataset for the batch loop to be iterated over.

        """
        path_list = [join_path(item, self.loaded_property.name) for item in subject_list]
        self._prepare_managers(path_list, correct=correct)

        # Build the output signature: one ragged-shaped tensor per subject plus
        # the integer "data_size" entry emitted by the generator.
        type_spec = {}
        for item in subject_list:
            dict_ref = "/".join([item, self.loaded_property.name])
            type_spec[str.encode(dict_ref)] = tf.TensorSpec(
                shape=(None, None, self.loaded_property.n_dims), dtype=self.dtype
            )
        type_spec[str.encode("data_size")] = tf.TensorSpec(shape=(), dtype=tf.int32)

        batch_generator, batch_generator_args = self.data_manager.batch_generator(
            system=self.system_property, loop_array=loop_array
        )
        ds = tf.data.Dataset.from_generator(
            generator=batch_generator,
            args=batch_generator_args,
            output_signature=type_spec,
        )

        return ds.prefetch(tf.data.AUTOTUNE)

    def get_ensemble_dataset(self, batch: dict, subject: Union[str, list]):
        """
        Collect the ensemble loop dataset.

        Parameters
        ----------
        subject : str or list
                What object(s) to loop over.
        batch : dict
                A batch of data to be looped over in ensembles.

        Returns
        -------
        dataset : tf.data.Dataset
                A TensorFlow dataset object for the ensemble loop to be iterated over.

        """
        (
            ensemble_generator,
            ensemble_generators_args,
        ) = self.data_manager.ensemble_generator(
            glob_data=batch, system=self.system_property
        )

        # Accept either a single subject string or a list of them.
        type_spec = {}
        if isinstance(subject, str):
            loop_list = [subject]
        else:
            loop_list = subject
        for item in loop_list:
            dict_ref = "/".join([item, self.loaded_property.name])
            type_spec[str.encode(dict_ref)] = tf.TensorSpec(
                shape=(None, None, self.loaded_property.n_dims), dtype=self.dtype
            )

        ds = tf.data.Dataset.from_generator(
            generator=ensemble_generator,
            args=ensemble_generators_args,
            output_signature=type_spec,
        )

        return ds.prefetch(tf.data.AUTOTUNE)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc