• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

WashU-Astroparticle-Lab / straxion / 16531916059

25 Jul 2025 09:26PM UTC coverage: 57.563% (+15.2%) from 42.405%
16531916059

Pull #15

github

web-flow
Merge 7cbc6623d into f4922cd70
Pull Request #15: Technical robustness hits.py tests

3 of 5 new or added lines in 1 file covered. (60.0%)

1 existing line in 1 file now uncovered.

274 of 476 relevant lines covered (57.56%)

0.58 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.12
/straxion/plugins/hits.py
1
import strax
1✔
2
import numpy as np
1✔
3
import warnings
1✔
4
from straxion.utils import (
1✔
5
    DATA_DTYPE,
6
    INDEX_DTYPE,
7
    SECOND_TO_NANOSECOND,
8
    HIT_WINDOW_LENGTH_LEFT,
9
    HIT_WINDOW_LENGTH_RIGHT,
10
    base_waveform_dtype,
11
)
12

13
export, __all__ = strax.exporter()
1✔
14

15

16
@export
1✔
17
@strax.takes_config(
1✔
18
    strax.Option(
19
        "record_length",
20
        default=5_000_000,
21
        track=False,  # Not tracking record length, but we will have to check if it is as promised
22
        type=int,
23
        help=(
24
            "Number of samples in each dataset."
25
            "We assumed that each sample is equally spaced in time, with interval 1/fs."
26
            "It should not go beyond a billion so that numpy can still handle."
27
        ),
28
    ),
29
    strax.Option(
30
        "fs",
31
        default=50_000,
32
        track=True,
33
        type=int,
34
        help="Sampling frequency (assumed the same for all channels) in unit of Hz",
35
    ),
36
    strax.Option(
37
        "hit_thresholds_sigma",
38
        default=[3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
39
        track=True,
40
        type=list,
41
        help="Threshold for hit finding in units of sigma of standard deviation of the noise.",
42
    ),
43
    strax.Option(
44
        "noisy_channel_signal_std_multipliers",
45
        default=[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
46
        track=True,
47
        type=list,
48
        help=(
49
            "If the signal standard deviation above this threshold times of signal absolute "
50
            "mean, the signal is considered noisy and the hit threshold is increased."
51
        ),
52
    ),
53
    strax.Option(
54
        "min_pulse_widths",
55
        default=[20, 20, 20, 20, 20, 20, 20, 20, 20, 20],
56
        track=True,
57
        type=list,
58
        help=(
59
            "Minimum pulse width in unit of samples. If the pulse width is below this "
60
            "threshold, the hit is not considered a new hit."
61
        ),
62
    ),
63
    strax.Option(
64
        "hit_convolved_inspection_window_length",
65
        default=60,
66
        track=True,
67
        type=int,
68
        help=(
69
            "Length of the convolved hit inspection window (to find maximum and minimum) "
70
            "in unit of samples."
71
        ),
72
    ),
73
    strax.Option(
74
        "hit_extended_inspection_window_length",
75
        default=100,
76
        track=True,
77
        type=int,
78
        help=(
79
            "Length of the extended convolved hit inspection window (to find maximum and minimum) "
80
            "in unit of samples."
81
        ),
82
    ),
83
    strax.Option(
84
        "hit_moving_average_inspection_window_length",
85
        default=40,
86
        track=True,
87
        type=int,
88
        help=(
89
            "Length of the moving averaged hit inspection window (to find maximum and minimum) "
90
            "in unit of samples."
91
        ),
92
    ),
93
)
94
class Hits(strax.Plugin):
1✔
95
    """Find and characterize hits in processed phase angle data.
96

97
    This plugin identifies significant signal excursions (hits) in processed phase angle
98
    data and extracts their characteristics including amplitude, timing, and waveform
99
    data. The hit-finding algorithm uses adaptive thresholds based on signal statistics
100
    and applies various filtering criteria to distinguish real hits from noise.
101

102
    Processing workflow:
103
    1. Calculate adaptive hit thresholds based on signal statistics for each channel.
104
    2. Identify hit candidates using threshold crossing and minimum width criteria.
105
    3. Calculate hit characteristics (amplitude, timing, alignment point).
106
    4. Extract and align hit waveforms for further analysis.
107

108
    Provides:
109
    - hits: Characterized hits with waveform data and timing information.
110

111
    """
112

113
    __version__ = "0.0.0"
1✔
114

115
    # Inherited from straxen. Not optimized outside XENONnT.
116
    rechunk_on_save = False
1✔
117
    compressor = "zstd"
1✔
118
    chunk_target_size_mb = 2000
1✔
119
    rechunk_on_load = True
1✔
120
    chunk_source_size_mb = 100
1✔
121

122
    depends_on = ["records"]
1✔
123
    provides = "hits"
1✔
124
    data_kind = "hits"
1✔
125
    save_when = strax.SaveWhen.ALWAYS
1✔
126

127
    def setup(self):
        """Cache configuration values on the instance and validate them.

        Copies the hit-window constants and the hit-finding options from
        ``self.config`` onto attributes, derives the sampling interval ``dt`` from the
        sampling frequency, and finally runs ``_check_hit_parameters``.

        """
        cfg = self.config

        # Extent of the stored hit waveform on either side of the alignment point.
        self.hit_window_length_left = HIT_WINDOW_LENGTH_LEFT
        self.hit_window_length_right = HIT_WINDOW_LENGTH_RIGHT
        self.hit_waveform_length = HIT_WINDOW_LENGTH_LEFT + HIT_WINDOW_LENGTH_RIGHT

        # List-valued options become arrays; they are later indexed by channel number.
        self.hit_thresholds_sigma = np.array(cfg["hit_thresholds_sigma"])
        self.noisy_channel_signal_std_multipliers = np.array(
            cfg["noisy_channel_signal_std_multipliers"]
        )

        # Inspection window lengths used when characterizing a hit.
        self.hit_ma_inspection_window_length = cfg["hit_moving_average_inspection_window_length"]
        self.hit_convolved_inspection_window_length = cfg["hit_convolved_inspection_window_length"]
        self.hit_extended_inspection_window_length = cfg["hit_extended_inspection_window_length"]

        self.record_length = cfg["record_length"]
        # Sampling interval: the period 1/fs scaled by SECOND_TO_NANOSECOND.
        self.dt = 1 / cfg["fs"] * SECOND_TO_NANOSECOND

        self._check_hit_parameters()
1✔
150

151
    def _check_hit_parameters(self):
1✔
152
        """Check for potentially problematic parameters and issue warnings."""
153
        if self.hit_ma_inspection_window_length > self.hit_waveform_length:
1✔
154
            warnings.warn(
×
155
                "The hit-waveform recording window might be too short to save enough information: "
156
                f"hit_ma_inspection_window_length={self.hit_ma_inspection_window_length} "
157
                f"is larger than hit_waveform_length={self.hit_waveform_length}."
158
            )
159
        if self.hit_convolved_inspection_window_length > self.hit_waveform_length:
1✔
160
            warnings.warn(
×
161
                "The hit-waveform recording window might be too short to save enough information: "
162
                "hit_convolved_inspection_window_length="
163
                f"{self.hit_convolved_inspection_window_length} "
164
                f"is larger than hit_waveform_length={self.hit_waveform_length}."
165
            )
166
        if self.hit_extended_inspection_window_length > self.hit_waveform_length:
1✔
167
            warnings.warn(
×
168
                "The hit-waveform recording window might be too short to save enough information: "
169
                "hit_extended_inspection_window_length="
170
                f"{self.hit_extended_inspection_window_length} "
171
                f"is larger than hit_waveform_length={self.hit_waveform_length}."
172
            )
173

174
    def infer_dtype(self):
        """Build the dtype of the ``hits`` output.

        Starts from ``base_waveform_dtype()`` and appends the hit-specific fields.
        Every field is declared as ``((description, name), type)`` or, for waveform
        arrays, ``((description, name), type, length)``. All descriptions are plain
        strings (a stray trailing comma previously turned several of them into
        one-element tuples).

        Returns:
            list: Field specifications, the base waveform fields first and the
                hit-specific fields after them.

        """
        # Assigned here as well as in setup() so this method does not depend on
        # setup() having been called first.
        self.hit_waveform_length = HIT_WINDOW_LENGTH_LEFT + HIT_WINDOW_LENGTH_RIGHT

        dtype = base_waveform_dtype()
        dtype.extend(
            [
                (
                    (
                        "Width of the hit waveform (length above the hit threshold) "
                        "in unit of samples.",
                        "width",
                    ),
                    INDEX_DTYPE,
                ),
                (
                    (
                        "Hit waveform of phase angle (theta) only after baseline corrections, "
                        "aligned at the maximum of the moving averaged waveform.",
                        "data_theta",
                    ),
                    DATA_DTYPE,
                    self.hit_waveform_length,
                ),
                (
                    (
                        "Hit waveform of phase angle (theta) further smoothed by moving average, "
                        "aligned at the maximum of the moving averaged waveform.",
                        "data_theta_moving_average",
                    ),
                    DATA_DTYPE,
                    self.hit_waveform_length,
                ),
                (
                    (
                        "Hit waveform of phase angle (theta) further smoothed by pulse kernel, "
                        "aligned at the maximum of the moving averaged waveform.",
                        "data_theta_convolved",
                    ),
                    DATA_DTYPE,
                    self.hit_waveform_length,
                ),
                (
                    (
                        "Hit finding threshold determined by signal statistics in unit of rad.",
                        "hit_threshold",
                    ),
                    DATA_DTYPE,
                ),
                (
                    (
                        "Index of alignment point (the maximum) in the records",
                        "aligned_at_records_i",
                    ),
                    INDEX_DTYPE,
                ),
                (
                    (
                        "Maximum amplitude of the hit waveform (within the hit window) "
                        "in unit of rad.",
                        "amplitude_max",
                    ),
                    DATA_DTYPE,
                ),
                (
                    (
                        "Minimum amplitude of the hit waveform (within the hit window) "
                        "in unit of rad.",
                        "amplitude_min",
                    ),
                    DATA_DTYPE,
                ),
                (
                    (
                        "Maximum amplitude of the hit waveform (within the extended hit window) "
                        "in unit of rad.",
                        "amplitude_max_ext",
                    ),
                    DATA_DTYPE,
                ),
                (
                    (
                        "Minimum amplitude of the hit waveform (within the extended hit window) "
                        "in unit of rad.",
                        "amplitude_min_ext",
                    ),
                    DATA_DTYPE,
                ),
            ]
        )

        return dtype
1✔
294

295
    @staticmethod
1✔
296
    def calculate_hit_threshold(signal, hit_threshold_sigma, noisy_channel_signal_std_multiplier):
1✔
297
        """Calculate hit threshold based on signal statistics.
298

299
        Args:
300
            signal (np.ndarray): The signal array to analyze.
301
            hit_threshold_sigma (float): Threshold multiplier in units of sigma.
302
            noisy_channel_signal_std_multiplier (float): Multiplier to detect noisy channels.
303

304
        Returns:
305
            float: The calculated hit threshold.
306

307
        """
308
        signal_mean = np.mean(signal)
1✔
309
        signal_abs_mean = np.mean(np.abs(signal))
1✔
310
        signal_std = np.std(signal)
1✔
311

312
        # The naive hit threshold is a multiple of the standard deviation of the signal.
313
        hit_threshold = signal_mean + hit_threshold_sigma * signal_std
1✔
314

315
        # If the signal is noisy, the baseline might be too high.
316
        if signal_std > noisy_channel_signal_std_multiplier * signal_abs_mean:
1✔
317
            # We will use the quiet part of the signal to redefine a lowered hit threshold.
318
            quiet_mask = signal < hit_threshold
×
319
            hit_threshold = signal_mean + hit_threshold_sigma * np.std(signal[quiet_mask])
×
320

321
        return hit_threshold
1✔
322

323
    def compute(self, records):
        """Find hits in every record and return them ordered by time.

        Args:
            records: Array of processed records containing signal data.

        Returns:
            np.ndarray: Array of hits with waveform data and characteristics. An
                empty array of the plugin dtype is returned when no record
                produced any hit.

        """
        per_record_hits = (self._process_single_record(r) for r in records)
        # Keep only records that actually produced hits.
        found = [hits for hits in per_record_hits if hits is not None and len(hits) > 0]

        if not found:
            return np.zeros(0, dtype=self.infer_dtype())

        # Records are handled one after another, so the merged hits need sorting by time.
        merged = np.concatenate(found)
        time_order = np.argsort(merged["time"])
        return merged[time_order]
×
348

349
    def _process_single_record(self, record):
1✔
350
        """Process a single record to find hits.
351

352
        Args:
353
            record: Single record containing signal data.
354

355
        Returns:
356
            np.ndarray or None: Array of hits found in the record, or None if no hits.
357

358
        """
359
        ch = int(record["channel"])
1✔
360
        signal = record["data_theta_convolved"]
1✔
361
        signal_ma = record["data_theta_moving_average"]
1✔
362
        signal_raw = record["data_theta"]
1✔
363
        min_pulse_width = self.config["min_pulse_widths"][ch]
1✔
364

365
        # Calculate hit threshold and find hit candidates
366
        hit_threshold = self.calculate_hit_threshold(
1✔
367
            signal, self.hit_thresholds_sigma[ch], self.noisy_channel_signal_std_multipliers[ch]
368
        )
369

370
        hit_candidates = self._find_hit_candidates(signal, hit_threshold, min_pulse_width)
1✔
371
        if len(hit_candidates) == 0:
1✔
372
            return None
×
373

374
        # Process each hit candidate
375
        hits = self._process_hit_candidates(
1✔
376
            hit_candidates, record, signal, signal_ma, signal_raw, hit_threshold, ch
377
        )
378

379
        return hits
1✔
380

381
    def _find_hit_candidates(self, signal, hit_threshold, min_pulse_width):
1✔
382
        """Find potential hit candidates based on threshold crossing.
383

384
        Args:
385
            signal: The convolved signal array.
386
            hit_threshold: Threshold value for hit detection.
387
            min_pulse_width: Minimum width required for a valid hit.
388

389
        Returns:
390
            tuple: (hit_start_indices, hit_widths) for valid hits.
391

392
        """
393
        below_threshold_indices = np.where(signal < hit_threshold)[0]
1✔
394
        if len(below_threshold_indices) == 0:
1✔
395
            return [], []
1✔
396

397
        # Find the start of the hits
398
        hits_width = np.diff(below_threshold_indices, prepend=1)
1✔
399

400
        # Filter by minimum pulse width
401
        valid_mask = hits_width >= min_pulse_width
1✔
402
        hit_end_indices = below_threshold_indices[valid_mask]
1✔
403
        hit_widths = hits_width[valid_mask]
1✔
404
        hit_start_indices = hit_end_indices - hit_widths
1✔
405

406
        return hit_start_indices, hit_widths
1✔
407

408
    def _process_hit_candidates(
        self, hit_candidates, record, signal, signal_ma, signal_raw, hit_threshold, channel
    ):
        """Allocate the output array for one record and fill in every hit candidate.

        Args:
            hit_candidates: Tuple of (hit_start_indices, hit_widths).
            record: The original record.
            signal: The convolved signal array.
            signal_ma: The moving average signal array.
            signal_raw: The raw signal array.
            hit_threshold: The hit threshold value.
            channel: The channel number.

        Returns:
            np.ndarray: Array of processed hits, one entry per candidate.

        """
        start_indices, widths = hit_candidates
        n_candidates = len(start_indices)

        hits = np.zeros(n_candidates, dtype=self.infer_dtype())
        hits["width"] = widths

        # Each element of the output array is filled in place by _process_single_hit.
        for hit, start_i in zip(hits, start_indices):
            self._process_single_hit(
                hit, start_i, record, signal, signal_ma, signal_raw, hit_threshold, channel
            )

        return hits
1✔
437

438
    def _process_single_hit(
1✔
439
        self, hit, hit_start_i, record, signal, signal_ma, signal_raw, hit_threshold, channel
440
    ):
441
        """Process a single hit to extract its characteristics and waveform.
442

443
        Args:
444
            hit: The hit array element to populate.
445
            hit_start_i: Start index of the hit.
446
            record: The original record.
447
            signal: The convolved signal array.
448
            signal_ma: The moving average signal array.
449
            signal_raw: The raw signal array.
450
            hit_threshold: The hit threshold value.
451
            channel: The channel number.
452

453
        """
454
        # Set basic hit properties
455
        hit["hit_threshold"] = hit_threshold
×
456
        hit["channel"] = channel
×
457
        hit["dt"] = self.dt
×
458

459
        # Calculate amplitude characteristics
460
        self._calculate_hit_amplitudes(hit, hit_start_i, signal)
×
461

462
        # Find alignment point and extract waveforms
463
        aligned_index = self._find_alignment_point(hit_start_i, signal, signal_ma)
×
464
        hit["aligned_at_records_i"] = aligned_index
×
465

466
        # Extract and align waveforms
467
        self._extract_hit_waveforms(hit, aligned_index, record, signal_raw, signal_ma, signal)
×
468

469
    def _calculate_hit_amplitudes(self, hit, hit_start_i, signal):
1✔
470
        """Calculate amplitude characteristics for a hit.
471

472
        Args:
473
            hit: The hit array element to populate.
474
            hit_start_i: Start index of the hit.
475
            signal: The convolved signal array.
476

477
        """
478
        # Find the maximum and minimum of the hit in the inspection windows
479
        hit_inspection_waveform = signal[
×
480
            hit_start_i : min(
481
                hit_start_i + self.hit_convolved_inspection_window_length,
482
                self.record_length,
483
            )
484
        ]
485
        hit_extended_inspection_waveform = signal[
×
486
            hit_start_i : min(
487
                hit_start_i + self.hit_extended_inspection_window_length,
488
                self.record_length,
489
            )
490
        ]
491

492
        hit["amplitude_max"] = np.max(hit_inspection_waveform)
×
493
        hit["amplitude_min"] = np.min(hit_inspection_waveform)
×
494
        hit["amplitude_max_ext"] = np.max(hit_extended_inspection_waveform)
×
495
        hit["amplitude_min_ext"] = np.min(hit_extended_inspection_waveform)
×
496

497
    def _find_alignment_point(self, hit_start_i, signal, signal_ma):
1✔
498
        """Find the alignment point for waveform extraction.
499

500
        Args:
501
            hit_start_i: Start index of the hit.
502
            signal: The convolved signal array.
503
            signal_ma: The moving average signal array.
504

505
        Returns:
506
            int: Index of the alignment point.
507

508
        """
509
        # Index of kernel-convolved signal in records
510
        hit_inspection_waveform = signal[
×
511
            hit_start_i : min(
512
                hit_start_i + self.hit_convolved_inspection_window_length,
513
                self.record_length,
514
            )
515
        ]
516
        hit_max_i = np.argmax(hit_inspection_waveform) + hit_start_i
×
517

518
        # Align waveforms at the maximum of the moving averaged signal
519
        # Search the maximum in the moving averaged signal within the inspection window
520
        search_start = max(hit_max_i - self.hit_ma_inspection_window_length, 0)
×
521
        search_end = min(hit_max_i + self.hit_ma_inspection_window_length, self.record_length)
×
522

523
        argmax_ma_i = np.argmax(signal_ma[search_start:search_end]) + search_start
×
524

525
        return argmax_ma_i
×
526

527
    def _extract_hit_waveforms(self, hit, aligned_index, record, signal_raw, signal_ma, signal):
1✔
528
        """Extract and align hit waveforms.
529

530
        Args:
531
            hit: The hit array element to populate.
532
            aligned_index: Index of the alignment point.
533
            record: The original record.
534
            signal_raw: The raw signal array.
535
            signal_ma: The moving average signal array.
536
            signal: The convolved signal array.
537

538
        """
539
        # Calculate valid sample ranges
540
        n_right_valid_samples = min(
×
541
            self.record_length - aligned_index, self.hit_window_length_right
542
        )
543
        n_left_valid_samples = min(aligned_index, self.hit_window_length_left)
×
544

545
        # Calculate waveform extraction boundaries
546
        hit_wf_start_i = max(aligned_index - self.hit_window_length_left, 0)
×
547
        hit_wf_end_i = min(aligned_index + self.hit_window_length_right, self.record_length)
×
548

549
        # Set timing information
NEW
550
        hit["time"] = record["time"] + hit_wf_start_i * self.dt
×
NEW
551
        hit["endtime"] = record["time"] + hit_wf_end_i * self.dt
×
UNCOV
552
        hit["length"] = hit_wf_end_i - hit_wf_start_i
×
553

554
        # Calculate target indices in the hit waveform arrays
555
        target_start = self.hit_window_length_left - n_left_valid_samples
×
556
        target_end = self.hit_window_length_left + n_right_valid_samples
×
557

558
        # Extract waveforms
559
        hit["data_theta"][target_start:target_end] = signal_raw[hit_wf_start_i:hit_wf_end_i]
×
560
        hit["data_theta_moving_average"][target_start:target_end] = signal_ma[
×
561
            hit_wf_start_i:hit_wf_end_i
562
        ]
563
        hit["data_theta_convolved"][target_start:target_end] = signal[hit_wf_start_i:hit_wf_end_i]
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc