/mdsuite/database/simulation_database.py
"""
MDSuite: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
"""
import dataclasses
import logging
import pathlib
import time
import typing
from typing import List

import h5py as hf
import numpy as np
import tensorflow as tf

from mdsuite.utils.meta_functions import join_path

log = logging.getLogger(__name__)


@dataclasses.dataclass(frozen=True)
class PropertyInfo:
    """
    Information of a trajectory property.

    Examples
    --------
    pos_info = PropertyInfo('Positions', 3)
    vel_info = PropertyInfo('Velocities', 3)

    Attributes
    ----------
    name:
        The name of the property
    n_dims:
        The dimensionality of the property
    """

    name: str
    n_dims: int


@dataclasses.dataclass
class SpeciesInfo:
    """
    Information about a species.

    Attributes
    ----------
    name
        Name of the species (e.g. 'Na')
    n_particles
        Number of particles of that species
    properties: list of PropertyInfo
        List of the properties that were recorded for the species.
        Mass and charge are optional.
    """

    name: str
    n_particles: int
    properties: List[PropertyInfo]
    mass: typing.Optional[float] = None
    charge: float = 0

    def __eq__(self, other):
        same = (
            self.name == other.name
            and self.n_particles == other.n_particles
            and self.mass == other.mass
            and self.charge == other.charge
        )
        if len(self.properties) != len(other.properties):
            return False

        for prop_s, prop_o in zip(self.properties, other.properties):
            same = same and prop_s == prop_o
        return same
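
# Usage sketch (illustrative, not part of the original file; values are
# hypothetical): species info for sodium with positions and velocities
# recorded.
#
#   na_info = SpeciesInfo(
#       name="Na",
#       n_particles=500,
#       properties=[PropertyInfo("Positions", 3), PropertyInfo("Velocities", 3)],
#       mass=22.99,
#       charge=1,
#   )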


@dataclasses.dataclass
class MoleculeInfo(SpeciesInfo):
    """Information about a molecule.

    All the information of a species, plus groups.

    Attributes
    ----------
    groups: dict
        A molecule-specific dictionary for mapping the molecule to the
        particles. The keys of this dict are index references to a specific
        molecule, i.e. molecule 1, and the values are a dict of atom species
        and their indices belonging to that specific molecule,
        e.g.
            water = {"groups": {"0": {"H": [0, 1], "O": [0]}}}
        This tells us that the 0th water molecule consists of the 0th and 1st
        hydrogen atoms in the database as well as the 0th oxygen atom.
    """

    groups: dict = None

    def __eq__(self, other):
        """Add a check to see if the groups are identical."""
        if self.groups != other.groups:
            return False
        return super().__eq__(other)
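
# Usage sketch (illustrative, not part of the original file; values are
# hypothetical): a water molecule whose 0th molecule is built from the 0th and
# 1st hydrogen atoms and the 0th oxygen atom, mirroring the docstring example.
#
#   water_info = MoleculeInfo(
#       name="Water",
#       n_particles=1,
#       properties=[PropertyInfo("Positions", 3)],
#       groups={"0": {"H": [0, 1], "O": [0]}},
#   )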


@dataclasses.dataclass
class TrajectoryMetadata:
    """Trajectory metadata container.

    This metadata must be extracted from trajectory files to build the database
    into which the trajectory will be stored.

    Attributes
    ----------
    n_configurations : int
        Number of configurations of the whole trajectory.
    species_list : list of SpeciesInfo
        The information about all species in the system.
    box_l : list of float
        The simulation box size in three dimensions.
    sample_rate : int, optional
        The number of timesteps between consecutive samples.
        # todo remove in favour of sample_step
    sample_step : float, optional
        The time between consecutive configurations.
        E.g. for a simulation with time step 0.1 where the trajectory is
        written every 5 steps: sample_step = 0.5. Does not have to be
        specified (e.g. configurations from a Monte Carlo scheme), but is
        needed for all dynamic observables.
    temperature : float, optional
        The set temperature of the system.
        Optional because it is only applicable to MD simulations with a
        thermostat. Needed for certain observables.
    simulation_data : dict, optional
        All other simulation data that can be extracted from the trajectory
        metadata, e.g. software version, pressure in NPT simulations, time
        step, ...
    """

    n_configurations: int
    species_list: List[SpeciesInfo]
    box_l: list = None
    sample_rate: int = 1
    sample_step: float = None
    temperature: float = None
    simulation_data: dict = dataclasses.field(default_factory=dict)
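
# Usage sketch (illustrative, not part of the original file; values are
# hypothetical): metadata for a 10,000-configuration trajectory of the sodium
# species sketched above.
#
#   metadata = TrajectoryMetadata(
#       n_configurations=10_000,
#       species_list=[na_info],
#       box_l=[42.0, 42.0, 42.0],
#       sample_step=0.5,
#       temperature=300.0,
#   )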


class TrajectoryChunkData:
    """Class to specify the data format for transfer from the file to the database."""

    def __init__(self, species_list: List[SpeciesInfo], chunk_size: int):
        """
        Parameters
        ----------
        species_list : List[SpeciesInfo]
            List of SpeciesInfo.
            It contains the information about which species are present and
            which properties are recorded for each.
        chunk_size : int
            The number of configurations to be stored in this chunk.
        """
        self.chunk_size = chunk_size
        self.species_list = species_list
        self._data = {}
        for sp_info in species_list:
            self._data[sp_info.name] = {}
            for prop_info in sp_info.properties:
                self._data[sp_info.name][prop_info.name] = np.zeros(
                    (chunk_size, sp_info.n_particles, prop_info.n_dims)
                )

    def add_data(self, data: np.ndarray, config_idx, species_name, property_name):
        """
        Add configuration data to the chunk.

        Parameters
        ----------
        data:
            The data to be added, with shape (n_configs, n_particles, n_dims).
            n_particles and n_dims relate to the species and the property that
            is being added.
        config_idx:
            Start index of the configs that are being added.
        species_name
            Name of the species to which the data belongs.
        property_name
            Name of the property being added.

        Example
        -------
        Storing velocity information for 42 Na atoms in the 17th iteration of
        a loop that reads 5 configs per loop:
        add_data(vel_array, 16 * 5, 'Na', 'Velocities')
        where vel_array.shape == (5, 42, 3)
        """
        n_configs = len(data)
        self._data[species_name][property_name][
            config_idx : config_idx + n_configs, :, :
        ] = data

    def get_data(self):
        """Return the chunk data as a dict of {species: {property: array}}."""
        return self._data
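
# Usage sketch (illustrative, not part of the original file): filling a
# 5-configuration chunk with velocities for the hypothetical 500-atom sodium
# species sketched above.
#
#   chunk = TrajectoryChunkData(species_list=[na_info], chunk_size=5)
#   vel_array = np.zeros((5, 500, 3))  # (n_configs, n_particles, n_dims)
#   chunk.add_data(vel_array, 0, "Na", "Velocities")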


class Database:
    """
    Database class.

    Databases make up a large part of the functionality of MDSuite and are kept
    fairly consistent in structure. Therefore, the database_path structure we
    are using has a separate class with commonly used methods which act as
    wrappers for the hdf5 database_path.

    Attributes
    ----------
    path : str|Path
            The path to the database file.
    """

    def __init__(self, path: typing.Union[str, pathlib.Path] = "database"):
        """
        Constructor for the database_path class.

        Parameters
        ----------
        path : str|Path
                The path to the database file.
        """
        if isinstance(path, pathlib.Path):
            self.path = path.as_posix()
        elif isinstance(path, str):
            self.path = path  # name of the database_path
        else:
            # TODO fix this!
            log.debug(f"Expected str|Path but found {type(path)}")
            self.path = path

    @staticmethod
    def _update_indices(
        data: np.ndarray, reference: np.ndarray, batch_size: int, n_atoms: int
    ):
        """
        Update the indices key of the structure dictionary if the tensor_values
        must be sorted.

        Parameters
        ----------
        data : np.ndarray
        reference : np.ndarray
        batch_size : int
        n_atoms : int

        Returns
        -------
        np.ndarray
                Flattened row indices that pick the rows matching the reference
                particle ids out of each batch.
        """
        # The first column of the data holds the particle ids; one block of
        # n_atoms rows corresponds to one configuration.
        ids = np.reshape(np.array(data[:, 0]).astype(int), (-1, n_atoms))
        ref_ids = np.argsort(ids, axis=1)
        n_batches = ids.shape[0]

        return (
            ref_ids[:, reference - 1] + (np.arange(n_batches) * n_atoms)[None].T
        ).flatten()

    @staticmethod
    def _build_path_input(structure: dict) -> dict:
        """
        Build an input to a hdf5 database_path from a dictionary.

        In many cases, whilst a dict can be passed on to a method, it is not
        ideal for use in the hdf5 database_path. This method takes a dictionary
        and returns a new dictionary with the relevant file path.

        Parameters
        ----------
        structure : dict
                General structure of the dictionary with relevant dataset
                sizes, e.g.
                {'Na': {'Forces': (200, 5000, 3)},
                'Pressure': (5000, 6), 'Temperature': (5000, 1)}. In this case,
                the last value in the tuple corresponds to the number of
                components that will be parsed to the database_path. The value
                in the dict can also be an integer corresponding to a resize
                operation such as {'Na': {'velocities': 100}}. In any case, the
                deepest portion of the dict must be a non-dict object and will
                be returned as the value of the path to it in the new
                dictionary.

        Returns
        -------
        architecture : dict
                Corrected path in the hdf5 database_path, e.g.
                {'/Na/Velocities': 100} or {'/Na/Forces': (200, 5000, 3)}
        """
        # Build file paths for the addition.
        architecture = {}
        for group in structure:
            if type(structure[group]) is not dict:
                architecture[group] = structure[group]
            else:
                for subgroup in structure[group]:
                    db_path = join_path(group, subgroup)
                    architecture[db_path] = structure[group][subgroup]

        return architecture
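
    # Usage sketch (illustrative, not part of the original file): nested dicts
    # are flattened into per-dataset paths via join_path, while flat entries
    # pass through unchanged, e.g. (using the docstring's path convention)
    #
    #   {'Na': {'Forces': (200, 5000, 3)}, 'Temperature': (5000, 1)}
    #   -> {'/Na/Forces': (200, 5000, 3), 'Temperature': (5000, 1)}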

    def add_data(self, chunk: TrajectoryChunkData, start_idx: int):
        """
        Add new data to the dataset.

        Parameters
        ----------
        chunk:
            a data chunk
        start_idx:
            Configuration at which to start writing.
        """
        workaround_time_in_axis_1 = True

        chunk_data = chunk.get_data()

        with hf.File(self.path, "r+") as database:
            stop_index = start_idx + chunk.chunk_size

            for sp_info in chunk.species_list:
                for prop_info in sp_info.properties:
                    dataset_name = f"{sp_info.name}/{prop_info.name}"
                    write_data = chunk_data[sp_info.name][prop_info.name]

                    dataset_shape = database[dataset_name].shape
                    if len(dataset_shape) == 2:
                        # only one particle
                        database[dataset_name][start_idx:stop_index, :] = write_data[
                            :, 0, :
                        ]

                    elif len(dataset_shape) == 3:
                        if workaround_time_in_axis_1:
                            # The dataset stores (n_particles, n_configs, n_dims),
                            # while the chunk holds (n_configs, n_particles, n_dims),
                            # so the first two axes are swapped before writing.
                            database[dataset_name][
                                :, start_idx:stop_index, :
                            ] = np.swapaxes(write_data, 0, 1)
                        else:
                            database[dataset_name][start_idx:stop_index, ...] = write_data
                    else:
                        raise ValueError(
                            "dataset shape must be either (n_part, n_config, n_dim) or"
                            " (n_config, n_dim)"
                        )
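
    # Usage sketch (illustrative, not part of the original file; names are
    # hypothetical): writing the chunk from the TrajectoryChunkData example
    # above into an existing database, starting at configuration 0.
    #
    #   database = Database("database.hdf5")
    #   database.add_data(chunk, start_idx=0)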

    def resize_datasets(self, structure: dict):
        """
        Resize a dataset so more tensor_values can be added.

        Parameters
        ----------
        structure : dict
                Path to the dataset that needs to be resized, e.g.
                {'Na': {'velocities': (32, 100, 3)}}
                will extend the 'Na/velocities' dataset by 100 configurations
                along the time axis.
        """
        with hf.File(self.path, "r+") as db:
            # construct the architecture dict
            architecture = self._build_path_input(structure=structure)

            # Check for a type error in the dataset information
            for identifier in architecture:
                dataset_information = architecture[identifier]
                if not isinstance(dataset_information, tuple):
                    raise TypeError("Invalid input for dataset generation")

                # get the correct maximum shape for the dataset -- changes if an
                # experiment property or an atomic property
                if len(dataset_information[:-1]) == 1:
                    axis = 0
                    expansion = dataset_information[0] + db[identifier].shape[0]
                else:
                    axis = 1
                    expansion = dataset_information[1] + db[identifier].shape[1]

                db[identifier].resize(expansion, axis)
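
    # Usage sketch (illustrative, not part of the original file): for an
    # existing 'Na/velocities' dataset of shape (32, 200, 3), the call below
    # extends the time axis by 100, giving shape (32, 300, 3).
    #
    #   database.resize_datasets({"Na": {"velocities": (32, 100, 3)}})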

    def initialize_database(self, structure: dict):
        """
        Build a database_path with a general structure.

        Note, this method WILL overwrite a pre-existing database_path. This is
        because it is only to be called on the initial construction of an
        experiment class and the first addition of tensor_values to it.

        Parameters
        ----------
        structure : dict
                General structure of the dictionary with relevant dataset
                sizes, e.g. {'Na': {'Forces': (200, 5000, 3)},
                'Pressure': (5000, 6), 'Temperature': (5000, 1)}. In this case,
                the last value in the tuple corresponds to the number of
                components that will be parsed to the database_path.
        """
        self.add_dataset(structure)  # add a dataset to the groups

    def database_exists(self) -> bool:
        """Check if the database file already exists."""
        return pathlib.Path(self.path).exists()

    def add_dataset(self, structure: dict):
        """
        Add a dataset of the necessary size to the database_path.

        Just as a separate method exists for building the group structure of
        the hdf5 database_path, so too do we include a separate method for
        adding a dataset. This is so datasets can be added not just upon the
        initial construction of the database_path, but also if tensor_values
        are added in the future that should also be stored. This method
        assumes that a group has already been built; although this is not
        necessary for HDF5, the separation of the actions is good practice.

        Parameters
        ----------
        structure : dict
                Structure of a single property to be added to the
                database_path, e.g. {'Na': {'Forces': (200, 5000, 3)}}

        Returns
        -------
        Updates the database_path directly.
        """
        with hf.File(self.path, "a") as database:
            architecture = self._build_path_input(structure)  # get the correct file path
            for item in architecture:
                dataset_information = architecture[item]  # get the tuple information
                dataset_path = item  # get the dataset path in the database_path

                # Check for a type error in the dataset information
                if not isinstance(dataset_information, tuple):
                    raise TypeError("Invalid input for dataset generation")

                if len(dataset_information[:-1]) == 1:
                    # experiment-wide property: let the time axis (axis 0) grow
                    vector_length = dataset_information[-1]
                    max_shape = (None, vector_length)
                else:
                    # per-atom property: let the time axis (axis 1) grow
                    max_shape = list(dataset_information)
                    max_shape[1] = None
                    max_shape = tuple(max_shape)

                database.create_dataset(
                    dataset_path,
                    dataset_information,
                    maxshape=max_shape,
                    compression="gzip",
                    chunks=True,
                )
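
    # Usage sketch (illustrative, not part of the original file): for a
    # per-atom property the dataset below is created with shape
    # (200, 5000, 3) and maxshape (200, None, 3), so further configurations
    # can later be appended along the time axis.
    #
    #   database.add_dataset({"Na": {"Forces": (200, 5000, 3)}})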

    def _add_group_structure(self, structure: dict):
        """
        Add a simple group structure to a database_path.

        This method will take an input structure and build the required group
        structure in the hdf5 database_path. This will NOT, however,
        instantiate the datasets for the structure, only the group hierarchy.

        Parameters
        ----------
        structure : dict
                Structure of a single property to be added to the
                database_path, e.g. {'Na': {'Forces': (200, 5000, 3)}}

        Returns
        -------
        Updates the database_path directly.
        """
        with hf.File(self.path, "a") as database:
            # Build file paths for the addition.
            architecture = self._build_path_input(structure=structure)
            for item in list(architecture):
                if item in database:
                    log.info("Group structure already exists")
                else:
                    database.create_group(item)

    def get_memory_information(self) -> dict:
        """
        Get memory information from the database_path.

        Returns
        -------
        memory_database : dict
                A dictionary of the memory information of the groups in the
                database_path.
        """
        with hf.File(self.path, "r") as database:
            memory_database = {}
            for item in database:
                for ds in database[item]:
                    memory_database[join_path(item, ds)] = database[item][ds].nbytes

        return memory_database

    def check_existence(self, path: str) -> bool:
        """
        Check to see if a dataset is in the database_path.

        Parameters
        ----------
        path : str
                Path to the desired dataset.

        Returns
        -------
        response : bool
                True if the path exists, False otherwise.
        """
        with hf.File(self.path, "r") as database_object:
            # Collect the full names of all datasets in the file.
            keys = []
            database_object.visit(
                lambda item: keys.append(database_object[item].name)
                if isinstance(database_object[item], hf.Dataset)
                else None
            )
            path = f"/{path}"  # add the / to avoid name overlapping

            response = any(item.endswith(path) for item in keys)
        return response
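
    # Usage sketch (illustrative, not part of the original file):
    #
    #   database.check_existence("Na/Velocities")  # True once the dataset exists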

    def change_key_names(self, mapping: dict):
        """
        Change the name of database_path keys.

        Parameters
        ----------
        mapping : dict
                Mapping for the change of names.

        Returns
        -------
        Updates the database_path.
        """
        with hf.File(self.path, "r+") as db:
            groups = list(db.keys())

            for item in groups:
                if item in mapping:
                    db.move(item, mapping[item])

    def load_data(
        self,
        path_list: list = None,
        select_slice=np.s_[:],
        dictionary: bool = False,
        scaling: list = None,
        d_size: int = None,
    ):
        """
        Load tensor_values from the database_path for some operation.

        Should be called by the tensor_values fetch class as this will ensure
        correct loading and pre-loading.

        Returns
        -------
        data : dict
                Dictionary of tf.Tensor objects keyed by dataset path, with an
                additional b'data_size' entry holding d_size.
        """
        if scaling is None:
            scaling = [1 for _ in range(len(path_list))]

        with hf.File(self.path, "r") as database:
            data = {}
            for i, item in enumerate(path_list):
                if type(select_slice) is dict:
                    # index is the particle species name in the full item as a str.
                    slice_index = item.decode().split("/")[0]
                    my_slice = select_slice[slice_index]
                else:
                    my_slice = select_slice
                try:
                    data[item] = (
                        tf.convert_to_tensor(database[item][my_slice], dtype=tf.float64)
                        * scaling[i]
                    )
                except TypeError:
                    # h5py cannot apply the fancy slice directly; index in two steps.
                    data[item] = (
                        tf.convert_to_tensor(
                            database[item][my_slice[0]][:, my_slice[1], :],
                            dtype=tf.float64,
                        )
                        * scaling[i]
                    )
            data[str.encode("data_size")] = d_size

        return data
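
    # Usage sketch (illustrative, not part of the original file; names are
    # hypothetical): loading the first 100 configurations of one dataset.
    # Byte-string keys mirror the decode() call above.
    #
    #   data = database.load_data(
    #       path_list=[b"Na/Velocities"],
    #       select_slice=np.s_[:, :100],
    #       d_size=100,
    #   )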

    def get_load_time(self, database_path: str = None):
        """
        Calculate the open/close time of the database_path.

        Parameters
        ----------
        database_path : str
                Database path on which to test the time. Defaults to this
                database if not given.

        Returns
        -------
        opening time : float
                Time taken to open and close the database_path.
        """
        if database_path is None:
            database_path = self.path

        start = time.time()
        database = hf.File(database_path, "r")
        database.close()
        stop = time.time()

        return stop - start

    def get_data_size(self, data_path: str) -> tuple:
        """
        Return the size of a dataset as a tuple (n_rows, n_columns, n_bytes).

        Parameters
        ----------
        data_path : str
                Path to the tensor_values in the hdf5 database_path.

        Returns
        -------
        dataset_properties : tuple
                Tuple of tensor_values about the dataset, e.g.
                (n_rows, n_columns, n_bytes)
        """
        with hf.File(self.path, "r") as db:
            data_tuple = (
                db[data_path].shape[0],
                db[data_path].shape[1],
                db[data_path].nbytes,
            )

        return data_tuple

    def get_database_summary(self):
        """
        Get a summary of the database properties.

        Returns
        -------
        summary : list
                A list of the top-level group names in the database.
        """
        with hf.File(self.path, "r") as db:
            return list(db.keys())