zincware / MDSuite — build 4992945232 (pending completion)
Triggered by a push via github-actions; committer: PythonFZ
Commit: Merge branch 'pre-commit' of https://github.com/zincware/MDSuite into pre-commit

974 of 1318 branches covered (73.9%); branch coverage is included in the aggregate %.
2 of 2 new or added lines in 2 files covered (100.0%).
4123 of 4974 relevant lines covered (82.89%), 2.49 hits per line.

Source File

/mdsuite/database/data_manager.py (29.41% covered)
"""
MDSuite: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
Module for the data manager. The data manager handles loading of data as TensorFlow
generators. These generators allow for the full use of the TF data pipelines but can
require special formatting rules.
"""
import logging

import numpy as np
import tensorflow as tf
from tqdm import tqdm

from mdsuite.database.simulation_database import Database

log = logging.getLogger(__name__)


class DataManager:
    """
    Class for the MDS tensor_values fetcher.

    Due to the amount of tensor_values that needs to be collected and the possibility
    to optimize repeated loading, a separate tensor_values fetching class is required.
    This class manages how tensor_values is loaded from the MDS database_path and
    optimizes processes such as pre-loading and parallel reading.
    """

    def __init__(
        self,
        database: Database = None,
        data_path: list = None,
        data_range: int = None,
        n_batches: int = None,
        batch_size: int = None,
        ensemble_loop: int = None,
        correlation_time: int = 1,
        remainder: int = None,
        atom_selection=np.s_[:],
        minibatch: bool = False,
        atom_batch_size: int = None,
        n_atom_batches: int = None,
        atom_remainder: int = None,
        offset: int = 0,
    ):
        """
        Constructor for the DataManager class.

        Parameters
        ----------
        database : Database
                Database object from which tensor_values should be loaded.
        data_path : list
                Path in the HDF5 database to be loaded.
        data_range : int
                Data range used in the calculator.
        n_batches : int
                Number of batches required.
        batch_size : int
                Size of a batch.
        ensemble_loop : int
                Number of ensembles to be looped over.
        correlation_time : int
                Correlation time used in the calculator.
        remainder : int
                Remainder used in the batching.
        atom_remainder : int
                Atom-wise remainder used in the atom-wise batching.
        minibatch : bool
                If true, atom-wise batching is required.
        atom_batch_size : int
                Size of an atom-wise batch.
        n_atom_batches : int
                Number of atom-wise batches.
        atom_selection : np.s_ or dict
                Selection of atoms in the calculation.
        offset : int
                Offset in the data loading if it should not be loaded from the start.
        """
        self.database = database
        self.data_path = data_path
        self.minibatch = minibatch
        self.atom_batch_size = atom_batch_size
        self.n_atom_batches = n_atom_batches
        self.atom_remainder = atom_remainder
        self.offset = offset

        self.data_range = data_range
        self.n_batches = n_batches
        self.batch_size = batch_size
        self.remainder = remainder
        self.ensemble_loop = ensemble_loop
        self.correlation_time = correlation_time
        self.atom_selection = atom_selection

    def batch_generator(  # noqa: C901
        self,
        dictionary: bool = False,
        system: bool = False,
        remainder: bool = False,
        loop_array: np.ndarray = None,
    ) -> tuple:
        """
        Build a generator object for the batch loop.

        Parameters
        ----------
        dictionary : bool
                If true, return a dict. This is now the default and could be removed.
        system : bool
                If true, a system parameter is being called for.
        remainder : bool
                If true, a remainder batch must be computed.
        loop_array : np.ndarray
                If this is not None, elements of this array will be looped over in
                the batches, which load data at their indices. For example,
                    loop_array = [[1, 4, 7], [10, 13, 16], [19, 21, 24]]
                In this case, in the first batch, configurations 1, 4, and 7 will be
                loaded for the analysis. This is particularly important for the
                structural properties.

        Returns
        -------
        A generator function and the arguments it should be called with.
        """
        args = (
            self.n_batches,
            self.batch_size,
            self.database.path,
            self.data_path,
            dictionary,
        )

        def generator(
            batch_number: int,
            batch_size: int,
            database: str,
            data_path: list,
            dictionary: bool,
        ):
            """
            Generator function for the batch loop.

            Parameters
            ----------
            batch_number : int
                    Number of batches to be looped over.
            batch_size : int
                    Size of each batch to load.
            database : str
                    Path to the database_path from which to load the tensor_values.
            data_path : list
                    Paths to the tensor_values in the database_path.
            dictionary : bool
                    If true, tensor_values is returned in a dictionary.

            Yields
            ------
            The tensor_values loaded from the database_path for one batch.
            """
            database = Database(database)

            loop_over_remainder = self.remainder > 0

            for batch in range(batch_number + int(loop_over_remainder)):
                start = int(batch * batch_size) + self.offset
                stop = int(start + batch_size)
                data_size = tf.cast(batch_size, dtype=tf.int32)
                # Handle the remainder batch, which is smaller than a full batch.
                if batch == batch_number:
                    stop = int(start + self.remainder)
                    data_size = tf.cast(self.remainder, dtype=tf.int32)
                    # TODO make default

                if loop_array is not None:
                    if isinstance(self.atom_selection, dict):
                        select_slice = {}
                        for item in self.atom_selection:
                            select_slice[item] = np.s_[
                                self.atom_selection[item], loop_array[batch]
                            ]
                    else:
                        select_slice = np.s_[self.atom_selection, loop_array[batch]]
                elif system:
                    select_slice = np.s_[start:stop]
                else:
                    if isinstance(self.atom_selection, dict):
                        select_slice = {}
                        for item in self.atom_selection:
                            select_slice[item] = np.s_[
                                self.atom_selection[item], start:stop
                            ]
                    else:
                        select_slice = np.s_[self.atom_selection, start:stop]

                yield database.load_data(
                    data_path,
                    select_slice=select_slice,
                    dictionary=dictionary,
                    d_size=data_size,
                )

        def atom_generator(
            batch_number: int,
            batch_size: int,
            database: str,
            data_path: list,
            dictionary: bool,
        ):
            """
            Generator function for a mini-batched calculation.

            Parameters
            ----------
            batch_number : int
                    Number of batches to be looped over.
            batch_size : int
                    Size of each batch to load.
            database : str
                    Path to the database_path from which to load the tensor_values.
            data_path : list
                    Paths to the tensor_values in the database_path.
            dictionary : bool
                    If true, tensor_values is returned in a dictionary.

            Yields
            ------
            The tensor_values loaded from the database_path for one atom batch and
            one configuration batch.
            """
            # Atom selection not currently available for mini-batched calculations
249
            if type(self.atom_selection) is dict:
×
250
                raise ValueError(
×
251
                    "Atom selection is not currently available "
252
                    "for mini-batched calculations"
253
                )
254

255
            database = Database(database)
×
256
            _atom_remainder = [1 if self.atom_remainder else 0][0]
×
257
            start = 0
×
258
            for atom_batch in tqdm(
×
259
                range(self.n_atom_batches + _atom_remainder),
260
                total=self.n_atom_batches + _atom_remainder,
261
                ncols=70,
262
                desc="batch loop",
263
            ):
264
                atom_start = atom_batch * self.atom_batch_size
×
265
                atom_stop = atom_start + self.atom_batch_size
×
266
                if atom_batch == self.n_atom_batches:
×
267
                    atom_stop = start + self.atom_remainder
×
268
                for batch in range(batch_number + int(remainder)):
×
269
                    start = int(batch * batch_size) + self.offset
×
270
                    stop = int(start + batch_size)
×
271
                    data_size = tf.cast(batch_size, dtype=tf.int32)
×
272
                    if batch == batch_number:
×
273
                        stop = int(start + self.remainder)
×
274
                        data_size = tf.cast(self.remainder, dtype=tf.int16)
×
275
                    select_slice = np.s_[int(atom_start) : int(atom_stop), start:stop]
×
276
                    yield database.load_data(
×
277
                        data_path,
278
                        select_slice=select_slice,
279
                        dictionary=dictionary,
280
                        d_size=data_size,
281
                    )
282

283
        if self.minibatch:
3✔
284
            return atom_generator, args
3✔
285
        else:
286
            return generator, args
3✔
287

288
    def ensemble_generator(self, system: bool = False, glob_data: dict = None) -> tuple:
3✔
289
        """
290
        Build a generator for the ensemble loop.
291

292
        Parameters
293
        ----------
294
        system : bool
295
                If true, the system generator is returned.
296
        glob_data : dict
297
                data to be loaded in ensembles from a tensorflow generator.
298
                e.g. {b'Na/Positions': tf.Tensor}.
299
                Will usually include a b'data_size' key which is checked in the
300
                loop and ignored. All keys are in byte arrays. This appears when you
301
                pass a dict to the tensorflow generator.
302

303
        Returns
304
        -------
305
        Ensemble loop generator
306
        """
307
        args = (self.ensemble_loop, self.correlation_time, self.data_range)
3✔
308

309
        def dictionary_generator(ensemble_loop, correlation_time, data_range):
3✔
310
            """
311
            Generator for the ensemble loop
312
            Parameters
313
            ----------
314
            ensemble_loop : int
315
                    Number of ensembles to loop over
316
            correlation_time : int
317
                    Distance between ensembles
318
            data_range : int
319
                    Size of each ensemble
320
            Returns
321
            -------
322
            None.
323
            """
324
            ensemble_loop = int(
×
325
                np.clip(
326
                    (glob_data[b"data_size"] - data_range) / correlation_time, 1, None
327
                )
328
            )
329
            for ensemble in range(ensemble_loop):
×
330
                start = ensemble * correlation_time
×
331
                stop = start + data_range
×
332
                output_dict = {}
×
333
                for item in glob_data:
×
334
                    if item == str.encode("data_size"):
×
335
                        pass
×
336
                    else:
337
                        output_dict[item] = glob_data[item][:, start:stop]
×
338

339
                yield output_dict
×
340

341
        return dictionary_generator, args
3✔