• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

OpenCOMPES / sed / 6520232780

14 Oct 2023 10:12PM UTC coverage: 90.267% (-0.3%) from 90.603%
6520232780

Pull #181

github

rettigl
define jitter_amps as single amplitude in default config
Pull Request #181: define jitter_amps as single amplitude in default config

4229 of 4685 relevant lines covered (90.27%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.86
/sed/calibrator/energy.py
1
"""sed.calibrator.energy module. Code for energy calibration and
2
correction. Mostly ported from https://github.com/mpes-kit/mpes.
3
"""
4
import itertools as it
1✔
5
import warnings as wn
1✔
6
from copy import deepcopy
1✔
7
from functools import partial
1✔
8
from typing import Any
1✔
9
from typing import cast
1✔
10
from typing import Dict
1✔
11
from typing import List
1✔
12
from typing import Sequence
1✔
13
from typing import Tuple
1✔
14
from typing import Union
1✔
15

16
import bokeh.plotting as pbk
1✔
17
import dask.dataframe
1✔
18
import h5py
1✔
19
import ipywidgets as ipw
1✔
20
import matplotlib
1✔
21
import matplotlib.pyplot as plt
1✔
22
import numpy as np
1✔
23
import pandas as pd
1✔
24
import psutil
1✔
25
import xarray as xr
1✔
26
from bokeh.io import output_notebook
1✔
27
from bokeh.palettes import Category10 as ColorCycle
1✔
28
from fastdtw import fastdtw
1✔
29
from IPython.display import display
1✔
30
from lmfit import Minimizer
1✔
31
from lmfit import Parameters
1✔
32
from lmfit.printfuncs import report_fit
1✔
33
from numpy.linalg import lstsq
1✔
34
from scipy.signal import savgol_filter
1✔
35
from scipy.sparse.linalg import lsqr
1✔
36

37
from sed.binning import bin_dataframe
1✔
38
from sed.loader.base.loader import BaseLoader
1✔
39

40

41
class EnergyCalibrator:
1✔
42
    """Electron binding energy calibration workflow.
43

44
    For the initialization of the EnergyCalibrator class an instance of a
45
    loader is required. The data can be loaded using the optional arguments,
46
    or using the load_data method or bin_data method.
47

48
    Args:
49
        loader (BaseLoader): Instance of a loader, subclassed from BaseLoader.
50
        biases (np.ndarray, optional): Bias voltages used. Defaults to None.
51
        traces (np.ndarray, optional): TOF-Data traces corresponding to the bias
52
            values. Defaults to None.
53
        tof (np.ndarray, optional): TOF-values for the data traces.
54
            Defaults to None.
55
        config (dict, optional): Config dictionary. Defaults to None.
56
    """
57

58
    def __init__(
1✔
59
        self,
60
        loader: BaseLoader,
61
        biases: np.ndarray = None,
62
        traces: np.ndarray = None,
63
        tof: np.ndarray = None,
64
        config: dict = None,
65
    ):
66
        """For the initialization of the EnergyCalibrator class an instance of a
67
        loader is required. The data can be loaded using the optional arguments,
68
        or using the load_data method or bin_data method.
69

70
        Args:
71
            loader (BaseLoader): Instance of a loader, subclassed from BaseLoader.
72
            biases (np.ndarray, optional): Bias voltages used. Defaults to None.
73
            traces (np.ndarray, optional): TOF-Data traces corresponding to the bias
74
                values. Defaults to None.
75
            tof (np.ndarray, optional): TOF-values for the data traces.
76
                Defaults to None.
77
            config (dict, optional): Config dictionary. Defaults to None.
78
        """
79
        self.loader = loader
1✔
80
        self.biases: np.ndarray = None
1✔
81
        self.traces: np.ndarray = None
1✔
82
        self.traces_normed: np.ndarray = None
1✔
83
        self.tof: np.ndarray = None
1✔
84

85
        if traces is not None and tof is not None and biases is not None:
1✔
86
            self.load_data(biases=biases, traces=traces, tof=tof)
×
87

88
        if config is None:
1✔
89
            config = {}
×
90

91
        self._config = config
1✔
92

93
        self.featranges: List[Tuple] = []  # Value ranges for feature detection
1✔
94
        self.peaks: np.ndarray = np.asarray([])
1✔
95
        self.calibration: Dict[Any, Any] = {}
1✔
96

97
        self.tof_column = self._config["dataframe"]["tof_column"]
1✔
98
        self.corrected_tof_column = self._config["dataframe"]["corrected_tof_column"]
1✔
99
        self.energy_column = self._config["dataframe"]["energy_column"]
1✔
100
        self.x_column = self._config["dataframe"]["x_column"]
1✔
101
        self.y_column = self._config["dataframe"]["y_column"]
1✔
102
        self.binwidth: float = self._config["dataframe"]["tof_binwidth"]
1✔
103
        self.binning: int = self._config["dataframe"]["tof_binning"]
1✔
104
        self.x_width = self._config["energy"]["x_width"]
1✔
105
        self.y_width = self._config["energy"]["y_width"]
1✔
106
        self.tof_width = np.asarray(
1✔
107
            self._config["energy"]["tof_width"],
108
        ) / 2 ** (self.binning - 1)
109
        self.tof_fermi = self._config["energy"]["tof_fermi"] / 2 ** (self.binning - 1)
1✔
110
        self.color_clip = self._config["energy"]["color_clip"]
1✔
111

112
        self.correction: Dict[Any, Any] = {}
1✔
113

114
    @property
1✔
115
    def ntraces(self) -> int:
1✔
116
        """Property returning the number of traces.
117

118
        Returns:
119
            int: The number of loaded/calculated traces.
120
        """
121
        return len(self.traces)
1✔
122

123
    @property
1✔
124
    def nranges(self) -> int:
1✔
125
        """Property returning the number of specified feature ranges which Can be a
126
        multiple of ntraces.
127

128
        Returns:
129
            int: The number of specified feature ranges.
130
        """
131
        return len(self.featranges)
1✔
132

133
    @property
1✔
134
    def dup(self) -> int:
1✔
135
        """Property returning the duplication number, i.e. the number of feature
136
        ranges per trace.
137

138
        Returns:
139
            int: The duplication number.
140
        """
141
        return int(np.round(self.nranges / self.ntraces))
1✔
142

143
    def load_data(
1✔
144
        self,
145
        biases: np.ndarray = None,
146
        traces: np.ndarray = None,
147
        tof: np.ndarray = None,
148
    ):
149
        """Load data into the class. Not provided parameters will be overwritten by
150
        empty arrays.
151

152
        Args:
153
            biases (np.ndarray, optional): Bias voltages used. Defaults to None.
154
            traces (np.ndarray, optional): TOF-Data traces corresponding to the bias
155
                values. Defaults to None.
156
            tof (np.ndarray, optional): TOF-values for the data traces.
157
                Defaults to None.
158
        """
159
        if biases is not None:
1✔
160
            self.biases = biases
1✔
161
        else:
162
            self.biases = np.asarray([])
×
163
        if tof is not None:
1✔
164
            self.tof = tof
1✔
165
        else:
166
            self.tof = np.asarray([])
×
167
        if traces is not None:
1✔
168
            self.traces = self.traces_normed = traces
1✔
169
        else:
170
            self.traces = self.traces_normed = np.asarray([])
×
171

172
    def bin_data(
1✔
173
        self,
174
        data_files: List[str],
175
        axes: List[str] = None,
176
        bins: List[int] = None,
177
        ranges: Sequence[Tuple[float, float]] = None,
178
        biases: np.ndarray = None,
179
        bias_key: str = None,
180
        **kwds,
181
    ):
182
        """Bin data from single-event files, and load into class.
183

184
        Args:
185
            data_files (List[str]): list of file names to bin
186
            axes (List[str], optional): bin axes. Defaults to
187
                config["dataframe"]["tof_column"].
188
            bins (List[int], optional): number of bins.
189
                Defaults to config["energy"]["bins"].
190
            ranges (Sequence[Tuple[float, float]], optional): bin ranges.
191
                Defaults to config["energy"]["ranges"].
192
            biases (np.ndarray, optional): Bias voltages used.
193
                If not provided, biases are extracted from the file meta data.
194
            bias_key (str, optional): hdf5 path where bias values are stored.
195
                Defaults to config["energy"]["bias_key"].
196
            **kwds: Keyword parameters for bin_dataframe
197
        """
198
        if axes is None:
1✔
199
            axes = [self.tof_column]
1✔
200
        if bins is None:
1✔
201
            bins = [self._config["energy"]["bins"]]
1✔
202
        if ranges is None:
1✔
203
            ranges_ = [
1✔
204
                np.array(self._config["energy"]["ranges"]) / 2 ** (self.binning - 1),
205
            ]
206
            ranges = [cast(Tuple[float, float], tuple(v)) for v in ranges_]
1✔
207
        # pylint: disable=duplicate-code
208
        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
1✔
209
        mode = kwds.pop("mode", self._config["binning"]["mode"])
1✔
210
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
1✔
211
        try:
1✔
212
            num_cores = kwds.pop("num_cores", self._config["binning"]["num_cores"])
1✔
213
        except KeyError:
1✔
214
            num_cores = psutil.cpu_count() - 1
1✔
215
        threads_per_worker = kwds.pop(
1✔
216
            "threads_per_worker",
217
            self._config["binning"]["threads_per_worker"],
218
        )
219
        threadpool_api = kwds.pop(
1✔
220
            "threadpool_API",
221
            self._config["binning"]["threadpool_API"],
222
        )
223

224
        read_biases = False
1✔
225
        if biases is None:
1✔
226
            read_biases = True
1✔
227
            if bias_key is None:
1✔
228
                try:
1✔
229
                    bias_key = self._config["energy"]["bias_key"]
1✔
230
                except KeyError as exc:
1✔
231
                    raise ValueError(
1✔
232
                        "Either Bias Values or a valid bias_key has to be present!",
233
                    ) from exc
234

235
        dataframe, _ = self.loader.read_dataframe(
1✔
236
            files=data_files,
237
            collect_metadata=False,
238
        )
239
        traces = bin_dataframe(
1✔
240
            dataframe,
241
            bins=bins,
242
            axes=axes,
243
            ranges=ranges,
244
            hist_mode=hist_mode,
245
            mode=mode,
246
            pbar=pbar,
247
            n_cores=num_cores,
248
            threads_per_worker=threads_per_worker,
249
            threadpool_api=threadpool_api,
250
            return_partitions=True,
251
            **kwds,
252
        )
253
        if read_biases:
1✔
254
            if bias_key:
1✔
255
                try:
1✔
256
                    biases = extract_bias(data_files, bias_key)
1✔
257
                except KeyError as exc:
1✔
258
                    raise ValueError(
1✔
259
                        "Either Bias Values or a valid bias_key has to be present!",
260
                    ) from exc
261
        tof = traces.coords[(axes[0])]
1✔
262
        self.traces = self.traces_normed = np.asarray(traces.T)
1✔
263
        self.tof = np.asarray(tof)
1✔
264
        self.biases = np.asarray(biases)
1✔
265

266
    def normalize(self, smooth: bool = False, span: int = 7, order: int = 1):
        """Normalize the loaded traces and store them in ``traces_normed``.

        ``self.traces`` is left untouched; the result of ``normspec`` replaces
        ``self.traces_normed``.

        Args:
            smooth (bool, optional): Smooth the signals before normalizing.
                Defaults to False.
            span (int, optional): Window span of the smoothing filter
                (see ``scipy.signal.savgol_filter()``). Defaults to 7.
            order (int, optional): Polynomial order of the smoothing filter
                (see ``scipy.signal.savgol_filter()``). Defaults to 1.
        """
        smoothing_options = {"smooth": smooth, "span": span, "order": order}
        self.traces_normed = normspec(self.traces, **smoothing_options)
283

284
    def adjust_ranges(
        self,
        ranges: Tuple,
        ref_id: int = 0,
        traces: np.ndarray = None,
        peak_window: int = 7,
        apply: bool = False,
        **kwds,
    ):
        """Display a tool to select or extract the equivalent feature ranges
        (containing the peaks) among all traces.

        Ranges for all traces are inferred from ``ranges`` on the reference
        trace and peaks are searched within them. The result is shown in an
        interactive plot with sliders for the reference trace and the range.

        Args:
            ranges (Tuple):
                Collection of feature detection ranges, within which an algorithm
                (i.e. 1D peak detector) will look for the feature.
            ref_id (int, optional): Index of the reference trace. Defaults to 0.
            traces (np.ndarray, optional): Collection of energy dispersion curves,
                used for range inference, peak search and plotting.
                Defaults to self.traces_normed.
            peak_window (int, optional): area around a peak to check for other peaks.
                Used for the initial extraction, every slider update and apply.
                Defaults to 7.
            apply (bool, optional): Option to directly apply the provided parameters.
                Defaults to False.
            **kwds: Keyword arguments:

                - **labels** (list): Legend label per trace. Defaults to the bias
                  values in V.
                - **figsize** (tuple): Figure size. Defaults to (8, 4).

                All remaining keywords are passed to the trace alignment
                (see ``find_correspondence()``). The exceptions are
                ``infer_others`` and ``mode``, which this method fixes to True
                and "replace".
        """
        if traces is None:
            traces = self.traces_normed

        # Take the plot options out first; what remains is meant for alignment.
        labels = kwds.pop("labels", [str(b) + " V" for b in self.biases])
        figsize = kwds.pop("figsize", (8, 4))
        # infer_others/mode are fixed here and would otherwise clash with the
        # explicit arguments passed to add_ranges().
        align_kwds = {
            key: val for key, val in kwds.items() if key not in ("infer_others", "mode")
        }

        self.add_ranges(
            ranges=ranges,
            ref_id=ref_id,
            traces=traces,
            infer_others=True,
            mode="replace",
            **align_kwds,
        )
        self.feature_extract(traces=traces, peak_window=peak_window)

        # make plot
        plot_segs = []
        plot_peaks = []
        fig, ax = plt.subplots(figsize=figsize)
        colors = plt.get_cmap("rainbow")(np.linspace(0, 1, len(traces)))
        for itr, color in zip(range(len(traces)), colors):
            trace = traces[itr, :]
            # main traces
            ax.plot(
                self.tof,
                trace,
                ls="-",
                color=color,
                linewidth=1,
                label=labels[itr],
            )
            # segments:
            seg = self.featranges[itr]
            cond = (self.tof >= seg[0]) & (self.tof <= seg[1])
            tofseg, traceseg = self.tof[cond], trace[cond]
            (line,) = ax.plot(
                tofseg,
                traceseg,
                ls="-",
                color=color,
                linewidth=3,
            )
            plot_segs.append(line)
            # markers
            (scatt,) = ax.plot(
                self.peaks[itr, 0],
                self.peaks[itr, 1],
                ls="",
                marker=".",
                color="k",
                markersize=10,
            )
            plot_peaks.append(scatt)
        ax.legend(fontsize=8, loc="upper right")
        ax.set_title("")

        def update(refid, ranges):
            # Recompute ranges/peaks on the same traces that are plotted.
            self.add_ranges(ranges, refid, traces=traces, **align_kwds)
            self.feature_extract(traces=traces, peak_window=peak_window)
            for itr, trace in enumerate(traces):
                seg = self.featranges[itr]
                cond = (self.tof >= seg[0]) & (self.tof <= seg[1])
                plot_segs[itr].set_xdata(self.tof[cond])
                plot_segs[itr].set_ydata(trace[cond])
                # Line2D data must be sequences; wrap the scalar peak position.
                plot_peaks[itr].set_xdata([self.peaks[itr, 0]])
                plot_peaks[itr].set_ydata([self.peaks[itr, 1]])

            fig.canvas.draw_idle()

        # The reference index must address an existing trace.
        refid_slider = ipw.IntSlider(
            value=ref_id,
            min=0,
            max=max(len(traces) - 1, 0),
            step=1,
        )

        ranges_slider = ipw.IntRangeSlider(
            value=list(ranges),
            min=min(self.tof),
            max=max(self.tof),
            step=1,
        )

        update(ranges=ranges, refid=ref_id)

        ipw.interact(
            update,
            refid=refid_slider,
            ranges=ranges_slider,
        )

        def apply_func(apply: bool):  # pylint: disable=unused-argument
            self.add_ranges(
                ranges_slider.value,
                refid_slider.value,
                traces=traces,
                **align_kwds,
            )
            self.feature_extract(traces=traces, peak_window=peak_window)
            ranges_slider.close()
            refid_slider.close()
            apply_button.close()

        apply_button = ipw.Button(description="apply")
        display(apply_button)  # pylint: disable=duplicate-code
        apply_button.on_click(apply_func)
        plt.show()

        if apply:
            apply_func(True)
1✔
423

424
    def add_ranges(
1✔
425
        self,
426
        ranges: Union[List[Tuple], Tuple],
427
        ref_id: int = 0,
428
        traces: np.ndarray = None,
429
        infer_others: bool = True,
430
        mode: str = "replace",
431
        **kwds,
432
    ):
433
        """Select or extract the equivalent feature ranges (containing the peaks) among all traces.
434

435
        Args:
436
            ranges (Union[List[Tuple], Tuple]):
437
                Collection of feature detection ranges, within which an algorithm
438
                (i.e. 1D peak detector) with look for the feature.
439
            ref_id (int, optional): Index of the reference trace. Defaults to 0.
440
            traces (np.ndarray, optional): Collection of energy dispersion curves.
441
                Defaults to self.traces_normed.
442
            infer_others (bool, optional): Option to infer the feature detection range
443
                in other traces from a given one using a time warp algorthm.
444
                Defaults to True.
445
            mode (str, optional): Specification on how to change the feature ranges
446
                ('append' or 'replace'). Defaults to "replace".
447
            **kwds:
448
                keyword arguments for trace alignment (see ``find_correspondence()``).
449
        """
450
        if traces is None:
1✔
451
            traces = self.traces_normed
1✔
452

453
        # Infer the corresponding feature detection range of other traces by alignment
454
        if infer_others:
1✔
455
            assert isinstance(ranges, tuple)
1✔
456
            newranges: List[Tuple] = []
1✔
457

458
            for i in range(self.ntraces):
1✔
459

460
                pathcorr = find_correspondence(
1✔
461
                    traces[ref_id, :],
462
                    traces[i, :],
463
                    **kwds,
464
                )
465
                newranges.append(range_convert(self.tof, ranges, pathcorr))
1✔
466

467
        else:
468
            if isinstance(ranges, list):
1✔
469
                newranges = ranges
1✔
470
            else:
471
                newranges = [ranges]
×
472

473
        if mode == "append":
1✔
474
            self.featranges += newranges
×
475
        elif mode == "replace":
1✔
476
            self.featranges = newranges
1✔
477

478
    def feature_extract(
        self,
        ranges: List[Tuple] = None,
        traces: np.ndarray = None,
        peak_window: int = 7,
    ):
        """Detect the landmark (e.g. peak) of every trace within its feature range.

        The detected landmarks are stored in ``self.peaks``.

        Args:
            ranges (List[Tuple], optional): List of ranges in each trace to look for
                the peak feature, [start, end]. Defaults to self.featranges.
            traces (np.ndarray, optional): Collection of 1D spectra to use for
                calibration. Defaults to self.traces_normed.
            peak_window (int, optional): area around a peak to check for other peaks.
                Defaults to 7.
        """
        search_ranges = self.featranges if ranges is None else ranges
        spectra = self.traces_normed if traces is None else traces

        # Stack the traces ``dup`` times (dup = nranges / ntraces) so the number
        # of rows matches the number of feature ranges.
        tiled_spectra = np.tile(spectra, (self.dup, 1))
        self.peaks = peaksearch(
            tiled_spectra,
            self.tof,
            ranges=search_ranges,
            pkwindow=peak_window,
        )
509

510
    def calibrate(
1✔
511
        self,
512
        ref_id: int = 0,
513
        method: str = "lmfit",
514
        energy_scale: str = "kinetic",
515
        landmarks: np.ndarray = None,
516
        biases: np.ndarray = None,
517
        t: np.ndarray = None,
518
        **kwds,
519
    ) -> dict:
520
        """Calculate the functional mapping between time-of-flight and the energy
521
        scale using optimization methods.
522

523
        Args:
524
            ref_id (int, optional): The reference trace index (an integer).
525
                Defaults to 0.
526
            method (str, optional):  Method for determining the energy calibration.
527

528
                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
529
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.
530

531
                Defaults to 'lmfit'.
532
            energy_scale (str, optional): Direction of increasing energy scale.
533

534
                - **'kinetic'**: increasing energy with decreasing TOF.
535
                - **'binding'**: increasing energy with increasing TOF.
536

537
                Defaults to "kinetic".
538
            landmarks (np.ndarray, optional): Extracted peak positions (TOF) used for
539
                calibration. Defaults to self.peaks.
540
            biases (np.ndarray, optional): Bias values. Defaults to self.biases.
541
            t (np.ndarray, optional): TOF values. Defaults to self.tof.
542
            **kwds: keyword arguments.
543
                See available keywords for ``poly_energy_calibration()`` and
544
                ``fit_energy_calibration()``
545

546
        Raises:
547
            ValueError: Raised if invalid 'energy_scale' is passed.
548
            NotImplementedError: Raised if invalid 'method' is passed.
549

550
        Returns:
551
            dict: Calibration dictionary with coefficients.
552
        """
553
        if landmarks is None:
1✔
554
            landmarks = self.peaks[:, 0]
1✔
555
        if biases is None:
1✔
556
            biases = self.biases
1✔
557
        if t is None:
1✔
558
            t = self.tof
1✔
559
        if energy_scale == "kinetic":
1✔
560
            sign = -1
1✔
561
        elif energy_scale == "binding":
1✔
562
            sign = 1
1✔
563
        else:
564
            raise ValueError(
1✔
565
                'energy_scale needs to be either "binding" or "kinetic"',
566
                f", got {energy_scale}.",
567
            )
568

569
        binwidth = kwds.pop("binwidth", self.binwidth)
1✔
570
        binning = kwds.pop("binning", self.binning)
1✔
571

572
        if method == "lmfit":
1✔
573
            self.calibration = fit_energy_calibation(
1✔
574
                landmarks,
575
                sign * biases,
576
                binwidth,
577
                binning,
578
                ref_id=ref_id,
579
                t=t,
580
                energy_scale=energy_scale,
581
                **kwds,
582
            )
583
        elif method in ("lstsq", "lsqr"):
1✔
584
            self.calibration = poly_energy_calibration(
1✔
585
                landmarks,
586
                sign * biases,
587
                ref_id=ref_id,
588
                aug=self.dup,
589
                method=method,
590
                t=t,
591
                energy_scale=energy_scale,
592
                **kwds,
593
            )
594
        else:
595
            raise NotImplementedError()
1✔
596

597
        return self.calibration
1✔
598

599
    def view(
        self,
        traces: np.ndarray,
        segs: List[Tuple] = None,
        peaks: np.ndarray = None,
        show_legend: bool = True,
        backend: str = "matplotlib",
        linekwds: dict = None,
        linesegkwds: dict = None,
        scatterkwds: dict = None,
        legkwds: dict = None,
        **kwds,
    ):
        """Display a plot showing line traces with annotation.

        Args:
            traces (np.ndarray): Matrix of traces to visualize.
            segs (List[Tuple], optional): Segments to be highlighted in the
                visualization. Defaults to None.
            peaks (np.ndarray, optional): Peak positions for labelling the traces.
                Defaults to None.
            show_legend (bool, optional): Option to display bias voltages as legends.
                Defaults to True.
            backend (str, optional): Backend specification, choose between 'matplotlib'
                (static) or 'bokeh' (interactive). Defaults to "matplotlib".
            linekwds (dict, optional): Keyword arguments for the main trace lines
                (see ``matplotlib.pyplot.plot()`` / bokeh ``figure.line()``).
                Defaults to {}.
            linesegkwds (dict, optional): Keyword arguments for line segments plotting
                (see ``matplotlib.pyplot.plot()`` / bokeh ``figure.line()``).
                Defaults to {}.
            scatterkwds (dict, optional): Keyword arguments for scatter plot
                (see ``matplotlib.pyplot.scatter()``). Defaults to {}.
            legkwds (dict, optional): Keyword arguments for legend
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
            **kwds: keyword arguments:

                - **labels** (list): Labels for each curve
                - **xaxis** (np.ndarray): x (horizontal) axis values
                - **title** (str): Title of the plot
                - **legend_location** (str): Location of the plot legend
                  (bokeh backend only)
                - **align** (bool): Option to shift traces by bias voltage
                - **energy_scale** (str): "kinetic" (default) or other; sets the
                  sign of the alignment shift
                - **figsize** (tuple): Figure size

                For the bokeh backend, any remaining keywords are also passed to
                the main trace lines.
        """
        # None defaults avoid sharing one mutable dict between calls.
        linekwds = {} if linekwds is None else linekwds
        linesegkwds = {} if linesegkwds is None else linesegkwds
        scatterkwds = {} if scatterkwds is None else scatterkwds
        legkwds = {} if legkwds is None else legkwds

        lbs = kwds.pop("labels", [str(b) + " V" for b in self.biases])
        xaxis = kwds.pop("xaxis", self.tof)
        ttl = kwds.pop("title", "")
        align = kwds.pop("align", False)
        energy_scale = kwds.pop("energy_scale", "kinetic")
        # Popped up front so it is never forwarded to the plotting calls below.
        legend_location = kwds.pop("legend_location", "top_right")

        sign = 1 if energy_scale == "kinetic" else -1

        def x_of(itr: int):
            # Optionally shift each trace by its bias relative to the reference trace.
            if align:
                return xaxis + sign * (self.biases[itr] - self.biases[self.calibration["refid"]])
            return xaxis

        if backend == "matplotlib":

            figsize = kwds.pop("figsize", (12, 4))
            _, ax = plt.subplots(figsize=figsize)
            for itr, trace in enumerate(traces):
                ax.plot(
                    x_of(itr),
                    trace,
                    ls="-",
                    linewidth=1,
                    label=lbs[itr],
                    **linekwds,
                )

                # Emphasize selected EDC segments
                if segs is not None:
                    seg = segs[itr]
                    cond = (self.tof >= seg[0]) & (self.tof <= seg[1])
                    tofseg, traceseg = self.tof[cond], trace[cond]
                    ax.plot(
                        tofseg,
                        traceseg,
                        ls="-",
                        linewidth=2,
                        **linesegkwds,
                    )
                # Emphasize extracted local maxima
                if peaks is not None:
                    ax.scatter(
                        peaks[itr, 0],
                        peaks[itr, 1],
                        s=30,
                        **scatterkwds,
                    )

            if show_legend:
                # Best effort: a legend failure must not prevent the plot.
                try:
                    ax.legend(fontsize=12, **legkwds)
                except TypeError:
                    pass

            ax.set_title(ttl)

        elif backend == "bokeh":

            output_notebook(hide_banner=True)
            colors = it.cycle(ColorCycle[10])
            ttp = [("(x, y)", "($x, $y)")]

            figsize = kwds.pop("figsize", (800, 300))
            fig = pbk.figure(
                title=ttl,
                plot_width=figsize[0],
                plot_height=figsize[1],
                tooltips=ttp,
            )
            # Leftover **kwds keep styling the main traces; linekwds take precedence.
            main_line_kwds = {**kwds, **linekwds}
            # Plotting the main traces
            for itr, color in zip(range(len(traces)), colors):
                trace = traces[itr, :]
                fig.line(
                    x_of(itr),
                    trace,
                    color=color,
                    line_dash="solid",
                    line_width=1,
                    line_alpha=1,
                    legend_label=lbs[itr],
                    **main_line_kwds,
                )

                # Emphasize selected EDC segments
                if segs is not None:
                    seg = segs[itr]
                    cond = (self.tof >= seg[0]) & (self.tof <= seg[1])
                    tofseg, traceseg = self.tof[cond], trace[cond]
                    fig.line(
                        tofseg,
                        traceseg,
                        color=color,
                        line_width=3,
                        **linesegkwds,
                    )

                # Plot detected peaks
                if peaks is not None:
                    fig.scatter(
                        peaks[itr, 0],
                        peaks[itr, 1],
                        fill_color=color,
                        fill_alpha=0.8,
                        line_color=None,
                        size=5,
                        **scatterkwds,
                    )

            if show_legend:
                fig.legend.location = legend_location
                fig.legend.spacing = 0
                fig.legend.padding = 2

            pbk.show(fig)
1✔
771

772
    def append_energy_axis(
        self,
        df: Union[pd.DataFrame, dask.dataframe.DataFrame],
        tof_column: str = None,
        energy_column: str = None,
        calibration: dict = None,
        **kwds,
    ) -> Tuple[Union[pd.DataFrame, dask.dataframe.DataFrame], dict]:
        """Calculate and append the energy axis to the events dataframe.

        Args:
            df (Union[pd.DataFrame, dask.dataframe.DataFrame]):
                Dataframe to apply the energy axis calibration to.
            tof_column (str, optional): Label of the source column.
                Defaults to ``self.corrected_tof_column`` if that column exists in
                ``df``, otherwise ``self.tof_column``.
            energy_column (str, optional): Label of the destination column.
                Defaults to ``self.energy_column``.
            calibration (dict, optional): Calibration dictionary. If provided,
                overrides calibration from class or config. The dict is copied and
                never modified in place.
                Defaults to self.calibration or config["energy"]["calibration"].
            **kwds: additional keyword arguments for the energy conversion.
                ``binwidth`` and ``binning`` are consumed here (defaulting to
                ``self.binwidth`` / ``self.binning``); all remaining entries are
                added to (a copy of) the calibration dictionary.

        Raises:
            ValueError: Raised if expected calibration parameters are missing.
            NotImplementedError: Raised if an invalid calib_type is found.

        Returns:
            Tuple[Union[pd.DataFrame, dask.dataframe.DataFrame], dict]: dataframe
            with added energy column, and the energy calibration metadata dictionary.
        """
        if tof_column is None:
            # Prefer the corrected TOF column when a TOF correction was applied.
            if self.corrected_tof_column in df.columns:
                tof_column = self.corrected_tof_column
            else:
                tof_column = self.tof_column

        if energy_column is None:
            energy_column = self.energy_column

        binwidth = kwds.pop("binwidth", self.binwidth)
        binning = kwds.pop("binning", self.binning)

        # pylint: disable=duplicate-code
        if calibration is None:
            if self.calibration:
                calibration = self.calibration
            else:
                calibration = self._config["energy"].get("calibration", {})
        # Always work on a private copy: the keyword overrides and the derived
        # metadata fields added below must not leak into the caller's dict,
        # self.calibration, or the config.
        calibration = deepcopy(calibration)
        calibration.update(kwds)

        # try to determine calibration type if not provided
        if "calib_type" not in calibration:
            if "t0" in calibration and "d" in calibration and "E0" in calibration:
                calibration["calib_type"] = "fit"
            elif "coeffs" in calibration and "E0" in calibration:
                calibration["calib_type"] = "poly"
            else:
                raise ValueError("No valid calibration parameters provided!")

        if calibration["calib_type"] == "fit":
            # Kinetic energy scale is the default for fit-type calibrations, whether
            # the type was given explicitly or auto-detected above.
            calibration.setdefault("energy_scale", "kinetic")
            # Fitting metadata for nexus
            calibration["fit_function"] = "(a0/(x0-a1))**2 + a2"
            calibration["coefficients"] = np.array(
                [
                    calibration["d"],
                    calibration["t0"],
                    calibration["E0"],
                ],
            )
            df[energy_column] = tof2ev(
                calibration["d"],
                calibration["t0"],
                binwidth,
                binning,
                calibration["energy_scale"],
                calibration["E0"],
                df[tof_column].astype("float64"),
            )
        elif calibration["calib_type"] == "poly":
            # Fitting metadata for nexus: a0 is E0, a1..aN are the polynomial terms.
            fit_function = "a0"
            for term in range(1, len(calibration["coeffs"]) + 1):
                fit_function += f" + a{term}*x0**{term}"
            calibration["fit_function"] = fit_function
            calibration["coefficients"] = np.concatenate(
                (calibration["coeffs"], [calibration["E0"]]),
            )[::-1]
            df[energy_column] = tof2evpoly(
                calibration["coeffs"],
                calibration["E0"],
                df[tof_column].astype("float64"),
            )
        else:
            raise NotImplementedError(
                f"Calibration type '{calibration['calib_type']}' not implemented!",
            )

        metadata = self.gather_calibration_metadata(calibration)

        return df, metadata
881

882
    def gather_calibration_metadata(self, calibration: dict = None) -> dict:
        """Build the metadata record describing an applied energy calibration.

        Args:
            calibration (dict, optional): Energy calibration parameters. When not
                given, ``self.calibration`` is used instead. Defaults to None.

        Returns:
            dict: Dictionary with the keys ``"applied"`` (always True),
            ``"calibration"`` (a deep copy of the calibration, guaranteed to hold an
            ``"axis"`` entry, set to 0 if it was missing) and ``"tof"`` (a deep copy
            of ``self.tof``).
        """
        source = self.calibration if calibration is None else calibration
        calibration_copy = deepcopy(source)
        tof_copy = deepcopy(self.tof)
        # Make sure a calibrated-axis entry always exists, even if still empty.
        if "axis" not in calibration_copy:
            calibration_copy["axis"] = 0

        return {
            "applied": True,
            "calibration": calibration_copy,
            "tof": tof_copy,
        }
903

904
    def adjust_energy_correction(
        self,
        image: xr.DataArray,
        correction_type: str = None,
        amplitude: float = None,
        center: Tuple[float, float] = None,
        correction: dict = None,
        apply: bool = False,
        **kwds,
    ):
        """Visualize the energy correction function on top of the TOF/X/Y graphs.

        Two cuts of ``image`` are plotted: the x/TOF plane summed over a y-window
        around the correction center, and the y/TOF plane summed over an x-window
        around it. The current correction curve is overlaid on both cuts and can be
        tuned interactively with ipywidgets sliders. Pressing the "apply" button
        (or passing ``apply=True``) stores the adjusted parameters in
        ``self.correction``.

        Args:
            image (xr.DataArray): Image data cube (x, y, tof) of binned data to plot.
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction"]["correction_type"].
            amplitude (float, optional): Amplitude of the time-of-flight correction
                term. Defaults to config["energy"]["correction"]["amplitude"].
            center (Tuple[float, float], optional): Center (x/y) coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            correction (dict, optional): Correction dict. Defaults to self.correction
                or the config values, and is updated from provided and adjusted
                parameters. A dict passed in here is copied, not modified in place.
            apply (bool, optional): whether to store the provided parameters within
                the class. Defaults to False.
            **kwds: Additional parameters to use for the adjustment plots:

                - **x_column** (str): Name of the x column.
                - **y_column** (str): Name of the y column.
                - **tof_column** (str): Name of the TOF column to convert.
                - **x_width** (int, int): x range to integrate around the center
                - **y_width** (int, int): y range to integrate around the center
                - **tof_fermi** (int): TOF value of the Fermi level
                - **tof_width** (int, int): TOF range to plot around tof_fermi
                - **color_clip** (int): highest value to plot in the color range

                Additional parameters for the correction functions:

                - **diameter** (float): Length scale of the spherical correction.
                - **gamma** (float): Linewidth value for correction using a 2D
                  Lorentz profile.
                - **sigma** (float): Standard deviation for correction using a 2D
                  Gaussian profile.
                - **gamma2** (float): Linewidth value for correction using an
                  asymmetric 2D Lorentz profile, X-direction.
                - **amplitude2** (float): Amplitude value for correction using an
                  asymmetric 2D Lorentz profile, X-direction.

        Raises:
            ValueError: Raised if required correction parameters are missing.
            NotImplementedError: Raised for invalid correction_type.
        """
        matplotlib.use("module://ipympl.backend_nbagg")

        if correction is None:
            if self.correction:
                correction = self.correction
            else:
                correction = self._config["energy"].get("correction", {})
        # Private copy: the overrides below must not modify the caller's dict,
        # self.correction, or the config.
        correction = deepcopy(correction)

        if correction_type is not None:
            correction["correction_type"] = correction_type

        if amplitude is not None:
            correction["amplitude"] = amplitude

        if center is not None:
            correction["center"] = center

        x_column = kwds.pop("x_column", self.x_column)
        y_column = kwds.pop("y_column", self.y_column)
        tof_column = kwds.pop("tof_column", self.tof_column)
        x_width = kwds.pop("x_width", self.x_width)
        y_width = kwds.pop("y_width", self.y_width)
        tof_fermi = kwds.pop("tof_fermi", self.tof_fermi)
        tof_width = kwds.pop("tof_width", self.tof_width)
        color_clip = kwds.pop("color_clip", self.color_clip)

        # Remaining kwds are model parameters (gamma, sigma, diameter, ...).
        correction = {**correction, **kwds}

        if not {"correction_type", "amplitude", "center"}.issubset(set(correction.keys())):
            raise ValueError(
                "No valid energy correction found in config and required parameters missing!",
            )

        if isinstance(correction["center"], list):
            correction["center"] = tuple(correction["center"])

        x = image.coords[x_column].values
        y = image.coords[y_column].values

        x_center = correction["center"][0]
        y_center = correction["center"][1]

        correction_x = tof_fermi - correction_function(
            x=x,
            y=y_center,
            **correction,
        )
        correction_y = tof_fermi - correction_function(
            x=x_center,
            y=y,
            **correction,
        )
        fig, ax = plt.subplots(2, 1)
        tof_slice = slice(tof_fermi + tof_width[0], tof_fermi + tof_width[1])
        image.loc[
            {
                y_column: slice(y_center + y_width[0], y_center + y_width[1]),
                tof_column: tof_slice,
            }
        ].sum(dim=y_column).T.plot(
            ax=ax[0],
            cmap="terrain_r",
            vmax=color_clip,
            yincrease=False,
        )
        image.loc[
            {
                x_column: slice(x_center + x_width[0], x_center + x_width[1]),
                tof_column: tof_slice,
            }
        ].sum(dim=x_column).T.plot(
            ax=ax[1],
            cmap="terrain_r",
            vmax=color_clip,
            yincrease=False,
        )
        (trace1,) = ax[0].plot(x, correction_x)
        line1 = ax[0].axvline(x=x_center)
        (trace2,) = ax[1].plot(y, correction_y)
        line2 = ax[1].axvline(x=y_center)

        amplitude_slider = ipw.FloatSlider(
            value=correction["amplitude"],
            min=0,
            max=10,
            step=0.1,
        )
        x_center_slider = ipw.FloatSlider(
            value=x_center,
            min=0,
            max=self._config["momentum"]["detector_ranges"][0][1],
            step=1,
        )
        y_center_slider = ipw.FloatSlider(
            value=y_center,
            min=0,
            max=self._config["momentum"]["detector_ranges"][1][1],
            step=1,
        )

        def update(amplitude, x_center, y_center, **kwds):
            """Store the slider values in ``correction`` and redraw both overlays."""
            nonlocal correction
            correction["amplitude"] = amplitude
            correction["center"] = (x_center, y_center)
            correction = {**correction, **kwds}
            correction_x = tof_fermi - correction_function(
                x=x,
                y=y_center,
                **correction,
            )
            correction_y = tof_fermi - correction_function(
                x=x_center,
                y=y,
                **correction,
            )

            trace1.set_ydata(correction_x)
            # axvline lines hold two x points; newer matplotlib requires a sequence.
            line1.set_xdata([x_center, x_center])
            trace2.set_ydata(correction_y)
            line2.set_xdata([y_center, y_center])

            fig.canvas.draw_idle()

        def common_apply_func(apply: bool):  # pylint: disable=unused-argument
            """Store the shared parameters in self.correction and close the shared widgets."""
            self.correction = {}
            self.correction["amplitude"] = correction["amplitude"]
            self.correction["center"] = correction["center"]
            self.correction["correction_type"] = correction["correction_type"]
            amplitude_slider.close()
            x_center_slider.close()
            y_center_slider.close()
            apply_button.close()

        if correction["correction_type"] == "spherical":
            try:
                update(correction["amplitude"], x_center, y_center, diameter=correction["diameter"])
            except KeyError as exc:
                raise ValueError(
                    "Parameter 'diameter' required for correction type 'spherical', "
                    "but not present!",
                ) from exc

            diameter_slider = ipw.FloatSlider(
                value=correction["diameter"],
                min=0,
                max=10000,
                step=100,
            )

            ipw.interact(
                update,
                amplitude=amplitude_slider,
                x_center=x_center_slider,
                y_center=y_center_slider,
                diameter=diameter_slider,
            )

            def apply_func(apply: bool):
                common_apply_func(apply)
                self.correction["diameter"] = correction["diameter"]
                diameter_slider.close()

        elif correction["correction_type"] == "Lorentzian":
            try:
                update(correction["amplitude"], x_center, y_center, gamma=correction["gamma"])
            except KeyError as exc:
                raise ValueError(
                    "Parameter 'gamma' required for correction type 'Lorentzian', but not present!",
                ) from exc

            gamma_slider = ipw.FloatSlider(
                value=correction["gamma"],
                min=0,
                max=2000,
                step=1,
            )

            ipw.interact(
                update,
                amplitude=amplitude_slider,
                x_center=x_center_slider,
                y_center=y_center_slider,
                gamma=gamma_slider,
            )

            def apply_func(apply: bool):
                common_apply_func(apply)
                self.correction["gamma"] = correction["gamma"]
                gamma_slider.close()

        elif correction["correction_type"] == "Gaussian":
            try:
                update(correction["amplitude"], x_center, y_center, sigma=correction["sigma"])
            except KeyError as exc:
                raise ValueError(
                    "Parameter 'sigma' required for correction type 'Gaussian', but not present!",
                ) from exc

            sigma_slider = ipw.FloatSlider(
                value=correction["sigma"],
                min=0,
                max=1000,
                step=1,
            )

            ipw.interact(
                update,
                amplitude=amplitude_slider,
                x_center=x_center_slider,
                y_center=y_center_slider,
                sigma=sigma_slider,
            )

            def apply_func(apply: bool):
                common_apply_func(apply)
                self.correction["sigma"] = correction["sigma"]
                sigma_slider.close()

        elif correction["correction_type"] == "Lorentzian_asymmetric":
            try:
                # The X-direction parameters default to their symmetric counterparts,
                # but user-provided values must be kept.
                if "amplitude2" not in correction:
                    correction["amplitude2"] = correction["amplitude"]
                if "gamma2" not in correction:
                    correction["gamma2"] = correction["gamma"]
                update(
                    correction["amplitude"],
                    x_center,
                    y_center,
                    gamma=correction["gamma"],
                    amplitude2=correction["amplitude2"],
                    gamma2=correction["gamma2"],
                )
            except KeyError as exc:
                raise ValueError(
                    "Parameter 'gamma' required for correction type 'Lorentzian_asymmetric', "
                    "but not present!",
                ) from exc

            gamma_slider = ipw.FloatSlider(
                value=correction["gamma"],
                min=0,
                max=2000,
                step=1,
            )

            amplitude2_slider = ipw.FloatSlider(
                value=correction["amplitude2"],
                min=0,
                max=10,
                step=0.1,
            )

            gamma2_slider = ipw.FloatSlider(
                value=correction["gamma2"],
                min=0,
                max=2000,
                step=1,
            )

            ipw.interact(
                update,
                amplitude=amplitude_slider,
                x_center=x_center_slider,
                y_center=y_center_slider,
                gamma=gamma_slider,
                amplitude2=amplitude2_slider,
                gamma2=gamma2_slider,
            )

            def apply_func(apply: bool):
                common_apply_func(apply)
                self.correction["gamma"] = correction["gamma"]
                self.correction["amplitude2"] = correction["amplitude2"]
                self.correction["gamma2"] = correction["gamma2"]
                gamma_slider.close()
                amplitude2_slider.close()
                gamma2_slider.close()

        else:
            raise NotImplementedError(
                f"Correction type '{correction['correction_type']}' not implemented!",
            )
        # pylint: disable=duplicate-code
        apply_button = ipw.Button(description="apply")
        display(apply_button)
        apply_button.on_click(apply_func)
        plt.show()

        if apply:
            apply_func(True)
1254

1255
    def apply_energy_correction(
        self,
        df: Union[pd.DataFrame, dask.dataframe.DataFrame],
        tof_column: str = None,
        new_tof_column: str = None,
        correction_type: str = None,
        amplitude: float = None,
        correction: dict = None,
        **kwds,
    ) -> Tuple[Union[pd.DataFrame, dask.dataframe.DataFrame], dict]:
        """Apply correction to the time-of-flight (TOF) axis of single-event data.

        The corrected column is ``df[tof_column] + correction_function(x, y, ...)``
        evaluated on the per-event x/y columns.

        Args:
            df (Union[pd.DataFrame, dask.dataframe.DataFrame]): The dataframe where
                to apply the energy correction to.
            tof_column (str, optional): Name of the source column to convert.
                Defaults to ``self.tof_column``.
            new_tof_column (str, optional): Name of the destination column to convert.
                Defaults to ``self.corrected_tof_column``.
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to the "correction_type" entry of the correction dictionary.
            amplitude (float, optional): Amplitude of the time-of-flight correction
                term. Defaults to the "amplitude" entry of the correction dictionary.
            correction (dict, optional): Correction dictionary containing parameters
                for the correction. It is copied, not modified in place.
                Defaults to self.correction or config["energy"]["correction"].
            **kwds: Additional parameters to use for the correction:

                - **x_column** (str): Name of the x column.
                - **y_column** (str): Name of the y column.
                - **center** (Tuple[float, float]): Center (x/y) of the correction.
                - **diameter** (float): Length scale of the spherical correction.
                - **gamma** (float): Linewidth value for correction using a 2D
                  Lorentz profile.
                - **sigma** (float): Standard deviation for correction using a 2D
                  Gaussian profile.
                - **gamma2** (float): Linewidth value for correction using an
                  asymmetric 2D Lorentz profile, X-direction.
                - **amplitude2** (float): Amplitude value for correction using an
                  asymmetric 2D Lorentz profile, X-direction.

        Raises:
            ValueError: Raised if "correction_type", "center" or "amplitude" are
                missing from the resulting correction parameters.

        Returns:
            Tuple[Union[pd.DataFrame, dask.dataframe.DataFrame], dict]: dataframe
            with added column, and the energy correction metadata dictionary.
        """
        if correction is None:
            if self.correction:
                correction = self.correction
            else:
                correction = self._config["energy"].get("correction", {})
        # Private copy: the overrides below must not modify the caller's dict,
        # self.correction, or the config.
        correction = deepcopy(correction)

        if correction_type is not None:
            correction["correction_type"] = correction_type

        if amplitude is not None:
            correction["amplitude"] = amplitude

        x_column = kwds.pop("x_column", self.x_column)
        y_column = kwds.pop("y_column", self.y_column)

        # Remaining keyword arguments are correction-model parameters.
        correction.update(kwds)

        if tof_column is None:
            tof_column = self.tof_column

        if new_tof_column is None:
            new_tof_column = self.corrected_tof_column

        missing_keys = {"correction_type", "center", "amplitude"} - set(correction.keys())
        if missing_keys:
            raise ValueError(f"Required correction parameters '{missing_keys}' missing!")

        df[new_tof_column] = df[tof_column] + correction_function(
            x=df[x_column],
            y=df[y_column],
            **correction,
        )
        metadata = self.gather_correction_metadata(correction=correction)

        return df, metadata
1342

1343
    def gather_correction_metadata(self, correction: dict = None) -> dict:
        """Build the metadata record describing an applied energy correction.

        Args:
            correction (dict, optional): Energy correction parameters. When not
                given, ``self.correction`` is used instead. Defaults to None.

        Returns:
            dict: ``{"applied": True, "correction": <deep copy of the parameters>}``.
        """
        chosen = self.correction if correction is None else correction
        return {"applied": True, "correction": deepcopy(chosen)}
1360

1361

1362
def extract_bias(files: List[str], bias_key: str) -> np.ndarray:
    """Read bias values from hdf5 files.

    Args:
        files (List[str]): List of filenames.
        bias_key (str): hdf5 path to the bias value. A leading "@" selects the
            root attribute with the remaining name instead of a dataset.

    Returns:
        np.ndarray: Array of bias values (rounded to two decimals), one per file.
    """
    bias_list: List[float] = []
    for file in files:
        with h5py.File(file, "r") as file_handle:
            if bias_key.startswith("@"):
                value = file_handle.attrs[bias_key[1:]]
            else:
                # Indexing the file yields a Dataset object, which cannot be
                # rounded; ``[()]`` reads its (scalar) contents.
                value = file_handle[bias_key][()]
            bias_list.append(round(value, 2))

    return np.asarray(bias_list)
1381

1382

1383
def correction_function(
    x: Union[float, np.ndarray],
    y: Union[float, np.ndarray],
    correction_type: str,
    center: Tuple[float, float],
    amplitude: float,
    **kwds,
) -> Union[float, np.ndarray]:
    """Evaluate the TOF correction model at the given X/Y coordinates.

    Args:
        x (Union[float, np.ndarray]): x coordinate(s).
        y (Union[float, np.ndarray]): y coordinate(s).
        correction_type (str): Model name, one of "spherical", "Lorentzian",
            "Gaussian" or "Lorentzian_asymmetric".
        center (Tuple[float, float]): Center position (x, y) of the distribution.
        amplitude (float): Amplitude of the correction.
        **kwds: Model parameters:

            - **diameter** (float): Length scale of the spherical model; enters
              as ``sqrt(1 - r**2 / diameter**2)``. Required for "spherical".
            - **gamma** (float): Linewidth of the 2D Lorentz profile. Required
              for "Lorentzian" and "Lorentzian_asymmetric".
            - **sigma** (float): Standard deviation of the 2D Gaussian profile.
              Required for "Gaussian".
            - **gamma2** (float): X-direction linewidth for
              "Lorentzian_asymmetric". Defaults to gamma.
            - **amplitude2** (float): X-direction amplitude for
              "Lorentzian_asymmetric". Defaults to amplitude.

    Raises:
        ValueError: A parameter required by the chosen model is missing.
        NotImplementedError: Unknown correction_type.

    Returns:
        Union[float, np.ndarray]: The correction value(s), shaped like x/y.
    """

    def _pop_required(name: str):
        """Remove and return a mandatory model parameter from kwds."""
        try:
            return kwds.pop(name)
        except KeyError as exc:
            raise ValueError(
                f"Parameter '{name}' required for correction type '{correction_type}' "
                "but not provided!",
            ) from exc

    def _lorentz_term(scale, width, dist_sq):
        """Lorentz-shaped term that vanishes at zero distance."""
        return 100000 * scale / (width * np.pi) * (width**2 / (dist_sq + width**2) - 1)

    if correction_type == "spherical":
        diameter = _pop_required("diameter")
        radius_sq = (x - center[0]) ** 2 + (y - center[1]) ** 2
        return -((1 - np.sqrt(1 - radius_sq / diameter**2)) * 100 * amplitude)

    if correction_type == "Lorentzian":
        gamma = _pop_required("gamma")
        radius_sq = (x - center[0]) ** 2 + (y - center[1]) ** 2
        return _lorentz_term(amplitude, gamma, radius_sq)

    if correction_type == "Gaussian":
        sigma = _pop_required("sigma")
        radius_sq = (x - center[0]) ** 2 + (y - center[1]) ** 2
        prefactor = 20000 * amplitude / np.sqrt(2 * np.pi * sigma**2)
        return prefactor * (np.exp(-radius_sq / (2 * sigma**2)) - 1)

    if correction_type == "Lorentzian_asymmetric":
        gamma = _pop_required("gamma")
        gamma_x = kwds.pop("gamma2", gamma)
        amplitude_x = kwds.pop("amplitude2", amplitude)
        # Separable model: y-direction term plus x-direction term.
        y_term = _lorentz_term(amplitude, gamma, (y - center[1]) ** 2)
        x_term = _lorentz_term(amplitude_x, gamma_x, (x - center[0]) ** 2)
        return y_term + x_term

    raise NotImplementedError
1496

1497

1498
def normspec(
    specs: np.ndarray,
    smooth: bool = False,
    span: int = 7,
    order: int = 1,
) -> np.ndarray:
    """Normalize a series of 1D signals to their respective maxima.

    Args:
        specs (np.ndarray): Collection of 1D signals, e.g. a 2D array with one
            signal per row, or a sequence of 1D sequences.
        smooth (bool, optional): Option to smooth the signals before normalization.
            Defaults to False.
        span (int, optional): Window length of the Savitzky-Golay smoothing filter
            (see ``scipy.signal.savgol_filter()``). Defaults to 7.
        order (int, optional): Polynomial order of the Savitzky-Golay smoothing
            filter (see ``scipy.signal.savgol_filter()``). Defaults to 1.

    Returns:
        np.ndarray: The matrix assembled from the maximum-normalized signals
        (an empty array if ``specs`` is empty).
    """
    normalized = []
    for spec in specs:
        if smooth:
            spec = savgol_filter(spec, span, order)
        # Convert first so plain lists/tuples can be divided element-wise.
        spec = np.asarray(spec)
        normalized.append(spec / spec.max())

    # Assemble once after the loop, which also yields an array for empty input.
    return np.asarray(normalized)
1538

1539

1540
def find_correspondence(
    sig_still: np.ndarray,
    sig_mov: np.ndarray,
    **kwds,
) -> np.ndarray:
    """Align two 1D traces with a time-warp algorithm and return the warp path.

    Args:
        sig_still (np.ndarray): Reference 1D signal.
        sig_mov (np.ndarray): 1D signal to be aligned to the reference.
        **kwds: Options for ``fastdtw.fastdtw()``. Only two keys are read:
            ``dist_metric`` (passed on as ``dist``, default None) and ``radius``
            (default 1). Any other keys are ignored.

    Returns:
        np.ndarray: Pixel-wise path correspondences between the two input 1D
        arrays (sig_still, sig_mov).
    """
    metric = kwds.pop("dist_metric", None)
    window_radius = kwds.pop("radius", 1)
    # Only the second element of fastdtw's result (the path) is needed.
    warp_path = fastdtw(sig_still, sig_mov, dist=metric, radius=window_radius)[1]
    return np.asarray(warp_path)
1561

1562

1563
def range_convert(
    x: np.ndarray,
    xrng: Tuple,
    pathcorr: np.ndarray,
) -> Tuple:
    """Map a value range onto another trace through a pairwise path correspondence.

    Each boundary value is first snapped to its nearest sample index in ``x``.
    That index is then matched (nearest) against the first column of the path.
    The paired index from the second column is converted back into an ``x`` value.

    Args:
        x (np.ndarray): Values of the x axis (e.g. time-of-flight values).
        xrng (Tuple): Boundary value range on the x axis.
        pathcorr (np.ndarray): Path correspondence between two 1D arrays in the
            following form,
            [(id_1_trace_1, id_1_trace_2), (id_2_trace_1, id_2_trace_2), ...]

    Returns:
        Tuple: Transformed range according to the path correspondence.
    """
    path = np.asarray(pathcorr)
    reference_ids = path[:, 0]

    converted = []
    for bound in xrng:
        source_idx = find_nearest(bound, x)
        path_row = find_nearest(source_idx, reference_ids)
        converted.append(x[path[path_row, 1]])

    return tuple(converted)
1591

1592

1593
def find_nearest(val: float, narray: np.ndarray) -> int:
    """Return the index of the element of ``narray`` closest to ``val``.

    Closeness is measured as the absolute difference. On ties the lowest index
    wins, following ``np.argmin``.

    Args:
        val (float): Value of interest.
        narray (np.ndarray): The array to look for the nearest value.

    Returns:
        int: Array index of the value nearest to the given one.
    """
    distance = np.abs(narray - val)
    nearest = np.argmin(distance)
    return int(nearest)
1604

1605

1606
def peaksearch(
    traces: np.ndarray,
    tof: np.ndarray,
    ranges: List[Tuple] = None,
    pkwindow: int = 3,
    plot: bool = False,
) -> np.ndarray:
    """Detect a list of peaks in the corresponding regions of multiple spectra.

    For each trace, the first maximum reported by ``peakdetect1d`` inside the
    matching TOF window is kept. Ranges and traces are paired with ``zip``, so
    only as many traces as there are ranges are processed.

    Args:
        traces (np.ndarray): Collection of 1D spectra.
        tof (np.ndarray): Time-of-flight values.
        ranges (List[Tuple]): List of ranges for peak detection in the format
            [(LowerBound1, UpperBound1), (LowerBound2, UpperBound2), ....], one per
            trace. Required; leaving it at the default ``None`` raises a ValueError.
        pkwindow (int, optional): Window width of a peak (amounts to lookahead in
            ``peakdetect1d``). Defaults to 3.
        plot (bool, optional): Specify whether to display a custom plot of the peak
            search results. Defaults to False.

    Raises:
        ValueError: If ``ranges`` is not provided.
        IndexError: If no peak is detected within one of the ranges.

    Returns:
        np.ndarray: Collection of peak positions, one ``[tof, intensity]`` row per
        processed trace.
    """
    if ranges is None:
        # Without ranges, zip() below would fail with an opaque TypeError.
        raise ValueError(
            "peaksearch requires 'ranges': one (lower, upper) TOF window per trace.",
        )

    pkmaxs = []
    if plot:
        plt.figure(figsize=(10, 4))

    for rng, trace in zip(ranges, traces.tolist()):
        cond = (tof >= rng[0]) & (tof <= rng[1])
        trace = np.array(trace).ravel()
        tofseg, trseg = tof[cond], trace[cond]
        maxs, _ = peakdetect1d(trseg, tofseg, lookahead=pkwindow)
        try:
            pkmaxs.append(maxs[0, :])
        except IndexError:  # No peak found for this range
            print(f"No peak detected in range {rng}.")
            raise

        if plot:
            plt.plot(tof, trace, "--k", linewidth=1)
            plt.plot(tofseg, trseg, linewidth=2)
            plt.scatter(maxs[0, 0], maxs[0, 1], s=30)

    return np.asarray(pkmaxs)
1651

1652

1653
# 1D peak detection algorithm adapted from Sixten Bergman
1654
# https://gist.github.com/sixtenbe/1178136#file-peakdetect-py
1655
def _datacheck_peakdetect(
    x_axis: np.ndarray,
    y_axis: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
    """Validate and normalize the inputs of the 1D peak detection algorithm.

    A missing x axis is replaced by the sample indices of ``y_axis``. Both axes
    are returned as numpy arrays.

    Args:
        x_axis (np.ndarray): x-axis array (or None to use sample indices)
        y_axis (np.ndarray): y-axis array

    Raises:
        ValueError: Raised if x and y values don't have the same length.

    Returns:
        Tuple[np.ndarray, np.ndarray]: Tuple of checked (x/y) arrays.
    """
    if x_axis is None:
        x_axis = np.arange(len(y_axis))

    if len(x_axis) != len(y_axis):
        raise ValueError(
            "Input vectors y_axis and x_axis must have same length",
        )

    return np.asarray(x_axis), np.asarray(y_axis)
1685

1686

1687
def peakdetect1d(
    y_axis: np.ndarray,
    x_axis: np.ndarray = None,
    lookahead: int = 200,
    delta: int = 0,
) -> Tuple[np.ndarray, np.ndarray]:
    """Function for detecting local maxima and minima in a signal.
    Discovers peaks by searching for values which are surrounded by lower
    or larger values for maxima and minima respectively

    Converted from/based on a MATLAB script at:
    http://billauer.co.il/peakdet.html

    Args:
        y_axis (np.ndarray): A list containing the signal over which to find peaks.
        x_axis (np.ndarray, optional): A x-axis whose values correspond to the y_axis
            list and is used in the return to specify the position of the peaks. If
            omitted an index of the y_axis is used.
        lookahead (int, optional): distance to look ahead from a peak candidate to
            determine if it is the actual peak
            '(samples / period) / f' where '4 >= f >= 1.25' might be a good value.
            Defaults to 200.
        delta (int, optional): this specifies a minimum difference between a peak and
            the following points, before a peak may be considered a peak. Useful
            to hinder the function from picking up false peaks towards to end of
            the signal. To work well delta should be set to delta >= RMSnoise * 5.
            Defaults to 0.

    Raises:
        ValueError: Raised if lookahead or delta are out of range, or if x_axis and
            y_axis differ in length.

    Returns:
        Tuple[np.ndarray, np.ndarray]: Tuple of (maxima, minima). Each entry is an
        array of ``[x, y]`` rows (position taken from ``x_axis``, value from
        ``y_axis``), or an empty array if no extremum of that kind was kept. The
        very first extremum detected is discarded as a likely false hit.
    """
    max_peaks = []
    min_peaks = []
    dump = []  # Used to pop the first hit which almost always is false

    # Check input data
    x_axis, y_axis = _datacheck_peakdetect(x_axis, y_axis)
    # Store data length for later use
    length = len(y_axis)

    # Perform some checks
    if lookahead < 1:
        raise ValueError("Lookahead must be '1' or above in value")

    if not (np.isscalar(delta) and delta >= 0):
        raise ValueError("delta must be a positive number")

    # Maxima and minima candidates are temporarily stored in _max and _min.
    # ``np.inf`` is used because the ``np.Inf`` alias was removed in NumPy 2.0.
    _min, _max = np.inf, -np.inf

    # Only detect peak if there is 'lookahead' amount of points after it
    for index, (x, y) in enumerate(
        zip(x_axis[:-lookahead], y_axis[:-lookahead]),
    ):

        if y > _max:
            _max = y
            _max_pos = x

        if y < _min:
            _min = y
            _min_pos = x

        # Find local maxima
        if y < _max - delta and _max != np.inf:
            # Maxima peak candidate found
            # look ahead in signal to ensure that this is a peak and not jitter
            if y_axis[index : index + lookahead].max() < _max:

                max_peaks.append([_max_pos, _max])
                dump.append(True)
                # Set algorithm to only find minima now
                _max = np.inf
                _min = np.inf

                # NOTE: the loop stops lookahead samples before the end, so this
                # guard cannot trigger; kept from the reference implementation.
                if index + lookahead >= length:
                    # The end is within lookahead no more peaks can be found
                    break
                continue

        # Find local minima
        if y > _min + delta and _min != -np.inf:
            # Minima peak candidate found
            # look ahead in signal to ensure that this is a peak and not jitter
            if y_axis[index : index + lookahead].min() > _min:

                min_peaks.append([_min_pos, _min])
                dump.append(False)
                # Set algorithm to only find maxima now
                _min = -np.inf
                _max = -np.inf

                if index + lookahead >= length:
                    # The end is within lookahead no more peaks can be found
                    break

    # Remove the false hit on the first value of the y_axis
    try:
        if dump[0]:
            max_peaks.pop(0)
        else:
            min_peaks.pop(0)
        del dump

    except IndexError:  # When no peaks have been found
        pass

    return (np.asarray(max_peaks), np.asarray(min_peaks))
1806

1807

1808
def fit_energy_calibation(
    pos: Union[List[float], np.ndarray],
    vals: Union[List[float], np.ndarray],
    binwidth: float,
    binning: int,
    ref_id: int = 0,
    ref_energy: float = None,
    t: Union[List[float], np.ndarray] = None,
    energy_scale: str = "kinetic",
    **kwds,
) -> dict:
    """Energy calibration by nonlinear least squares fitting of spectral landmarks on
    a set of energy dispersion curves (EDCs). This is done here by fitting to the
    function d/(t-t0)**2.

    Args:
        pos (Union[List[float], np.ndarray]): Positions of the spectral landmarks
            (e.g. peaks) in the EDCs, in TOF bins.
        vals (Union[List[float], np.ndarray]): Bias voltage value associated with
            each EDC.
        binwidth (float): Time width of each original TOF bin in s.
        binning (int): Binning factor of the TOF values (one TOF bin spans
            ``2**binning`` original bins).
        ref_id (int, optional): Reference dataset index. Defaults to 0.
        ref_energy (float, optional): Energy value of the feature in the reference
            trace (eV). Required to output the calibration. Defaults to None.
        t (Union[List[float], np.ndarray], optional): Array of TOF values. Required
            to calculate calibration trace. Defaults to None.
        energy_scale (str, optional): Direction of increasing energy scale.

            - **'kinetic'**: increasing energy with decreasing TOF.
            - **'binding'**: increasing energy with increasing TOF.
        **kwds: Optional initial values for the fit: ``d_init`` (default 1),
            ``t0_init`` (default 1e-6) and ``E0_init`` (default ``min(vals)``).

    Returns:
        dict: A dictionary of fitting results including the following,

        - "d", "t0", "E0": Fitted parameters of the ``tof2ev`` model.
        - "energy_scale": The energy scale used for the fit.
        - "axis": Calibrated energy axis for ``t`` (only if ``ref_energy`` and
          ``t`` are given; "E0" is then replaced by the offset that places the
          reference landmark at ``ref_energy``).
        - "refid": The reference index used (same condition as "axis").
    """
    vals = np.asarray(vals)
    nvals = vals.size
    # tof2ev does arithmetic on the TOF positions, which fails on plain lists.
    pos = np.asarray(pos)

    if ref_id >= nvals:
        wn.warn(
            "Reference index (refid) cannot be larger than the number of traces! "
            "Reset to the largest allowed number.",
            stacklevel=2,
        )
        ref_id = nvals - 1

    def residual(pars, time, data, binwidth, binning, energy_scale):
        model = tof2ev(
            pars["d"],
            pars["t0"],
            binwidth,
            binning,
            energy_scale,
            pars["E0"],
            time,
        )
        if data is None:
            return model
        return model - data

    pars = Parameters()
    pars.add(name="d", value=kwds.pop("d_init", 1))
    pars.add(
        name="t0",
        value=kwds.pop("t0_init", 1e-6),
        # t0 must lie before the earliest landmark (with one bin of margin).
        max=(min(pos) - 1) * binwidth * 2**binning,
    )
    pars.add(name="E0", value=kwds.pop("E0_init", min(vals)))
    fit = Minimizer(
        residual,
        pars,
        fcn_args=(pos, vals, binwidth, binning, energy_scale),
    )
    result = fit.leastsq()
    report_fit(result)

    # Construct the calibrating function
    pfunc = partial(
        tof2ev,
        result.params["d"].value,
        result.params["t0"].value,
        binwidth,
        binning,
        energy_scale,
    )

    # Return results according to specification
    ecalibdict = {}
    ecalibdict["d"] = result.params["d"].value
    ecalibdict["t0"] = result.params["t0"].value
    ecalibdict["E0"] = result.params["E0"].value
    ecalibdict["energy_scale"] = energy_scale

    if (ref_energy is not None) and (t is not None):
        t = np.asarray(t)
        # Energy of the reference landmark relative to ref_energy; subtracting it
        # shifts the axis so that the landmark lands exactly on ref_energy.
        energy_offset = pfunc(-1 * ref_energy, pos[ref_id])
        ecalibdict["axis"] = pfunc(-energy_offset, t)
        ecalibdict["E0"] = -energy_offset
        ecalibdict["refid"] = ref_id

    return ecalibdict
1910

1911

1912
def poly_energy_calibration(
    pos: Union[List[float], np.ndarray],
    vals: Union[List[float], np.ndarray],
    order: int = 3,
    ref_id: int = 0,
    ref_energy: float = None,
    t: Union[List[float], np.ndarray] = None,
    aug: int = 1,
    method: str = "lstsq",
    energy_scale: str = "kinetic",
    **kwds,
) -> dict:
    """Energy calibration by linear least squares fitting of a polynomial to spectral
    landmarks on a set of energy dispersion curves (EDCs). This amounts to solving
    for the coefficient vector, a, in the system of equations T.a = b. Here T is the
    differential drift time matrix and b the differential bias vector, and
    assuming that the energy-drift-time relationship can be written in the form,
    E = sum_n (a_n * t**n) + E0

    Args:
        pos (Union[List[float], np.ndarray]): Positions of the spectral landmarks
            (e.g. peaks) in the EDCs.
        vals (Union[List[float], np.ndarray]): Bias voltage value associated with
            each EDC.
        order (int, optional): Polynomial order of the fitting function. Defaults to 3.
        ref_id (int, optional): Reference dataset index. Defaults to 0.
        ref_energy (float, optional): Energy value of the feature in the reference
            trace (eV). Required to output the calibration. Defaults to None.
        t (Union[List[float], np.ndarray], optional): Array of TOF values. Required
            to calculate calibration trace. Defaults to None.
        aug (int, optional): Fitting dimension augmentation
            (1=no change, 2=double, etc). Defaults to 1.
        method (str, optional): Least-squares solver used for T.a = b.

            - **'lstsq'**: ``numpy.linalg.lstsq``.
            - **'lsqr'**: ``scipy.sparse.linalg.lsqr``; ``**kwds`` are forwarded.

            The lmfit-based 1/t^2 calibration is provided by
            ``fit_energy_calibation`` and is not handled here.
            Defaults to "lstsq".
        energy_scale (str, optional): Direction of increasing energy scale.

            - **'kinetic'**: increasing energy with decreasing TOF.
            - **'binding'**: increasing energy with increasing TOF.

    Raises:
        NotImplementedError: If ``method`` is not 'lstsq' or 'lsqr'.

    Returns:
        dict: A dictionary of fitting parameters including the following,

        - "coeffs": Fitted polynomial coefficients (the a's), highest power first.
        - "offset": Minimum time-of-flight corresponding to a peak.
        - "Tmat": the T matrix (differential time-of-flight) in the equation Ta=b.
        - "bvec": the b vector (differential bias) in the fitting Ta=b.
        - "energy_scale": The energy scale used.
        - "axis", "E0", "refid": Fitted energy axis, energy offset and reference
          index (only if ``ref_energy`` and ``t`` are given).
    """
    # Previously an unsupported method fell through and failed with an
    # UnboundLocalError on the solution; reject it up front instead.
    if method not in ("lstsq", "lsqr"):
        raise NotImplementedError(
            f"Energy calibration method '{method}' is not supported here; "
            "use 'lstsq' or 'lsqr'.",
        )

    vals = np.asarray(vals)
    nvals = vals.size
    # Float positions avoid silent int64 wrap-around when raising large
    # integer TOF bins to high polynomial powers.
    pos = np.asarray(pos, dtype=float)

    if ref_id >= nvals:
        wn.warn(
            "Reference index (refid) cannot be larger than the number of traces! "
            "Reset to the largest allowed number.",
            stacklevel=2,
        )
        ref_id = nvals - 1

    # Top-to-bottom ordering of terms in the T matrix
    termorder = np.delete(range(0, nvals, 1), ref_id)
    termorder = np.tile(termorder, aug)
    # Left-to-right ordering of polynomials in the T matrix (highest power first)
    polyorder = np.linspace(order, 1, order, dtype="int")

    # Construct the T (differential drift time) matrix, Tmat = Tmain - Tsec.
    # The reference row is duplicated to match the number of secondary rows.
    t_main = np.tile(pos[ref_id] ** polyorder, (aug * (nvals - 1), 1))
    t_sec = pos[termorder, np.newaxis] ** polyorder
    t_mat = t_main - t_sec

    # Construct the b vector (differential bias)
    bvec = vals[ref_id] - np.delete(vals, ref_id)
    bvec = np.tile(bvec, aug)

    # Solve for the a vector (polynomial coefficients) using least squares
    if method == "lstsq":
        sol = lstsq(t_mat, bvec, rcond=None)
    else:  # method == "lsqr"
        sol = lsqr(t_mat, bvec, **kwds)
    poly_a = sol[0]

    # Return results according to specification
    ecalibdict = {}
    ecalibdict["offset"] = pos.min()
    ecalibdict["coeffs"] = poly_a
    ecalibdict["Tmat"] = t_mat
    ecalibdict["bvec"] = bvec
    ecalibdict["energy_scale"] = energy_scale

    if ref_energy is not None and t is not None:
        # Construct the calibrating function
        pfunc = partial(tof2evpoly, poly_a)
        # tof2evpoly raises t to integer powers, which fails on plain lists.
        t = np.asarray(t)
        energy_offset = pfunc(-1 * ref_energy, pos[ref_id])
        ecalibdict["axis"] = pfunc(-energy_offset, t)
        ecalibdict["E0"] = -energy_offset
        ecalibdict["refid"] = ref_id

    return ecalibdict
2022

2023

2024
def tof2ev(
    tof_distance: float,
    time_offset: float,
    binwidth: float,
    binning: int,
    energy_scale: str,
    energy_offset: float,
    t: float,
) -> float:
    """Convert time-of-flight to energy via E = sign * (m_e/2) * (d / (t - t0))**2 + E0.

    The TOF value ``t`` is given in bins. It is turned into a flight time by
    multiplying with the width of one binned bin, ``binwidth * 2**binning``,
    and then shifted by ``time_offset``.

    Args:
        tof_distance (float): Drift distance in meter.
        time_offset (float): Time offset t0 in s.
        binwidth (float): Time width of each original TOF bin in s.
        binning (int): Binning factor of the TOF values.
        energy_scale (str): Direction of increasing energy scale.

            - **'kinetic'**: increasing energy with decreasing TOF.
            - **'binding'**: increasing energy with increasing TOF.

            Any value other than 'kinetic' is treated like 'binding'.
        energy_offset (float): Energy offset in eV.
        t (float): TOF value in bin number.

    Returns:
        float: Converted energy in eV
    """
    # m_e / 2 expressed in eV * s**2 / m**2, hence flight times in seconds.
    half_electron_mass = 2.84281e-12
    direction = -1 if energy_scale != "kinetic" else 1
    flight_time = t * binwidth * 2**binning - time_offset
    return half_electron_mass * direction * (tof_distance / flight_time) ** 2 + energy_offset
2061

2062

2063
def tof2evpoly(
    poly_a: Union[List[float], np.ndarray],
    energy_offset: float,
    t: float,
) -> float:
    """Evaluate the polynomial TOF-to-energy approximation.

    Computes ``sum_k a_k * t**k + energy_offset`` for ``k = 1 .. len(poly_a)``.
    The coefficients in ``poly_a`` are ordered from the highest power down to
    ``t**1``, and there is no constant term besides ``energy_offset``.

    Args:
        poly_a (Union[List[float], np.ndarray]): Polynomial coefficients.
        energy_offset (float): Energy offset in eV.
        t (float): TOF value in bin number.

    Returns:
        float: Converted energy.
    """
    # Flip to ascending order so the k-th coefficient (1-based) multiplies t**k.
    ascending = poly_a[::-1]
    series = sum(
        (coeff * t**power for power, coeff in enumerate(ascending, start=1)),
        0.0,
    )
    return series + energy_offset
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc