• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

XENONnT / straxen / 10286081415

07 Aug 2024 02:17PM UTC coverage: 91.122% (-0.01%) from 91.135%
10286081415

Pull #1404

github

web-flow
Merge 68e283066 into 4d6c2b7ba
Pull Request #1404: Plugins for position reconstruction with conditional normalizing flow

170 of 190 new or added lines in 6 files covered. (89.47%)

3 existing lines in 1 file now uncovered.

9145 of 10036 relevant lines covered (91.12%)

1.82 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.83
/straxen/plugins/records/records.py
1
from typing import Tuple
2✔
2
from immutabledict import immutabledict
2✔
3
import numba
2✔
4
import numpy as np
2✔
5

6
import strax
2✔
7
import straxen
2✔
8

9
# strax.exporter() hands back an @export decorator and the __all__ list that
# the decorator appends to; @export is used below to publish this module's API.
export, __all__ = strax.exporter()
# NO_PULSE_COUNTS is a plain constant (no decorator possible), so it is added
# to __all__ by hand.
__all__.extend(["NO_PULSE_COUNTS"])
11

12

13
@export
class PulseProcessing(strax.Plugin):
    """
    Split raw_records into:
     - (tpc) records
     - aqmon_records
     - pulse_counts

    For TPC records, apply basic processing:
        1. Flip, baseline, and integrate the waveform
        2. Apply software HE veto after high-energy peaks.
        3. Find hits, apply linear filter, and zero outside hits.

    pulse_counts holds some average information for the individual PMT
    channels for each chunk of raw_records. This includes e.g.
    number of recorded pulses, lone_pulses (pulses which do not
    overlap with any other pulse), or mean values of baseline and
    baseline rms per channel.
    """

    __version__ = "0.2.3"

    parallel = "process"
    # Only pulse_counts (tiny, chunk-level summaries) is rechunked on save.
    rechunk_on_save = immutabledict(records=False, veto_regions=True, pulse_counts=True)
    compressor = "zstd"

    depends_on = "raw_records"

    provides: Tuple[str, ...] = ("records", "veto_regions", "pulse_counts")
    data_kind = {k: k for k in provides}
    # pulse_counts is always saved; the bulkier outputs only when targeted.
    save_when = immutabledict(
        records=strax.SaveWhen.TARGET,
        veto_regions=strax.SaveWhen.TARGET,
        pulse_counts=strax.SaveWhen.ALWAYS,
    )

    hev_gain_model = straxen.URLConfig(
        default=None, infer_type=False, help="PMT gain model used in the software high-energy veto."
    )

    baseline_samples = straxen.URLConfig(
        default=40,
        infer_type=False,
        help="Number of samples to use at the start of the pulse to determine the baseline",
    )

    # Tail veto options
    tail_veto_threshold = straxen.URLConfig(
        default=0,
        infer_type=False,
        help="Minimum peakarea in PE to trigger tail veto.Set to None, 0 or False to disable veto.",
    )

    tail_veto_duration = straxen.URLConfig(
        default=int(3e6), infer_type=False, help="Time in ns to veto after large peaks"
    )

    tail_veto_resolution = straxen.URLConfig(
        default=int(1e3),
        infer_type=False,
        help="Time resolution in ns for pass-veto waveform summation",
    )

    tail_veto_pass_fraction = straxen.URLConfig(
        default=0.05, infer_type=False, help="Pass veto if maximum amplitude above max * fraction"
    )

    tail_veto_pass_extend = straxen.URLConfig(
        default=3,
        infer_type=False,
        help="Extend pass veto by this many samples (tail_veto_resolution!)",
    )

    max_veto_value = straxen.URLConfig(
        default=None,
        infer_type=False,
        help=(
            "Optionally pass a HE peak that exceeds this absolute area. "
            "(if performing a hard veto, can keep a few statistics.)"
        ),
    )

    # PMT pulse processing options
    pmt_pulse_filter = straxen.URLConfig(
        default=None, infer_type=False, help="Linear filter to apply to pulses, will be normalized."
    )

    save_outside_hits = straxen.URLConfig(
        default=(3, 20),
        infer_type=False,
        help="Save (left, right) samples besides hits; cut the rest",
    )

    n_tpc_pmts = straxen.URLConfig(type=int, help="Number of TPC PMTs")

    check_raw_record_overlaps = straxen.URLConfig(
        default=True,
        track=False,
        infer_type=False,
        help="Crash if any of the pulses in raw_records overlap with others in the same channel",
    )

    allow_sloppy_chunking = straxen.URLConfig(
        default=False,
        track=False,
        infer_type=False,
        help=(
            "Use a default baseline for incorrectly chunked fragments. "
            "This is a kludge for improperly converted XENON1T data."
        ),
    )

    hit_min_amplitude = straxen.URLConfig(
        track=True,
        infer_type=False,
        default="cmt://hit_thresholds_tpc?version=ONLINE&run_id=plugin.run_id",
        help=(
            "Minimum hit amplitude in ADC counts above baseline. "
            "Specify as a tuple of length n_tpc_pmts, or a number,"
            'or a string like "pmt_commissioning_initial" which means calling'
            "hitfinder_thresholds.py"
            "or a tuple like (correction=str, version=str, nT=boolean),"
            "which means we are using cmt."
        ),
    )

    def infer_dtype(self):
        """Build the output dtypes; record length is taken from the upstream plugin."""
        # Get record_length from the plugin making raw_records
        self.record_length = strax.record_length_from_dtype(
            self.deps["raw_records"].dtype_for("raw_records")
        )

        dtype = dict()
        # Every provided data type whose name contains 'records' gets the
        # standard record dtype (here: only 'records' itself).
        for p in self.provides:
            if "records" in p:
                dtype[p] = strax.record_dtype(self.record_length)
        dtype["veto_regions"] = strax.hit_dtype
        dtype["pulse_counts"] = pulse_count_dtype(self.n_tpc_pmts)

        return dtype

    def setup(self):
        """Cache config-derived state used in compute()."""
        # HE veto runs only if a gain model is set AND the threshold is truthy
        # (None, 0 or False disables it, as documented in tail_veto_threshold).
        self.hev_enabled = self.hev_gain_model is not None and self.tail_veto_threshold
        if self.hev_enabled:
            self.to_pe = self.hev_gain_model
        self.hit_thresholds = self.hit_min_amplitude

    def compute(self, raw_records, start, end):
        """Process one chunk of raw_records.

        :param raw_records: raw records for this chunk
        :param start: chunk start time in ns
        :param end: chunk end time in ns
        :return: dict with 'records', 'pulse_counts' and 'veto_regions' arrays
        """
        if self.check_raw_record_overlaps:
            # 3000 is a generous upper bound on channel numbers used only for
            # the per-channel bookkeeping array inside check_overlaps.
            check_overlaps(raw_records, n_channels=3000)

        # Throw away any non-TPC records; this should only happen for XENON1T
        # converted data
        raw_records = raw_records[raw_records["channel"] < self.n_tpc_pmts]

        # Convert everything to the records data type -- adds extra fields.
        r = strax.raw_to_records(raw_records)
        del raw_records

        # Do not trust in DAQ + strax.baseline to leave the
        # out-of-bounds samples to zero.
        strax.zero_out_of_bounds(r)

        strax.baseline(
            r,
            baseline_samples=self.baseline_samples,
            allow_sloppy_chunking=self.allow_sloppy_chunking,
            flip=True,
        )

        strax.integrate(r)

        # Pulse counting happens before any veto / hit cutting, so it reflects
        # everything the DAQ recorded in this chunk.
        pulse_counts = count_pulses(r, self.n_tpc_pmts)
        pulse_counts["time"] = start
        pulse_counts["endtime"] = end

        if len(r) and self.hev_enabled:
            r, r_vetoed, veto_regions = software_he_veto(
                r,
                self.to_pe,
                end,
                area_threshold=self.tail_veto_threshold,
                veto_length=self.tail_veto_duration,
                veto_res=self.tail_veto_resolution,
                pass_veto_extend=self.tail_veto_pass_extend,
                pass_veto_fraction=self.tail_veto_pass_fraction,
                max_veto_value=self.max_veto_value,
            )

            # In the future, we'll probably want to sum the waveforms
            # inside the vetoed regions, so we can still save the "peaks".
            del r_vetoed

        else:
            veto_regions = np.zeros(0, dtype=strax.hit_dtype)

        if len(r):
            # Find hits
            # -- before filtering, since filtering messes with the S/N
            hits = strax.find_hits(r, min_amplitude=self.hit_thresholds)

            if self.pmt_pulse_filter:
                # Filter to concentrate the PMT pulses
                strax.filter_records(r, np.array(self.pmt_pulse_filter))

            le, re = self.save_outside_hits
            r = strax.cut_outside_hits(r, hits, left_extension=le, right_extension=re)

            # Probably overkill, but just to be sure...
            strax.zero_out_of_bounds(r)

        return dict(records=r, pulse_counts=pulse_counts, veto_regions=veto_regions)
225

226

227
##
228
# Software HE Veto
229
##
230

231

232
@export
def software_he_veto(
    records,
    to_pe,
    chunk_end,
    area_threshold=int(1e5),
    veto_length=int(3e6),
    veto_res=int(1e3),
    pass_veto_fraction=0.01,
    pass_veto_extend=3,
    max_veto_value=None,
):
    """Veto veto_length (time in ns) after peaks larger than area_threshold (in PE).

    Further large peaks inside the veto regions are still passed:
    We sum the waveform inside the veto region (with time resolution
    veto_res in ns) and pass regions within pass_veto_extend samples
    of samples with amplitude above pass_veto_fraction times the maximum.

    :return: (preserved records, vetoed records, veto intervals).

    :param records: PMT records
    :param to_pe: ADC to PE conversion factors for the channels in records.
    :param chunk_end: Endtime of chunk to set as maximum ceiling for the veto period
    :param area_threshold: Minimum peak area to trigger the veto.
    Note we use a much rougher clustering than in later processing.
    :param veto_length: Time in ns to veto after the peak
    :param veto_res: Resolution of the sum waveform inside the veto region.
    Do not make too large without increasing integer type in some strax
    dtypes...
    :param pass_veto_fraction: fraction of maximum sum waveform amplitude to
    trigger veto passing of further peaks
    :param pass_veto_extend: samples to extend (left and right) the pass veto
    regions.
    :param max_veto_value: if not None, pass peaks that exceed this area
    no matter what.

    """
    # veto_res must fit the int16 'dt' field of the region dtype below.
    veto_res = int(veto_res)
    if veto_res > np.iinfo(np.int16).max:
        raise ValueError("Veto resolution does not fit 16-bit int")
    # Round the veto length up to a whole number of resolution steps.
    veto_length = np.ceil(veto_length / veto_res).astype(np.int64) * veto_res
    veto_n = int(veto_length / veto_res) + 1

    # 1. Find large peaks in the data.
    # This will actually return big agglomerations of peaks and their tails
    peaks = strax.find_peaks(
        records,
        to_pe,
        gap_threshold=1,
        left_extension=0,
        right_extension=0,
        min_channels=100,
        min_area=area_threshold,
        result_dtype=strax.peak_dtype(n_channels=len(to_pe), n_sum_wv_samples=veto_n),
    )

    # 2a. Set 'candidate regions' at these peaks. These should:
    #  - Have a fixed maximum length (else we can't use the strax hitfinder on them)
    #  - Never extend beyond the current chunk
    #  - Do not overlap
    veto_start = peaks["time"]
    veto_end = np.clip(peaks["time"] + veto_length, None, chunk_end)
    # Truncate each region at the start of the next one to prevent overlap.
    veto_end[:-1] = np.clip(veto_end[:-1], None, veto_start[1:])

    # 2b. Convert these into strax record-like objects
    # Note the waveform is float32 though (it's a summed waveform)
    regions = np.zeros(
        len(veto_start),
        dtype=strax.interval_dtype
        + [
            ("data", (np.float32, veto_n)),
            ("baseline", np.float32),
            ("baseline_rms", np.float32),
            ("reduction_level", np.int64),
            ("record_i", np.int64),
            ("pulse_length", np.int64),
        ],
    )
    regions["time"] = veto_start
    # NOTE(review): dividing the duration by veto_n (a sample count) rather
    # than veto_res (the sample width in ns) looks suspicious -- confirm
    # against the intended region length before touching this.
    regions["length"] = (veto_end - veto_start) // veto_n
    regions["pulse_length"] = veto_n
    regions["dt"] = veto_res

    if not len(regions):
        # No veto anywhere in this data
        return records, records[:0], np.zeros(0, strax.hit_dtype)

    # 3. Find pass_veto regions with big peaks inside the veto regions.
    # For this we compute a rough sum waveform (at low resolution,
    # without looping over the pulse data)
    rough_sum(regions, records, to_pe, veto_n, veto_res)
    if max_veto_value is not None:
        # Absolute threshold: anything above max_veto_value passes the veto.
        pass_veto = strax.find_hits(regions, min_amplitude=max_veto_value)
    else:
        # Relative threshold: normalize each region to its own maximum.
        regions["data"] /= np.max(regions["data"], axis=1)[:, np.newaxis]
        pass_veto = strax.find_hits(regions, min_amplitude=pass_veto_fraction)

    # 4. Extend these by a few samples and inverse to find veto regions
    regions["data"] = 1
    regions = strax.cut_outside_hits(
        regions, pass_veto, left_extension=pass_veto_extend, right_extension=pass_veto_extend
    )
    # Invert: samples NOT in a pass-veto window become the actual veto.
    regions["data"] = 1 - regions["data"]
    veto = strax.find_hits(regions, min_amplitude=1)
    # Do not remove very tiny regions
    veto = veto[veto["length"] > 2 * pass_veto_extend]

    # 5. Apply the veto and return results
    veto_mask = strax.fully_contained_in(records, veto) == -1
    return tuple(list(mask_and_not(records, veto_mask)) + [veto])
343

344

345
@numba.njit(cache=True, nogil=True)
def rough_sum(regions, records, to_pe, n, dt):
    """Accumulate ultra-rough sum waveforms into regions["data"], assuming:

     - every record is a single peak at its first sample
     - all regions have the same length and dt
    and probably not caring too much about boundaries

    """
    if len(regions) == 0 or len(records) == 0:
        return

    # n and dt are passed in explicitly rather than read from the region
    # dtype, to avoid overflows/wraparounds of the small dt integer type.

    n_records = len(records)
    rec_i = 0  # shared cursor: each record is visited at most once
    for reg_i in range(len(regions)):
        if rec_i >= n_records:
            break
        left = regions[reg_i]["time"]
        right = left + n * dt
        # Sweep the cursor past everything starting before this region's
        # right edge, adding PE-weighted areas for records inside it.
        while rec_i < n_records:
            t = records[rec_i]["time"]
            if t >= right:
                break
            if t >= left:
                sample = int((t - left) // dt)
                regions[reg_i]["data"][sample] += (
                    records[rec_i]["area"] * to_pe[records[rec_i]["channel"]]
                )
            rec_i += 1
381

382

383
##
384
# Pulse counting
385
##
386
@export
def pulse_count_dtype(n_channels):
    """Return the numpy dtype spec for one chunk's pulse-count summary.

    NB: the compact dt/length interval dtype is deliberately avoided here;
    its integer types are too small to hold these huge chunk-wide intervals,
    so plain int64 time/endtime fields are used instead.
    """
    per_channel = (np.int64, n_channels)
    return [
        (("Start time of the chunk", "time"), np.int64),
        (("End time of the chunk", "endtime"), np.int64),
        (("Number of pulses", "pulse_count"), per_channel),
        (("Number of lone pulses", "lone_pulse_count"), per_channel),
        (("Integral of all pulses in ADC_count x samples", "pulse_area"), per_channel),
        (
            ("Integral of lone pulses in ADC_count x samples", "lone_pulse_area"),
            per_channel,
        ),
        (("Average baseline", "baseline_mean"), (np.int16, n_channels)),
        (("Average baseline rms", "baseline_rms_mean"), (np.float32, n_channels)),
    ]
403

404

405
def count_pulses(records, n_channels):
    """Return array with one element, with pulse count info from records."""
    dtype = pulse_count_dtype(n_channels)
    if not len(records):
        # Nothing recorded: return an empty (length-0) result.
        return np.zeros(0, dtype=dtype)
    result = np.zeros(1, dtype=dtype)
    _count_pulses(records, n_channels, result)
    return result
412

413

414
# Sentinel stored in the int16 baseline_mean field when a channel saw no
# pulses (count == 0 makes the mean NaN, which an int16 cannot represent).
NO_PULSE_COUNTS = -9999  # Special value required by average_baseline in case counts = 0
415

416

417
@numba.njit(cache=True, nogil=True)
def _count_pulses(records, n_channels, result):
    """Fill result[0] with per-channel pulse counts, areas and baseline means.

    A pulse is counted on its first fragment (record_i == 0). A pulse is
    'lone' if it overlaps no other pulse in ANY channel: it must start after
    everything seen so far ended (last_end_seen) and end before the next
    record starts (next_start). Assumes records are sorted by time.
    """
    count = np.zeros(n_channels, dtype=np.int64)
    lone_count = np.zeros(n_channels, dtype=np.int64)
    area = np.zeros(n_channels, dtype=np.int64)
    lone_area = np.zeros(n_channels, dtype=np.int64)

    # Both trackers are global over channels, not per-channel:
    # overlap with a pulse in any other channel disqualifies 'lone' status.
    last_end_seen = 0
    next_start = 0

    # Array of booleans to track whether we are currently in a lone pulse
    # in each channel
    in_lone_pulse = np.zeros(n_channels, dtype=np.bool_)
    baseline_buffer = np.zeros(n_channels, dtype=np.float64)
    baseline_rms_buffer = np.zeros(n_channels, dtype=np.float64)
    for r_i, r in enumerate(records):
        if r_i != len(records) - 1:
            # Peek at the start of the following record (any channel);
            # for the last record, next_start keeps its previous value.
            next_start = records[r_i + 1]["time"]

        ch = r["channel"]
        if ch >= n_channels:
            # print works inside numba nopython mode; gives context before raising
            print("Channel:", ch)
            raise RuntimeError("Out of bounds channel in get_counts!")

        area[ch] += r["area"]  # <-- Summing total area in channel

        if r["record_i"] == 0:
            # First fragment of a pulse: count it and accumulate baselines.
            count[ch] += 1
            baseline_buffer[ch] += r["baseline"]
            baseline_rms_buffer[ch] += r["baseline_rms"]

            if r["time"] > last_end_seen and r["time"] + r["pulse_length"] * r["dt"] < next_start:
                # This is a lone pulse
                lone_count[ch] += 1
                in_lone_pulse[ch] = True
                lone_area[ch] += r["area"]
            else:
                in_lone_pulse[ch] = False

            last_end_seen = max(last_end_seen, r["time"] + r["pulse_length"] * r["dt"])

        elif in_lone_pulse[ch]:
            # This is a subsequent fragment of a lone pulse
            lone_area[ch] += r["area"]

    res = result[0]
    res["pulse_count"][:] = count[:]
    res["lone_pulse_count"][:] = lone_count[:]
    res["pulse_area"][:] = area[:]
    res["lone_pulse_area"][:] = lone_area[:]
    # Channels with count == 0 give 0/0 = NaN; replace with the sentinel so
    # the value survives the cast into the int16 baseline_mean field.
    means = baseline_buffer / count
    means[np.isnan(means)] = NO_PULSE_COUNTS
    res["baseline_mean"][:] = means[:]
    res["baseline_rms_mean"][:] = (baseline_rms_buffer / count)[:]
471

472

473
##
474
# Misc
475
##
476
@export
@numba.njit(cache=True, nogil=True)
def mask_and_not(x, mask):
    """Split x by the boolean mask: return (selected, rejected) arrays."""
    selected = x[mask]
    rejected = x[~mask]
    return selected, rejected
480

481

482
@export
@numba.njit(cache=True, nogil=True)
def channel_split(rr, first_other_ch):
    """Split rr into (records with channel < first_other_ch, the rest)."""
    return mask_and_not(rr, rr["channel"] < first_other_ch)
487

488

489
@export
def check_overlaps(records, n_channels):
    """Raise a ValueError if any of the pulses in records overlap.

    Assumes records is already sorted by time.

    :param records: record array to check
    :param n_channels: upper bound on channel numbers appearing in records
    """
    last_end = np.zeros(n_channels, dtype=np.int64)
    bad_channel, bad_time = _check_overlaps(records, last_end)
    if bad_channel == -9999:
        # Sentinel from _check_overlaps: no overlap found anywhere.
        return
    raise ValueError(
        f"Bad data! In channel {bad_channel}, a pulse starts at {bad_time}, "
        "BEFORE the previous pulse in that same channel ended "
        f"(at {last_end[bad_channel]})"
    )
504

505

506
@numba.njit(cache=True, nogil=True)
def _check_overlaps(records, last_end):
    """Scan time-sorted records; return (channel, time) of the first record
    that starts before its channel's previous pulse ended, or (-9999, -9999)
    if no overlap exists. Mutates last_end (per-channel end times) in place."""
    for rec in records:
        ch = rec["channel"]
        if rec["time"] < last_end[ch]:
            return ch, rec["time"]
        last_end[ch] = strax.endtime(rec)
    return -9999, -9999
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc