• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

zincware / MDSuite / 3999396905

pending completion
3999396905

push

github-actions

GitHub
[merge before other PRs] ruff updates (#580)

960 of 1311 branches covered (73.23%)

Branch coverage included in aggregate %.

15 of 15 new or added lines in 11 files covered. (100.0%)

4034 of 4930 relevant lines covered (81.83%)

3.19 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.67
/mdsuite/calculators/einstein_diffusion_coefficients.py
1
"""
2
MDSuite: A Zincwarecode package.
3

4
License
5
-------
6
This program and the accompanying materials are made available under the terms
7
of the Eclipse Public License v2.0 which accompanies this distribution, and is
8
available at https://www.eclipse.org/legal/epl-v20.html
9

10
SPDX-License-Identifier: EPL-2.0
11

12
Copyright Contributors to the Zincwarecode Project.
13

14
Contact Information
15
-------------------
16
email: zincwarecode@gmail.com
17
github: https://github.com/zincware
18
web: https://zincwarecode.com/
19

20
Citation
21
--------
22
If you use this module please cite us with:
23

24
Summary
25
-------
26
Module for the computation of self-diffusion coefficients using the Einstein method.
27
"""
28
from __future__ import annotations
4✔
29

30
import logging
4✔
31
from abc import ABC
4✔
32
from dataclasses import dataclass
4✔
33
from typing import Any, List, Union
4✔
34

35
import numpy as np
4✔
36
import tensorflow as tf
4✔
37
from bokeh.models import HoverTool, LinearAxis, Span
4✔
38
from bokeh.models.ranges import Range1d
4✔
39
from bokeh.plotting import figure
4✔
40
from tqdm import tqdm
4✔
41

42
from mdsuite import utils
4✔
43
from mdsuite.calculators.calculator import call
4✔
44
from mdsuite.calculators.trajectory_calculator import TrajectoryCalculator
4✔
45
from mdsuite.database.mdsuite_properties import mdsuite_properties
4✔
46
from mdsuite.utils.calculator_helper_methods import fit_einstein_curve
4✔
47

48
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
4✔
49

50

51
@dataclass(init=True, repr=True, eq=True)
class Args:
    """
    Arguments of an Einstein diffusion run that affect its stored result.

    One instance is built per call of the calculator and kept as ``self.args``.
    """

    # Required length of each ensemble along its time axis.
    data_range: int
    # Correlation time used in the window sampling.
    correlation_time: int
    # Selection of atoms to use within the HDF5 database.
    atom_selection: np.s_
    # Time-lag indices at which the MSD is evaluated.
    tau_values: np.s_
    # If true, molecules are used instead of atoms.
    molecules: bool
    # Names of the species to analyse.
    species: list
    # Last index included in the linear fit of the MSD.
    fit_range: int
62

63

64
class EinsteinDiffusionCoefficients(TrajectoryCalculator, ABC):
    """
    Class for the Einstein diffusion coefficient implementation.

    For every selected species the mean square displacement (MSD) is
    accumulated over all ensembles of the trajectory. A straight line is then
    fitted to it, and the self-diffusion coefficient is reported as
    ``gradient / 6``.

    Attributes
    ----------
    msd_array : np.ndarray
            MSD data, summed over particles and updated during each ensemble
            computation; divided by ``count`` in ``fit_diff_coeff``.
    count : int
            Number of particle trajectories that contributed to ``msd_array``.

    See Also
    --------
    mdsuite.calculators.calculator.Calculator class

    Examples
    --------
    project.experiment.run.EinsteinDiffusionCoefficients(data_range=500,
                                                         plot=True,
                                                         correlation_time=10)
    """

    def __init__(self, **kwargs):
        """
        Set up labels, result keys and the database property to load.

        Parameters
        ----------
        **kwargs
                Forwarded unchanged to ``TrajectoryCalculator.__init__``
                (e.g. ``experiment`` or ``experiments``).
        """
        super().__init__(**kwargs)
        self.scale_function = {"linear": {"scale_factor": 150}}
        self.loaded_property = mdsuite_properties.unwrapped_positions
        self.x_label = r"$$\text{Time} / s$$"
        self.y_label = r"$$\text{MSD} / m^{2}$$"
        self.result_keys = [
            "diffusion_coefficient",
            "uncertainty",
            "gradient",
            "intercept",
        ]
        self.result_series_keys = ["time", "msd", "gradients", "gradient_errors"]
        self.analysis_name = "Einstein Self-Diffusion Coefficients"
        self._dtype = tf.float64

        self.msd_array = None

        log.info("starting Einstein Diffusion Computation")

    @call
    def __call__(
        self,
        plot: bool = True,
        species: list = None,
        data_range: int = 100,
        correlation_time: int = 1,
        atom_selection: np.s_ = np.s_[:],
        molecules: bool = False,
        tau_values: Union[int, List, Any] = np.s_[:],
        fit_range: int = -1,
    ):
        """
        Store the user arguments of the computation.

        Parameters
        ----------
        plot : bool
                If true, plot the output.
        species : list
                List of species on which to operate. Defaults to all molecules
                (``molecules=True``) or all species of the experiment.
        data_range : int
                Data range to use in the analysis.
        correlation_time : int
                Correlation time to use in the window sampling.
        atom_selection : np.s_
                Selection of atoms to use within the HDF5 database.
        molecules : bool
                If true, molecules are used instead of atoms.
        tau_values : Union[int, list, np.s_]
                Selection of tau values to use in the window sliding.
        fit_range : int
                Last index of the MSD curve used in the linear fit. The
                default ``-1`` means ``data_range - 1``.

        Returns
        -------
        None
        """
        if species is None:
            if molecules:
                species = list(self.experiment.molecules)
            else:
                species = list(self.experiment.species)

        if fit_range == -1:
            fit_range = int(data_range - 1)
        # set args that will affect the computation result
        self.args = Args(
            data_range=data_range,
            correlation_time=correlation_time,
            atom_selection=atom_selection,
            tau_values=tau_values,
            molecules=molecules,
            species=species,
            fit_range=fit_range,
        )
        self.plot = plot
        self.system_property = False

    def ensemble_operation(self, ensemble):
        """
        Calculate the particle-summed MSD of a single ensemble.

        Parameters
        ----------
        ensemble : tf.Tensor
                An ensemble of data to be operated on. Presumably shaped
                (n_particles, data_range, n_dimensions) -- TODO confirm.

        Returns
        -------
        np.ndarray
                Squared displacements from the first configuration of the
                ensemble, taken at ``self.args.tau_values``, summed over
                particles and over the last (dimension) axis.

        Notes
        -----
        Side effect: ``self.count`` is increased by the number of particles
        summed, so that ``fit_diff_coeff`` can turn the sum into a mean.
        """
        msd = tf.math.squared_difference(
            tf.gather(ensemble, self.args.tau_values, axis=1), ensemble[:, None, 0]
        )
        # Every particle contributes one MSD curve to the sum below; this is
        # the only place the normalisation count is updated.
        self.count += msd.shape[0]
        # Sum over particles and dimensions; the particle average is taken
        # once in fit_diff_coeff by dividing by self.count.
        msd = tf.reduce_sum(tf.reduce_sum(msd, axis=0), axis=-1)

        return np.array(msd)

    def fit_diff_coeff(self):
        """
        Apply unit conversion, fit a line to the MSD and build the result dict.

        Returns
        -------
        dict
                Diffusion coefficient, its uncertainty, the fit gradient and
                intercept, plus the time, MSD and running gradient series.
        """
        # Turn the particle/ensemble sum into a mean.
        self.msd_array /= self.count

        self.msd_array *= self.experiment.units.length**2
        self.time *= self.experiment.units.time

        fit_values, covariance, gradients, gradient_errors = fit_einstein_curve(
            x_data=self.time, y_data=self.msd_array, fit_max_index=self.args.fit_range
        )
        # Standard error of the gradient (first fit parameter).
        error = np.sqrt(np.diag(covariance))[0]

        # D = gradient / 6; the running gradients are scaled the same way.
        data = {
            self.result_keys[0]: 1 / 6.0 * fit_values[0],
            self.result_keys[1]: 1 / 6.0 * error,
            self.result_keys[2]: fit_values[0],
            self.result_keys[3]: fit_values[1],
            self.result_series_keys[0]: self.time.tolist(),
            self.result_series_keys[1]: self.msd_array.tolist(),
            self.result_series_keys[2]: (np.array(gradients) / 6).tolist(),
            self.result_series_keys[3]: (np.array(gradient_errors) / 6).tolist(),
        }
        return data

    def run_calculator(self):
        """Run the analysis for every selected species and queue the results."""
        self._run_dependency_check()
        for species in self.args.species:
            # Here for now to avoid issues (presumably because fit_diff_coeff
            # rescales self.time in place). Should be moved out when
            # calculators become species-wise.
            self.time = None
            self.time = self._handle_tau_values()
            dict_ref = str.encode("/".join([species, self.loaded_property.name]))
            batch_ds = self.get_batch_dataset([species])
            self.msd_array = np.zeros(self.data_resolution)
            # Number of particle trajectories summed into msd_array; it is
            # incremented inside ensemble_operation only.
            self.count = 0
            # loop over batches to get MSD
            for batch in tqdm(
                batch_ds,
                ncols=70,
                desc=species,
                total=self.n_batches,
                disable=self.memory_manager.minibatch,
            ):
                ensemble_ds = self.get_ensemble_dataset(batch, species)

                for ensemble in ensemble_ds:
                    ensemble_data = ensemble[dict_ref]
                    # Only full-length ensembles contribute to the MSD.
                    if ensemble_data.shape[1] != self.args.data_range:
                        continue
                    self.msd_array += self.ensemble_operation(ensemble_data)

            fit_results = self.fit_diff_coeff()
            self.queue_data(data=fit_results, subjects=[species])

    def plot_data(self, data):
        """
        Plot the MSD, its linear fit and the running diffusion coefficient.

        Parameters
        ----------
        data : dict
                Mapping of species name to a result dictionary carrying the
                ``result_keys`` and ``result_series_keys`` entries (as produced
                by ``fit_diff_coeff``).

        Returns
        -------
        None
                One bokeh figure per species is appended to ``self.plot_array``.
        """
        for selected_species, val in data.items():
            fig = figure(x_axis_label=self.x_label, y_axis_label=self.y_label)

            gradients = np.array(val[self.result_series_keys[2]])
            gradient_errors = np.array(val[self.result_series_keys[3]])

            time = np.array(val[self.result_series_keys[0]])
            msd = np.array(val[self.result_series_keys[1]])

            fig.y_range = Range1d(0.0, 1.1 * max(msd))

            # Vertical marker at the end of the fitted region
            span = Span(
                location=time[self.args.fit_range],
                dimension="height",
                line_dash="dashed",
            )
            # MSD curve and its linear fit
            fig.line(
                time,
                msd,
                color=utils.Colour.ORANGE,
                legend_label=(
                    f"{selected_species}: {val[self.result_keys[0]]: 0.3E} +-"
                    f" {val[self.result_keys[1]]: 0.3E}"
                ),
            )
            fit_data = val[self.result_keys[2]] * time + val[self.result_keys[3]]
            fig.line(time, fit_data, color=utils.Colour.PAUA, legend_label="Curve fit.")

            # Secondary y-axis for the running diffusion coefficient
            fig.extra_y_ranges = {
                "Diff_range": Range1d(
                    start=0.999 * min(gradients), end=1.001 * max(gradients)
                )
            }
            fig.add_layout(
                LinearAxis(
                    y_range_name="Diff_range",
                    axis_label=r"$$\text{Diffusion Coefficient} / m^{2}s^{-1}$$",
                ),
                "right",
            )
            # The gradient series may be shorter than the time series; align
            # it with the end of the time axis.
            grad_time = time[-len(gradients) :]
            fig.line(
                grad_time,
                gradients,
                y_range_name="Diff_range",
                color=utils.Colour.MULBERRY,
            )
            fig.varea(
                grad_time,
                gradients - gradient_errors,
                gradients + gradient_errors,
                alpha=0.3,
                color=utils.Colour.ORANGE,
                y_range_name="Diff_range",
            )

            fig.add_tools(HoverTool())
            fig.add_layout(span)
            self.plot_array.append(fig)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc