
stfc / janus-core / 12712565578

10 Jan 2025 03:39PM UTC coverage: 95.166%. Remained the same.

Pull Request #369: Comment correlation algorithm.
Merge 7ca6bd0ca into bf53f87b2

12 of 13 new or added lines in 2 files covered (92.31%).
2 existing lines in 1 file now uncovered.
2264 of 2379 relevant lines covered (95.17%).
3.8 hits per line.

Source File: /janus_core/processing/correlator.py (84.91% covered)

"""Module to correlate scalar data on-the-fly."""

from __future__ import annotations

from collections.abc import Iterable

from ase import Atoms
import numpy as np

from janus_core.processing.observables import Observable


class Correlator:
    """
    Correlate scalar real values, <ab>.

    Implements the algorithm detailed in https://doi.org/10.1063/1.3491098.

    Data pairs are observed iteratively and stored in a set of rolling hierarchical
    data blocks.

    Once a block has accumulated a full averaging window of values, coarse graining
    is applied: the next, coarser block level is updated with the average of the
    values accumulated up to that point.

    The correlation is continuously updated when any block is updated with new data.

    Parameters
    ----------
    blocks : int
        Number of correlation blocks.
    points : int
        Number of points per block.
    averaging : int
        Averaging window per block level.
    """

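    # Example of the block hierarchy described above: with points=16 and
    # averaging=2, block 0 stores the 16 most recent raw values, block 1 stores
    # averages over 2 raw values, block 2 averages over 4, and so on. The
    # maximum lag therefore grows as points * averaging**(blocks - 1), while
    # storage remains blocks * points values per correland.
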
    def __init__(self, *, blocks: int, points: int, averaging: int) -> None:
        """
        Initialise an empty Correlator.

        Parameters
        ----------
        blocks : int
            Number of correlation blocks.
        points : int
            Number of points per block.
        averaging : int
            Averaging window per block level.
        """
        # Number of resolution levels.
        self._blocks = blocks
        # Data points at each resolution.
        self._points = points
        # Coarse-graining between resolution levels.
        self._averaging = averaging
        # Which levels have been updated with data.
        self._max_block_used = 0
        # First point in a coarse-grained block relevant for correlation updates.
        self._min_dist = self._points // self._averaging

        # Sum of data seen for calculating the average between blocks.
        self._accumulator = np.zeros((self._blocks, 2))
        # Data points accumulated at each block.
        self._count_accumulated = np.zeros(self._blocks, dtype=int)
        # Current position in each block's data store.
        self._shift_index = np.zeros(self._blocks, dtype=int)
        # Rolling data store for each block.
        self._shift = np.zeros((self._blocks, self._points, 2))
        # If data is stored in this block's rolling data store, at a given index.
        self._shift_not_null = np.zeros((self._blocks, self._points), dtype=bool)
        # Running correlation values.
        self._correlation = np.zeros((self._blocks, self._points))
        # Count of correlation updates.
        self._count_correlated = np.zeros((self._blocks, self._points), dtype=int)

    def update(self, a: float, b: float) -> None:
        """
        Update the correlation, <ab>, with new values a and b.

        Parameters
        ----------
        a : float
            Newly observed value of left correland.
        b : float
            Newly observed value of right correland.
        """
        self._propagate(a, b, 0)

    def _propagate(self, a: float, b: float, block: int) -> None:
        """
        Propagate update down block hierarchy.

        Parameters
        ----------
        a : float
            Newly observed value of left correland/average.
        b : float
            Newly observed value of right correland/average.
        block : int
            Block in the hierarchy being updated.
        """
        if block == self._blocks:
            # Hit the end of the data structure.
            return

        shift = self._shift_index[block]
        self._max_block_used = max(self._max_block_used, block)

        # Update the rolling data store, and accumulate.
        self._shift[block, shift, :] = a, b
        self._accumulator[block, :] += a, b
        self._shift_not_null[block, shift] = True
        self._count_accumulated[block] += 1

        if self._count_accumulated[block] == self._averaging:
            # Hit the coarse graining threshold, the next block can be updated.
            self._propagate(
                self._accumulator[block, 0] / self._averaging,
                self._accumulator[block, 1] / self._averaging,
                block + 1,
            )
            # Reset the accumulator at this block level.
            self._accumulator[block, :] = 0.0
            self._count_accumulated[block] = 0

        # Update the correlation.
        i = self._shift_index[block]
        if block == 0:
            # Need to multiply by all in this block (full resolution).
            j = i
            for point in range(self._points):
                if self._shifts_valid(block, i, j):
                    # Correlate at this lag.
                    self._correlation[block, point] += (
                        self._shift[block, i, 0] * self._shift[block, j, 1]
                    )
                    self._count_correlated[block, point] += 1
                j -= 1
                if j < 0:
                    # Wrap to start of rolling data store.
                    j += self._points
        else:
            # Only need to update after points/averaging.
            # The previous block already accounts for those points.
            # Start at the first lag not already covered by the previous block.
            j = i - self._min_dist
            for point in range(self._min_dist, self._points):
                if j < 0:
                    j += self._points
                if self._shifts_valid(block, i, j):
                    # Correlate at this lag.
                    self._correlation[block, point] += (
                        self._shift[block, i, 0] * self._shift[block, j, 1]
                    )
                    self._count_correlated[block, point] += 1
                j -= 1

        # Update the rolling data store.
        self._shift_index[block] = (self._shift_index[block] + 1) % self._points

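    # Example of the update pattern in _propagate: with points=8 and
    # averaging=2, block 0 correlates each new value against lags 0-7 in raw
    # sample intervals, while block 1 only updates lag indices 4-7, since
    # shorter lags are already resolved at full resolution by block 0.
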
    def _shifts_valid(self, block: int, p_i: int, p_j: int) -> bool:
        """
        Return True if the shift registers have data.

        Parameters
        ----------
        block : int
            Block to check the shift register of.
        p_i : int
            Index i in the shift (left correland).
        p_j : int
            Index j in the shift (right correland).

        Returns
        -------
        bool
            Whether the shift indices have data.
        """
        return self._shift_not_null[block, p_i] and self._shift_not_null[block, p_j]

    def get_lags(self) -> Iterable[float]:
        """
        Obtain the correlation lag times.

        Returns
        -------
        Iterable[float]
            The correlation lag times.
        """
        lags = np.zeros(self._points * self._blocks)

        lag = 0
        for i in range(self._points):
            if self._count_correlated[0, i] > 0:
                # Data has been correlated at full resolution.
                lags[lag] = i
                lag += 1
        for k in range(1, self._max_block_used):
            for i in range(self._min_dist, self._points):
                if self._count_correlated[k, i] > 0:
                    # Data has been correlated at a coarse-grained level.
                    lags[lag] = float(i) * float(self._averaging) ** k
                    lag += 1
        return lags[0:lag]

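    # Example of the lag values returned by get_lags: a correlated point i in
    # block k corresponds to a lag of i * averaging**k update calls, so with
    # averaging=2, index 6 in block 2 is reported as a lag of 24.
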
    def get_value(self) -> Iterable[float]:
        """
        Obtain the correlation value.

        Returns
        -------
        Iterable[float]
            The correlation values <a(t)b(t+t')>.
        """
        correlation = np.zeros(self._points * self._blocks)

        lag = 0
        for i in range(self._points):
            if self._count_correlated[0, i] > 0:
                # Data has been correlated at full resolution.
                correlation[lag] = (
                    self._correlation[0, i] / self._count_correlated[0, i]
                )
                lag += 1
        for k in range(1, self._max_block_used):
            for i in range(self._min_dist, self._points):
                # Indices less than points/averaging are accounted for in the
                # previous block.
                if self._count_correlated[k, i] > 0:
                    # Data has been correlated at a coarse-grained level.
                    correlation[lag] = (
                        self._correlation[k, i] / self._count_correlated[k, i]
                    )
                    lag += 1
        return correlation[0:lag]


class Correlation:
    """
    Represents a user correlation, <ab>.

    Parameters
    ----------
    a : Observable
        Observable for a.
    b : Observable
        Observable for b.
    name : str
        Name of correlation.
    blocks : int
        Number of correlation blocks.
    points : int
        Number of points per block.
    averaging : int
        Averaging window per block level.
    update_frequency : int
        Frequency to update the correlation, in MD steps.
    """

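    # For example, a velocity autocorrelation function uses the same velocity
    # Observable for both a and b, while a cross-correlation, e.g. between
    # stress components, passes two different Observables.
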
    def __init__(
        self,
        *,
        a: Observable,
        b: Observable,
        name: str,
        blocks: int,
        points: int,
        averaging: int,
        update_frequency: int,
    ) -> None:
        """
        Initialise a correlation.

        Parameters
        ----------
        a : Observable
            Observable for a.
        b : Observable
            Observable for b.
        name : str
            Name of correlation.
        blocks : int
            Number of correlation blocks.
        points : int
            Number of points per block.
        averaging : int
            Averaging window per block level.
        update_frequency : int
            Frequency to update the correlation, in MD steps.
        """
        self.name = name
        self.blocks = blocks
        self.points = points
        self.averaging = averaging
        self._get_a = a
        self._get_b = b

        self._correlators = None
        self._update_frequency = update_frequency

    @property
    def update_frequency(self) -> int:
        """
        Get update frequency.

        Returns
        -------
        int
            Correlation update frequency.
        """
        return self._update_frequency

    def update(self, atoms: Atoms) -> None:
        """
        Update a correlation.

        Parameters
        ----------
        atoms : Atoms
            Atoms object to observe values from.
        """
        # All pairs of data to be correlated.
        value_pairs = zip(self._get_a(atoms).flatten(), self._get_b(atoms).flatten())
        if self._correlators is None:
            # Initialise correlators automatically.
            self._correlators = [
                Correlator(
                    blocks=self.blocks, points=self.points, averaging=self.averaging
                )
                for _ in range(len(self._get_a(atoms).flatten()))
            ]
        for corr, values in zip(self._correlators, value_pairs):
            corr.update(*values)

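    # In update, one Correlator is created per component of the flattened
    # observable, e.g. an observable returning an (n_atoms, 3) array yields
    # 3 * n_atoms correlators, whose results are averaged over in get().
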
    def get(self) -> tuple[Iterable[float], Iterable[float]]:
        """
        Get the correlation value and lags, averaging over atoms if applicable.

        Returns
        -------
        correlation : Iterable[float]
            The correlation values <a(t)b(t+t')>.
        lags : Iterable[float]
            The correlation lag times t'.
        """
        if self._correlators:
            lags = self._correlators[0].get_lags()
            return np.mean([cor.get_value() for cor in self._correlators], axis=0), lags
        return [], []

    def __str__(self) -> str:
        """
        Return string representation of correlation.

        Returns
        -------
        str
            String representation.
        """
        return self.name
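
As a minimal usage sketch (the parameter values and synthetic signal below are illustrative assumptions, not taken from janus-core's tests or documentation), a Correlator can be fed scalar pairs directly and then queried for the averaged correlation and its lag times:

import numpy as np

from janus_core.processing.correlator import Correlator

# Illustrative parameters: 3 block levels, 16 points per block, averaging window of 2.
correlator = Correlator(blocks=3, points=16, averaging=2)

# Autocorrelate a synthetic, exponentially decorrelating signal.
rng = np.random.default_rng(42)
value = 0.0
for _ in range(1000):
    value = 0.9 * value + rng.normal()
    correlator.update(value, value)

lags = correlator.get_lags()    # lag times, in units of update calls
acf = correlator.get_value()    # <a(t)b(t+t')> at each lag

Correlation wraps the same machinery for each component of a pair of Observables and is intended to be driven from a molecular dynamics loop, calling update(atoms) every update_frequency steps.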