
OpenCOMPES / sed / 9587172660

19 Jun 2024 07:19PM UTC coverage: 92.009% (+0.05%) from 91.962%

Pull Request #411: Energy calibration bias shift (github)
rettigl: rename group_name to dataset_key

81 of 99 new or added lines in 3 files covered (81.82%).
128 existing lines in 3 files now uncovered.
6494 of 7058 relevant lines covered (92.01%).
0.92 hits per line.

Source File: /sed/calibrator/energy.py (92.65% covered)
"""sed.calibrator.energy module. Code for energy calibration and
correction. Mostly ported from https://github.com/mpes-kit/mpes.
"""
import itertools as it
from copy import deepcopy
from datetime import datetime
from functools import partial
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Literal
from typing import Sequence
from typing import Tuple
from typing import Union

import bokeh.plotting as pbk
import dask.dataframe
import h5py
import ipywidgets as ipw
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr
from bokeh.io import output_notebook
from bokeh.palettes import Category10 as ColorCycle
from fastdtw import fastdtw
from IPython.display import display
from lmfit import Minimizer
from lmfit import Parameters
from lmfit.printfuncs import report_fit
from numpy.linalg import lstsq
from scipy.signal import savgol_filter
from scipy.sparse.linalg import lsqr

from sed.binning import bin_dataframe
from sed.core import dfops
from sed.loader.base.loader import BaseLoader


class EnergyCalibrator:
    """Electron binding energy calibration workflow.

    For the initialization of the EnergyCalibrator class an instance of a
    loader is required. The data can be loaded using the optional arguments,
    or using the load_data method or bin_data method.

    Args:
        loader (BaseLoader): Instance of a loader, subclassed from BaseLoader.
        biases (np.ndarray, optional): Bias voltages used. Defaults to None.
        traces (np.ndarray, optional): TOF-Data traces corresponding to the bias
            values. Defaults to None.
        tof (np.ndarray, optional): TOF-values for the data traces.
            Defaults to None.
        config (dict, optional): Config dictionary. Defaults to None.
    """

    def __init__(
        self,
        loader: BaseLoader,
        biases: np.ndarray = None,
        traces: np.ndarray = None,
        tof: np.ndarray = None,
        config: dict = None,
    ):
        """For the initialization of the EnergyCalibrator class an instance of a
        loader is required. The data can be loaded using the optional arguments,
        or using the load_data method or bin_data method.

        Args:
            loader (BaseLoader): Instance of a loader, subclassed from BaseLoader.
            biases (np.ndarray, optional): Bias voltages used. Defaults to None.
            traces (np.ndarray, optional): TOF-Data traces corresponding to the bias
                values. Defaults to None.
            tof (np.ndarray, optional): TOF-values for the data traces.
                Defaults to None.
            config (dict, optional): Config dictionary. Defaults to None.
        """
        self.loader = loader
        self.biases: np.ndarray = None
        self.traces: np.ndarray = None
        self.traces_normed: np.ndarray = None
        self.tof: np.ndarray = None

        if traces is not None and tof is not None and biases is not None:
            self.load_data(biases=biases, traces=traces, tof=tof)

        if config is None:
            config = {}

        self._config = config

        self.featranges: List[Tuple] = []  # Value ranges for feature detection
        self.peaks: np.ndarray = np.asarray([])
        self.calibration: Dict[str, Any] = self._config["energy"].get("calibration", {})

        self.tof_column = self._config["dataframe"]["tof_column"]
        self.tof_ns_column = self._config["dataframe"].get("tof_ns_column", None)
        self.corrected_tof_column = self._config["dataframe"]["corrected_tof_column"]
        self.energy_column = self._config["dataframe"]["energy_column"]
        self.x_column = self._config["dataframe"]["x_column"]
        self.y_column = self._config["dataframe"]["y_column"]
        self.binwidth: float = self._config["dataframe"]["tof_binwidth"]
        self.binning: int = self._config["dataframe"]["tof_binning"]
        self.x_width = self._config["energy"]["x_width"]
        self.y_width = self._config["energy"]["y_width"]
        self.tof_width = np.asarray(
            self._config["energy"]["tof_width"],
        ) / 2 ** (self.binning - 1)
        self.tof_fermi = self._config["energy"]["tof_fermi"] / 2 ** (self.binning - 1)
        self.color_clip = self._config["energy"]["color_clip"]
        self.sector_delays = self._config["dataframe"].get("sector_delays", None)
        self.sector_id_column = self._config["dataframe"].get("sector_id_column", None)
        self.offsets: Dict[str, Any] = self._config["energy"].get("offsets", {})
        self.correction: Dict[str, Any] = self._config["energy"].get("correction", {})

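    # A minimal construction sketch (names here are illustrative assumptions, not
    # part of this module): any BaseLoader subclass plus a config dict providing
    # the "dataframe" and "energy" sections read above will do.
    #
    #     loader = SomeLoader(config=config)  # hypothetical BaseLoader subclass
    #     ec = EnergyCalibrator(loader=loader, config=config)
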
    @property
    def ntraces(self) -> int:
        """Property returning the number of traces.

        Returns:
            int: The number of loaded/calculated traces.
        """
        return len(self.traces)

    @property
    def nranges(self) -> int:
        """Property returning the number of specified feature ranges, which can
        be a multiple of ntraces.

        Returns:
            int: The number of specified feature ranges.
        """
        return len(self.featranges)

    @property
    def dup(self) -> int:
        """Property returning the duplication number, i.e. the number of feature
        ranges per trace.

        Returns:
            int: The duplication number.
        """
        return int(np.round(self.nranges / self.ntraces))

    def load_data(
        self,
        biases: np.ndarray = None,
        traces: np.ndarray = None,
        tof: np.ndarray = None,
    ):
        """Load data into the class. Parameters not provided are replaced by
        empty arrays.

        Args:
            biases (np.ndarray, optional): Bias voltages used. Defaults to None.
            traces (np.ndarray, optional): TOF-Data traces corresponding to the bias
                values. Defaults to None.
            tof (np.ndarray, optional): TOF-values for the data traces.
                Defaults to None.
        """
        if biases is not None:
            self.biases = biases
        else:
            self.biases = np.asarray([])
        if tof is not None:
            self.tof = tof
        else:
            self.tof = np.asarray([])
        if traces is not None:
            self.traces = self.traces_normed = traces
        else:
            self.traces = self.traces_normed = np.asarray([])

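    # Sketch of loading pre-binned calibration data directly, assuming three bias
    # traces on a common TOF axis (synthetic arrays, for illustration only):
    #
    #     tof = np.arange(60000, 70000)
    #     traces = np.random.rand(3, tof.size)
    #     ec.load_data(biases=np.array([10.0, 20.0, 30.0]), traces=traces, tof=tof)
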
    def bin_data(
        self,
        data_files: List[str],
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        **kwds,
    ):
        """Bin data from single-event files, and load into class.

        Args:
            data_files (List[str]): list of file names to bin
            axes (List[str], optional): bin axes. Defaults to
                config["dataframe"]["tof_column"].
            bins (List[int], optional): number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used.
                If not provided, biases are extracted from the file meta data.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            **kwds: Keyword parameters for bin_dataframe
        """
        if axes is None:
            axes = [self.tof_column]
        if bins is None:
            bins = [self._config["energy"]["bins"]]
        if ranges is None:
            ranges_ = [
                np.array(self._config["energy"]["ranges"]) / 2 ** (self.binning - 1),
            ]
            ranges = [cast(Tuple[float, float], tuple(v)) for v in ranges_]
        # pylint: disable=duplicate-code
        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
        mode = kwds.pop("mode", self._config["binning"]["mode"])
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
        try:
            num_cores = kwds.pop("num_cores", self._config["binning"]["num_cores"])
        except KeyError:
            num_cores = psutil.cpu_count() - 1
        threads_per_worker = kwds.pop(
            "threads_per_worker",
            self._config["binning"]["threads_per_worker"],
        )
        threadpool_api = kwds.pop(
            "threadpool_API",
            self._config["binning"]["threadpool_API"],
        )

        read_biases = False
        if biases is None:
            read_biases = True
            if bias_key is None:
                try:
                    bias_key = self._config["energy"]["bias_key"]
                except KeyError as exc:
                    raise ValueError(
                        "Either bias values or a valid bias_key has to be present!",
                    ) from exc

        dataframe, _, _ = self.loader.read_dataframe(
            files=data_files,
            collect_metadata=False,
        )
        traces = bin_dataframe(
            dataframe,
            bins=bins,
            axes=axes,
            ranges=ranges,
            hist_mode=hist_mode,
            mode=mode,
            pbar=pbar,
            n_cores=num_cores,
            threads_per_worker=threads_per_worker,
            threadpool_api=threadpool_api,
            return_partitions=True,
            **kwds,
        )
        if read_biases:
            if bias_key:
                try:
                    biases = extract_bias(data_files, bias_key)
                except KeyError as exc:
                    raise ValueError(
                        "Either bias values or a valid bias_key has to be present!",
                    ) from exc
        tof = traces.coords[(axes[0])]
        self.traces = self.traces_normed = np.asarray(traces.T)
        self.tof = np.asarray(tof)
        self.biases = np.asarray(biases)

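    # Sketch of binning raw single-event files instead (illustrative file names;
    # bins, ranges, and bias_key are taken from the config if not passed):
    #
    #     ec.bin_data(data_files=["scan_00.h5", "scan_01.h5"])
    #
    # With biases=None, the bias voltages are read from each file at
    # config["energy"]["bias_key"] via extract_bias() (defined elsewhere in this
    # module), and the binned traces populate self.traces and self.tof.
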
    def normalize(self, smooth: bool = False, span: int = 7, order: int = 1):
        """Normalize the spectra along an axis.

        Args:
            smooth (bool, optional): Option to smooth the signals before normalization.
                Defaults to False.
            span (int, optional): span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``). Defaults to 7.
            order (int, optional): order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``). Defaults to 1.
        """
        self.traces_normed = normspec(
            self.traces,
            smooth=smooth,
            span=span,
            order=order,
        )

1✔
290
        self,
291
        ranges: Tuple,
292
        ref_id: int = 0,
293
        traces: np.ndarray = None,
294
        peak_window: int = 7,
295
        apply: bool = False,
296
        **kwds,
297
    ):
298
        """Display a tool to select or extract the equivalent feature ranges
299
        (containing the peaks) among all traces.
300

301
        Args:
302
            ranges (Tuple):
303
                Collection of feature detection ranges, within which an algorithm
304
                (i.e. 1D peak detector) with look for the feature.
305
            ref_id (int, optional): Index of the reference trace. Defaults to 0.
306
            traces (np.ndarray, optional): Collection of energy dispersion curves.
307
                Defaults to self.traces_normed.
308
            peak_window (int, optional): area around a peak to check for other peaks.
309
                Defaults to 7.
310
            apply (bool, optional): Option to directly apply the provided parameters.
311
                Defaults to False.
312
            **kwds:
313
                keyword arguments for trace alignment (see ``find_correspondence()``).
314
        """
315
        if traces is None:
1✔
316
            traces = self.traces_normed
1✔
317

318
        self.add_ranges(
1✔
319
            ranges=ranges,
320
            ref_id=ref_id,
321
            traces=traces,
322
            infer_others=True,
323
            mode="replace",
324
        )
325
        self.feature_extract(peak_window=peak_window)
1✔
326

327
        # make plot
328
        labels = kwds.pop("labels", [str(b) + " V" for b in self.biases])
1✔
329
        figsize = kwds.pop("figsize", (8, 4))
1✔
330
        plot_segs = []
1✔
331
        plot_peaks = []
1✔
332
        fig, ax = plt.subplots(figsize=figsize)
1✔
333
        colors = plt.get_cmap("rainbow")(np.linspace(0, 1, len(traces)))
1✔
334
        for itr, color in zip(range(len(traces)), colors):
1✔
335
            trace = traces[itr, :]
1✔
336
            # main traces
337
            ax.plot(
1✔
338
                self.tof,
339
                trace,
340
                ls="-",
341
                color=color,
342
                linewidth=1,
343
                label=labels[itr],
344
            )
345
            # segments:
346
            seg = self.featranges[itr]
1✔
347
            cond = (self.tof >= seg[0]) & (self.tof <= seg[1])
1✔
348
            tofseg, traceseg = self.tof[cond], trace[cond]
1✔
349
            (line,) = ax.plot(
1✔
350
                tofseg,
351
                traceseg,
352
                ls="-",
353
                color=color,
354
                linewidth=3,
355
            )
356
            plot_segs.append(line)
1✔
357
            # markers
358
            (scatt,) = ax.plot(
1✔
359
                self.peaks[itr, 0],
360
                self.peaks[itr, 1],
361
                ls="",
362
                marker=".",
363
                color="k",
364
                markersize=10,
365
            )
366
            plot_peaks.append(scatt)
1✔
367
        ax.legend(fontsize=8, loc="upper right")
1✔
368
        ax.set_title("")
1✔
369

370
        def update(refid, ranges):
1✔
371
            self.add_ranges(ranges, refid, traces=traces)
1✔
372
            self.feature_extract(peak_window=7)
1✔
373
            for itr, _ in enumerate(self.traces_normed):
1✔
374
                seg = self.featranges[itr]
1✔
375
                cond = (self.tof >= seg[0]) & (self.tof <= seg[1])
1✔
376
                tofseg, traceseg = (
1✔
377
                    self.tof[cond],
378
                    self.traces_normed[itr][cond],
379
                )
380
                plot_segs[itr].set_ydata(traceseg)
1✔
381
                plot_segs[itr].set_xdata(tofseg)
1✔
382

383
                plot_peaks[itr].set_xdata(self.peaks[itr, 0])
1✔
384
                plot_peaks[itr].set_ydata(self.peaks[itr, 1])
1✔
385

386
            fig.canvas.draw_idle()
1✔
387

388
        refid_slider = ipw.IntSlider(
1✔
389
            value=ref_id,
390
            min=0,
391
            max=10,
392
            step=1,
393
        )
394

395
        ranges_slider = ipw.IntRangeSlider(
1✔
396
            value=list(ranges),
397
            min=min(self.tof),
398
            max=max(self.tof),
399
            step=1,
400
        )
401

402
        update(ranges=ranges, refid=ref_id)
1✔
403

404
        ipw.interact(
1✔
405
            update,
406
            refid=refid_slider,
407
            ranges=ranges_slider,
408
        )
409

410
        def apply_func(apply: bool):  # noqa: ARG001
1✔
411
            self.add_ranges(
1✔
412
                ranges_slider.value,
413
                refid_slider.value,
414
                traces=self.traces_normed,
415
            )
416
            self.feature_extract(peak_window=7)
1✔
417
            ranges_slider.close()
1✔
418
            refid_slider.close()
1✔
419
            apply_button.close()
1✔
420

421
        apply_button = ipw.Button(description="apply")
1✔
422
        display(apply_button)  # pylint: disable=duplicate-code
1✔
423
        apply_button.on_click(apply_func)
1✔
424
        plt.show()
1✔
425

426
        if apply:
1✔
427
            apply_func(True)
1✔
428

429
    def add_ranges(
        self,
        ranges: Union[List[Tuple], Tuple],
        ref_id: int = 0,
        traces: np.ndarray = None,
        infer_others: bool = True,
        mode: str = "replace",
        **kwds,
    ):
        """Select or extract the equivalent feature ranges (containing the peaks) among all traces.

        Args:
            ranges (Union[List[Tuple], Tuple]):
                Collection of feature detection ranges, within which an algorithm
                (e.g. a 1D peak detector) will look for the feature.
            ref_id (int, optional): Index of the reference trace. Defaults to 0.
            traces (np.ndarray, optional): Collection of energy dispersion curves.
                Defaults to self.traces_normed.
            infer_others (bool, optional): Option to infer the feature detection range
                in other traces from a given one using a time warp algorithm.
                Defaults to True.
            mode (str, optional): Specification on how to change the feature ranges
                ('append' or 'replace'). Defaults to "replace".
            **kwds:
                keyword arguments for trace alignment (see ``find_correspondence()``).
        """
        if traces is None:
            traces = self.traces_normed

        # Infer the corresponding feature detection range of other traces by alignment
        if infer_others:
            assert isinstance(ranges, tuple)
            newranges: List[Tuple] = []

            for i in range(self.ntraces):
                pathcorr = find_correspondence(
                    traces[ref_id, :],
                    traces[i, :],
                    **kwds,
                )
                newranges.append(range_convert(self.tof, ranges, pathcorr))

        else:
            if isinstance(ranges, list):
                newranges = ranges
            else:
                newranges = [ranges]

        if mode == "append":
            self.featranges += newranges
        elif mode == "replace":
            self.featranges = newranges

    def feature_extract(
        self,
        ranges: List[Tuple] = None,
        traces: np.ndarray = None,
        peak_window: int = 7,
    ):
        """Select or extract the equivalent landmarks (e.g. peaks) among all traces.

        Args:
            ranges (List[Tuple], optional): List of ranges in each trace to look for
                the peak feature, [start, end]. Defaults to self.featranges.
            traces (np.ndarray, optional): Collection of 1D spectra to use for
                calibration. Defaults to self.traces_normed.
            peak_window (int, optional): area around a peak to check for other peaks.
                Defaults to 7.
        """
        if ranges is None:
            ranges = self.featranges

        if traces is None:
            traces = self.traces_normed

        # Augment the content of the calibration data
        traces_aug = np.tile(traces, (self.dup, 1))
        # Run peak detection for each trace within the specified ranges
        self.peaks = peaksearch(
            traces_aug,
            self.tof,
            ranges=ranges,
            pkwindow=peak_window,
        )

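    # Non-interactive workflow sketch (illustrative range values): specify the
    # feature range on the reference trace; ranges on the other traces are then
    # inferred via the fastdtw-based alignment in find_correspondence(), and 1D
    # peak detection runs within each range:
    #
    #     ec.add_ranges(ranges=(64500, 65300), ref_id=0)
    #     ec.feature_extract(peak_window=7)
    #     ec.peaks  # rows of (tof, amplitude) peak positions, one per trace
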
    def calibrate(
        self,
        ref_energy: float = 0,
        method: str = "lmfit",
        energy_scale: str = "kinetic",
        landmarks: np.ndarray = None,
        biases: np.ndarray = None,
        t: np.ndarray = None,
        verbose: bool = True,
        **kwds,
    ) -> dict:
        """Calculate the functional mapping between time-of-flight and the energy
        scale using optimization methods.

        Args:
            ref_energy (float): Binding/kinetic energy of the detected feature.
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to 'lmfit'.
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to "kinetic".
            landmarks (np.ndarray, optional): Extracted peak positions (TOF) used for
                calibration. Defaults to self.peaks.
            biases (np.ndarray, optional): Bias values. Defaults to self.biases.
            t (np.ndarray, optional): TOF values. Defaults to self.tof.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to True.
            **kwds: keyword arguments.
                See available keywords for ``poly_energy_calibration()`` and
                ``fit_energy_calibration()``

        Raises:
            ValueError: Raised if invalid 'energy_scale' is passed.
            NotImplementedError: Raised if invalid 'method' is passed.

        Returns:
            dict: Calibration dictionary with coefficients.
        """
        if landmarks is None:
            landmarks = self.peaks[:, 0]
        if biases is None:
            biases = self.biases
        if t is None:
            t = self.tof
        if energy_scale == "kinetic":
            sign = -1
        elif energy_scale == "binding":
            sign = 1
        else:
            raise ValueError(
                'energy_scale needs to be either "binding" or "kinetic"',
                f", got {energy_scale}.",
            )

        binwidth = kwds.pop("binwidth", self.binwidth)
        binning = kwds.pop("binning", self.binning)

        if method == "lmfit":
            self.calibration = fit_energy_calibration(
                landmarks,
                sign * biases,
                binwidth,
                binning,
                ref_energy=ref_energy,
                t=t,
                energy_scale=energy_scale,
                verbose=verbose,
                **kwds,
            )
        elif method in ("lstsq", "lsqr"):
            self.calibration = poly_energy_calibration(
                landmarks,
                sign * biases,
                ref_energy=ref_energy,
                aug=self.dup,
                method=method,
                t=t,
                energy_scale=energy_scale,
                **kwds,
            )
        else:
            raise NotImplementedError()

        self.calibration["creation_date"] = datetime.now().timestamp()
        return self.calibration

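    # Calibration sketch: with peaks extracted, fit the TOF-to-energy mapping.
    # For method="lmfit" the model is the 1/t^2 form E = (d/(t - t0))**2 + E0
    # (cf. the fit_function recorded in append_energy_axis below); "lstsq"/"lsqr"
    # fit a polynomial instead. Illustrative call:
    #
    #     calibration = ec.calibrate(ref_energy=0, method="lmfit",
    #                                energy_scale="kinetic")
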
    def view(  # pylint: disable=dangerous-default-value
        self,
        traces: np.ndarray,
        segs: List[Tuple] = None,
        peaks: np.ndarray = None,
        show_legend: bool = True,
        backend: str = "matplotlib",
        linekwds: dict = {},
        linesegkwds: dict = {},
        scatterkwds: dict = {},
        legkwds: dict = {},
        **kwds,
    ):
        """Display a plot showing line traces with annotation.

        Args:
            traces (np.ndarray): Matrix of traces to visualize.
            segs (List[Tuple], optional): Segments to be highlighted in the
                visualization. Defaults to None.
            peaks (np.ndarray, optional): Peak positions for labelling the traces.
                Defaults to None.
            show_legend (bool, optional): Option to display bias voltages as legends.
                Defaults to True.
            backend (str, optional): Backend specification, choose between 'matplotlib'
                (static) or 'bokeh' (interactive). Defaults to "matplotlib".
            linekwds (dict, optional): Keyword arguments for line plotting
                (see ``matplotlib.pyplot.plot()``). Defaults to {}.
            linesegkwds (dict, optional): Keyword arguments for line segments plotting
                (see ``matplotlib.pyplot.plot()``). Defaults to {}.
            scatterkwds (dict, optional): Keyword arguments for scatter plot
                (see ``matplotlib.pyplot.scatter()``). Defaults to {}.
            legkwds (dict, optional): Keyword arguments for legend
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
            **kwds: keyword arguments:

                - **labels** (list): Labels for each curve
                - **xaxis** (np.ndarray): x (horizontal) axis values
                - **title** (str): Title of the plot
                - **legend_location** (str): Location of the plot legend
                - **align** (bool): Option to shift traces by bias voltage
                - **energy_scale** (str): Direction of increasing energy scale
                  used for alignment ('kinetic' or 'binding')
        """
        lbs = kwds.pop("labels", [str(b) + " V" for b in self.biases])
        xaxis = kwds.pop("xaxis", self.tof)
        ttl = kwds.pop("title", "")
        align = kwds.pop("align", False)
        energy_scale = kwds.pop("energy_scale", "kinetic")

        sign = 1 if energy_scale == "kinetic" else -1

        if backend == "matplotlib":
            figsize = kwds.pop("figsize", (12, 4))
            fig, ax = plt.subplots(figsize=figsize)
            for itr, trace in enumerate(traces):
                if align:
                    ax.plot(
                        xaxis + sign * (self.biases[itr]),
                        trace,
                        ls="-",
                        linewidth=1,
                        label=lbs[itr],
                        **linekwds,
                    )
                else:
                    ax.plot(
                        xaxis,
                        trace,
                        ls="-",
                        linewidth=1,
                        label=lbs[itr],
                        **linekwds,
                    )

                # Emphasize selected EDC segments
                if segs is not None:
                    seg = segs[itr]
                    cond = (self.tof >= seg[0]) & (self.tof <= seg[1])
                    tofseg, traceseg = self.tof[cond], trace[cond]
                    ax.plot(
                        tofseg,
                        traceseg,
                        ls="-",
                        linewidth=2,
                        **linesegkwds,
                    )
                # Emphasize extracted local maxima
                if peaks is not None:
                    ax.scatter(
                        peaks[itr, 0],
                        peaks[itr, 1],
                        s=30,
                        **scatterkwds,
                    )

            if show_legend:
                try:
                    ax.legend(fontsize=12, **legkwds)
                except TypeError:
                    pass

            ax.set_title(ttl)

        elif backend == "bokeh":
            output_notebook(hide_banner=True)
            colors = it.cycle(ColorCycle[10])
            ttp = [("(x, y)", "($x, $y)")]

            figsize = kwds.pop("figsize", (800, 300))
            fig = pbk.figure(
                title=ttl,
                width=figsize[0],
                height=figsize[1],
                tooltips=ttp,
            )
            # Plotting the main traces
            for itr, color in zip(range(len(traces)), colors):
                trace = traces[itr, :]
                if align:
                    fig.line(
                        xaxis + sign * (self.biases[itr]),
                        trace,
                        color=color,
                        line_dash="solid",
                        line_width=1,
                        line_alpha=1,
                        legend_label=lbs[itr],
                        **kwds,
                    )
                else:
                    fig.line(
                        xaxis,
                        trace,
                        color=color,
                        line_dash="solid",
                        line_width=1,
                        line_alpha=1,
                        legend_label=lbs[itr],
                        **kwds,
                    )

                # Emphasize selected EDC segments
                if segs is not None:
                    seg = segs[itr]
                    cond = (self.tof >= seg[0]) & (self.tof <= seg[1])
                    tofseg, traceseg = self.tof[cond], trace[cond]
                    fig.line(
                        tofseg,
                        traceseg,
                        color=color,
                        line_width=3,
                        **linekwds,
                    )

                # Plot detected peaks
                if peaks is not None:
                    fig.scatter(
                        peaks[itr, 0],
                        peaks[itr, 1],
                        fill_color=color,
                        fill_alpha=0.8,
                        line_color=None,
                        size=5,
                        **scatterkwds,
                    )

            if show_legend:
                fig.legend.location = kwds.pop("legend_location", "top_right")
                fig.legend.spacing = 0
                fig.legend.padding = 2

            pbk.show(fig)

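    # Plotting sketch (illustrative calls): inspect the normalized traces with
    # either backend; align=True shifts each trace by its bias voltage.
    #
    #     ec.view(traces=ec.traces_normed, peaks=ec.peaks, backend="matplotlib")
    #     ec.view(traces=ec.traces_normed, backend="bokeh", align=True)
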
    def append_energy_axis(
        self,
        df: Union[pd.DataFrame, dask.dataframe.DataFrame],
        tof_column: str = None,
        energy_column: str = None,
        calibration: dict = None,
        verbose: bool = True,
        **kwds,
    ) -> Tuple[Union[pd.DataFrame, dask.dataframe.DataFrame], dict]:
        """Calculate and append the energy axis to the events dataframe.

        Args:
            df (Union[pd.DataFrame, dask.dataframe.DataFrame]):
                Dataframe to apply the energy axis calibration to.
            tof_column (str, optional): Label of the source column.
                Defaults to config["dataframe"]["tof_column"].
            energy_column (str, optional): Label of the destination column.
                Defaults to config["dataframe"]["energy_column"].
            calibration (dict, optional): Calibration dictionary. If provided,
                overrides calibration from class or config.
                Defaults to self.calibration or config["energy"]["calibration"].
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to True.
            **kwds: additional keyword arguments for the energy conversion. They are
                added to the calibration dictionary.

        Raises:
            ValueError: Raised if expected calibration parameters are missing.
            NotImplementedError: Raised if an invalid calib_type is found.

        Returns:
            Tuple[Union[pd.DataFrame, dask.dataframe.DataFrame], dict]: Dataframe
            with added energy column, and energy calibration metadata dictionary.
        """
        if tof_column is None:
            if self.corrected_tof_column in df.columns:
                tof_column = self.corrected_tof_column
            else:
                tof_column = self.tof_column

        if energy_column is None:
            energy_column = self.energy_column

        binwidth = kwds.pop("binwidth", self.binwidth)
        binning = kwds.pop("binning", self.binning)

        # pylint: disable=duplicate-code
        if calibration is None:
            calibration = deepcopy(self.calibration)

        if len(kwds) > 0:
            for key, value in kwds.items():
                calibration[key] = value
            calibration["creation_date"] = datetime.now().timestamp()

        elif "creation_date" in calibration and verbose:
            datestring = datetime.fromtimestamp(calibration["creation_date"]).strftime(
                "%m/%d/%Y, %H:%M:%S",
            )
            print(f"Using energy calibration parameters generated on {datestring}")

        # try to determine calibration type if not provided
        if "calib_type" not in calibration:
            if "t0" in calibration and "d" in calibration and "E0" in calibration:
                calibration["calib_type"] = "fit"
                if "energy_scale" not in calibration:
                    calibration["energy_scale"] = "kinetic"

            elif "coeffs" in calibration and "E0" in calibration:
                calibration["calib_type"] = "poly"
            else:
                raise ValueError("No valid calibration parameters provided!")

        if calibration["calib_type"] == "fit":
            # Fitting metadata for nexus
            calibration["fit_function"] = "(a0/(x0-a1))**2 + a2"
            calibration["coefficients"] = np.array(
                [
                    calibration["d"],
                    calibration["t0"],
                    calibration["E0"],
                ],
            )
            df[energy_column] = tof2ev(
                calibration["d"],
                calibration["t0"],
                binwidth,
                binning,
                calibration["energy_scale"],
                calibration["E0"],
                df[tof_column].astype("float64"),
            )
        elif calibration["calib_type"] == "poly":
            # Fitting metadata for nexus
            fit_function = "a0"
            for term in range(1, len(calibration["coeffs"]) + 1):
                fit_function += f" + a{term}*x0**{term}"
            calibration["fit_function"] = fit_function
            calibration["coefficients"] = np.concatenate(
                (calibration["coeffs"], [calibration["E0"]]),
            )[::-1]
            df[energy_column] = tof2evpoly(
                calibration["coeffs"],
                calibration["E0"],
                df[tof_column].astype("float64"),
            )
        else:
            raise NotImplementedError

        metadata = self.gather_calibration_metadata(calibration)

        return df, metadata

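    # Conversion sketch: applying the stored calibration appends the energy
    # column, using E = (d/(t - t0))**2 + E0 for "fit"-type calibrations (tof2ev)
    # or the recorded polynomial for "poly"-type ones (tof2evpoly). Illustrative
    # call on an events dataframe `df`:
    #
    #     df, metadata = ec.append_energy_axis(df)
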
    def append_tof_ns_axis(
        self,
        df: Union[pd.DataFrame, dask.dataframe.DataFrame],
        tof_column: str = None,
        tof_ns_column: str = None,
        **kwds,
    ) -> Tuple[Union[pd.DataFrame, dask.dataframe.DataFrame], dict]:
        """Convert the time-of-flight axis from steps to time in nanoseconds.

        Args:
            df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to convert.
            tof_column (str, optional): Name of the column containing the
                time-of-flight steps. Defaults to config["dataframe"]["tof_column"].
            tof_ns_column (str, optional): Name of the column to store the
                time-of-flight in nanoseconds. Defaults to config["dataframe"]["tof_ns_column"].
            **kwds: additional keyword arguments:

                - **binwidth** (float): Time-of-flight binwidth in ns.
                  Defaults to config["energy"]["tof_binwidth"].
                - **binning** (int): Time-of-flight binning factor.
                  Defaults to config["energy"]["tof_binning"].

        Returns:
            Tuple[Union[pd.DataFrame, dask.dataframe.DataFrame], dict]: Dataframe
            with the new column, and metadata dictionary.
        """
        binwidth = kwds.pop("binwidth", self.binwidth)
        binning = kwds.pop("binning", self.binning)
        if tof_column is None:
            if self.corrected_tof_column in df.columns:
                tof_column = self.corrected_tof_column
            else:
                tof_column = self.tof_column

        if tof_ns_column is None:
            tof_ns_column = self.tof_ns_column

        df[tof_ns_column] = tof2ns(
            binwidth,
            binning,
            df[tof_column].astype("float64"),
        )
        metadata: Dict[str, Any] = {
            "applied": True,
            "binwidth": binwidth,
            "binning": binning,
        }
        return df, metadata

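    # Sketch: the steps-to-nanoseconds conversion in tof2ns (defined elsewhere in
    # this module) only needs binwidth and the binning factor; note that binning
    # enters as a power of two throughout this class, cf. the 2**(binning - 1)
    # factors in __init__.
    #
    #     df, metadata = ec.append_tof_ns_axis(df)
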
    def gather_calibration_metadata(self, calibration: dict = None) -> dict:
        """Collect metadata from the energy calibration.

        Args:
            calibration (dict, optional): Dictionary with energy calibration
                parameters. Defaults to None.

        Returns:
            dict: Generated metadata dictionary.
        """
        if calibration is None:
            calibration = self.calibration
        metadata: Dict[Any, Any] = {}
        metadata["applied"] = True
        metadata["calibration"] = deepcopy(calibration)
        metadata["tof"] = deepcopy(self.tof)
        # create empty calibrated axis entry, if it is not present.
        if "axis" not in metadata["calibration"]:
            metadata["calibration"]["axis"] = 0

        return metadata

    def adjust_energy_correction(
        self,
        image: xr.DataArray,
        correction_type: str = None,
        amplitude: float = None,
        center: Tuple[float, float] = None,
        correction: dict = None,
        apply: bool = False,
        **kwds,
    ):
        """Visualize the energy correction function on top of the TOF/X/Y graphs.

        Args:
            image (xr.DataArray): Image data cube (x, y, tof) of binned data to plot.
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the time-of-flight correction
                term. Defaults to config["energy"]["correction"]["amplitude"].
            center (Tuple[float, float], optional): Center (x/y) coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            correction (dict, optional): Correction dict. Defaults to the config values
                and is updated from provided and adjusted parameters.
            apply (bool, optional): whether to store the provided parameters within
                the class. Defaults to False.
            **kwds: Additional parameters to use for the adjustment plots:

                - **x_column** (str): Name of the x column.
                - **y_column** (str): Name of the y column.
                - **tof_column** (str): Name of the TOF column to convert.
                - **x_width** (int, int): x range to integrate around the center
                - **y_width** (int, int): y range to integrate around the center
                - **tof_fermi** (int): TOF value of the Fermi level
                - **tof_width** (int, int): TOF range to plot around tof_fermi
                - **color_clip** (int): highest value to plot in the color range

                Additional parameters for the correction functions:

                - **d** (float): Field-free drift distance.
                - **gamma** (float): Linewidth value for correction using a 2D
                  Lorentz profile.
                - **sigma** (float): Standard deviation for correction using a 2D
                  Gaussian profile.
                - **gamma2** (float): Linewidth value for correction using an
                  asymmetric 2D Lorentz profile, X-direction.
                - **amplitude2** (float): Amplitude value for correction using an
                  asymmetric 2D Lorentz profile, X-direction.

        Raises:
            NotImplementedError: Raised for invalid correction_type.
        """
        matplotlib.use("module://ipympl.backend_nbagg")

        if correction is None:
            correction = deepcopy(self.correction)

        if correction_type is not None:
            correction["correction_type"] = correction_type

        if amplitude is not None:
            correction["amplitude"] = amplitude

        if center is not None:
            correction["center"] = center

        x_column = kwds.pop("x_column", self.x_column)
        y_column = kwds.pop("y_column", self.y_column)
        tof_column = kwds.pop("tof_column", self.tof_column)
        x_width = kwds.pop("x_width", self.x_width)
        y_width = kwds.pop("y_width", self.y_width)
        tof_fermi = kwds.pop("tof_fermi", self.tof_fermi)
        tof_width = kwds.pop("tof_width", self.tof_width)
        color_clip = kwds.pop("color_clip", self.color_clip)

        correction = {**correction, **kwds}

        if not {"correction_type", "amplitude", "center"}.issubset(set(correction.keys())):
            raise ValueError(
                "No valid energy correction found in config and required parameters missing!",
            )

        if isinstance(correction["center"], list):
            correction["center"] = tuple(correction["center"])

        x = image.coords[x_column].values
        y = image.coords[y_column].values

        x_center = correction["center"][0]
        y_center = correction["center"][1]

        correction_x = tof_fermi - correction_function(
            x=x,
            y=y_center,
            **correction,
        )
        correction_y = tof_fermi - correction_function(
            x=x_center,
            y=y,
            **correction,
        )
        fig, ax = plt.subplots(2, 1)
        image.loc[
            {
                y_column: slice(y_center + y_width[0], y_center + y_width[1]),
                tof_column: slice(
                    tof_fermi + tof_width[0],
                    tof_fermi + tof_width[1],
                ),
            }
        ].sum(dim=y_column).T.plot(
            ax=ax[0],
            cmap="terrain_r",
            vmax=color_clip,
            yincrease=False,
        )
        image.loc[
            {
                x_column: slice(x_center + x_width[0], x_center + x_width[1]),
                tof_column: slice(
                    tof_fermi + tof_width[0],
                    tof_fermi + tof_width[1],
                ),
            }
        ].sum(dim=x_column).T.plot(
            ax=ax[1],
            cmap="terrain_r",
            vmax=color_clip,
            yincrease=False,
        )
        (trace1,) = ax[0].plot(x, correction_x)
        line1 = ax[0].axvline(x=x_center)
        (trace2,) = ax[1].plot(y, correction_y)
        line2 = ax[1].axvline(x=y_center)

        amplitude_slider = ipw.FloatSlider(
            value=correction["amplitude"],
            min=0,
            max=10,
            step=0.1,
        )
        x_center_slider = ipw.FloatSlider(
            value=x_center,
            min=0,
            max=self._config["momentum"]["detector_ranges"][0][1],
            step=1,
        )
        y_center_slider = ipw.FloatSlider(
            value=y_center,
            min=0,
            max=self._config["momentum"]["detector_ranges"][1][1],
            step=1,
        )

        def update(amplitude, x_center, y_center, **kwds):
            nonlocal correction
            correction["amplitude"] = amplitude
            correction["center"] = (x_center, y_center)
            correction = {**correction, **kwds}
            correction_x = tof_fermi - correction_function(
                x=x,
                y=y_center,
                **correction,
            )
            correction_y = tof_fermi - correction_function(
                x=x_center,
                y=y,
                **correction,
            )

            trace1.set_ydata(correction_x)
            line1.set_xdata(x=x_center)
            trace2.set_ydata(correction_y)
            line2.set_xdata(x=y_center)

            fig.canvas.draw_idle()

        def common_apply_func(apply: bool):  # noqa: ARG001
            self.correction = {}
            self.correction["amplitude"] = correction["amplitude"]
            self.correction["center"] = correction["center"]
            self.correction["correction_type"] = correction["correction_type"]
            self.correction["creation_date"] = datetime.now().timestamp()
            amplitude_slider.close()
            x_center_slider.close()
            y_center_slider.close()
            apply_button.close()

        if correction["correction_type"] == "spherical":
            try:
                update(correction["amplitude"], x_center, y_center, diameter=correction["diameter"])
            except KeyError as exc:
                raise ValueError(
                    "Parameter 'diameter' required for correction type 'spherical', ",
                    "but not present!",
                ) from exc

            diameter_slider = ipw.FloatSlider(
                value=correction["diameter"],
                min=0,
                max=10000,
                step=100,
            )

            ipw.interact(
                update,
                amplitude=amplitude_slider,
                x_center=x_center_slider,
                y_center=y_center_slider,
                diameter=diameter_slider,
            )

            def apply_func(apply: bool):
                common_apply_func(apply)
                self.correction["diameter"] = correction["diameter"]
                diameter_slider.close()

        elif correction["correction_type"] == "Lorentzian":
            try:
                update(correction["amplitude"], x_center, y_center, gamma=correction["gamma"])
            except KeyError as exc:
                raise ValueError(
                    "Parameter 'gamma' required for correction type 'Lorentzian', but not present!",
                ) from exc

            gamma_slider = ipw.FloatSlider(
                value=correction["gamma"],
                min=0,
                max=2000,
                step=1,
            )

            ipw.interact(
                update,
                amplitude=amplitude_slider,
                x_center=x_center_slider,
                y_center=y_center_slider,
                gamma=gamma_slider,
            )

            def apply_func(apply: bool):
                common_apply_func(apply)
                self.correction["gamma"] = correction["gamma"]
                gamma_slider.close()

        elif correction["correction_type"] == "Gaussian":
            try:
                update(correction["amplitude"], x_center, y_center, sigma=correction["sigma"])
            except KeyError as exc:
                raise ValueError(
                    "Parameter 'sigma' required for correction type 'Gaussian', but not present!",
                ) from exc

            sigma_slider = ipw.FloatSlider(
                value=correction["sigma"],
                min=0,
                max=1000,
                step=1,
            )

            ipw.interact(
                update,
                amplitude=amplitude_slider,
                x_center=x_center_slider,
                y_center=y_center_slider,
                sigma=sigma_slider,
            )

            def apply_func(apply: bool):
                common_apply_func(apply)
                self.correction["sigma"] = correction["sigma"]
                sigma_slider.close()

        elif correction["correction_type"] == "Lorentzian_asymmetric":
            try:
                if "amplitude2" not in correction:
                    correction["amplitude2"] = correction["amplitude"]
                if "gamma2" not in correction:
                    correction["gamma2"] = correction["gamma"]
                update(
                    correction["amplitude"],
                    x_center,
                    y_center,
                    gamma=correction["gamma"],
                    amplitude2=correction["amplitude2"],
                    gamma2=correction["gamma2"],
                )
            except KeyError as exc:
                raise ValueError(
                    "Parameter 'gamma' required for correction type 'Lorentzian_asymmetric', ",
                    "but not present!",
                ) from exc

            gamma_slider = ipw.FloatSlider(
                value=correction["gamma"],
                min=0,
                max=2000,
                step=1,
            )

            amplitude2_slider = ipw.FloatSlider(
                value=correction["amplitude2"],
                min=0,
                max=10,
                step=0.1,
            )

            gamma2_slider = ipw.FloatSlider(
                value=correction["gamma2"],
                min=0,
                max=2000,
                step=1,
            )

            ipw.interact(
                update,
                amplitude=amplitude_slider,
                x_center=x_center_slider,
                y_center=y_center_slider,
                gamma=gamma_slider,
                amplitude2=amplitude2_slider,
                gamma2=gamma2_slider,
            )

            def apply_func(apply: bool):
                common_apply_func(apply)
                self.correction["gamma"] = correction["gamma"]
                self.correction["amplitude2"] = correction["amplitude2"]
                self.correction["gamma2"] = correction["gamma2"]
                gamma_slider.close()
                amplitude2_slider.close()
                gamma2_slider.close()

        else:
            raise NotImplementedError
        # pylint: disable=duplicate-code
        apply_button = ipw.Button(description="apply")
        display(apply_button)
        apply_button.on_click(apply_func)
        plt.show()

        if apply:
            apply_func(True)

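    # Adjustment sketch (illustrative parameter values): visualize a Lorentzian
    # correction on a binned (x, y, tof) cube and store it on the instance:
    #
    #     ec.adjust_energy_correction(image, correction_type="Lorentzian",
    #                                 amplitude=2.5, center=(730, 730),
    #                                 gamma=900, apply=True)
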
    def apply_energy_correction(
1✔
1310
        self,
1311
        df: Union[pd.DataFrame, dask.dataframe.DataFrame],
1312
        tof_column: str = None,
1313
        new_tof_column: str = None,
1314
        correction_type: str = None,
1315
        amplitude: float = None,
1316
        correction: dict = None,
1317
        verbose: bool = True,
1318
        **kwds,
1319
    ) -> Tuple[Union[pd.DataFrame, dask.dataframe.DataFrame], dict]:
1320
        """Apply correction to the time-of-flight (TOF) axis of single-event data.
1321

1322
        Args:
1323
            df (Union[pd.DataFrame, dask.dataframe.DataFrame]): The dataframe where
1324
                to apply the energy correction to.
1325
            tof_column (str, optional): Name of the source column to convert.
1326
                Defaults to config["dataframe"]["tof_column"].
1327
            new_tof_column (str, optional): Name of the destination column to convert.
1328
                Defaults to config["dataframe"]["corrected_tof_column"].
1329
            correction_type (str, optional): Type of correction to apply to the TOF
1330
                axis. Valid values are:
1331

1332
                - 'spherical'
1333
                - 'Lorentzian'
1334
                - 'Gaussian'
1335
                - 'Lorentzian_asymmetric'
1336

1337
                Defaults to config["energy"]["correction_type"].
1338
            amplitude (float, optional): Amplitude of the time-of-flight correction
1339
                term. Defaults to config["energy"]["correction"]["amplitude"].
1340
            correction (dict, optional): Correction dictionary containing parameters
1341
                for the correction. Defaults to self.correction or
1342
                config["energy"]["correction"].
1343
            verbose (bool, optional): Option to print out diagnostic information.
1344
                Defaults to True.
1345
            **kwds: Additional parameters to use for the correction:
1346

1347
                - **x_column** (str): Name of the x column.
1348
                - **y_column** (str): Name of the y column.
1349
                - **diameter** (float): Field-free drift distance.
1350
                - **gamma** (float): Linewidth value for correction using a 2D
1351
                  Lorentz profile.
1352
                - **sigma** (float): Standard deviation for correction using a 2D
1353
                  Gaussian profile.
1354
                - **gamma2** (float): Linewidth value for correction using an
1355
                  asymmetric 2D Lorentz profile, X-direction.
1356
                - **amplitude2** (float): Amplitude value for correction using an
1357
                  asymmetric 2D Lorentz profile, X-direction.
1358

1359
        Returns:
1360
            Tuple[Union[pd.DataFrame, dask.dataframe.DataFrame], dict]: dataframe with
1361
            added column and energy correction metadata dictionary.
1362
        """
1363
        if correction is None:
1✔
1364
            correction = deepcopy(self.correction)
1✔
1365

1366
        x_column = kwds.pop("x_column", self.x_column)
1✔
1367
        y_column = kwds.pop("y_column", self.y_column)
1✔
1368

1369
        if tof_column is None:
1✔
1370
            tof_column = self.tof_column
1✔
1371

1372
        if new_tof_column is None:
1✔
1373
            new_tof_column = self.corrected_tof_column
1✔
1374

1375
        if correction_type is not None or amplitude is not None or len(kwds) > 0:
1✔
1376
            if correction_type is not None:
1✔
1377
                correction["correction_type"] = correction_type
1✔
1378

1379
            if amplitude is not None:
1✔
1380
                correction["amplitude"] = amplitude
1✔
1381

1382
            for key, value in kwds.items():
1✔
1383
                correction[key] = value
1✔
1384

1385
            correction["creation_date"] = datetime.now().timestamp()
1✔
1386

1387
        elif "creation_date" in correction and verbose:
1✔
1388
            datestring = datetime.fromtimestamp(correction["creation_date"]).strftime(
1✔
1389
                "%m/%d/%Y, %H:%M:%S",
1390
            )
1391
            print(f"Using energy correction parameters generated on {datestring}")
1✔
1392

1393
        missing_keys = {"correction_type", "center", "amplitude"} - set(correction.keys())
1✔
1394
        if missing_keys:
1✔
1395
            raise ValueError(f"Required correction parameters '{missing_keys}' missing!")
1✔
1396

1397
        df[new_tof_column] = df[tof_column] + correction_function(
1✔
1398
            x=df[x_column],
1399
            y=df[y_column],
1400
            **correction,
1401
        )
1402
        metadata = self.gather_correction_metadata(correction=correction)
1✔
1403

1404
        return df, metadata
1✔
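
    # Example (illustrative sketch, not part of the class): applying a
    # Lorentzian correction from a hand-written parameter dictionary. The
    # center, amplitude, and gamma values below are hypothetical, and ``ec``
    # is assumed to be a configured EnergyCalibrator with a loaded dataframe.
    #
    #     correction = {
    #         "correction_type": "Lorentzian",
    #         "center": (730.0, 730.0),
    #         "amplitude": 2.5,
    #         "gamma": 920.0,
    #     }
    #     df, meta = ec.apply_energy_correction(df, correction=correction)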
1405

1406
    def gather_correction_metadata(self, correction: dict = None) -> dict:
1✔
1407
        """Collect meta data for energy correction
1408

1409
        Args:
1410
            correction (dict, optional): Dictionary with energy correction parameters.
1411
                Defaults to None.
1412

1413
        Returns:
1414
            dict: Generated metadata dictionary.
1415
        """
1416
        if correction is None:
1✔
UNCOV
1417
            correction = self.correction
×
1418
        metadata: Dict[Any, Any] = {}
1✔
1419
        metadata["applied"] = True
1✔
1420
        metadata["correction"] = deepcopy(correction)
1✔
1421

1422
        return metadata
1✔
1423

1424
    def align_dld_sectors(
1✔
1425
        self,
1426
        df: dask.dataframe.DataFrame,
1427
        tof_column: str = None,
1428
        sector_id_column: str = None,
1429
        sector_delays: np.ndarray = None,
1430
    ) -> Tuple[dask.dataframe.DataFrame, dict]:
1431
        """Aligns the time-of-flight axis of the different sections of a detector.
1432

1433
        Args:
1434
            df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to use.
1435
            tof_column (str, optional): Name of the column containing the time-of-flight values.
1436
                Defaults to config["dataframe"]["tof_column"].
1437
            sector_id_column (str, optional): Name of the column containing the sector id values.
1438
                Defaults to config["dataframe"]["sector_id_column"].
1439
            sector_delays (np.ndarray, optional): Array containing the sector delays. Defaults to
1440
                config["dataframe"]["sector_delays"].
1441

1442
        Returns:
1443
            dask.dataframe.DataFrame: Dataframe with the new columns.
1444
            dict: Metadata dictionary.
1445
        """
1446
        if sector_delays is None:
1✔
1447
            sector_delays = self.sector_delays
1✔
1448
        if sector_id_column is None:
1✔
1449
            sector_id_column = self.sector_id_column
1✔
1450

1451
        if sector_delays is None or sector_id_column is None:
1✔
1452
            raise ValueError(
1✔
1453
                "No value for sector_delays or sector_id_column found in config."
1454
                "Config file is not properly configured for dld sector correction.",
1455
            )
1456
        tof_column = tof_column or self.tof_column
1✔
1457

1458
        # align the 8 sectors
1459
        sector_delays_arr = dask.array.from_array(sector_delays)
1✔
1460

1461
        def align_sector(x):
1✔
1462
            val = x[tof_column] - sector_delays_arr[x[sector_id_column].values.astype(int)]
1✔
1463
            return val.astype(np.float32)
1✔
1464

1465
        df[tof_column] = df.map_partitions(align_sector, meta=(tof_column, np.float32))
1✔
1466
        metadata: Dict[str, Any] = {
1✔
1467
            "applied": True,
1468
            "sector_delays": sector_delays,
1469
        }
1470
        return df, metadata
1✔
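
    # Example (illustrative sketch; the delay values are hypothetical): with
    # eight detector sectors, each event's time-of-flight is shifted by the
    # delay of the sector (``sector_id_column``) it was recorded in.
    #
    #     ec.sector_delays = np.array([0.0, 0.3, -0.2, 0.1, 0.0, 0.2, -0.1, 0.4])
    #     df, meta = ec.align_dld_sectors(df)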
1471

1472
    def add_offsets(
1✔
1473
        self,
1474
        df: Union[pd.DataFrame, dask.dataframe.DataFrame] = None,
1475
        offsets: Dict[str, Any] = None,
1476
        constant: float = None,
1477
        columns: Union[str, Sequence[str]] = None,
1478
        weights: Union[float, Sequence[float]] = None,
1479
        preserve_mean: Union[bool, Sequence[bool]] = False,
1480
        reductions: Union[str, Sequence[str]] = None,
1481
        energy_column: str = None,
1482
        verbose: bool = True,
1483
    ) -> Tuple[Union[pd.DataFrame, dask.dataframe.DataFrame], dict]:
1484
        """Apply an offset to the energy column by the values of the provided columns.
1485

1486
        If no parameter is passed to this function, the offset is applied as defined in the
1487
        config file. If parameters are passed, they are used to generate a new offset dictionary
1488
        and the offset is applied using the ``dfops.apply_offset_from_columns()`` function.
1489

1490
        Args:
1491
            df (Union[pd.DataFrame, dask.dataframe.DataFrame]): Dataframe to use.
1492
            offsets (Dict, optional): Dictionary of energy offset parameters.
1493
            constant (float, optional): The constant to shift the energy axis by.
1494
            columns (Union[str, Sequence[str]]): Name of the column(s) to apply the shift from.
1495
            weights (Union[float, Sequence[float]]): Weights to apply to the columns.
1496
                Can also be used to flip the sign (e.g. -1). Defaults to 1.
1497
            preserve_mean (bool): Whether to subtract the mean of the column before applying the
1498
                shift. Defaults to False.
1499
            reductions (str): The reduction to apply to the column. Should be an available method
1500
                of dask.dataframe.Series. For example "mean". In this case the function is applied
1501
                to the column to generate a single value for the whole dataset. If None, the shift
1502
                is applied per-dataframe-row. Defaults to None. Currently only "mean" is supported.
1503
            energy_column (str, optional): Name of the column containing the energy values.
1504
            verbose (bool, optional): Option to print out diagnostic information.
1505
                Defaults to True.
1506

1507
        Returns:
1508
            dask.dataframe.DataFrame: Dataframe with the new columns.
1509
            dict: Metadata dictionary.
1510
        """
1511
        if offsets is None:
1✔
1512
            offsets = deepcopy(self.offsets)
1✔
1513

1514
        if energy_column is None:
1✔
1515
            energy_column = self.energy_column
1✔
1516

1517
        metadata: Dict[str, Any] = {
1✔
1518
            "applied": True,
1519
        }
1520

1521
        # flip sign for binding energy scale
1522
        energy_scale = self.calibration.get("energy_scale", None)
1✔
1523
        if energy_scale is None:
1✔
1524
            raise ValueError("Energy scale not set. Cannot interpret the sign of the offset.")
1✔
1525
        if energy_scale not in ["binding", "kinetic"]:
1✔
1526
            raise ValueError(f"Invalid energy scale: {energy_scale}")
1✔
1527
        scale_sign: Literal[-1, 1] = -1 if energy_scale == "binding" else 1
1✔
1528

1529
        if columns is not None or constant is not None:
1✔
1530
            # pylint:disable=duplicate-code
1531
            # use passed parameters, overwrite config
1532
            offsets = {}
1✔
1533
            offsets["creation_date"] = datetime.now().timestamp()
1✔
1534
            # column-based offsets
1535
            if columns is not None:
1✔
1536
                if isinstance(columns, str):
1✔
1537
                    columns = [columns]
1✔
1538

1539
                if weights is None:
1✔
1540
                    weights = 1
1✔
1541
                if isinstance(weights, (int, float, np.integer, np.floating)):
1✔
1542
                    weights = [weights]
1✔
1543
                if len(weights) == 1:
1✔
1544
                    weights = [weights[0]] * len(columns)
1✔
1545
                if not isinstance(weights, Sequence):
1✔
UNCOV
1546
                    raise TypeError(f"Invalid type for weights: {type(weights)}")
×
1547
                if not all(isinstance(s, (int, float, np.integer, np.floating)) for s in weights):
1✔
UNCOV
1548
                    raise TypeError(f"Invalid type for weights: {type(weights)}")
×
1549

1550
                if preserve_mean is None:
1✔
NEW
UNCOV
1551
                    preserve_mean = False
×
1552
                if not isinstance(preserve_mean, Sequence):
1✔
1553
                    preserve_mean = [preserve_mean]
1✔
1554
                if len(preserve_mean) == 1:
1✔
1555
                    preserve_mean = [preserve_mean[0]] * len(columns)
1✔
1556

1557
                # note: a plain string is itself a Sequence, so check explicitly
                if isinstance(reductions, str) or not isinstance(reductions, Sequence):
1✔
1558
                    reductions = [reductions]
1✔
1559
                if len(reductions) == 1:
1✔
1560
                    reductions = [reductions[0]] * len(columns)
1✔
1561

1562
                # store in offsets dictionary
1563
                for col, weight, pmean, red in zip(columns, weights, preserve_mean, reductions):
1✔
1564
                    offsets[col] = {
1✔
1565
                        "weight": weight,
1566
                        "preserve_mean": pmean,
1567
                        "reduction": red,
1568
                    }
1569

1570
            # constant offset
1571
            if isinstance(constant, (int, float, np.integer, np.floating)):
1✔
1572
                offsets["constant"] = constant
1✔
1573
            elif constant is not None:
1✔
1574
                raise TypeError(f"Invalid type for constant: {type(constant)}")
×
1575

1576
        elif "creation_date" in offsets and verbose:
1✔
UNCOV
1577
            datestring = datetime.fromtimestamp(offsets["creation_date"]).strftime(
×
1578
                "%m/%d/%Y, %H:%M:%S",
1579
            )
UNCOV
1580
            print(f"Using energy offset parameters generated on {datestring}")
×
1581

1582
        if len(offsets) > 0:
1✔
1583
            # unpack dictionary
1584
            # pylint: disable=duplicate-code
1585
            columns = []
1✔
1586
            weights = []
1✔
1587
            preserve_mean = []
1✔
1588
            reductions = []
1✔
1589
            if verbose:
1✔
1590
                print("Energy offset parameters:")
1✔
1591
            for k, v in offsets.items():
1✔
1592
                if k == "creation_date":
1✔
1593
                    continue
1✔
1594
                if k == "constant":
1✔
1595
                    # flip sign if binding energy scale
1596
                    constant = v * scale_sign
1✔
1597
                    if verbose:
1✔
1598
                        print(f"   Constant: {constant} ")
1✔
1599
                else:
1600
                    columns.append(k)
1✔
1601
                    try:
1✔
1602
                        weight = v["weight"]
1✔
UNCOV
1603
                    except KeyError:
×
UNCOV
1604
                        weight = 1
×
1605
                    if not isinstance(weight, (int, float, np.integer, np.floating)):
1✔
1606
                        raise TypeError(f"Invalid type for weight of column {k}: {type(weight)}")
1✔
1607
                    # flip sign if binding energy scale
1608
                    weight = weight * scale_sign
1✔
1609
                    weights.append(weight)
1✔
1610
                    pm = v.get("preserve_mean", False)
1✔
1611
                    if str(pm).lower() in ["false", "0", "no"]:
1✔
1612
                        pm = False
1✔
1613
                    elif str(pm).lower() in ["true", "1", "yes"]:
1✔
1614
                        pm = True
1✔
1615
                    preserve_mean.append(pm)
1✔
1616
                    red = v.get("reduction", None)
1✔
1617
                    if str(red).lower() in ["none", "null"]:
1✔
1618
                        red = None
1✔
1619
                    reductions.append(red)
1✔
1620
                    if verbose:
1✔
1621
                        print(
1✔
1622
                            f"   Column[{k}]: Weight={weight}, Preserve Mean: {pm}, ",
1623
                            f"Reductions: {red}.",
1624
                        )
1625

1626
            if len(columns) > 0:
1✔
1627
                df = dfops.offset_by_other_columns(
1✔
1628
                    df=df,
1629
                    target_column=energy_column,
1630
                    offset_columns=columns,
1631
                    weights=weights,
1632
                    preserve_mean=preserve_mean,
1633
                    reductions=reductions,
1634
                )
1635

1636
        # apply constant
1637
        if constant:
1✔
1638
            if not isinstance(constant, (int, float, np.integer, np.floating)):
1✔
1639
                raise TypeError(f"Invalid type for constant: {type(constant)}")
1✔
1640
            df[energy_column] = df.map_partitions(
1✔
1641
                lambda x: x[energy_column] + constant,
1642
                meta=(energy_column, np.float64),
1643
            )
1644

1645
        self.offsets = offsets
1✔
1646
        metadata["offsets"] = offsets
1✔
1647

1648
        return df, metadata
1✔
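
    # Example (illustrative sketch; "monochromatorPhotonEnergy" is a
    # hypothetical column name): shift the energy axis by a constant plus a
    # mean-preserving, sign-flipped per-row contribution from another column.
    #
    #     df, meta = ec.add_offsets(
    #         df,
    #         constant=1.5,
    #         columns="monochromatorPhotonEnergy",
    #         weights=-1,
    #         preserve_mean=True,
    #     )
    #
    # On a binding energy scale, the signs of the constant and the weights
    # are flipped automatically.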
1649

1650

1651
def extract_bias(files: List[str], bias_key: str) -> np.ndarray:
1✔
1652
    """Read bias values from hdf5 files
1653

1654
    Args:
1655
        files (List[str]): List of filenames
1656
        bias_key (str): hdf5 path to the bias value
1657

1658
    Returns:
1659
        np.ndarray: Array of bias values.
1660
    """
1661
    bias_list: List[float] = []
1✔
1662
    for file in files:
1✔
1663
        with h5py.File(file, "r") as file_handle:
1✔
1664
            if bias_key[0] == "@":
1✔
1665
                bias_list.append(round(file_handle.attrs[bias_key[1:]], 2))
1✔
1666
            else:
UNCOV
1667
                # Read the scalar dataset value; round() would fail on the
                # h5py Dataset object itself
                bias_list.append(round(float(file_handle[bias_key][()]), 2))
×
1668

1669
    return np.asarray(bias_list)
1✔
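
# Example (illustrative sketch; the file names and hdf5 key are hypothetical):
#
#     biases = extract_bias(
#         files=["scan_001.h5", "scan_002.h5"],
#         bias_key="@KTOF:Lens:Sample:V",
#     )
#
# A leading "@" in bias_key reads the value from the file attributes instead
# of from a dataset.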
1670

1671

1672
def correction_function(
1✔
1673
    x: Union[float, np.ndarray],
1674
    y: Union[float, np.ndarray],
1675
    correction_type: str,
1676
    center: Tuple[float, float],
1677
    amplitude: float,
1678
    **kwds,
1679
) -> Union[float, np.ndarray]:
1680
    """Calculate the TOF correction based on the given X/Y coordinates and a model.
1681

1682
    Args:
1683
        x (Union[float, np.ndarray]): x coordinate
1684
        y (Union[float, np.ndarray]): y coordinate
1685
        correction_type (str): type of correction. One of
1686
            "spherical", "Lorentzian", "Gaussian", or "Lorentzian_asymmetric".
1687
        center (Tuple[float, float]): center position of the distribution (x,y)
1688
        amplitude (float): Amplitude of the correction
1689
        **kwds: Keyword arguments:
1690

1691
            - **diameter** (float): Field-free drift distance.
1692
            - **gamma** (float): Linewidth value for correction using a 2D
1693
              Lorentz profile.
1694
            - **sigma** (float): Standard deviation for correction using a 2D
1695
              Gaussian profile.
1696
            - **gamma2** (float): Linewidth value for correction using an
1697
              asymmetric 2D Lorentz profile, X-direction.
1698
            - **amplitude2** (float): Amplitude value for correction using an
1699
              asymmetric 2D Lorentz profile, X-direction.
1700

1701
    Returns:
1702
        Union[float, np.ndarray]: calculated correction value
1703
    """
1704
    if correction_type == "spherical":
1✔
1705
        try:
1✔
1706
            diameter = kwds.pop("diameter")
1✔
1707
        except KeyError as exc:
1✔
1708
            raise ValueError(
1✔
1709
                f"Parameter 'diameter' required for correction type '{correction_type}' "
1710
                "but not provided!",
1711
            ) from exc
1712
        correction = -(
1✔
1713
            (
1714
                1
1715
                - np.sqrt(
1716
                    1 - ((x - center[0]) ** 2 + (y - center[1]) ** 2) / diameter**2,
1717
                )
1718
            )
1719
            * 100
1720
            * amplitude
1721
        )
1722

1723
    elif correction_type == "Lorentzian":
1✔
1724
        try:
1✔
1725
            gamma = kwds.pop("gamma")
1✔
1726
        except KeyError as exc:
1✔
1727
            raise ValueError(
1✔
1728
                f"Parameter 'gamma' required for correction type '{correction_type}' "
1729
                "but not provided!",
1730
            ) from exc
1731
        correction = (
1✔
1732
            100000
1733
            * amplitude
1734
            / (gamma * np.pi)
1735
            * (gamma**2 / ((x - center[0]) ** 2 + (y - center[1]) ** 2 + gamma**2) - 1)
1736
        )
1737

1738
    elif correction_type == "Gaussian":
1✔
1739
        try:
1✔
1740
            sigma = kwds.pop("sigma")
1✔
1741
        except KeyError as exc:
1✔
1742
            raise ValueError(
1✔
1743
                f"Parameter 'sigma' required for correction type '{correction_type}' "
1744
                "but not provided!",
1745
            ) from exc
1746
        correction = (
1✔
1747
            20000
1748
            * amplitude
1749
            / np.sqrt(2 * np.pi * sigma**2)
1750
            * (
1751
                np.exp(
1752
                    -((x - center[0]) ** 2 + (y - center[1]) ** 2) / (2 * sigma**2),
1753
                )
1754
                - 1
1755
            )
1756
        )
1757

1758
    elif correction_type == "Lorentzian_asymmetric":
1✔
1759
        try:
1✔
1760
            gamma = kwds.pop("gamma")
1✔
1761
        except KeyError as exc:
1✔
1762
            raise ValueError(
1✔
1763
                f"Parameter 'gamma' required for correction type '{correction_type}' "
1764
                "but not provided!",
1765
            ) from exc
1766
        gamma2 = kwds.pop("gamma2", gamma)
1✔
1767
        amplitude2 = kwds.pop("amplitude2", amplitude)
1✔
1768
        correction = (
1✔
1769
            100000
1770
            * amplitude
1771
            / (gamma * np.pi)
1772
            * (gamma**2 / ((y - center[1]) ** 2 + gamma**2) - 1)
1773
        )
1774
        correction += (
1✔
1775
            100000
1776
            * amplitude2
1777
            / (gamma2 * np.pi)
1778
            * (gamma2**2 / ((x - center[0]) ** 2 + gamma2**2) - 1)
1779
        )
1780

1781
    else:
UNCOV
1782
        raise NotImplementedError
×
1783

1784
    return correction
1✔
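
# Example (illustrative sketch): evaluating a spherical correction on a
# hypothetical detector grid. The correction vanishes at the center and grows
# in magnitude towards the detector edges.
#
#     >>> import numpy as np
#     >>> x, y = np.meshgrid(np.arange(1496), np.arange(1496))
#     >>> corr = correction_function(
#     ...     x,
#     ...     y,
#     ...     correction_type="spherical",
#     ...     center=(748.0, 748.0),
#     ...     amplitude=1.0,
#     ...     diameter=3000.0,
#     ... )
#     >>> corr.shape
#     (1496, 1496)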
1785

1786

1787
def normspec(
1✔
1788
    specs: np.ndarray,
1789
    smooth: bool = False,
1790
    span: int = 7,
1791
    order: int = 1,
1792
) -> np.ndarray:
1793
    """Normalize a series of 1D signals.
1794

1795
    Args:
1796
        specs (np.ndarray): Collection of 1D signals.
1797
        smooth (bool, optional): Option to smooth the signals before normalization.
1798
            Defaults to False.
1799
        span (int, optional): Smoothing span parameter of the LOESS method
1800
            (see ``scipy.signal.savgol_filter()``). Defaults to 7.
1801
        order (int, optional): Smoothing order parameter of the LOESS method
1802
            (see ``scipy.signal.savgol_filter()``). Defaults to 1.
1803

1804
    Returns:
1805
        np.ndarray: The matrix assembled from a list of maximum-normalized signals.
1806
    """
1807
    nspec = len(specs)
1✔
1808
    specnorm = []
1✔
1809

1810
    for i in range(nspec):
1✔
1811
        spec = specs[i]
1✔
1812

1813
        if smooth:
1✔
1814
            spec = savgol_filter(spec, span, order)
1✔
1815

1816
        if type(spec) in (list, tuple):
1✔
UNCOV
1817
            nsp = spec / max(spec)
×
1818
        else:
1819
            nsp = spec / spec.max()
1✔
1820
        specnorm.append(nsp)
1✔
1821

1822
    # Assemble the matrix of maximum-normalized signals
1823
    normalized_specs = np.asarray(specnorm)
1✔
1824

1825
    return normalized_specs
1✔
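
# Example (illustrative sketch): each row is divided by its own maximum.
#
#     >>> import numpy as np
#     >>> normspec(np.array([[1.0, 2.0, 4.0], [3.0, 6.0, 3.0]]))
#     array([[0.25, 0.5 , 1.  ],
#            [0.5 , 1.  , 0.5 ]])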
1826

1827

1828
def find_correspondence(
1✔
1829
    sig_still: np.ndarray,
1830
    sig_mov: np.ndarray,
1831
    **kwds,
1832
) -> np.ndarray:
1833
    """Determine the correspondence between two 1D traces by alignment using a
1834
    time-warp algorithm.
1835

1836
    Args:
1837
        sig_still (np.ndarray): Reference 1D signals.
1838
        sig_mov (np.ndarray): 1D signal to be aligned.
1839
        **kwds: keyword arguments for ``fastdtw.fastdtw()``
1840

1841
    Returns:
1842
        np.ndarray: Pixel-wise path correspondences between two input 1D arrays
1843
        (sig_still, sig_mov).
1844
    """
1845
    dist = kwds.pop("dist_metric", None)
1✔
1846
    rad = kwds.pop("radius", 1)
1✔
1847
    _, pathcorr = fastdtw(sig_still, sig_mov, dist=dist, radius=rad)
1✔
1848
    return np.asarray(pathcorr)
1✔
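
# Example (illustrative sketch): aligning a trace with a shifted copy of
# itself. Each row of the returned path maps an index of ``sig_still`` to an
# index of ``sig_mov``.
#
#     >>> import numpy as np
#     >>> sig = np.exp(-((np.arange(100) - 50) ** 2) / 50)
#     >>> path = find_correspondence(sig, np.roll(sig, 5))
#     >>> path.shape[1]
#     2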
1849

1850

1851
def range_convert(
1✔
1852
    x: np.ndarray,
1853
    xrng: Tuple,
1854
    pathcorr: np.ndarray,
1855
) -> Tuple:
1856
    """Convert value range using a pairwise path correspondence (e.g. obtained
1857
    from time warping algorithm).
1858

1859
    Args:
1860
        x (np.ndarray): Values of the x axis (e.g. time-of-flight values).
1861
        xrng (Tuple): Boundary value range on the x axis.
1862
        pathcorr (np.ndarray): Path correspondence between two 1D arrays in the
1863
            following form,
1864
            [(id_1_trace_1, id_1_trace_2), (id_2_trace_1, id_2_trace_2), ...]
1865

1866
    Returns:
1867
        Tuple: Transformed range according to the path correspondence.
1868
    """
1869
    pathcorr = np.asarray(pathcorr)
1✔
1870
    xrange_trans = []
1✔
1871

1872
    for xval in xrng:  # Transform each value in the range
1✔
1873
        xind = find_nearest(xval, x)
1✔
1874
        xind_alt = find_nearest(xind, pathcorr[:, 0])
1✔
1875
        xind_trans = pathcorr[xind_alt, 1]
1✔
1876
        xrange_trans.append(x[xind_trans])
1✔
1877

1878
    return tuple(xrange_trans)
1✔
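
# Example (illustrative sketch; the TOF boundary values are hypothetical):
# mapping a feature range from a reference trace onto a shifted trace using
# the warping path obtained from ``find_correspondence``.
#
#     pathcorr = find_correspondence(trace_ref, trace_shifted)
#     rng_shifted = range_convert(tof, (128120, 128350), pathcorr)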
1879

1880

1881
def find_nearest(val: float, narray: np.ndarray) -> int:
1✔
1882
    """Find the value closest to a given one in a 1D array.
1883

1884
    Args:
1885
        val (float): Value of interest.
1886
        narray (np.ndarray):  The array to look for the nearest value.
1887

1888
    Returns:
1889
        int: Array index of the value nearest to the given one.
1890
    """
1891
    return int(np.argmin(np.abs(narray - val)))
1✔
1892

1893

1894
def peaksearch(
1✔
1895
    traces: np.ndarray,
1896
    tof: np.ndarray,
1897
    ranges: List[Tuple] = None,
1898
    pkwindow: int = 3,
1899
    plot: bool = False,
1900
) -> np.ndarray:
1901
    """Detect a list of peaks in the corresponding regions of multiple spectra.
1902

1903
    Args:
1904
        traces (np.ndarray): Collection of 1D spectra.
1905
        tof (np.ndarray): Time-of-flight values.
1906
        ranges (List[Tuple], optional): List of ranges for peak detection in the format
1907
            [(LowerBound1, UpperBound1), (LowerBound2, UpperBound2), ...].
1908
            Defaults to None.
1909
        pkwindow (int, optional): Window width of a peak (amounts to lookahead in
1910
            ``peakdetect1d``). Defaults to 3.
1911
        plot (bool, optional): Specify whether to display a custom plot of the peak
1912
            search results. Defaults to False.
1913

1914
    Returns:
1915
        np.ndarray: Collection of peak positions.
1916
    """
1917
    pkmaxs = []
1✔
1918
    if plot:
1✔
UNCOV
1919
        plt.figure(figsize=(10, 4))
×
1920

1921
    for rng, trace in zip(ranges, traces.tolist()):
1✔
1922
        cond = (tof >= rng[0]) & (tof <= rng[1])
1✔
1923
        trace = np.array(trace).ravel()
1✔
1924
        tofseg, trseg = tof[cond], trace[cond]
1✔
1925
        maxs, _ = peakdetect1d(trseg, tofseg, lookahead=pkwindow)
1✔
1926
        try:
1✔
1927
            pkmaxs.append(maxs[0, :])
1✔
UNCOV
1928
        except IndexError:  # No peak found for this range
×
UNCOV
1929
            print(f"No peak detected in range {rng}.")
×
UNCOV
1930
            raise
×
1931

1932
        if plot:
1✔
UNCOV
1933
            plt.plot(tof, trace, "--k", linewidth=1)
×
UNCOV
1934
            plt.plot(tofseg, trseg, linewidth=2)
×
UNCOV
1935
            plt.scatter(maxs[0, 0], maxs[0, 1], s=30)
×
1936

1937
    return np.asarray(pkmaxs)
1✔
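
# Example (illustrative sketch; the TOF windows are hypothetical): each row of
# the result holds the (position, height) of the first peak detected in the
# corresponding range.
#
#     peaks = peaksearch(
#         traces=normalized_traces,
#         tof=tof,
#         ranges=[(65000, 65500), (66000, 66500)],
#         pkwindow=3,
#     )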
1938

1939

1940
# 1D peak detection algorithm adapted from Sixten Bergman
1941
# https://gist.github.com/sixtenbe/1178136#file-peakdetect-py
1942
def _datacheck_peakdetect(
1✔
1943
    x_axis: np.ndarray,
1944
    y_axis: np.ndarray,
1945
) -> Tuple[np.ndarray, np.ndarray]:
1946
    """Input format checking for 1D peakdtect algorithm
1947

1948
    Args:
1949
        x_axis (np.ndarray): x-axis array
1950
        y_axis (np.ndarray): y-axis array
1951

1952
    Raises:
1953
        ValueError: Raised if x and y values don't have the same length.
1954

1955
    Returns:
1956
        Tuple[np.ndarray, np.ndarray]: Tuple of checked (x/y) arrays.
1957
    """
1958

1959
    if x_axis is None:
1✔
UNCOV
1960
        x_axis = np.arange(len(y_axis))
×
1961

1962
    if len(y_axis) != len(x_axis):
1✔
UNCOV
1963
        raise ValueError(
×
1964
            "Input vectors y_axis and x_axis must have same length",
1965
        )
1966

1967
    # Needs to be a numpy array
1968
    y_axis = np.asarray(y_axis)
1✔
1969
    x_axis = np.asarray(x_axis)
1✔
1970

1971
    return x_axis, y_axis
1✔
1972

1973

1974
def peakdetect1d(
1✔
1975
    y_axis: np.ndarray,
1976
    x_axis: np.ndarray = None,
1977
    lookahead: int = 200,
1978
    delta: int = 0,
1979
) -> Tuple[np.ndarray, np.ndarray]:
1980
    """Function for detecting local maxima and minima in a signal.
1981
    Discovers peaks by searching for values which are surrounded by lower
1982
    values (for maxima) or higher values (for minima), respectively.
1983

1984
    Converted from/based on a MATLAB script at:
1985
    http://billauer.co.il/peakdet.html
1986

1987
    Args:
1988
        y_axis (np.ndarray): A list containing the signal over which to find peaks.
1989
        x_axis (np.ndarray, optional): A x-axis whose values correspond to the y_axis
1990
            list and is used in the return to specify the position of the peaks. If
1991
            omitted an index of the y_axis is used.
1992
        lookahead (int, optional): distance to look ahead from a peak candidate to
1993
            determine if it is the actual peak
1994
            '(samples / period) / f' where '4 >= f >= 1.25' might be a good value.
1995
            Defaults to 200.
1996
        delta (int, optional): this specifies a minimum difference between a peak and
1997
            the following points, before a peak may be considered a peak. Useful
1998
            to hinder the function from picking up false peaks towards the end of
1999
            the signal. To work well delta should be set to delta >= RMSnoise * 5.
2000
            Defaults to 0.
2001

2002
    Raises:
2003
        ValueError: Raised if lookahead and delta are out of range.
2004

2005
    Returns:
2006
        Tuple[np.ndarray, np.ndarray]: Tuple of positions of the positive peaks,
2007
        positions of the negative peaks
2008
    """
2009
    max_peaks = []
1✔
2010
    min_peaks = []
1✔
2011
    dump = []  # Used to pop the first hit, which is almost always false
1✔
2012

2013
    # Check input data
2014
    x_axis, y_axis = _datacheck_peakdetect(x_axis, y_axis)
1✔
2015
    # Store data length for later use
2016
    length = len(y_axis)
1✔
2017

2018
    # Perform some checks
2019
    if lookahead < 1:
1✔
UNCOV
2020
        raise ValueError("Lookahead must be '1' or above in value")
×
2021

2022
    if not (np.ndim(delta) == 0 and delta >= 0):
1✔
UNCOV
2023
        raise ValueError("delta must be a positive number")
×
2024

2025
    # maxima and minima candidates are temporarily stored in
2026
    # mx and mn respectively
2027
    _min, _max = np.Inf, -np.Inf
1✔
2028

2029
    # Only detect peak if there is 'lookahead' amount of points after it
2030
    for index, (x, y) in enumerate(
1✔
2031
        zip(x_axis[:-lookahead], y_axis[:-lookahead]),
2032
    ):
2033
        if y > _max:
1✔
2034
            _max = y
1✔
2035
            _max_pos = x
1✔
2036

2037
        if y < _min:
1✔
2038
            _min = y
1✔
2039
            _min_pos = x
1✔
2040

2041
        # Find local maxima
2042
        if y < _max - delta and _max != np.Inf:
1✔
2043
            # Maxima peak candidate found
2044
            # look ahead in signal to ensure that this is a peak and not jitter
2045
            if y_axis[index : index + lookahead].max() < _max:
1✔
2046
                max_peaks.append([_max_pos, _max])
1✔
2047
                dump.append(True)
1✔
2048
                # Set algorithm to only find minima now
2049
                _max = np.Inf
1✔
2050
                _min = np.Inf
1✔
2051

2052
                if index + lookahead >= length:
1✔
2053
                    # The end is within lookahead no more peaks can be found
UNCOV
2054
                    break
×
UNCOV
2055
                continue
×
2056
            # else:
2057
            #    mx = ahead
2058
            #    mxpos = x_axis[np.where(y_axis[index:index+lookahead]==mx)]
2059

2060
        # Find local minima
2061
        if y > _min + delta and _min != -np.Inf:
1✔
2062
            # Minima peak candidate found
2063
            # look ahead in signal to ensure that this is a peak and not jitter
2064
            if y_axis[index : index + lookahead].min() > _min:
1✔
2065
                min_peaks.append([_min_pos, _min])
1✔
2066
                dump.append(False)
1✔
2067
                # Set algorithm to only find maxima now
2068
                _min = -np.Inf
1✔
2069
                _max = -np.Inf
1✔
2070

2071
                if index + lookahead >= length:
1✔
2072
                    # The end is within lookahead no more peaks can be found
UNCOV
2073
                    break
×
2074
            # else:
2075
            #    mn = ahead
2076
            #    mnpos = x_axis[np.where(y_axis[index:index+lookahead]==mn)]
2077

2078
    # Remove the false hit on the first value of the y_axis
2079
    try:
1✔
2080
        if dump[0]:
1✔
UNCOV
2081
            max_peaks.pop(0)
×
2082
        else:
2083
            min_peaks.pop(0)
1✔
2084
        del dump
1✔
2085

UNCOV
2086
    except IndexError:  # When no peaks have been found
×
UNCOV
2087
        pass
×
2088

2089
    return (np.asarray(max_peaks), np.asarray(min_peaks))
1✔
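
# Example (illustrative sketch): a sine wave over two periods has two maxima.
#
#     >>> import numpy as np
#     >>> x = np.linspace(0, 4 * np.pi, 1000)
#     >>> max_peaks, min_peaks = peakdetect1d(np.sin(x), x, lookahead=50)
#     >>> len(max_peaks)
#     2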
2090

2091

2092
def fit_energy_calibration(
1✔
2093
    pos: Union[List[float], np.ndarray],
2094
    vals: Union[List[float], np.ndarray],
2095
    binwidth: float,
2096
    binning: int,
2097
    ref_energy: float,
2098
    t: Union[List[float], np.ndarray] = None,
2099
    energy_scale: str = "kinetic",
2100
    verbose: bool = True,
2101
    **kwds,
2102
) -> dict:
2103
    """Energy calibration by nonlinear least squares fitting of spectral landmarks on
2104
    a set of (energy dispersion curves (EDCs). This is done here by fitting to the
2105
    function d/(t-t0)**2.
2106

2107
    Args:
2108
        pos (Union[List[float], np.ndarray]): Positions of the spectral landmarks
2109
            (e.g. peaks) in the EDCs.
2110
        vals (Union[List[float], np.ndarray]): Bias voltage value associated with
2111
            each EDC.
2112
        binwidth (float): Time width of each original TOF bin in seconds.
2113
        binning (int): Binning factor of the TOF values.
2114
        ref_energy (float): Energy value of the feature in the reference
2115
            trace (eV).
2116
        t (Union[List[float], np.ndarray], optional): Array of TOF values. Required
2117
            to calculate calibration trace. Defaults to None.
2118
        energy_scale (str, optional): Direction of increasing energy scale.
2119

2120
            - **'kinetic'**: increasing energy with decreasing TOF.
2121
            - **'binding'**: increasing energy with increasing TOF.
2122
        verbose (bool, optional): Option to print out diagnostic information.
2123
            Defaults to True.
2124
        **kwds: keyword arguments:
2125

2126
            - **t0** (dict): constraints and initial value for the fit parameter t0,
2127
              corresponding to the time of flight offset. Defaults to 1e-6.
2128
            - **E0** (dict): constraints and initial value for the fit parameter E0,
2129
              corresponding to the energy offset. Defaults to min(vals).
2130
            - **d** (dict): constraints and initial value for the fit parameter d,
2131
              corresponding to the drift distance. Defaults to 1.
2132

2133
    Returns:
2134
        dict: A dictionary of fitting parameters including the following,
2135

2136
        - "coeffs": Fitted function coefficents.
2137
        - "axis": Fitted energy axis.
2138
    """
2139
    vals = np.asarray(vals)
1✔
2140

2141
    def residual(pars, time, data, binwidth, binning, energy_scale):
1✔
2142
        model = tof2ev(
1✔
2143
            pars["d"],
2144
            pars["t0"],
2145
            binwidth,
2146
            binning,
2147
            energy_scale,
2148
            pars["E0"],
2149
            time,
2150
        )
2151
        if data is None:
1✔
UNCOV
2152
            return model
×
2153
        return model - data
1✔
2154

2155
    pars = Parameters()
1✔
2156
    d_pars = kwds.pop("d", {})
1✔
2157
    pars.add(
1✔
2158
        name="d",
2159
        value=d_pars.get("value", 1),
2160
        min=d_pars.get("min", -np.inf),
2161
        max=d_pars.get("max", np.inf),
2162
        vary=d_pars.get("vary", True),
2163
    )
2164
    t0_pars = kwds.pop("t0", {})
1✔
2165
    pars.add(
1✔
2166
        name="t0",
2167
        value=t0_pars.get("value", 1e-6),
2168
        min=t0_pars.get("min", -np.inf),
2169
        max=t0_pars.get(
2170
            "max",
2171
            (min(pos) - 1) * binwidth * 2**binning,
2172
        ),
2173
        vary=t0_pars.get("vary", True),
2174
    )
2175
    E0_pars = kwds.pop("E0", {})  # pylint: disable=invalid-name
1✔
2176
    pars.add(
1✔
2177
        name="E0",
2178
        value=E0_pars.get("value", min(vals)),
2179
        min=E0_pars.get("min", -np.inf),
2180
        max=E0_pars.get("max", np.inf),
2181
        vary=E0_pars.get("vary", True),
2182
    )
2183
    fit = Minimizer(
1✔
2184
        residual,
2185
        pars,
2186
        fcn_args=(pos, vals, binwidth, binning, energy_scale),
2187
    )
2188
    result = fit.leastsq()
1✔
2189
    if verbose:
1✔
2190
        report_fit(result)
1✔
2191

2192
    # Construct the calibrating function
2193
    pfunc = partial(
1✔
2194
        tof2ev,
2195
        result.params["d"].value,
2196
        result.params["t0"].value,
2197
        binwidth,
2198
        binning,
2199
        energy_scale,
2200
    )
2201

2202
    # Return results according to specification
2203
    ecalibdict = {}
1✔
2204
    ecalibdict["d"] = result.params["d"].value
1✔
2205
    ecalibdict["t0"] = result.params["t0"].value
1✔
2206
    ecalibdict["E0"] = result.params["E0"].value
1✔
2207
    ecalibdict["energy_scale"] = energy_scale
1✔
2208
    energy_offset = pfunc(-1 * ref_energy, pos[0])
1✔
2209
    ecalibdict["E0"] = -(energy_offset - vals[0])
1✔
2210

2211
    if t is not None:
1✔
2212
        ecalibdict["axis"] = pfunc(ecalibdict["E0"], t)
1✔
2213

2214
    return ecalibdict
1✔
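
# Example (illustrative sketch; the landmark positions, bias voltages, and
# reference energy are synthetic, and numpy is assumed imported as np): the
# fit determines the drift distance d, the time offset t0, and the energy
# offset E0 of the d/(t-t0)**2 model.
#
#     calib = fit_energy_calibration(
#         pos=[128500, 129200, 130000],
#         vals=[20.0, 21.0, 22.0],
#         binwidth=4.125e-12,
#         binning=1,
#         ref_energy=-0.5,
#         t=np.arange(128000, 131000),
#     )
#     energy_axis = calib["axis"]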
2215

2216

2217
def poly_energy_calibration(
1✔
2218
    pos: Union[List[float], np.ndarray],
2219
    vals: Union[List[float], np.ndarray],
2220
    ref_energy: float,
2221
    order: int = 3,
2222
    t: Union[List[float], np.ndarray] = None,
2223
    aug: int = 1,
2224
    method: str = "lstsq",
2225
    energy_scale: str = "kinetic",
2226
    **kwds,
2227
) -> dict:
2228
    """Energy calibration by nonlinear least squares fitting of spectral landmarks on
2229
    a set of (energy dispersion curves (EDCs). This amounts to solving for the
2230
    coefficient vector, a, in the system of equations T.a = b. Here T is the
2231
    differential drift time matrix and b the differential bias vector, and
2232
    assuming that the energy-drift-time relationship can be written in the form,
2233
    E = sum_n (a_n * t**n) + E0
2234

2235

2236
    Args:
2237
        pos (Union[List[float], np.ndarray]): Positions of the spectral landmarks
2238
            (e.g. peaks) in the EDCs.
2239
        vals (Union[List[float], np.ndarray]): Bias voltage value associated with
2240
            each EDC.
2241
        ref_energy (float): Energy value of the feature in the reference
2242
            trace (eV).
2243
        order (int, optional): Polynomial order of the fitting function. Defaults to 3.
2244
        t (Union[List[float], np.ndarray], optional): Array of TOF values. Required
2245
            to calculate calibration trace. Defaults to None.
2246
        aug (int, optional): Fitting dimension augmentation
2247
            (1=no change, 2=double, etc). Defaults to 1.
2248
        method (str, optional): Method for determining the energy calibration.
2249

2250
            - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
2251
            - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.
2252

2253
            Defaults to "lstsq".
2254
        energy_scale (str, optional): Direction of increasing energy scale.
2255

2256
            - **'kinetic'**: increasing energy with decreasing TOF.
2257
            - **'binding'**: increasing energy with increasing TOF.
2258

2259
    Returns:
2260
        dict: A dictionary of fitting parameters including the following,
2261

2262
        - "coeffs": Fitted polynomial coefficients (the a's).
2263
        - "offset": Minimum time-of-flight corresponding to a peak.
2264
        - "Tmat": the T matrix (differential time-of-flight) in the equation Ta=b.
2265
        - "bvec": the b vector (differential bias) in the fitting Ta=b.
2266
        - "axis": Fitted energy axis.
2267
    """
2268
    vals = np.asarray(vals)
1✔
2269
    nvals = vals.size
1✔
2270

2271
    # Top-to-bottom ordering of terms in the T matrix
2272
    termorder = np.delete(range(0, nvals, 1), 0)
1✔
2273
    termorder = np.tile(termorder, aug)
1✔
2274
    # Left-to-right ordering of polynomials in the T matrix
2275
    polyorder = np.linspace(order, 1, order, dtype="int")
1✔
2276

2277
    # Construct the T (differential drift time) matrix, Tmat = Tmain - Tsec
2278
    t_main = np.array([pos[0] ** p for p in polyorder])
1✔
2279
    # Duplicate to the same order as the polynomials
2280
    t_main = np.tile(t_main, (aug * (nvals - 1), 1))
1✔
2281

2282
    t_sec = []
1✔
2283

2284
    for term in termorder:
1✔
2285
        t_sec.append([pos[term] ** p for p in polyorder])
1✔
2286

2287
    t_mat = t_main - np.asarray(t_sec)
1✔
2288

2289
    # Construct the b vector (differential bias)
2290
    bvec = vals[0] - np.delete(vals, 0)
1✔
2291
    bvec = np.tile(bvec, aug)
1✔
2292

2293
    # Solve for the a vector (polynomial coefficients) using least squares
2294
    if method == "lstsq":
1✔
2295
        sol = lstsq(t_mat, bvec, rcond=None)
1✔
2296
    elif method == "lsqr":
1✔
2297
        sol = lsqr(t_mat, bvec, **kwds)
1✔
2298
    poly_a = sol[0]
1✔
2299

2300
    # Construct the calibrating function
2301
    pfunc = partial(tof2evpoly, poly_a)
1✔
2302

2303
    # Return results according to specification
2304
    ecalibdict = {}
1✔
2305
    ecalibdict["offset"] = np.asarray(pos).min()
1✔
2306
    ecalibdict["coeffs"] = poly_a
1✔
2307
    ecalibdict["Tmat"] = t_mat
1✔
2308
    ecalibdict["bvec"] = bvec
1✔
2309
    ecalibdict["energy_scale"] = energy_scale
1✔
2310
    ecalibdict["E0"] = -(pfunc(-1 * ref_energy, pos[0]) + vals[0])
1✔
2311

2312
    if t is not None:
1✔
2313
        ecalibdict["axis"] = pfunc(-ecalibdict["E0"], t)
1✔
2314

2315
    return ecalibdict
1✔
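
# Example (illustrative sketch; same synthetic landmarks as above): the
# polynomial coefficients are solved from the linear system T.a = b built
# from pairwise differences of the landmark positions and bias voltages.
#
#     calib = poly_energy_calibration(
#         pos=[128500, 129200, 130000],
#         vals=[20.0, 21.0, 22.0],
#         ref_energy=-0.5,
#         order=2,
#         t=np.arange(128000, 131000),
#     )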
2316

2317

2318
def tof2ev(
1✔
2319
    tof_distance: float,
2320
    time_offset: float,
2321
    binwidth: float,
2322
    binning: int,
2323
    energy_scale: str,
2324
    energy_offset: float,
2325
    t: float,
2326
) -> float:
2327
    """(d/(t-t0))**2 expression of the time-of-flight to electron volt
2328
    conversion formula.
2329

2330
    Args:
2331
        tof_distance (float): Drift distance in meters.
2332
        time_offset (float): Time offset in seconds.
2333
        binwidth (float): Time width of each original TOF bin in seconds.
2334
        binning (int): Binning factor of the TOF values.
2335
        energy_scale (str, optional): Direction of increasing energy scale.
2336

2337
            - **'kinetic'**: increasing energy with decreasing TOF.
2338
            - **'binding'**: increasing energy with increasing TOF.
2339

2340
        energy_offset (float): Energy offset in eV.
2341
        t (float): TOF value in bin number.
2342

2343
    Returns:
2344
        float: Converted energy in eV
2345
    """
2346
    sign = 1 if energy_scale == "kinetic" else -1
1✔
2347

2348
    #         m_e/2 [eV]                      bin width [s]
2349
    energy = (
1✔
2350
        2.84281e-12 * sign * (tof_distance / (t * binwidth * 2**binning - time_offset)) ** 2
2351
        + energy_offset
2352
    )
2353

2354
    return energy
1✔
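
# Worked example (illustrative sketch with hypothetical values):
#
#     >>> tof2ev(1.0, 1e-6, 4.125e-12, 1, "kinetic", 0.0, 150000)
#
# evaluates 2.84281e-12 * (1.0 / (150000 * 4.125e-12 * 2 - 1e-6)) ** 2,
# i.e. roughly 50 eV of kinetic energy.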
2355

2356

2357
def tof2evpoly(
1✔
2358
    poly_a: Union[List[float], np.ndarray],
2359
    energy_offset: float,
2360
    t: float,
2361
) -> float:
2362
    """Polynomial approximation of the time-of-flight to electron volt
2363
    conversion formula.
2364

2365
    Args:
2366
        poly_a (Union[List[float], np.ndarray]): Polynomial coefficients.
2367
        energy_offset (float): Energy offset in eV.
2368
        t (float): TOF value in bin number.
2369

2370
    Returns:
2371
        float: Converted energy.
2372
    """
2373
    odr = len(poly_a)  # Polynomial order
1✔
2374
    poly_a = poly_a[::-1]
1✔
2375
    energy = 0.0
1✔
2376

2377
    for i, order in enumerate(range(1, odr + 1)):
1✔
2378
        energy += poly_a[i] * t**order
1✔
2379
    energy += energy_offset
1✔
2380

2381
    return energy
1✔
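
# Example (illustrative sketch): poly_a is given highest order first, so
# [2.0, 3.0] with offset 1.0 evaluates 2*t**2 + 3*t + 1.
#
#     >>> tof2evpoly([2.0, 3.0], energy_offset=1.0, t=2.0)
#     15.0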
2382

2383

2384
def tof2ns(
1✔
2385
    binwidth: float,
2386
    binning: int,
2387
    t: float,
2388
) -> float:
2389
    """Converts the time-of-flight steps to time-of-flight in nanoseconds.
2390

2391
    Designed for use with dask.dataframe.DataFrame.map_partitions.
2392

2393
    Args:
2394
        binwidth (float): Time step size in seconds.
2395
        binning (int): Binning of the time-of-flight steps.
2396
        t (float): TOF value in bin number.
2397

    Returns:
2398
        float: Converted time in nanoseconds.
2399
    """
2400
    val = t * 1e9 * binwidth * 2.0**binning
1✔
2401
    return val
1✔
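
# Example (illustrative sketch with a hypothetical bin width):
# 130000 bins * 4.125e-12 s * 2**1 = 1.0725e-6 s = 1072.5 ns.
#
#     >>> tof2ns(binwidth=4.125e-12, binning=1, t=130000)
#     1072.5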