• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

stfc / janus-core / 12744885267

13 Jan 2025 10:06AM UTC coverage: 95.159% (-0.007%) from 95.166%
12744885267

Pull #369

github

web-flow
Merge 592add3d6 into 1d7922781
Pull Request #369: Comment correlation algorithm, remove flattening of Velocity observable

2 of 2 new or added lines in 2 files covered. (100.0%)

46 existing lines in 10 files now uncovered.

2280 of 2396 relevant lines covered (95.16%)

3.8 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.91
/janus_core/processing/correlator.py
1
"""Module to correlate scalar data on-the-fly."""
2

3
from __future__ import annotations
4✔
4

5
from collections.abc import Iterable
4✔
6

7
from ase import Atoms
4✔
8
import numpy as np
4✔
9

10
from janus_core.processing.observables import Observable
4✔
11

12

13
class Correlator:
    """
    Correlate scalar real values, <ab>.

    Implements the algorithm detailed in https://doi.org/10.1063/1.3491098.

    Data pairs are observed iteratively and stored in a set of rolling hierarchical data
    blocks.

    Once a block is filled, coarse graining may be applied to update coarser block
    levels by updating the coarser block with the average of values accumulated up to
    that point in the filled block.

    The correlation is continuously updated when any block is updated with new data.

    Parameters
    ----------
    blocks : int
        Number of correlation blocks.
    points : int
        Number of points per block.
    averaging : int
        Averaging window per block level.

    Attributes
    ----------
    _max_block_used : int
        Which levels have been updated with data.
    _min_dist : int
        First point in coarse-grained block relevant for correlation updates.
    _accumulator : NDArray[float64]
        Sum of data seen for calculating the average between blocks.
    _count_accumulated : NDArray[int]
        Data points accumulated at this block.
    _shift_index : NDArray[int]
        Current position in each block's data store.
    _shift : NDArray[float64]
        Rolling data store for each block.
    _shift_not_null : NDArray[bool]
        If data is stored in this block's rolling data store, at a given index.
    _correlation : NDArray[float64]
        Running correlation values.
    _count_correlated : NDArray[int]
        Count of correlation updates for each block.
    """

    def __init__(self, *, blocks: int, points: int, averaging: int) -> None:
        """
        Initialise an empty Correlator.

        Parameters
        ----------
        blocks : int
            Number of resolution levels.
        points : int
            Data points at each resolution.
        averaging : int
            Coarse-graining between resolution levels.
        """
        self._blocks = blocks
        self._points = points
        self._averaging = averaging
        self._max_block_used = 0
        # Fix: must be integer (floor) division. _min_dist is used as a
        # range() bound and as an index offset in _propagate/get_lags/
        # get_value; true division produced a float, which raises TypeError
        # as soon as a coarse-grained block is processed.
        self._min_dist = self._points // self._averaging

        self._accumulator = np.zeros((self._blocks, 2))
        self._count_accumulated = np.zeros(self._blocks, dtype=int)
        self._shift_index = np.zeros(self._blocks, dtype=int)
        self._shift = np.zeros((self._blocks, self._points, 2))
        self._shift_not_null = np.zeros((self._blocks, self._points), dtype=bool)
        self._correlation = np.zeros((self._blocks, self._points))
        self._count_correlated = np.zeros((self._blocks, self._points), dtype=int)

    def update(self, a: float, b: float) -> None:
        """
        Update the correlation, <ab>, with new values a and b.

        Parameters
        ----------
        a : float
            Newly observed value of left correland.
        b : float
            Newly observed value of right correland.
        """
        # All updates enter the hierarchy at the finest-resolution block.
        self._propagate(a, b, 0)

    def _propagate(self, a: float, b: float, block: int) -> None:
        """
        Propagate update down block hierarchy.

        Parameters
        ----------
        a : float
            Newly observed value of left correland/average.
        b : float
            Newly observed value of right correland/average.
        block : int
            Block in the hierarchy being updated.
        """
        if block == self._blocks:
            # Hit the end of the data structure.
            return

        shift = self._shift_index[block]
        self._max_block_used = max(self._max_block_used, block)

        # Update the rolling data store, and accumulate.
        self._shift[block, shift, :] = a, b
        self._accumulator[block, :] += a, b
        self._shift_not_null[block, shift] = True
        self._count_accumulated[block] += 1

        if self._count_accumulated[block] == self._averaging:
            # Hit the coarse graining threshold, the next block can be updated
            # with the window average of this block's accumulated values.
            self._propagate(
                self._accumulator[block, 0] / self._averaging,
                self._accumulator[block, 1] / self._averaging,
                block + 1,
            )
            # Reset the accumulator at this block level.
            self._accumulator[block, :] = 0.0
            self._count_accumulated[block] = 0

        # Update the correlation.
        i = self._shift_index[block]
        if block == 0:
            # Need to multiply by all in this block (full resolution).
            j = i
            for point in range(self._points):
                if self._shifts_valid(block, i, j):
                    # Correlate at this lag.
                    self._correlation[block, point] += (
                        self._shift[block, i, 0] * self._shift[block, j, 1]
                    )
                    self._count_correlated[block, point] += 1
                j -= 1
                if j < 0:
                    # Wrap to start of rolling data store.
                    j += self._points
        else:
            # Only need to update after points/averaging.
            # The previous block already accounts for those points.
            # Fix: j was unbound in this branch, raising NameError the first
            # time a coarse-grained block was correlated. As in the block-0
            # branch (where j == i - point), the loop starting at
            # point == _min_dist must start the right correland _min_dist
            # entries behind the newest one.
            j = i - self._min_dist
            for point in range(self._min_dist, self._points):
                if j < 0:
                    # Wrap to start of rolling data store.
                    j = j + self._points
                if self._shifts_valid(block, i, j):
                    # Correlate at this lag.
                    self._correlation[block, point] += (
                        self._shift[block, i, 0] * self._shift[block, j, 1]
                    )
                    self._count_correlated[block, point] += 1
                j = j - 1
        # Update the rolling data store.
        self._shift_index[block] = (self._shift_index[block] + 1) % self._points

    def _shifts_valid(self, block: int, p_i: int, p_j: int) -> bool:
        """
        Return True if the shift registers have data.

        Parameters
        ----------
        block : int
            Block to check the shift register of.
        p_i : int
            Index i in the shift (left correland).
        p_j : int
            Index j in the shift (right correland).

        Returns
        -------
        bool
            Whether the shift indices have data.
        """
        return self._shift_not_null[block, p_i] and self._shift_not_null[block, p_j]

    def get_lags(self) -> Iterable[float]:
        """
        Obtain the correlation lag times.

        Returns
        -------
        Iterable[float]
            The correlation lag times.
        """
        lags = np.zeros(self._points * self._blocks)

        lag = 0
        for i in range(self._points):
            if self._count_correlated[0, i] > 0:
                # Data has been correlated, at full resolution.
                lags[lag] = i
                lag += 1
        # NOTE(review): range(1, _max_block_used) excludes the deepest updated
        # block (index _max_block_used) from the output — confirm intentional.
        for k in range(1, self._max_block_used):
            for i in range(self._min_dist, self._points):
                if self._count_correlated[k, i] > 0:
                    # Data has been correlated at a coarse-grained level; each
                    # level k spaces points by averaging**k raw steps.
                    lags[lag] = float(i) * float(self._averaging) ** k
                    lag += 1
        return lags[0:lag]

    def get_value(self) -> Iterable[float]:
        """
        Obtain the correlation value.

        Returns
        -------
        Iterable[float]
            The correlation values <a(t)b(t+t')>.
        """
        correlation = np.zeros(self._points * self._blocks)

        lag = 0
        for i in range(self._points):
            if self._count_correlated[0, i] > 0:
                # Data has been correlated at full resolution.
                correlation[lag] = (
                    self._correlation[0, i] / self._count_correlated[0, i]
                )
                lag += 1
        # NOTE(review): range(1, _max_block_used) excludes the deepest updated
        # block, mirroring get_lags — confirm intentional.
        for k in range(1, self._max_block_used):
            for i in range(self._min_dist, self._points):
                # Indices less than points/averaging accounted in the previous block.
                if self._count_correlated[k, i] > 0:
                    # Data has been correlated at a coarse-grained level.
                    correlation[lag] = (
                        self._correlation[k, i] / self._count_correlated[k, i]
                    )
                    lag += 1
        return correlation[0:lag]
242

243

244
class Correlation:
    """
    Represents a user correlation, <ab>.

    Parameters
    ----------
    a : Observable
        Observable for a.
    b : Observable
        Observable for b.
    name : str
        Name of correlation.
    blocks : int
        Number of correlation blocks.
    points : int
        Number of points per block.
    averaging : int
        Averaging window per block level.
    update_frequency : int
        Frequency to update the correlation, md steps.
    """

    def __init__(
        self,
        *,
        a: Observable,
        b: Observable,
        name: str,
        blocks: int,
        points: int,
        averaging: int,
        update_frequency: int,
    ) -> None:
        """
        Initialise a correlation.

        Parameters
        ----------
        a : Observable
            Observable for a.
        b : Observable
            Observable for b.
        name : str
            Name of correlation.
        blocks : int
            Number of correlation blocks.
        points : int
            Number of points per block.
        averaging : int
            Averaging window per block level.
        update_frequency : int
            Frequency to update the correlation, md steps.
        """
        self.name = name
        self.blocks = blocks
        self.points = points
        self.averaging = averaging
        self._get_a = a
        self._get_b = b

        # Correlators are created lazily on the first update, once the number
        # of observed components is known.
        self._correlators = None
        self._update_frequency = update_frequency

    @property
    def update_frequency(self) -> int:
        """
        Get update frequency.

        Returns
        -------
        int
            Correlation update frequency.
        """
        return self._update_frequency

    def update(self, atoms: Atoms) -> None:
        """
        Update a correlation.

        Parameters
        ----------
        atoms : Atoms
            Atoms object to observe values from.
        """
        # Evaluate each observable exactly once per update; previously a was
        # evaluated twice (once to build the value pairs and again for len()
        # when initialising the correlators).
        a_values = self._get_a(atoms).flatten()
        b_values = self._get_b(atoms).flatten()
        if self._correlators is None:
            # Initialise correlators automatically, one per component pair.
            self._correlators = [
                Correlator(
                    blocks=self.blocks, points=self.points, averaging=self.averaging
                )
                for _ in range(len(a_values))
            ]
        # All pairs of data to be correlated.
        for corr, values in zip(self._correlators, zip(a_values, b_values)):
            corr.update(*values)

    def get(self) -> tuple[Iterable[float], Iterable[float]]:
        """
        Get the correlation value and lags, averaging over atoms if applicable.

        Returns
        -------
        correlation : Iterable[float]
            The correlation values <a(t)b(t+t')>.
        lags : Iterable[float]
            The correlation lag times t'.
        """
        if self._correlators:
            # All correlators share the same lag structure; take the first.
            lags = self._correlators[0].get_lags()
            return np.mean([cor.get_value() for cor in self._correlators], axis=0), lags
        # No data observed yet.
        return [], []

    def __str__(self) -> str:
        """
        Return string representation of correlation.

        Returns
        -------
        str
            String representation.
        """
        return self.name
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc