• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

WashU-Astroparticle-Lab / straxion / 17967756706

24 Sep 2025 05:58AM UTC coverage: 57.59% (+0.9%) from 56.703%
17967756706

Pull #47

github

web-flow
Merge ce42fdf8e into 74115fa76
Pull Request #47: Simpler hit finding

3 of 7 new or added lines in 2 files covered. (42.86%)

3 existing lines in 2 files now uncovered.

626 of 1087 relevant lines covered (57.59%)

1.15 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

46.67
/straxion/plugins/hit_classification.py
1
import strax
2✔
2
import numpy as np
2✔
3
from straxion.utils import (
2✔
4
    TIME_DTYPE,
5
    CHANNEL_DTYPE,
6
    SECOND_TO_NANOSECOND,
7
    HIT_WINDOW_LENGTH_LEFT,
8
    DATA_DTYPE,
9
)
10

11
# Obtain strax's @export decorator together with the module's __all__ list;
# classes decorated with @export below are added to __all__ automatically.
export, __all__ = strax.exporter()
12

13

14
@export
@strax.takes_config(
    strax.Option(
        "max_spike_coincidence",
        type=int,
        default=1,
        help=("Maximum number of spikes that can be coincident with a photon candidate hit."),
    ),
    strax.Option(
        "spike_coincidence_window",
        type=float,
        default=0.131e-3,
        help=("Window length for checking spike coincidence, in unit of seconds."),
    ),
    strax.Option(
        "spike_threshold_dx",
        default=None,
        track=True,
        type=float,
        help="Threshold for spike finding in units of dx=df/f0.",
    ),
    strax.Option(
        "spike_thresholds_sigma",
        default=[3.0 for _ in range(41)],
        track=True,
        type=list,
        help=(
            "Threshold for spike finding in units of sigma of standard deviation of the noise. "
            "If None, the spike threshold will be calculated based on the signal statistics."
        ),
    ),
    strax.Option(
        "fs",
        default=38_000,
        track=True,
        type=int,
        help="Sampling frequency (assumed the same for all channels) in unit of Hz",
    ),
    strax.Option(
        "symmetric_spike_inspection_window_length",
        type=int,
        default=25,
        help=(
            "Length of the inspection window for identifying symmetric spikes, "
            "in unit of samples."
        ),
    ),
    strax.Option(
        "symmetric_spike_min_slope",
        type=list,
        default=[0.0 for _ in range(41)],
        help=(
            "Minimum rise edge slope of the moving averaged signal for identifying a physical hit "
            "against symmetric spikes, in unit of dx/second."
        ),
    ),
)
class SpikeCoincidence(strax.Plugin):
    """Classify hits into different types based on their coincidence with spikes.

    A hit is labelled a photon candidate only if it is neither coincident with
    spikes in too many channels nor shaped like a symmetric spike (judged by the
    rise-edge slope of the moving-averaged waveform).
    """

    __version__ = "0.1.0"

    depends_on = ("hits", "records")
    provides = "hit_classification"
    data_kind = "hits"
    save_when = strax.SaveWhen.ALWAYS

    def infer_dtype(self):
        """Build the output dtype: timing/channel fields, boolean classification
        flags, and the numeric features used to derive those flags."""
        base_dtype = [
            (("Start time since unix epoch [ns]", "time"), TIME_DTYPE),
            (("Exclusive end time since unix epoch [ns]", "endtime"), TIME_DTYPE),
            (("Channel number defined by channel_map", "channel"), CHANNEL_DTYPE),
        ]

        hit_id_dtype = [
            (("Is in coincidence with spikes", "is_coincident_with_spikes"), bool),
            (("Is symmetric spike hit", "is_symmetric_spike"), bool),
            (("Photon candidate hit", "is_photon_candidate"), bool),
        ]

        hit_feature_dtype = [
            (
                (
                    "Rise edge slope of the hit waveform, in unit of dx/second",
                    "rise_edge_slope",
                ),
                DATA_DTYPE,
            ),
            (
                ("Number of channels with spikes coinciding with the hit", "n_spikes_coinciding"),
                int,
            ),
        ]

        return base_dtype + hit_id_dtype + hit_feature_dtype

    def setup(self):
        # Convert the coincidence window from seconds to samples.
        self.spike_coincidence_window = int(
            round(self.config["spike_coincidence_window"] * self.config["fs"])
        )
        self.spike_threshold_dx = self.config["spike_threshold_dx"]
        # FIX: this attribute was never assigned before, so the first call to
        # determine_spike_threshold raised AttributeError.
        self.spike_thresholds_sigma = self.config["spike_thresholds_sigma"]
        self.ss_min_slope = np.array(self.config["symmetric_spike_min_slope"])
        self.ss_window = self.config["symmetric_spike_inspection_window_length"]
        self.max_spike_coincidence = self.config["max_spike_coincidence"]
        # Exact sample spacing in nanoseconds.
        self.dt_exact = 1 / self.config["fs"] * SECOND_TO_NANOSECOND

    @staticmethod
    def calculate_spike_threshold(signal, spike_threshold_sigma):
        """Calculate spike threshold based on signal statistics.

        Args:
            signal (np.ndarray): The signal array to analyze, one row per record.
            spike_threshold_sigma (float or np.ndarray): Threshold multiplier in
                units of sigma, scalar or per-record.

        Returns:
            np.ndarray: The calculated per-record spike threshold.

        """
        signal_mean = np.mean(signal, axis=1)
        signal_std = np.std(signal, axis=1)

        # The naive spike threshold is a multiple of the standard deviation of the signal.
        spike_threshold = signal_mean + spike_threshold_sigma * signal_std

        return spike_threshold

    def determine_spike_threshold(self, records):
        """Determine the spike threshold based on the provided configuration.

        You can either provide spike_threshold_dx or spike_thresholds_sigma.
        You cannot provide both.

        FIX: this method previously overwrote self.spike_threshold_dx /
        self.spike_thresholds_sigma, so on the second chunk both were non-None
        and the ValueError below fired even for a valid configuration. It now
        reads the pristine config values on every call.
        """
        dx_config = self.config["spike_threshold_dx"]
        sigma_config = self.config["spike_thresholds_sigma"]

        if dx_config is None and sigma_config is not None:
            # If spike_thresholds_sigma is a single value,
            # we need to convert it to an array.
            # (isinstance also accepts int so integer-valued configs work.)
            if isinstance(sigma_config, (int, float)):
                sigma = np.full(len(records["channel"]), sigma_config)
            else:
                sigma = np.array(sigma_config)
            self.spike_thresholds_sigma = sigma
            # Calculate spike threshold and find spike candidates.
            # NOTE(review): indexing the per-record sigma array by channel number
            # assumes channel numbers are valid indices into it — confirm.
            self.spike_threshold_dx = self.calculate_spike_threshold(
                records["data_dx_convolved"],
                sigma[records["channel"]],
            )
        elif dx_config is not None and sigma_config is None:
            # If spike_threshold_dx is a single value, we need to convert it to an array.
            if isinstance(dx_config, (int, float)):
                self.spike_threshold_dx = np.full(len(records["channel"]), dx_config)
            else:
                self.spike_threshold_dx = np.array(dx_config)
        else:
            raise ValueError(
                "Either spike_threshold_dx or spike_thresholds_sigma "
                "must be provided. You cannot provide both."
            )

    def _get_ss_window(self, hits, window_start_offset, window_end_offset):
        """Extract windows from all hits using vectorized operations.

        Args:
            hits: Array of hits containing the data
            window_start_offset: Offset from climax_shift for window start
            window_end_offset: Offset from climax_shift for window end
                (currently unused; the window length is fixed to ss_window —
                kept for interface compatibility)

        Returns:
            Array of extracted windows with shape (n_hits, ss_window)
        """
        # The inspected window ends at the maximum of the moving averaged signal.
        climax_shift = (
            hits["amplitude_moving_average_max_record_i"] - hits["amplitude_convolved_max_record_i"]
        )

        # Calculate start indices for all hits at once
        start_indices = window_start_offset + climax_shift

        # Extract windows using vectorized operations
        # Create index arrays for all hits
        n_hits = len(hits)
        window_indices = np.arange(self.ss_window)[None, :]  # Shape: (1, ss_window)
        start_indices = start_indices[:, None]  # Shape: (n_hits, 1)

        # Broadcast to get all indices for all hits
        all_indices = start_indices + window_indices  # Shape: (n_hits, ss_window)

        # Use advanced indexing to extract all windows at once
        return hits["data_dx_moving_average"][np.arange(n_hits)[:, None], all_indices]

    def compute_rise_edge_slope(self, hits, hit_classification):
        """Compute the rise edge slope of the moving averaged signal.

        The slope is the linear-fit coefficient over the ss_window samples
        immediately preceding the hit's alignment point, in dx/second.
        """
        # Temporary time stamps for the inspected window, in unit of seconds.
        dt = self.dt_exact
        times = np.arange(self.ss_window) * dt / SECOND_TO_NANOSECOND

        inspected_wfs = self._get_ss_window(
            hits, HIT_WINDOW_LENGTH_LEFT - self.ss_window, HIT_WINDOW_LENGTH_LEFT
        )
        # Fit a linear model to the inspected window; [0] is the slope term.
        hit_classification["rise_edge_slope"] = np.polyfit(times, inspected_wfs.T, 1)[0]

    def is_symmetric_spike_hit(self, hits, hit_classification):
        """Identify symmetric spike hits: rise edge too shallow for a physical hit."""
        hit_classification["is_symmetric_spike"] = (
            hit_classification["rise_edge_slope"] < self.ss_min_slope[hits["channel"]]
        )

    def find_spike_coincidence(self, hit_classification, hits, records):
        """Find the spike coincidence of the hit in the convolved signal.

        For each hit, count how many records contain a sample above their spike
        threshold within +/- spike_coincidence_window samples of the hit maximum.
        """
        spike_coincidence = np.zeros(len(hits), dtype=np.int64)
        for i, hit in enumerate(hits):
            # Get the index of the hit maximum in the record
            hit_climax_i = hit["amplitude_convolved_max_record_i"]

            # FIX: clamp the window start at 0. A negative slice start would
            # silently wrap around and inspect samples from the record's end.
            window_start = max(hit_climax_i - self.spike_coincidence_window, 0)

            # Extract windows from all records at once
            inspected_wfs = records["data_dx_convolved"][
                :,
                window_start : hit_climax_i + self.spike_coincidence_window,
            ]

            # Count records with spikes above threshold
            spike_coincidence[i] = np.sum(
                np.max(inspected_wfs, axis=1) > self.spike_threshold_dx[records["channel"]]
            )
        hit_classification["n_spikes_coinciding"] = spike_coincidence

    def compute(self, hits, records):
        # Per-chunk thresholds, derived from this chunk's records when
        # configured in sigma units.
        self.determine_spike_threshold(records)

        hit_classification = np.zeros(len(hits), dtype=self.infer_dtype())
        hit_classification["time"] = hits["time"]
        hit_classification["endtime"] = hits["endtime"]
        hit_classification["channel"] = hits["channel"]

        self.compute_rise_edge_slope(hits, hit_classification)
        self.find_spike_coincidence(hit_classification, hits, records)
        self.is_symmetric_spike_hit(hits, hit_classification)

        # Too many coincident spikes means the hit is likely noise pickup.
        hit_classification["is_coincident_with_spikes"] = (
            hit_classification["n_spikes_coinciding"] > self.max_spike_coincidence
        )
        # Photon candidates are hits that survived both vetoes.
        hit_classification["is_photon_candidate"] = ~(
            hit_classification["is_coincident_with_spikes"]
            | hit_classification["is_symmetric_spike"]
        )

        return hit_classification
263

264

265
@export
@strax.takes_config(
    strax.Option(
        "cr_ma_std_coeff",
        type=list,
        default=[20.0 for _ in range(41)],
        help=(
            "Coefficients applied to the moving averaged signal's "
            "standard deviation for identifying cosmic ray hits."
        ),
    ),
    strax.Option(
        "cr_convolved_std_coeff",
        type=list,
        default=[20.0 for _ in range(41)],
        help=(
            "Coefficients applied to the convolved signal's "
            "standard deviation for identifying cosmic ray hits."
        ),
    ),
    strax.Option(
        "cr_min_ma_amplitude",
        type=list,
        default=[1.0 for _ in range(41)],
        # NOTE(review): the help text below reads "signal's for" — likely a
        # typo for "signal for"; left untouched since help strings are runtime
        # data tracked by strax.
        help=(
            "Minimum amplitude of the moving averaged signal's " "for identifying cosmic ray hits."
        ),
    ),
    strax.Option(
        "symmetric_spike_min_slope",
        type=list,
        default=[75.0 for _ in range(41)],
        help=(
            "Minimum slope for identifying a physical hit against symmetric spikes, "
            "in unit of rad/second."
        ),
    ),
    strax.Option(
        "symmetric_spike_inspection_window_length",
        type=int,
        default=25,
        help=(
            "Length of the inspection window for identifying symmetric spikes, "
            "in unit of samples."
        ),
    ),
)
class HitClassification(strax.Plugin):
    """Classify hits into different types based on their characteristics.

    Each hit gets three boolean labels — cosmic ray (is_cr), symmetric spike
    (is_symmetric_spike), unidentified (is_unidentified) — plus the rise-edge
    slope of its moving-averaged waveform used for the symmetric-spike veto.

    NOTE(review): this plugin provides "hit_classification", the same data type
    as SpikeCoincidence elsewhere in this module — presumably they are
    alternative classifiers and only one is registered at a time; confirm.
    """

    __version__ = "0.0.0"

    depends_on = "hits"
    provides = "hit_classification"
    data_kind = "hits"
    save_when = strax.SaveWhen.ALWAYS

    def setup(self):
        # Per-channel thresholds/coefficients as arrays so they can be indexed
        # by hits["channel"] in the classification methods below.
        self.cr_ma_std_coeff = np.array(self.config["cr_ma_std_coeff"])
        self.cr_convolved_std_coeff = np.array(self.config["cr_convolved_std_coeff"])
        self.cr_min_ma_amplitude = np.array(self.config["cr_min_ma_amplitude"])
        self.ss_min_slope = np.array(self.config["symmetric_spike_min_slope"])
        self.ss_window = self.config["symmetric_spike_inspection_window_length"]

    def infer_dtype(self):
        """Build the output dtype: timing/channel fields, boolean classification
        flags, and the slope feature backing the symmetric-spike flag."""
        base_dtype = [
            (("Start time since unix epoch [ns]", "time"), TIME_DTYPE),
            (("Exclusive end time since unix epoch [ns]", "endtime"), TIME_DTYPE),
            (("Channel number defined by channel_map", "channel"), CHANNEL_DTYPE),
        ]

        hit_id_dtype = [
            (("Is identified as cosmic ray hit", "is_cr"), bool),
            (("Is identified as symmetric spike hit", "is_symmetric_spike"), bool),
            (("Is unidentified hit", "is_unidentified"), bool),
        ]

        hit_feature_dtype = [
            (
                (
                    "Rise edge slope of the moving averaged signal, in unit of rad/second",
                    "ma_rise_edge_slope",
                ),
                DATA_DTYPE,
            ),
        ]

        return base_dtype + hit_id_dtype + hit_feature_dtype

    def compute_ma_rise_edge_slope(self, hits, hit_classification):
        """Compute the rise edge slope of the moving averaged signal."""
        # All hits must share one sampling interval, otherwise a single `times`
        # axis cannot be used for the vectorized fit below.
        # NOTE(review): `assert` is stripped under `python -O`; consider raising
        # ValueError instead if this check must survive optimized runs.
        assert (
            len(np.unique(hits["dt"])) == 1
        ), "The sampling frequency is not constant!? We found {} unique values: {}".format(
            len(np.unique(hits["dt"])), np.unique(hits["dt"])
        )
        # Temporary time stamps for the inspected window, in unit of seconds.
        dt = hits["dt"][0]
        times = np.arange(self.ss_window) * dt / SECOND_TO_NANOSECOND

        # Extract windows from all hits at once (fully vectorized):
        # the ss_window samples just before the hit alignment point.
        inspected_wfs = hits["data_theta_moving_average"][
            :, HIT_WINDOW_LENGTH_LEFT - self.ss_window : HIT_WINDOW_LENGTH_LEFT
        ]
        # Fit a linear model to the inspected window; polyfit's leading
        # coefficient is the slope (one per hit via the transposed 2D fit).
        hit_classification["ma_rise_edge_slope"] = np.polyfit(times, inspected_wfs.T, 1)[0]

    def is_unidentified_hit(self, hits, hit_classification):
        """Identify unidentified hits.

        The hit is identified as an unidentified hit if the amplitude of the hit is
        less than the threshold.

        Args:
            hits (np.ndarray): Hit array.

        """
        hit_classification["is_unidentified"] = (
            hits["amplitude_convolved_max"] < hits["hit_threshold"]
        )

    def is_cr_hit(self, hits, hit_classification):
        """Identify cosmic ray hits.

        The hit is identified as a cosmic ray hit if it satisfies either of the following
        conditions, for both convolved and moving averaged signals:
        1. The amplitude of the hit is greater than the threshold.
        2. The amplitude of the hit is greater than the standard deviation of the signal
        multiplied by the coefficient.

        Args:
            hits (np.ndarray): Hit array.

        Returns:
            np.ndarray: Hit array with `is_cr` field.

        """
        # Convolved-signal criterion: above hit threshold AND above the
        # per-channel multiple of the record's convolved std (both required).
        mask_convolved = hits["amplitude_convolved_max"] >= hits["hit_threshold"]
        mask_convolved &= hits["amplitude_convolved_max_ext"] >= (
            hits["record_convolved_std"] * self.cr_convolved_std_coeff[hits["channel"]]
        )
        # Moving-average criterion: above the per-channel absolute minimum OR
        # above mean + coeff * std of the record's moving average (either one).
        mask_ma = hits["amplitude_ma_max_ext"] >= self.cr_min_ma_amplitude[hits["channel"]]
        mask_ma |= hits["amplitude_ma_max_ext"] >= (
            hits["record_ma_mean"] + hits["record_ma_std"] * self.cr_ma_std_coeff[hits["channel"]]
        )
        # A hit passing either criterion is flagged as a cosmic ray.
        hit_classification["is_cr"] = mask_convolved | mask_ma

    def is_symmetric_spike_hit(self, hits, hit_classification):
        """Flag hits whose rise-edge slope is below the per-channel minimum
        as symmetric spikes (computes the slope as a side effect)."""
        self.compute_ma_rise_edge_slope(hits, hit_classification)
        hit_classification["is_symmetric_spike"] = (
            hit_classification["ma_rise_edge_slope"] < self.ss_min_slope[hits["channel"]]
        )

    def compute(self, hits):
        # Allocate the output array and copy over the identifying fields.
        hit_classification = np.zeros(len(hits), dtype=self.infer_dtype())
        hit_classification["time"] = hits["time"]
        hit_classification["endtime"] = hits["endtime"]
        hit_classification["channel"] = hits["channel"]

        # Apply the three independent classifications; flags are not mutually
        # exclusive (a hit may carry more than one label).
        self.is_unidentified_hit(hits, hit_classification)
        self.is_cr_hit(hits, hit_classification)
        self.is_symmetric_spike_hit(hits, hit_classification)

        return hit_classification
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc