• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

zincware / MDSuite / 3999396905

pending completion
3999396905

push

github-actions

GitHub
[merge before other PRs] ruff updates (#580)

960 of 1311 branches covered (73.23%)

Branch coverage included in aggregate %.

15 of 15 new or added lines in 11 files covered. (100.0%)

4034 of 4930 relevant lines covered (81.83%)

3.19 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.35
/mdsuite/utils/meta_functions.py
1
"""
2
MDSuite: A Zincwarecode package.
3

4
License
5
-------
6
This program and the accompanying materials are made available under the terms
7
of the Eclipse Public License v2.0 which accompanies this distribution, and is
8
available at https://www.eclipse.org/legal/epl-v20.html
9

10
SPDX-License-Identifier: EPL-2.0
11

12
Copyright Contributors to the Zincwarecode Project.
13

14
Contact Information
15
-------------------
16
email: zincwarecode@gmail.com
17
github: https://github.com/zincware
18
web: https://zincwarecode.com/
19

20
Citation
21
--------
22
If you use this module please cite us with:
23

24
Summary
25
-------
26
"""
27

28
import json
4✔
29
import logging
4✔
30
import os
4✔
31
import pathlib
4✔
32
import typing
4✔
33
from functools import wraps
4✔
34
from time import time
4✔
35
from typing import Callable
4✔
36

37
import GPUtil
4✔
38
import numpy as np
4✔
39
import psutil
4✔
40
import tensorflow as tf
4✔
41
from scipy.signal import savgol_filter
4✔
42

43
from mdsuite.utils.exceptions import NoGPUInSystem
4✔
44
from mdsuite.utils.units import golden_ratio
4✔
45

46
log = logging.getLogger(__name__)
4✔
47

48

49
def gpu_available() -> bool:
    """Check if TensorFlow has access to any GPU device.

    Returns
    -------
    bool
        True if TensorFlow lists at least one physical GPU device, False
        otherwise.
    """
    # One visible GPU is sufficient, so compare the device count against zero.
    return len(tf.config.list_physical_devices("GPU")) > 0
52

53

54
# https://stackoverflow.com/questions/42033142/is-there-an-easy-way-to-check-if-an-object-is-json-serializable-in-python
55
def is_jsonable(x: dict) -> bool:
    """
    Report whether an object can be serialized by ``json.dumps``.

    Parameters
    ----------
    x: dict
        Object (typically a dictionary) to test for JSON serializability.

    Returns
    -------
    bool: True if ``json.dumps(x)`` succeeds, False if it raises a TypeError or
        an OverflowError.
    """
    try:
        json.dumps(x)
    except (TypeError, OverflowError):
        return False
    else:
        return True
71

72

73
def join_path(a, b):
    """Join two path components and force forward slashes in the result.

    Parameters
    ----------
    a: str
        Leading path component.
    b: str
        Trailing path component.

    Returns
    -------
    str: the joined path with every backslash replaced by a forward slash.

    Notes
    -----
    h5py 3.1.0 on windows relies on forward slashes, whereas os.path.join produces
    backward slashes there. Replacing them keeps MDSuite usable for Windows users.
    This helper is meant ONLY for navigation within a database_path; for ordinary
    file system paths use os.path.join directly.

    """
    joined = os.path.join(a, b)
    # Splitting on the backslash and re-joining with "/" swaps every separator.
    return "/".join(joined.split("\\"))
94

95

96
def get_dimensionality(box: list) -> int:
    """
    Determine how many dimensions of the experiment box are non-empty.

    Parameters
    ----------
    box : list
            box array of the experiment of the form [x, y, z]

    Returns
    -------
    dimensions : int
            3 if no entry is zero, 2 if exactly one entry is zero and 1 if two or
            more entries are zero.
    """
    # No empty side at all: a full three-dimensional box.
    if not (box[0] == 0 or box[1] == 0 or box[2] == 0):
        return 3

    # At least one side is empty; two empty sides leave a single dimension.
    two_sides_empty = (
        (box[0] == 0 and box[1] == 0)
        or (box[0] == 0 and box[2] == 0)
        or (box[1] == 0 and box[2] == 0)
    )
    return 1 if two_sides_empty else 2
130

131

132
def get_machine_properties() -> dict:
    """
    Collect basic hardware information about the current machine.

    Returns
    -------
    machine_properties : dict
            Dictionary with the keys "cpu" (logical core count reported by
            psutil), "memory" (currently available RAM reported by psutil) and
            "gpu" (mapping of GPU id to a dict holding its "name" and total
            "memory" as reported by GPUtil; empty if no GPU was found).
    """
    ram_available = psutil.virtual_memory().available
    logical_cores = psutil.cpu_count(logical=True)

    properties = {"cpu": logical_cores, "memory": ram_available, "gpu": {}}
    gpu_table = properties["gpu"]

    try:
        for device in GPUtil.getGPUs():
            gpu_table[device.id] = {
                "name": device.name,
                "memory": device.memoryTotal,
            }
    except (NoGPUInSystem, ValueError):
        # GPU discovery is best effort; the CPU and memory entries stay valid.
        log.warning("No GPUs detected, continuing without GPU support")

    return properties
159

160

161
def line_counter(filename: str) -> int:
    """
    Count the number of lines in a file.

    The file is iterated line by line rather than read in one go, so memory usage
    does not grow with the size of the file.

    Parameters
    ----------
    filename : str
            Name of the file to be read in.

    Returns
    -------
    lines : int
            Number of lines in the file
    """
    # Binary mode skips text decoding; the context manager guarantees the handle
    # is closed even if reading fails part-way through.
    with open(filename, "rb") as file_handle:
        return sum(1 for _ in file_handle)
183

184

185
def optimize_batch_size(
    filepath: typing.Union[str, pathlib.Path],
    number_of_configurations: int,
    _file_size: int = None,
    _memory: int = None,
    test: bool = False,
) -> int:
    """
    Optimize the size of batches during initial processing.

    During the database_path construction a batch size must be chosen in order to
    process the trajectories with the least RAM but reasonable performance.

    Parameters
    ----------
    filepath : str
            Path to the file be read in. This is not opened during the process, it is
            simply needed to read the file size.
    number_of_configurations : int
            Number of configurations in the trajectory.
    _file_size : int
            Mock file size to use during tests.
    _memory : int
            Mock memory to use during tests.
    test : bool
            If true, mock variables are used.

    Returns
    -------
    batch size : int
            Number of configurations to load in each batch. This is the full
            number of configurations when ten times the file size fits into the
            memory budget, otherwise the budget divided by five times the memory
            needed per configuration.

    Raises
    ------
    ZeroDivisionError
            If the file does not fit into the memory budget and the memory per
            configuration cannot be computed or evaluates to zero.
    """
    if test:
        file_size = _file_size
        database_memory = _memory
    else:
        computer_statistics = get_machine_properties()
        file_size = os.path.getsize(filepath)
        # Only 10% of the currently available memory is budgeted.
        database_memory = 0.1 * computer_statistics["memory"]

    # Decide on the "everything fits" shortcut first so that an empty file never
    # reaches the per-configuration division below.
    if 10 * file_size < database_memory:
        return int(number_of_configurations)

    memory_per_configuration = file_size / number_of_configurations
    # The factor of 5 accounts for the memory expansion during database_path
    # generation.
    return int(database_memory / (5 * memory_per_configuration))
239

240

241
def linear_fitting_function(x: np.array, a: float, b: float) -> np.array:
    """
    Evaluate a straight line, used as the model for linear fits.

    Analyses based on an Einstein relation require a linear curve to be fitted to
    some tensor_values. This function is handed to the scipy curve_fit routine as
    the model to fit.

    Parameters
    ----------
    x : np.array
            x tensor_values for fitting
    a : float
            Fitting parameter of the gradient
    b : float
            Fitting parameter for the y intercept

    Returns
    -------
    a*x + b : np.array
            The linear function evaluated at x.
    """
    gradient_term = a * x
    return gradient_term + b
264

265

266
def simple_file_read(filename: str) -> list:
    """
    Trivially read a file and load it into an array.

    There are many occasions when a file simply must be read and dumped into an
    array. In these cases, we call this method. This is NOT memory safe, and should
    not be used for processing large trajectory files.

    Parameters
    ----------
    filename : str
            Name of the file to be read in.

    Returns
    -------
    data_array: list
            One entry per line of the file, each being the list of
            whitespace-separated tokens found on that line.
    """
    # Plain read mode: nothing is written, so write permission on the file must
    # not be required.
    with open(filename, "r") as file_handle:
        return [line.split() for line in file_handle]
292

293

294
def timeit(f: Callable) -> Callable:
    """
    Decorator that logs how long each call of the wrapped callable takes.

    Parameters
    ----------
    f : Callable
            Function to be wrapped.

    Returns
    -------
    wrap : Callable
            Wrapper that calls f, logs the elapsed wall-clock time at INFO level
            and returns the result of f unchanged.

    Notes
    -----
    There is currently no test for this wrapper as there is no simple way of checking
    timing on a remote server.
    """

    @wraps(f)
    def wrap(*args, **kw):
        """Call the wrapped function and log its run time."""
        started = time()
        result = f(*args, **kw)
        elapsed = time() - started
        log.info(f"function '{f.__name__}' took {elapsed} s")

        return result

    return wrap
325

326

327
def apply_savgol_filter(
    data: np.ndarray, order: int = 2, window_length: int = 17
) -> np.ndarray:
    """
    Smooth data with a Savitzky-Golay filter.

    Thin wrapper around the scipy SavGol implementation that supplies default
    values for the polynomial order and the window length.

    Parameters
    ----------
    data : np.ndarray
            Array of tensor_values to be smoothed.
    order : int
            Order of polynomial to use in the smoothing.
    window_length : int
            Window length to use in the filtering.

    Returns
    -------
    filtered tensor_values : np.ndarray
            The filtered tensor_values exactly as returned by the scipy filter.

    Notes
    -----
    There are no tests for this method as a test would simply be testing the scipy
    implementation which they have done.
    """
    return savgol_filter(data, window_length=window_length, polyorder=order)
356

357

358
def closest_point(data: np.ndarray, value: float):
    """
    Find the entry of an array that lies closest to a given value.

    Parameters
    ----------
    data : np.ndarray
            Array to search.
    value : float
            Value to look for.

    Returns
    -------
    closest : element of data
            The entry of data with the smallest absolute difference to value. If
            several entries are equally close, the first one is returned.
    """

    def distance(candidate):
        """Absolute distance between a candidate entry and the target value."""
        return abs(candidate - value)

    return min(data, key=distance)
374

375

376
def golden_section_search(
    data: np.array,
    a: float,
    b: float,
    tol: float = 1e-5,
    h: float = None,
    c: float = None,
    d: float = None,
    fc: float = None,
    fd: float = None,
) -> tuple:
    """
    Narrow down the location of a function minimum by golden-section search.

    The search is used during analysis to locate minimums of functions, e.g. the
    minimums of a radial distribution function when evaluating coordination
    numbers. The function recurses on a shrinking interval and returns the
    interval in which the minimum is expected.

    Parameters
    ----------
    data : np.array
            Data on which to find minimums. data[0] holds the x values and
            data[1] the corresponding function values.
    a : float
            One bound of the search range.
    b : float
            Other bound of the search range. a and b are ordered internally, so
            they may be passed in either order.
    tol : float
            Interval width at or below which the search stops.
    h : float
            Width of the current interval; computed from a and b if not given.
            Used by the recursion.
    c : float
            Lower interior probe point, snapped to the closest value in data[0]
            if not given. Used by the recursion.
    d : float
            Upper interior probe point, snapped to the closest value in data[0]
            if not given. Used by the recursion.
    fc : float
            Function value at c, looked up in data[1] if not given.
    fd : float
            Function value at d, looked up in data[1] if not given.

    Returns
    -------
    minimum range : tuple
            Returns two radii values within which the minimum can be found.
    """
    # Golden ratio identities used to place the interior points.
    inverse_phi = 1 / golden_ratio
    inverse_phi_squared = 1 / (golden_ratio**2)

    # Tolerate bounds that were passed in the wrong order.
    a, b = min(a, b), max(a, b)

    if h is None:
        h = b - a
    if h <= tol:
        return a, b

    # Fill in whichever probe points / function values were not carried over
    # from the previous recursion step.
    if c is None:
        c = closest_point(data[0], a + inverse_phi_squared * h)
    if d is None:
        d = closest_point(data[0], a + inverse_phi * h)
    if fc is None:
        fc = data[1][np.where(data[0] == c)]
    if fd is None:
        fd = data[1][np.where(data[0] == d)]

    if fc < fd:
        # Continue on [a, d]; the old c is reused as the new upper probe point.
        lower, upper = a, d
        carried = {"c": None, "fc": None, "d": c, "fd": fc}
    else:
        # Continue on [c, b]; the old d is reused as the new lower probe point.
        lower, upper = c, b
        carried = {"c": d, "fc": fd, "d": None, "fd": None}

    return golden_section_search(
        data, lower, upper, tol, h * inverse_phi, **carried
    )
438

439

440
def get_nearest_divisor(a: int, b: int) -> int:
    """
    Find the nearest divisor of b that is not larger than a.

    If b%a is not 0, this method may be called to get the nearest number below a
    that makes b%a zero.

    Parameters
    ----------
    a : int
            divisor
    b : int
            target number

    Returns
    -------
    divisor : int
            a itself if it divides b evenly, otherwise the first smaller number
            that does.
    """
    # Step downwards from a until the division leaves no remainder.
    while b % a != 0:
        a -= 1

    return a
466

467

468
def split_array(data: np.array, condition: np.array) -> list:
    """
    Split an array into the entries that meet a condition and those that do not.

    Parameters
    ----------
    data : np.array
            tensor_values to split
    condition : np.array
            Boolean mask on which to split by.

    Returns
    -------
    split_array : list
            [matching, non-matching] arrays, or just [matching] when no entry
            fails the condition.
    """
    matching = data[condition]
    rejected = data[~condition]

    # Nothing was rejected by the mask: return only the matching part.
    if len(rejected) == 0:
        return [matching]

    return [matching, rejected]
491

492

493
def find_item(obj, key):
    """
    Recursively look up a key in a nested dictionary.

    Parameters
    ----------
    obj: dict
        nested dictionary with results
    key: str, float or other
        to find in the dictionary

    Returns
    -------
    item: dict value.
        The value stored under the key at the top level if present, otherwise the
        first non-None match found in a nested dictionary. None if the key is not
        found. The return type depends on the requested key.
    """
    if key in obj:
        return obj[key]

    for value in obj.values():
        if not isinstance(value, dict):
            continue
        found = find_item(value, key)
        if found is not None:
            return found

    return None
517

518

519
def sort_array_by_column(array: np.ndarray, column_idx: int):
    """
    Sort the rows of a 2D array by the numeric value of one column.

    Parameters
    ----------
    array : np.ndarray
            Two-dimensional array whose rows are to be sorted.
    column_idx : int
            Index of the column to sort by.

    Returns
    -------
    sorted array : np.ndarray
            The rows of array reordered so that the chosen column is ascending.

    Notes
    -----
    See https://stackoverflow.com/questions/2828059/
    sorting-arrays-in-numpy-by-column/35624868

    The sort column is converted to float first. If e.g. a lammps file was read
    in, a single str entry makes the whole array str, and sorting by id would then
    follow str ordering rules (i.e. '10' < '2') even though the id column holds
    numbers.
    """
    numeric_keys = array[:, column_idx].astype(float)
    row_order = np.argsort(numeric_keys)
    return array[row_order]
528

529

530
def check_a_in_b(a, b):
    """
    Check if any value of a is in b.

    Parameters
    ----------
    a: tf.Tensor
        Tensor whose entries (along the first axis) are looked for.
    b: tf.Tensor
        Tensor that is searched.

    Returns
    -------
    bool
        True as soon as one unstacked entry of a compares equal to some element
        of b, False if none does.

    """
    # any() stops at the first entry of a that is found in b.
    return any(tf.reduce_any(b == entry) for entry in tf.unstack(a))
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc