OpenCOMPES / sed / build 9799510297

04 Jul 2024 08:17PM UTC, coverage: 92.349% (-0.2%) from 92.511%

Pull Request #466: Catch illegal kwds (github, rettigl)
"add tests for illegal keyword errors"

106 of 123 new or added lines in 17 files covered. (86.18%)
3 existing lines in 1 file now uncovered.
6940 of 7515 relevant lines covered (92.35%)
0.92 hits per line

Source File: /sed/core/processor.py (86.28% covered)
"""This module contains the core class for the sed package."""
from __future__ import annotations

import pathlib
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from typing import cast

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.binning.binning import normalization_histogram_from_timed_dataframe
from sed.binning.binning import normalization_histogram_from_timestamps
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import add_time_stamped_data
from sed.core.dfops import apply_filter
from sed.core.dfops import apply_jitter
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader
from sed.loader.mpes.loader import get_archiver_data
from sed.loader.mpes.loader import MpesLoader

N_CPU = psutil.cpu_count()


class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a workflow for data
    correction, calibration and binning.

    Args:
        metadata (dict, optional): Dict of external metadata. Defaults to None.
        config (dict | str, optional): Config dictionary or config file name.
            Defaults to None.
        dataframe (pd.DataFrame | ddf.DataFrame, optional): Dataframe to load
            into the class. Defaults to None.
        files (list[str], optional): List of files to pass to the loader defined in
            the config. Defaults to None.
        folder (str, optional): Folder containing files to pass to the loader
            defined in the config. Defaults to None.
        runs (Sequence[str], optional): List of run identifiers to pass to the loader
            defined in the config. Defaults to None.
        collect_metadata (bool, optional): Option to collect metadata from files.
            Defaults to False.
        verbose (bool, optional): Option to print out diagnostic information.
            Defaults to config["core"]["verbose"] or False.
        **kwds: Keyword arguments passed to parse_config and to the reader.
    """

    def __init__(
        self,
        metadata: dict = None,
        config: dict | str = None,
        dataframe: pd.DataFrame | ddf.DataFrame = None,
        files: list[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external metadata. Defaults to None.
            config (dict | str, optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (pd.DataFrame | ddf.DataFrame, optional): Dataframe to load
                into the class. Defaults to None.
            files (list[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"] or False.
            **kwds: Keyword arguments passed to parse_config and to the reader.
        """
        # split off config keywords
        config_kwds = {
            key: value for key, value in kwds.items() if key in parse_config.__code__.co_varnames
        }
        for key in config_kwds.keys():
            del kwds[key]
        self._config = parse_config(config, **config_kwds)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

        if verbose is None:
            self.verbose = self._config["core"].get("verbose", False)
        else:
            self.verbose = verbose

        self._dataframe: pd.DataFrame | ddf.DataFrame = None
        self._timed_dataframe: pd.DataFrame | ddf.DataFrame = None
        self._files: list[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None
        self._normalization_histogram: xr.DataArray = None
        self._normalized: xr.DataArray = None

        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
        )

        self.ec = EnergyCalibrator(
            loader=get_loader(
                loader_name=loader_name,
                config=self._config,
            ),
            config=self._config,
        )

        self.mc = MomentumCorrector(
            config=self._config,
        )

        self.dc = DelayCalibrator(
            config=self._config,
        )

        self.use_copy_tool = self._config.get("core", {}).get(
            "use_copy_tool",
            False,
        )
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    source=self._config["core"]["copy_tool_source"],
                    dest=self._config["core"]["copy_tool_dest"],
                    **self._config["core"].get("copy_tool_kwds", {}),
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )

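    # A minimal usage sketch: constructing a processor from a config file and
    # raw data files. The config and file names are placeholders, not part of
    # this module.
    #
    #   processor = SedProcessor(
    #       config="sed_config.yaml",
    #       files=["Scan0001.h5", "Scan0002.h5"],
    #       collect_metadata=False,
    #   )
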
    def __repr__(self):
        if self._dataframe is None:
            df_str = "Dataframe: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        pretty_str = df_str + "\n" + "Metadata: " + "\n" + self._attributes.__repr__()
        return pretty_str

    def _repr_html_(self):
        html = "<div>"

        if self._dataframe is None:
            df_html = "Dataframe: No Data loaded"
        else:
            df_html = self._dataframe._repr_html_()

        html += f"<details><summary>Dataframe</summary>{df_html}</details>"

        # Add expandable section for attributes
        html += "<details><summary>Metadata</summary>"
        html += "<div style='padding-left: 10px;'>"
        html += self._attributes._repr_html_()
        html += "</div></details>"

        html += "</div>"

        return html

    ## Suggestion:
    # @property
    # def overview_panel(self):
    #     """Provides an overview panel with plots of different data attributes."""
    #     self.view_event_histogram(dfpid=2, backend="matplotlib")

    @property
    def dataframe(self) -> pd.DataFrame | ddf.DataFrame:
        """Accessor to the underlying dataframe.

        Returns:
            pd.DataFrame | ddf.DataFrame: Dataframe object.
        """
        return self._dataframe

    @dataframe.setter
    def dataframe(self, dataframe: pd.DataFrame | ddf.DataFrame):
        """Setter for the underlying dataframe.

        Args:
            dataframe (pd.DataFrame | ddf.DataFrame): The dataframe object to set.
        """
        if not isinstance(dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            dataframe,
            self._dataframe.__class__,
        ):
            raise ValueError(
                "'dataframe' has to be a Pandas or Dask dataframe and has to be of the same kind "
                "as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._dataframe.__class__}, provided type: {type(dataframe)}.",
            )
        self._dataframe = dataframe

    @property
    def timed_dataframe(self) -> pd.DataFrame | ddf.DataFrame:
        """Accessor to the underlying timed_dataframe.

        Returns:
            pd.DataFrame | ddf.DataFrame: Timed Dataframe object.
        """
        return self._timed_dataframe

    @timed_dataframe.setter
    def timed_dataframe(self, timed_dataframe: pd.DataFrame | ddf.DataFrame):
        """Setter for the underlying timed dataframe.

        Args:
            timed_dataframe (pd.DataFrame | ddf.DataFrame): The timed dataframe object to set.
        """
        if not isinstance(timed_dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            timed_dataframe,
            self._timed_dataframe.__class__,
        ):
            raise ValueError(
                "'timed_dataframe' has to be a Pandas or Dask dataframe and has to be of the same "
                "kind as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._timed_dataframe.__class__}, "
                f"provided type: {type(timed_dataframe)}.",
            )
        self._timed_dataframe = timed_dataframe

    @property
    def attributes(self) -> MetaHandler:
        """Accessor to the metadata dict.

        Returns:
            MetaHandler: The metadata object.
        """
        return self._attributes

    def add_attribute(self, attributes: dict, name: str, **kwds):
        """Function to add an element to the attributes dict.

        Args:
            attributes (dict): The attributes dictionary object to add.
            name (str): Key under which to add the dictionary to the attributes.
            **kwds: Additional keywords are passed to the ``MetaHandler.add()`` function.
        """
        self._attributes.add(
            entry=attributes,
            name=name,
            **kwds,
        )

    @property
    def config(self) -> dict[Any, Any]:
        """Getter attribute for the config dictionary.

        Returns:
            dict: The config dictionary.
        """
        return self._config

    @property
    def files(self) -> list[str]:
        """Getter attribute for the list of files.

        Returns:
            list[str]: The list of loaded files.
        """
        return self._files

    @property
    def binned(self) -> xr.DataArray:
        """Getter attribute for the binned data array.

        Returns:
            xr.DataArray: The binned data array.
        """
        if self._binned is None:
            raise ValueError("No binned data available, need to compute histogram first!")
        return self._binned

    @property
    def normalized(self) -> xr.DataArray:
        """Getter attribute for the normalized data array.

        Returns:
            xr.DataArray: The normalized data array.
        """
        if self._normalized is None:
            raise ValueError(
                "No normalized data available, compute data with normalization enabled!",
            )
        return self._normalized

    @property
    def normalization_histogram(self) -> xr.DataArray:
        """Getter attribute for the normalization histogram.

        Returns:
            xr.DataArray: The normalization histogram.
        """
        if self._normalization_histogram is None:
            raise ValueError("No normalization histogram available, generate histogram first!")
        return self._normalization_histogram

    def cpy(self, path: str | list[str]) -> str | list[str]:
        """Function to mirror a list of files or a folder from a network drive to
        local storage. Returns either the original or the copied path for the given
        path. The option to use this functionality is set by
        config["core"]["use_copy_tool"].

        Args:
            path (str | list[str]): Source path or path list.

        Returns:
            str | list[str]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        return path

    def load(
        self,
        dataframe: pd.DataFrame | ddf.DataFrame = None,
        metadata: dict = None,
        files: list[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (pd.DataFrame | ddf.DataFrame, optional): Data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external metadata. Defaults to None.
            files (list[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            collect_metadata (bool, optional): Option for collecting metadata in the reader.
                Defaults to False.
            **kwds:
                - *timed_dataframe*: timed dataframe if dataframe is provided.

                Additional keyword parameters are passed to ``loader.read_dataframe()``.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            timed_dataframe = kwds.pop("timed_dataframe", None)
        elif runs is not None:
            # If runs are provided, we only use the copy tool if a folder is also provided.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        elif files is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                files=cast(list[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._timed_dataframe = timed_dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )

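    # A usage sketch of the three loading modes wrapped by load(); the paths
    # and run identifiers are placeholders.
    #
    #   processor.load(files=["Scan0001.h5", "Scan0002.h5"])  # explicit file list
    #   processor.load(folder="/data/Scan0001")               # all files in a folder
    #   processor.load(runs=["44824"], folder="/data")        # runs within a base folder
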
    def filter_column(
        self,
        column: str,
        min_value: float = -np.inf,
        max_value: float = np.inf,
    ) -> None:
        """Filter values in a column which are outside of a given range.

        Args:
            column (str): Name of the column to filter.
            min_value (float, optional): Minimum value to keep. Defaults to -np.inf.
            max_value (float, optional): Maximum value to keep. Defaults to np.inf.
        """
        if column != "index" and column not in self._dataframe.columns:
            raise KeyError(f"Column {column} not found in dataframe!")
        if min_value >= max_value:
            raise ValueError("min_value has to be smaller than max_value!")
        if self._dataframe is not None:
            self._dataframe = apply_filter(
                self._dataframe,
                col=column,
                lower_bound=min_value,
                upper_bound=max_value,
            )
        if self._timed_dataframe is not None and column in self._timed_dataframe.columns:
            self._timed_dataframe = apply_filter(
                self._timed_dataframe,
                column,
                lower_bound=min_value,
                upper_bound=max_value,
            )
        metadata = {
            "filter": {
                "column": column,
                "min_value": min_value,
                "max_value": max_value,
            },
        }
        self._attributes.add(metadata, "filter", duplicate_policy="merge")

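    # A usage sketch: keep only events with detector x position inside a window;
    # the column name "X" is a placeholder for config["dataframe"]["x_column"].
    #
    #   processor.filter_column("X", min_value=1000, max_value=3000)
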
    # Momentum calibration workflow
    # 1. Bin raw detector data for distortion correction
    def bin_and_load_momentum_calibration(
        self,
        df_partitions: int | Sequence[int] = 100,
        axes: list[str] = None,
        bins: list[int] = None,
        ranges: Sequence[tuple[float, float]] = None,
        plane: int = 0,
        width: int = 5,
        apply: bool = False,
        **kwds,
    ):
        """1. step of the momentum correction workflow: Function to do an initial binning
        of the dataframe loaded into the class, slice a plane from it using an
        interactive view, and load it into the momentum corrector class.

        Args:
            df_partitions (int | Sequence[int], optional): Number of dataframe partitions
                to use for the initial binning. Defaults to 100.
            axes (list[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (list[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (Sequence[tuple[float, float]], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
            width (int, optional): Initial value for the width slider. Defaults to 5.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments passed to the pre_binning function.
        """
        self._pre_binned = self.pre_binning(
            df_partitions=df_partitions,
            axes=axes,
            bins=bins,
            ranges=ranges,
            **kwds,
        )

        self.mc.load_data(data=self._pre_binned)
        self.mc.select_slicer(plane=plane, width=width, apply=apply)

    # 2. Define momentum features for the spline warp correction.
    # Either autoselect features, or input features from view above.
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2. step of the distortion correction workflow: Define feature points in
        momentum space. They can either be manually selected using a GUI tool, be
        provided as a list of feature points, or be auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments for ``MomentumCorrector.feature_extract()`` and
                ``MomentumCorrector.feature_select()``.
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop("sigma", self._config["momentum"]["sigma"])
            fwhm = kwds.pop("fwhm", self._config["momentum"]["fwhm"])
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config["momentum"]["sigma_radius"],
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )

    # 3. Generate the spline warp correction from momentum features.
    # If no features have been selected before, use class defaults.
    def generate_splinewarp(
        self,
        use_center: bool = None,
        verbose: bool = None,
        **kwds,
    ):
        """3. step of the distortion correction workflow: Generate the correction
        function restoring the symmetry in the image using a splinewarp algorithm.

        Args:
            use_center (bool, optional): Option to use the position of the
                center point in the correction. Default is read from config, or set to True.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds: Keyword arguments for ``MomentumCorrector.spline_warp_estimate()``.
        """
        if verbose is None:
            verbose = self.verbose

        self.mc.spline_warp_estimate(use_center=use_center, verbose=verbose, **kwds)

        if self.mc.slice is not None and verbose:
            print("Original slice with reference features")
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)

            print("Corrected slice with target features")
            self.mc.view(
                image=self.mc.slice_corrected,
                annotated=True,
                points={"feats": self.mc.ptargs},
                backend="bokeh",
                crosshair=True,
            )

            print("Original slice with target features")
            self.mc.view(
                image=self.mc.slice,
                points={"feats": self.mc.ptargs},
                annotated=True,
                backend="bokeh",
            )

    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.correction) == 0:
            raise ValueError("No momentum correction parameters to save!")
        correction = {}
        for key, value in self.mc.correction.items():
            if key in ["reference_points", "target_points", "cdeform_field", "rdeform_field"]:
                continue
            if key in ["use_center", "rotation_symmetry"]:
                correction[key] = value
            elif key in ["center_point", "ascale"]:
                correction[key] = [float(i) for i in value]
            elif key in ["outer_points", "feature_points"]:
                correction[key] = []
                for point in value:
                    correction[key].append([float(i) for i in point])
            else:
                correction[key] = float(value)

        if "creation_date" not in correction:
            correction["creation_date"] = datetime.now().timestamp()

        config = {
            "momentum": {
                "correction": correction,
            },
        }
        save_config(config, filename, overwrite)
        print(f'Saved momentum correction parameters to "{filename}".')

    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    def pose_adjustment(
        self,
        transformations: dict[str, Any] = None,
        apply: bool = False,
        use_correction: bool = True,
        reset: bool = True,
        verbose: bool = None,
        **kwds,
    ):
        """4. step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            transformations (dict[str, Any], optional): Dictionary with transformations.
                Defaults to self.transformations or config["momentum"]["transformations"].
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
            reset (bool, optional): Option to reset the correction before transformation.
                Defaults to True.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds: Keyword parameters defining defaults for the transformations:

                - **scale** (float): Initial value of the scaling slider.
                - **xtrans** (float): Initial value of the xtrans slider.
                - **ytrans** (float): Initial value of the ytrans slider.
                - **angle** (float): Initial value of the angle slider.
        """
        if verbose is None:
            verbose = self.verbose

        # Generate homography as default if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                self.mc.slice = np.zeros(self._config["momentum"]["bins"][0:2])
            self.mc.slice_corrected = self.mc.slice

        if not use_correction:
            self.mc.reset_deformation()

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate distortion correction from config values
            self.mc.spline_warp_estimate(verbose=verbose)

        self.mc.pose_adjustment(
            transformations=transformations,
            apply=apply,
            reset=reset,
            verbose=verbose,
            **kwds,
        )

    # 4a. Save pose adjustment parameters to config file.
    def save_transformations(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the pose adjustment parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.transformations) == 0:
            raise ValueError("No momentum transformation parameters to save!")
        transformations = {}
        for key, value in self.mc.transformations.items():
            transformations[key] = float(value)

        if "creation_date" not in transformations:
            transformations["creation_date"] = datetime.now().timestamp()

        config = {
            "momentum": {
                "transformations": transformations,
            },
        }
        save_config(config, filename, overwrite)
        print(f'Saved momentum transformation parameters to "{filename}".')

    # 5. Apply the momentum correction to the dataframe
    def apply_momentum_correction(
        self,
        preview: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """Applies the distortion correction and (optional) pose adjustment
        to the dataframe.

        Args:
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds: Keyword parameters for ``MomentumCorrector.apply_correction``:

                - **rdeform_field** (np.ndarray, optional): Row deformation field.
                - **cdeform_field** (np.ndarray, optional): Column deformation field.
                - **inv_dfield** (np.ndarray, optional): Inverse deformation field.
        """
        if verbose is None:
            verbose = self.verbose

        x_column = self._config["dataframe"]["x_column"]
        y_column = self._config["dataframe"]["y_column"]

        if self._dataframe is not None:
            if verbose:
                print("Adding corrected X/Y columns to dataframe:")
            df, metadata = self.mc.apply_corrections(
                df=self._dataframe,
                verbose=verbose,
                **kwds,
            )
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                tdf, _ = self.mc.apply_corrections(
                    self._timed_dataframe,
                    verbose=False,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            self._dataframe = df
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        else:
            if self.verbose:
                print(self._dataframe)

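    # A usage sketch of distortion-correction steps 1-5 run non-interactively
    # (apply=True); the slicing plane, transformation values, and the feature
    # array are placeholders.
    #
    #   processor.bin_and_load_momentum_calibration(plane=33, width=10, apply=True)
    #   features = np.array([...])  # (N, 2) pixel coordinates of symmetry points
    #   processor.define_features(features=features, rotation_symmetry=6, apply=True)
    #   processor.generate_splinewarp(use_center=True)
    #   processor.pose_adjustment(xtrans=8, ytrans=7, angle=-1, apply=True)
    #   processor.apply_momentum_correction()
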
    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    def calibrate_momentum_axes(
        self,
        point_a: np.ndarray | list[int] = None,
        point_b: np.ndarray | list[int] = None,
        k_distance: float = None,
        k_coord_a: np.ndarray | list[float] = None,
        k_coord_b: np.ndarray | list[float] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply=False,
    ):
        """1. step of the momentum calibration workflow: Calibrate momentum
        axes using either provided pixel coordinates of a high-symmetry point and its
        distance to the BZ center, or the k-coordinates of two points in the BZ
        (depending on the equiscale option). Opens an interactive panel for selecting
        the points.

        Args:
            point_a (np.ndarray | list[int], optional): Pixel coordinates of the first
                point used for momentum calibration. Defaults to None.
            point_b (np.ndarray | list[int], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between points a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (np.ndarray | list[float], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (np.ndarray | list[float], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option to apply different scales to kx and ky.
                If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config["momentum"]["center_pixel"]

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )

    # 1a. Save momentum calibration parameters to config file.
    def save_momentum_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated momentum calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.calibration) == 0:
            raise ValueError("No momentum calibration parameters to save!")
        calibration = {}
        for key, value in self.mc.calibration.items():
            if key in ["kx_axis", "ky_axis", "grid", "extent"]:
                continue

            calibration[key] = float(value)

        if "creation_date" not in calibration:
            calibration["creation_date"] = datetime.now().timestamp()

        config = {"momentum": {"calibration": calibration}}
        save_config(config, filename, overwrite)
        print(f"Saved momentum calibration parameters to {filename}")

    # 2. Apply correction and calibration to the dataframe
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """2. step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes exist,
        these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds: Keyword args passed to ``MomentumCorrector.append_k_axis``.
        """
        if verbose is None:
            verbose = self.verbose

        x_column = self._config["dataframe"]["x_column"]
        y_column = self._config["dataframe"]["y_column"]

        if self._dataframe is not None:
            if verbose:
                print("Adding kx/ky columns to dataframe:")
            df, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                tdf, _ = self.mc.append_k_axis(
                    df=self._timed_dataframe,
                    calibration=calibration,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            self._dataframe = df
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        else:
            if self.verbose:
                print(self._dataframe)

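    # A usage sketch: calibrate the momentum axes from one high-symmetry point
    # and its known distance to the BZ center, then persist and apply; the pixel
    # coordinates and k-distance are placeholders.
    #
    #   processor.calibrate_momentum_axes(point_a=[308, 345], k_distance=1.28, apply=True)
    #   processor.save_momentum_calibration()
    #   processor.apply_momentum_calibration()
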
    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: tuple[float, float] = None,
        apply=False,
        **kwds,
    ):
        """1. step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data if
        not yet present.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
            **kwds: Keyword parameters passed to ``EnergyCalibrator.adjust_energy_correction()``.
        """
        if self._pre_binned is None:
            print(
                "Pre-binned data not present, binning using defaults from config...",
            )
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )

    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.correction) == 0:
            raise ValueError("No energy correction parameters to save!")
        correction = {}
        for key, val in self.ec.correction.items():
            if key == "correction_type":
                correction[key] = val
            elif key == "center":
                correction[key] = [float(i) for i in val]
            else:
                correction[key] = float(val)

        if "creation_date" not in correction:
            correction["creation_date"] = datetime.now().timestamp()

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)
        print(f"Saved energy correction parameters to {filename}")

    # 2. Apply energy correction to dataframe
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """2. step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction()``.
        """
        if verbose is None:
            verbose = self.verbose

        tof_column = self._config["dataframe"]["tof_column"]

        if self._dataframe is not None:
            if verbose:
                print("Applying energy correction to dataframe...")
            df, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                verbose=verbose,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.apply_energy_correction(
                    df=self._timed_dataframe,
                    correction=correction,
                    verbose=False,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        else:
            if verbose:
                print(self._dataframe)

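    # A usage sketch: adjust, persist, and apply a spherical TOF correction;
    # the amplitude and center values are placeholders.
    #
    #   processor.adjust_energy_correction(
    #       correction_type="spherical",
    #       amplitude=2.5,
    #       center=(730, 730),
    #       apply=True,
    #   )
    #   processor.save_energy_correction()
    #   processor.apply_energy_correction()
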
    # Energy calibrator workflow
1128
    # 1. Load and normalize data
1129
    def load_bias_series(
1✔
1130
        self,
1131
        binned_data: xr.DataArray | tuple[np.ndarray, np.ndarray, np.ndarray] = None,
1132
        data_files: list[str] = None,
1133
        axes: list[str] = None,
1134
        bins: list = None,
1135
        ranges: Sequence[tuple[float, float]] = None,
1136
        biases: np.ndarray = None,
1137
        bias_key: str = None,
1138
        normalize: bool = None,
1139
        span: int = None,
1140
        order: int = None,
1141
    ):
1142
        """1. step of the energy calibration workflow: Load and bin data from
1143
        single-event files, or load binned bias/TOF traces.
1144

1145
        Args:
1146
            binned_data (xr.DataArray | tuple[np.ndarray, np.ndarray, np.ndarray], optional):
1147
                Binned data If provided as DataArray, Needs to contain dimensions
1148
                config["dataframe"]["tof_column"] and config["dataframe"]["bias_column"]. If
1149
                provided as tuple, needs to contain elements tof, biases, traces.
1150
            data_files (list[str], optional): list of file paths to bin
1151
            axes (list[str], optional): bin axes.
1152
                Defaults to config["dataframe"]["tof_column"].
1153
            bins (list, optional): number of bins.
1154
                Defaults to config["energy"]["bins"].
1155
            ranges (Sequence[tuple[float, float]], optional): bin ranges.
1156
                Defaults to config["energy"]["ranges"].
1157
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
1158
                voltages are extracted from the data files.
1159
            bias_key (str, optional): hdf5 path where bias values are stored.
1160
                Defaults to config["energy"]["bias_key"].
1161
            normalize (bool, optional): Option to normalize traces.
1162
                Defaults to config["energy"]["normalize"].
1163
            span (int, optional): span smoothing parameters of the LOESS method
1164
                (see ``scipy.signal.savgol_filter()``).
1165
                Defaults to config["energy"]["normalize_span"].
1166
            order (int, optional): order smoothing parameters of the LOESS method
1167
                (see ``scipy.signal.savgol_filter()``).
1168
                Defaults to config["energy"]["normalize_order"].
1169
        """
1170
        if binned_data is not None:
1✔
1171
            if isinstance(binned_data, xr.DataArray):
1✔
1172
                if (
1✔
1173
                    self._config["dataframe"]["tof_column"] not in binned_data.dims
1174
                    or self._config["dataframe"]["bias_column"] not in binned_data.dims
1175
                ):
1176
                    raise ValueError(
1✔
1177
                        "If binned_data is provided as an xarray, it needs to contain dimensions "
1178
                        f"'{self._config['dataframe']['tof_column']}' and "
1179
                        f"'{self._config['dataframe']['bias_column']}'!.",
1180
                    )
1181
                tof = binned_data.coords[self._config["dataframe"]["tof_column"]].values
1✔
1182
                biases = binned_data.coords[self._config["dataframe"]["bias_column"]].values
1✔
1183
                traces = binned_data.values[:, :]
1✔
1184
            else:
1185
                try:
1✔
1186
                    (tof, biases, traces) = binned_data
1✔
1187
                except ValueError as exc:
1✔
1188
                    raise ValueError(
1✔
1189
                        "If binned_data is provided as tuple, it needs to contain "
1190
                        "(tof, biases, traces)!",
1191
                    ) from exc
1192
            self.ec.load_data(biases=biases, traces=traces, tof=tof)
1✔
1193

1194
        elif data_files is not None:
1✔
1195
            self.ec.bin_data(
1✔
1196
                data_files=cast(list[str], self.cpy(data_files)),
1197
                axes=axes,
1198
                bins=bins,
1199
                ranges=ranges,
1200
                biases=biases,
1201
                bias_key=bias_key,
1202
            )
1203

1204
        else:
1205
            raise ValueError("Either binned_data or data_files needs to be provided!")
1✔
1206

1207
        if (normalize is not None and normalize is True) or (
1✔
1208
            normalize is None and self._config["energy"]["normalize"]
1209
        ):
1210
            if span is None:
1✔
1211
                span = self._config["energy"]["normalize_span"]
1✔
1212
            if order is None:
1✔
1213
                order = self._config["energy"]["normalize_order"]
1✔
1214
            self.ec.normalize(smooth=True, span=span, order=order)
1✔
1215
        self.ec.view(
1✔
1216
            traces=self.ec.traces_normed,
1217
            xaxis=self.ec.tof,
1218
            backend="bokeh",
1219
        )
1220

1221
    # 2. extract ranges and get peak positions
    def find_bias_peaks(
        self,
        ranges: list[tuple] | tuple,
        ref_id: int = 0,
        infer_others: bool = True,
        mode: str = "replace",
        radius: int = None,
        peak_window: int = None,
        apply: bool = False,
    ):
        """2. step of the energy calibration workflow: Find a peak within a given range
        for the indicated reference trace, and try to find the same peak for all
        other traces. Uses fast_dtw to align curves, which might not work well if the
        shape of the curves changes qualitatively. Ideally, choose a reference trace in
        the middle of the set, and don't choose the range too narrow around the peak.
        Alternatively, a list of ranges for all traces can be provided.

        Args:
            ranges (list[tuple] | tuple): Tuple of TOF values indicating a range.
                Alternatively, a list of ranges for all traces can be given.
            ref_id (int, optional): The id of the trace the range refers to.
                Defaults to 0.
            infer_others (bool, optional): Whether to determine the range for the other
                traces. Defaults to True.
            mode (str, optional): Whether to "add" or "replace" existing ranges.
                Defaults to "replace".
            radius (int, optional): Radius parameter for fast_dtw.
                Defaults to config["energy"]["fastdtw_radius"].
            peak_window (int, optional): peak_window parameter for the peak detection
                algorithm: the number of points that have to behave monotonically
                around a peak. Defaults to config["energy"]["peak_window"].
            apply (bool, optional): Option to directly apply the provided parameters.
                Defaults to False.
        """
        if radius is None:
            radius = self._config["energy"]["fastdtw_radius"]
        if peak_window is None:
            peak_window = self._config["energy"]["peak_window"]
        if not infer_others:
            self.ec.add_ranges(
                ranges=ranges,
                ref_id=ref_id,
                infer_others=infer_others,
                mode=mode,
                radius=radius,
            )
            print(self.ec.featranges)
            try:
                self.ec.feature_extract(peak_window=peak_window)
                self.ec.view(
                    traces=self.ec.traces_normed,
                    segs=self.ec.featranges,
                    xaxis=self.ec.tof,
                    peaks=self.ec.peaks,
                    backend="bokeh",
                )
            except IndexError:
                print("Could not determine all peaks!")
                raise
        else:
            # New adjustment tool
            assert isinstance(ranges, tuple)
            self.ec.adjust_ranges(
                ranges=ranges,
                ref_id=ref_id,
                traces=self.ec.traces_normed,
                infer_others=infer_others,
                radius=radius,
                peak_window=peak_window,
                apply=apply,
            )

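    # Usage sketch (illustrative, not part of the class): assuming `sp` is a
    # SedProcessor with normalized bias traces available (e.g. from the bias-series
    # loading step above), and a hypothetical TOF window containing the same feature
    # in all traces, step 2 could be driven as:
    #
    #     sp.find_bias_peaks(ranges=(64000, 66000), ref_id=5, apply=True)
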
    # 3. Fit the energy calibration relation
1295
    def calibrate_energy_axis(
1✔
1296
        self,
1297
        ref_energy: float,
1298
        method: str = None,
1299
        energy_scale: str = None,
1300
        verbose: bool = None,
1301
        **kwds,
1302
    ):
1303
        """3. Step of the energy calibration workflow: Calculate the calibration
1304
        function for the energy axis, and apply it to the dataframe. Two
1305
        approximations are implemented, a (normally 3rd order) polynomial
1306
        approximation, and a d^2/(t-t0)^2 relation.
1307

1308
        Args:
1309
            ref_energy (float): Binding/kinetic energy of the detected feature.
1310
            method (str, optional): Method for determining the energy calibration.
1311

1312
                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
1313
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.
1314

1315
                Defaults to config["energy"]["calibration_method"]
1316
            energy_scale (str, optional): Direction of increasing energy scale.
1317

1318
                - **'kinetic'**: increasing energy with decreasing TOF.
1319
                - **'binding'**: increasing energy with increasing TOF.
1320

1321
                Defaults to config["energy"]["energy_scale"]
1322
            verbose (bool, optional): Option to print out diagnostic information.
1323
                Defaults to config["core"]["verbose"].
1324
            **kwds**: Keyword parameters passed to ``EnergyCalibrator.calibrate()``.
1325
        """
1326
        if verbose is None:
1✔
1327
            verbose = self.verbose
1✔
1328

1329
        if method is None:
1✔
1330
            method = self._config["energy"]["calibration_method"]
1✔
1331

1332
        if energy_scale is None:
1✔
1333
            energy_scale = self._config["energy"]["energy_scale"]
1✔
1334

1335
        self.ec.calibrate(
1✔
1336
            ref_energy=ref_energy,
1337
            method=method,
1338
            energy_scale=energy_scale,
1339
            verbose=verbose,
1340
            **kwds,
1341
        )
1342
        if verbose:
1✔
1343
            print("Quality of Calibration:")
1✔
1344
            self.ec.view(
1✔
1345
                traces=self.ec.traces_normed,
1346
                xaxis=self.ec.calibration["axis"],
1347
                align=True,
1348
                energy_scale=energy_scale,
1349
                backend="bokeh",
1350
            )
1351
            print("E/TOF relationship:")
1✔
1352
            self.ec.view(
1✔
1353
                traces=self.ec.calibration["axis"][None, :] + self.ec.biases[0],
1354
                xaxis=self.ec.tof,
1355
                backend="matplotlib",
1356
                show_legend=False,
1357
            )
1358
            if energy_scale == "kinetic":
1✔
1359
                plt.scatter(
1✔
1360
                    self.ec.peaks[:, 0],
1361
                    -(self.ec.biases - self.ec.biases[0]) + ref_energy,
1362
                    s=50,
1363
                    c="k",
1364
                )
1365
            elif energy_scale == "binding":
1✔
1366
                plt.scatter(
1✔
1367
                    self.ec.peaks[:, 0],
1368
                    self.ec.biases - self.ec.biases[0] + ref_energy,
1369
                    s=50,
1370
                    c="k",
1371
                )
1372
            else:
1373
                raise ValueError(
×
1374
                    'energy_scale needs to be either "binding" or "kinetic"',
1375
                    f", got {energy_scale}.",
1376
                )
1377
            plt.xlabel("Time-of-flight", fontsize=15)
1✔
1378
            plt.ylabel("Energy (eV)", fontsize=15)
1✔
1379
            plt.show()
1✔
1380

1381
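    # Usage sketch (illustrative): once peak positions have been extracted, the
    # calibration is fitted against a known reference feature; referencing the
    # Fermi edge at 0.0 eV is a hypothetical choice here:
    #
    #     sp.calibrate_energy_axis(ref_energy=0.0, method="lmfit", energy_scale="kinetic")
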
    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.calibration) == 0:
            raise ValueError("No energy calibration parameters to save!")
        calibration = {}
        for key, value in self.ec.calibration.items():
            if key in ["axis", "refid", "Tmat", "bvec"]:
                continue
            if key == "energy_scale":
                calibration[key] = value
            elif key == "coeffs":
                calibration[key] = [float(i) for i in value]
            else:
                calibration[key] = float(value)

        if "creation_date" not in calibration:
            calibration["creation_date"] = datetime.now().timestamp()

        config = {"energy": {"calibration": calibration}}
        save_config(config, filename, overwrite)
        print(f'Saved energy calibration parameters to "{filename}".')

    # 4. Apply energy calibration to the dataframe
    def append_energy_axis(
        self,
        calibration: dict = None,
        bias_voltage: float = None,
        preview: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """4. step of the energy calibration workflow: Apply the calibration function
        to the dataframe. Two approximations are implemented, a (normally 3rd order)
        polynomial approximation, and a d^2/(t-t0)^2 relation. A calibration dictionary
        can be provided.

        Args:
            calibration (dict, optional): Calibration dict containing calibration
                parameters. Overrides calibration from class or config.
                Defaults to None.
            bias_voltage (float, optional): Sample bias voltage of the scan data. If omitted,
                the bias voltage is read from the dataframe. If it is not found there,
                a warning is printed and the calibrated data might have an offset.
            preview (bool): Option to preview the first elements of the data frame.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds:
                Keyword args passed to ``EnergyCalibrator.append_energy_axis()``.
        """
        if verbose is None:
            verbose = self.verbose

        tof_column = self._config["dataframe"]["tof_column"]

        if self._dataframe is not None:
            if verbose:
                print("Adding energy column to dataframe:")
            df, metadata = self.ec.append_energy_axis(
                df=self._dataframe,
                calibration=calibration,
                bias_voltage=bias_voltage,
                verbose=verbose,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.append_energy_axis(
                    df=self._timed_dataframe,
                    calibration=calibration,
                    bias_voltage=bias_voltage,
                    verbose=False,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_calibration",
                duplicate_policy="merge",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf

        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        else:
            if verbose:
                print(self._dataframe)

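    # Usage sketch (illustrative): the bias value below is a made-up example;
    # omit `bias_voltage` to have it read from the dataframe instead:
    #
    #     sp.append_energy_axis(bias_voltage=16.8, preview=True)
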
    def add_energy_offset(
        self,
        constant: float = None,
        columns: str | Sequence[str] = None,
        weights: float | Sequence[float] = None,
        reductions: str | Sequence[str] = None,
        preserve_mean: bool | Sequence[bool] = None,
        preview: bool = False,
        verbose: bool = None,
    ) -> None:
        """Shift the energy axis of the dataframe by a given amount.

        Args:
            constant (float, optional): The constant to shift the energy axis by.
            columns (str | Sequence[str], optional): Name of the column(s) to apply the shift from.
            weights (float | Sequence[float], optional): weights to apply to the columns.
                Can also be used to flip the sign (e.g. -1). Defaults to 1.
            reductions (str | Sequence[str], optional): The reduction to apply to the column.
                Should be an available method of dask.dataframe.Series. For example "mean". In this
                case the function is applied to the column to generate a single value for the whole
                dataset. If None, the shift is applied per-dataframe-row. Defaults to None.
                Currently only "mean" is supported.
            preserve_mean (bool | Sequence[bool], optional): Whether to subtract the mean of the
                column before applying the shift. Defaults to False.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].

        Raises:
            ValueError: If the energy column is not in the dataframe.
        """
        if verbose is None:
            verbose = self.verbose

        energy_column = self._config["dataframe"]["energy_column"]
        if energy_column not in self._dataframe.columns:
            raise ValueError(
                f"Energy column {energy_column} not found in dataframe! "
                "Run `append_energy_axis()` first.",
            )
        if self.dataframe is not None:
            if verbose:
                print("Adding energy offset to dataframe:")
            df, metadata = self.ec.add_offsets(
                df=self._dataframe,
                constant=constant,
                columns=columns,
                energy_column=energy_column,
                weights=weights,
                reductions=reductions,
                preserve_mean=preserve_mean,
                verbose=verbose,
            )
            if self._timed_dataframe is not None and energy_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.add_offsets(
                    df=self._timed_dataframe,
                    constant=constant,
                    columns=columns,
                    energy_column=energy_column,
                    weights=weights,
                    reductions=reductions,
                    preserve_mean=preserve_mean,
                )

            self._attributes.add(
                metadata,
                "add_energy_offset",
                # TODO: allow only appending when no offset along this column(s) was applied
                # TODO: clear memory of modifications if the energy axis is recalculated
                duplicate_policy="append",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and energy_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        elif verbose:
            print(self._dataframe)

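    # Usage sketch (illustrative): shift the energy axis by a constant and,
    # per row, by a column; the constant and the column name
    # "monochromatorPhotonEnergy" are hypothetical placeholders:
    #
    #     sp.add_energy_offset(
    #         constant=-31.5,
    #         columns="monochromatorPhotonEnergy",
    #         weights=-1,
    #         preserve_mean=True,
    #     )
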
    def save_energy_offset(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy offset parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.offsets) == 0:
            raise ValueError("No energy offset parameters to save!")

        if "creation_date" not in self.ec.offsets.keys():
            self.ec.offsets["creation_date"] = datetime.now().timestamp()

        config = {"energy": {"offsets": self.ec.offsets}}
        save_config(config, filename, overwrite)
        print(f'Saved energy offset parameters to "{filename}".')

    def append_tof_ns_axis(
        self,
        preview: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """Convert time-of-flight channel steps to nanoseconds.

        Args:
            tof_ns_column (str, optional): Name of the generated column containing the
                time-of-flight in nanoseconds.
                Defaults to config["dataframe"]["tof_ns_column"].
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds: additional arguments are passed to ``EnergyCalibrator.append_tof_ns_axis()``.

        """
        if verbose is None:
            verbose = self.verbose

        tof_column = self._config["dataframe"]["tof_column"]

        if self._dataframe is not None:
            if verbose:
                print("Adding time-of-flight column in nanoseconds to dataframe:")
            # TODO assert order of execution through metadata

            df, metadata = self.ec.append_tof_ns_axis(
                df=self._dataframe,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.append_tof_ns_axis(
                    df=self._timed_dataframe,
                    **kwds,
                )

            self._attributes.add(
                metadata,
                "tof_ns_conversion",
                duplicate_policy="overwrite",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        else:
            if verbose:
                print(self._dataframe)

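    # Usage sketch (illustrative): convert TOF channel steps to nanoseconds and
    # preview the head of the resulting dataframe:
    #
    #     sp.append_tof_ns_axis(preview=True)
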
    def align_dld_sectors(
        self,
        sector_delays: np.ndarray = None,
        preview: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """Align the 8 sectors of the delay-line detector (DLD) of the HEXTOF endstation.

        Args:
            sector_delays (np.ndarray, optional): Array containing the sector delays. Defaults to
                config["dataframe"]["sector_delays"].
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds: additional arguments are passed to ``EnergyCalibrator.align_dld_sectors()``.
        """
        if verbose is None:
            verbose = self.verbose

        tof_column = self._config["dataframe"]["tof_column"]

        if self._dataframe is not None:
            if verbose:
                print("Aligning 8 sectors of dataframe")
            # TODO assert order of execution through metadata

            df, metadata = self.ec.align_dld_sectors(
                df=self._dataframe,
                sector_delays=sector_delays,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.align_dld_sectors(
                    df=self._timed_dataframe,
                    sector_delays=sector_delays,
                    **kwds,
                )

            self._attributes.add(
                metadata,
                "dld_sector_alignment",
                duplicate_policy="raise",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        else:
            if verbose:
                print(self._dataframe)

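    # Usage sketch (illustrative): align the detector sectors using the delays
    # stored in config["dataframe"]["sector_delays"]:
    #
    #     sp.align_dld_sectors()
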
    # Delay calibration function
    def calibrate_delay_axis(
        self,
        delay_range: tuple[float, float] = None,
        datafile: str = None,
        preview: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """Append delay column to dataframe. Either provide delay ranges, or read
        them from a file.

        Args:
            delay_range (tuple[float, float], optional): The scanned delay range in
                picoseconds. Defaults to None.
            datafile (str, optional): The file from which to read the delay ranges.
                Defaults to None.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
        """
        if verbose is None:
            verbose = self.verbose

        adc_column = self._config["dataframe"]["adc_column"]
        if adc_column not in self._dataframe.columns:
            raise ValueError(f"ADC column {adc_column} not found in dataframe, cannot calibrate!")

        if self._dataframe is not None:
            if verbose:
                print("Adding delay column to dataframe:")

            if delay_range is None and datafile is None:
                if len(self.dc.calibration) == 0:
                    try:
                        datafile = self._files[0]
                    except IndexError as exc:
                        raise IndexError(
                            "No datafile available, specify either 'datafile' or 'delay_range'",
                        ) from exc

            df, metadata = self.dc.append_delay_axis(
                self._dataframe,
                delay_range=delay_range,
                datafile=datafile,
                verbose=verbose,
                **kwds,
            )
            if self._timed_dataframe is not None and adc_column in self._timed_dataframe.columns:
                tdf, _ = self.dc.append_delay_axis(
                    self._timed_dataframe,
                    delay_range=delay_range,
                    datafile=datafile,
                    verbose=False,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "delay_calibration",
                duplicate_policy="overwrite",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and adc_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        else:
            if verbose:
                print(self._dataframe)

    def save_delay_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Save the generated delay calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"

        if len(self.dc.calibration) == 0:
            raise ValueError("No delay calibration parameters to save!")
        calibration = {}
        for key, value in self.dc.calibration.items():
            if key == "datafile":
                calibration[key] = value
            elif key in ["adc_range", "delay_range", "delay_range_mm"]:
                calibration[key] = [float(i) for i in value]
            else:
                calibration[key] = float(value)

        if "creation_date" not in calibration:
            calibration["creation_date"] = datetime.now().timestamp()

        config = {
            "delay": {
                "calibration": calibration,
            },
        }
        save_config(config, filename, overwrite)

    def add_delay_offset(
        self,
        constant: float = None,
        flip_delay_axis: bool = None,
        columns: str | Sequence[str] = None,
        weights: float | Sequence[float] = 1.0,
        reductions: str | Sequence[str] = None,
        preserve_mean: bool | Sequence[bool] = False,
        preview: bool = False,
        verbose: bool = None,
    ) -> None:
        """Shift the delay axis of the dataframe by a constant or other columns.

        Args:
            constant (float, optional): The constant to shift the delay axis by.
            flip_delay_axis (bool, optional): Option to reverse the direction of the delay axis.
            columns (str | Sequence[str], optional): Name of the column(s) to apply the shift from.
            weights (float | Sequence[float], optional): weights to apply to the columns.
                Can also be used to flip the sign (e.g. -1). Defaults to 1.
            reductions (str | Sequence[str], optional): The reduction to apply to the column.
                Should be an available method of dask.dataframe.Series. For example "mean". In this
                case the function is applied to the column to generate a single value for the whole
                dataset. If None, the shift is applied per-dataframe-row. Defaults to None.
                Currently only "mean" is supported.
            preserve_mean (bool | Sequence[bool], optional): Whether to subtract the mean of the
                column before applying the shift. Defaults to False.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].

        Raises:
            ValueError: If the delay column is not in the dataframe.
        """
        if verbose is None:
            verbose = self.verbose

        delay_column = self._config["dataframe"]["delay_column"]
        if delay_column not in self._dataframe.columns:
            raise ValueError(f"Delay column {delay_column} not found in dataframe!")

        if self.dataframe is not None:
            if verbose:
                print("Adding delay offset to dataframe:")
            df, metadata = self.dc.add_offsets(
                df=self._dataframe,
                constant=constant,
                flip_delay_axis=flip_delay_axis,
                columns=columns,
                delay_column=delay_column,
                weights=weights,
                reductions=reductions,
                preserve_mean=preserve_mean,
                verbose=verbose,
            )
            if self._timed_dataframe is not None and delay_column in self._timed_dataframe.columns:
                tdf, _ = self.dc.add_offsets(
                    df=self._timed_dataframe,
                    constant=constant,
                    flip_delay_axis=flip_delay_axis,
                    columns=columns,
                    delay_column=delay_column,
                    weights=weights,
                    reductions=reductions,
                    preserve_mean=preserve_mean,
                    verbose=False,
                )

            self._attributes.add(
                metadata,
                "delay_offset",
                duplicate_policy="append",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and delay_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        else:
            if verbose:
                print(self._dataframe)

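    # Usage sketch (illustrative): shift and flip the delay axis; the constant,
    # the column name "bam", and its weight are hypothetical placeholders:
    #
    #     sp.add_delay_offset(constant=13.0, flip_delay_axis=True, columns="bam", weights=-0.001)
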
    def save_delay_offsets(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Save the generated delay offset parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.dc.offsets) == 0:
            raise ValueError("No delay offset parameters to save!")

        if "creation_date" not in self.dc.offsets.keys():
            self.dc.offsets["creation_date"] = datetime.now().timestamp()

        config = {
            "delay": {
                "offsets": self.dc.offsets,
            },
        }
        save_config(config, filename, overwrite)
        print(f'Saved delay offset parameters to "{filename}".')

    def save_workflow_params(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Run all save-calibration-parameter methods.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        for method in [
            self.save_splinewarp,
            self.save_transformations,
            self.save_momentum_calibration,
            self.save_energy_correction,
            self.save_energy_calibration,
            self.save_energy_offset,
            self.save_delay_calibration,
            self.save_delay_offsets,
        ]:
            try:
                method(filename, overwrite)
            except (ValueError, AttributeError, KeyError):
                pass

    def add_jitter(
        self,
        cols: list[str] = None,
        amps: float | Sequence[float] = None,
        **kwds,
    ):
        """Add jitter to the selected dataframe columns.

        Args:
            cols (list[str], optional): The columns onto which to apply jitter.
                Defaults to config["dataframe"]["jitter_cols"].
            amps (float | Sequence[float], optional): Amplitude scalings for the
                jittering noise. If one number is given, the same is used for all axes.
                For uniform noise (default) it will cover the interval [-amp, +amp].
                Defaults to config["dataframe"]["jitter_amps"].
            **kwds: additional keyword arguments passed to ``apply_jitter``.
        """
        if cols is None:
            cols = self._config["dataframe"]["jitter_cols"]
        for loc, col in enumerate(cols):
            if col.startswith("@"):
                cols[loc] = self._config["dataframe"].get(col.strip("@"))

        if amps is None:
            amps = self._config["dataframe"]["jitter_amps"]

        self._dataframe = self._dataframe.map_partitions(
            apply_jitter,
            cols=cols,
            cols_jittered=cols,
            amps=amps,
            **kwds,
        )
        if self._timed_dataframe is not None:
            cols_timed = cols.copy()
            for col in cols:
                if col not in self._timed_dataframe.columns:
                    cols_timed.remove(col)

            if cols_timed:
                self._timed_dataframe = self._timed_dataframe.map_partitions(
                    apply_jitter,
                    cols=cols_timed,
                    cols_jittered=cols_timed,
                )
        metadata = []
        for col in cols:
            metadata.append(col)
        # TODO: allow only appending if columns are not jittered yet
        self._attributes.add(metadata, "jittering", duplicate_policy="append")

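    # Usage sketch (illustrative): de-quantize the default columns from
    # config["dataframe"]["jitter_cols"] with the default amplitudes:
    #
    #     sp.add_jitter()
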
    def add_time_stamped_data(
        self,
        dest_column: str,
        time_stamps: np.ndarray = None,
        data: np.ndarray = None,
        archiver_channel: str = None,
        **kwds,
    ):
        """Add data in the form of timestamp/value pairs to the dataframe, using
        interpolation to the timestamps in the dataframe. The time-stamped data can
        either be provided, or fetched from an EPICS archiver instance.

        Args:
            dest_column (str): destination column name
            time_stamps (np.ndarray, optional): Time stamps of the values to add. If omitted,
                time stamps are retrieved from the EPICS archiver
            data (np.ndarray, optional): Values corresponding to the time stamps in time_stamps.
                If omitted, data are retrieved from the EPICS archiver.
            archiver_channel (str, optional): EPICS archiver channel from which to retrieve data.
                Either this or data and time_stamps have to be present.
            **kwds:

                - **time_stamp_column**: Dataframe column containing time-stamp data

                Additional keyword arguments passed to ``add_time_stamped_data``.
        """
        time_stamp_column = kwds.pop(
            "time_stamp_column",
            self._config["dataframe"].get("time_stamp_alias", ""),
        )

        if time_stamps is None and data is None:
            if archiver_channel is None:
                raise ValueError(
                    "Either archiver_channel or both time_stamps and data have to be present!",
                )
            if self.loader.__name__ != "mpes":
                raise NotImplementedError(
                    "This function is currently only implemented for the mpes loader!",
                )
            ts_from, ts_to = cast(MpesLoader, self.loader).get_start_and_end_time()
            # get channel data with +-5 seconds safety margin
            time_stamps, data = get_archiver_data(
                archiver_url=self._config["metadata"].get("archiver_url", ""),
                archiver_channel=archiver_channel,
                ts_from=ts_from - 5,
                ts_to=ts_to + 5,
            )

        self._dataframe = add_time_stamped_data(
            self._dataframe,
            time_stamps=time_stamps,
            data=data,
            dest_column=dest_column,
            time_stamp_column=time_stamp_column,
            **kwds,
        )
        if self._timed_dataframe is not None:
            if time_stamp_column in self._timed_dataframe:
                self._timed_dataframe = add_time_stamped_data(
                    self._timed_dataframe,
                    time_stamps=time_stamps,
                    data=data,
                    dest_column=dest_column,
                    time_stamp_column=time_stamp_column,
                    **kwds,
                )
        metadata: list[Any] = []
        metadata.append(dest_column)
        metadata.append(time_stamps)
        metadata.append(data)
        self._attributes.add(metadata, "time_stamped_data", duplicate_policy="append")

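    # Usage sketch (illustrative): interpolate an externally recorded log onto the
    # dataframe; the timestamps, values, and destination column name are made up:
    #
    #     ts = np.array([1.0e9, 1.0e9 + 60, 1.0e9 + 120])  # epoch seconds
    #     vals = np.array([10.0, 10.5, 11.0])              # e.g. a temperature log
    #     sp.add_time_stamped_data("sample_temperature", time_stamps=ts, data=vals)
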
    def pre_binning(
        self,
        df_partitions: int | Sequence[int] = 100,
        axes: list[str] = None,
        bins: list[int] = None,
        ranges: Sequence[tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Function to do an initial binning of the dataframe loaded to the class.

        Args:
            df_partitions (int | Sequence[int], optional): Number of dataframe partitions to
                use for the initial binning. Defaults to 100.
            axes (list[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (list[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (Sequence[tuple[float, float]], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            **kwds: Keyword arguments passed to ``compute``.

        Returns:
            xr.DataArray: pre-binned data-array.
        """
        if axes is None:
            axes = self._config["momentum"]["axes"]
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))

        if bins is None:
            bins = self._config["momentum"]["bins"]
        if ranges is None:
            ranges_ = list(self._config["momentum"]["ranges"])
            ranges_[2] = np.asarray(ranges_[2]) / 2 ** (
                self._config["dataframe"]["tof_binning"] - 1
            )
            ranges = [cast(tuple[float, float], tuple(v)) for v in ranges_]

        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        return self.compute(
            bins=bins,
            axes=axes,
            ranges=ranges,
            df_partitions=df_partitions,
            **kwds,
        )

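    # Usage sketch (illustrative): pre-bin on a subset of partitions, using the
    # axes, bins, and ranges defaults from config["momentum"]:
    #
    #     prebinned = sp.pre_binning(df_partitions=50)
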
    def compute(
        self,
        bins: int | dict | tuple | list[int] | list[np.ndarray] | list[tuple] = 100,
        axes: str | Sequence[str] = None,
        ranges: Sequence[tuple[float, float]] = None,
        normalize_to_acquisition_time: bool | str = False,
        **kwds,
    ) -> xr.DataArray:
        """Compute the histogram along the given dimensions.

        Args:
            bins (int | dict | tuple | list[int] | list[np.ndarray] | list[tuple], optional):
                Definition of the bins. Can be any of the following cases:

                - an integer describing the number of bins on all dimensions
                - a tuple of 3 numbers describing start, end and step of the binning
                  range
                - an np.ndarray defining the binning edges
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
                - a dictionary made of the axes as keys and any of the above as values.

                This takes priority over the axes and range arguments. Defaults to 100.
            axes (str | Sequence[str], optional): The names of the axes (columns)
                on which to calculate the histogram. The order will be the order of the
                dimensions in the resulting array. Defaults to None.
            ranges (Sequence[tuple[float, float]], optional): list of tuples containing
                the start and end point of the binning range. Defaults to None.
            normalize_to_acquisition_time (bool | str): Option to normalize the
                result to the acquisition time. If a "slow" axis was scanned, providing
                the name of the scanned axis will compute and apply the corresponding
                normalization histogram. Defaults to False.
            **kwds: Keyword arguments:

                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
                  ``bin_dataframe`` for details. Defaults to
                  config["binning"]["hist_mode"].
                - **mode**: Defines how the results from each partition are combined.
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
                  Defaults to config["binning"]["mode"].
                - **pbar**: Option to show the tqdm progress bar. Defaults to
                  config["binning"]["pbar"].
                - **num_cores**: Number of CPU cores to use for parallelization.
                  Defaults to config["binning"]["num_cores"] or N_CPU-1.
                - **threads_per_worker**: Limit the number of threads that
                  multiprocessing can spawn per binning thread. Defaults to
                  config["binning"]["threads_per_worker"].
                - **threadpool_API**: The API to use for multiprocessing. "blas",
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
                  config["binning"]["threadpool_API"].
                - **df_partitions**: A sequence of dataframe partitions, or the
                  number of the dataframe partitions to use. Defaults to all partitions.
                - **filter**: A sequence of dictionaries with entries "col", "lower_bound",
                  "upper_bound" to apply as filter to the dataframe before binning. The
                  dataframe in the class remains unmodified by this.

                Additional kwds are passed to ``bin_dataframe``.

        Raises:
            AssertionError: Raised when no dataframe has been loaded.

        Returns:
            xr.DataArray: The result of the n-dimensional binning represented in an
            xarray object, combining the data with the axes.
        """
        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
        mode = kwds.pop("mode", self._config["binning"]["mode"])
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
        num_cores = kwds.pop("num_cores", self._config["binning"]["num_cores"])
        threads_per_worker = kwds.pop(
            "threads_per_worker",
            self._config["binning"]["threads_per_worker"],
        )
        threadpool_api = kwds.pop(
            "threadpool_API",
            self._config["binning"]["threadpool_API"],
        )
        df_partitions: int | Sequence[int] = kwds.pop("df_partitions", None)
        if isinstance(df_partitions, int):
            df_partitions = list(range(0, min(df_partitions, self._dataframe.npartitions)))
        if df_partitions is not None:
            dataframe = self._dataframe.partitions[df_partitions]
        else:
            dataframe = self._dataframe

        filter_params = kwds.pop("filter", None)
        if filter_params is not None:
            try:
                for param in filter_params:
                    if "col" not in param:
                        raise ValueError(
                            "'col' needs to be defined for each filter entry! ",
                            f"Not present in {param}.",
                        )
                    assert set(param.keys()).issubset({"col", "lower_bound", "upper_bound"})
                    dataframe = apply_filter(dataframe, **param)
            except AssertionError as exc:
                invalid_keys = set(param.keys()) - {"col", "lower_bound", "upper_bound"}
                raise ValueError(
                    "Only 'col', 'lower_bound' and 'upper_bound' allowed as filter entries. ",
                    f"Parameters {invalid_keys} not valid in {param}.",
                ) from exc

        self._binned = bin_dataframe(
            df=dataframe,
            bins=bins,
            axes=axes,
            ranges=ranges,
            hist_mode=hist_mode,
            mode=mode,
            pbar=pbar,
            n_cores=num_cores,
            threads_per_worker=threads_per_worker,
            threadpool_api=threadpool_api,
            **kwds,
        )

        for dim in self._binned.dims:
            try:
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
            except KeyError:
                pass

        self._binned.attrs["units"] = "counts"
        self._binned.attrs["long_name"] = "photoelectron counts"
        self._binned.attrs["metadata"] = self._attributes.metadata

        if normalize_to_acquisition_time:
            if isinstance(normalize_to_acquisition_time, str):
                axis = normalize_to_acquisition_time
                print(
                    f"Calculate normalization histogram for axis '{axis}'...",
                )
                self._normalization_histogram = self.get_normalization_histogram(
                    axis=axis,
                    df_partitions=df_partitions,
                )
                # if the axes are named correctly, xarray figures out the normalization correctly
                self._normalized = self._binned / self._normalization_histogram
                self._attributes.add(
                    self._normalization_histogram.values,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )
            else:
                acquisition_time = self.loader.get_elapsed_time(
                    fids=df_partitions,
                )
                if acquisition_time > 0:
                    self._normalized = self._binned / acquisition_time
                self._attributes.add(
                    acquisition_time,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )

            self._normalized.attrs["units"] = "counts/second"
            self._normalized.attrs["long_name"] = "photoelectron counts per second"
            self._normalized.attrs["metadata"] = self._attributes.metadata

            return self._normalized

        return self._binned

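    # Usage sketch (illustrative): bin three axes; the axis names, bin counts, and
    # ranges below are placeholders and must match columns of the loaded dataframe:
    #
    #     res = sp.compute(
    #         bins=[80, 80, 200],
    #         axes=["kx", "ky", "energy"],
    #         ranges=[(-2.0, 2.0), (-2.0, 2.0), (-4.0, 2.0)],
    #     )
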
    def get_normalization_histogram(
        self,
        axis: str = "delay",
        use_time_stamps: bool = False,
        **kwds,
    ) -> xr.DataArray:
        """Generates a normalization histogram from the timed dataframe. Optionally,
        use the TimeStamps column instead.

        Args:
            axis (str, optional): The axis for which to compute the histogram.
                Defaults to "delay".
            use_time_stamps (bool, optional): Use the TimeStamps column of the
                dataframe, rather than the timed dataframe. Defaults to False.
            **kwds: Keyword arguments:

                - **df_partitions**: A sequence of dataframe partitions, or the
                  number of the dataframe partitions to use. Defaults to all partitions.

        Raises:
            ValueError: Raised if no data are binned.
            ValueError: Raised if 'axis' not in binned coordinates.
            ValueError: Raised if config["dataframe"]["time_stamp_alias"] not found
                in Dataframe.

        Returns:
            xr.DataArray: The computed normalization histogram (in TimeStamp units
            per bin).
        """

        if self._binned is None:
            raise ValueError("Need to bin data first!")
        if axis not in self._binned.coords:
            raise ValueError(f"Axis '{axis}' not found in binned data!")

        df_partitions: int | Sequence[int] = kwds.pop("df_partitions", None)

        if len(kwds) > 0:
            raise TypeError(
                f"get_normalization_histogram() got unexpected keyword arguments {kwds.keys()}.",
            )

        if isinstance(df_partitions, int):
            df_partitions = list(range(0, min(df_partitions, self._dataframe.npartitions)))
        if use_time_stamps or self._timed_dataframe is None:
            if df_partitions is not None:
                self._normalization_histogram = normalization_histogram_from_timestamps(
                    self._dataframe.partitions[df_partitions],
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["time_stamp_alias"],
                )
            else:
                self._normalization_histogram = normalization_histogram_from_timestamps(
                    self._dataframe,
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["time_stamp_alias"],
                )
        else:
            if df_partitions is not None:
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
                    self._timed_dataframe.partitions[df_partitions],
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["timed_dataframe_unit_time"],
                )
            else:
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
                    self._timed_dataframe,
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["timed_dataframe_unit_time"],
                )

        return self._normalization_histogram

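    # Usage sketch (illustrative): after binning along "delay", compute the per-bin
    # acquisition-time histogram from a subset of partitions:
    #
    #     hist = sp.get_normalization_histogram(axis="delay", df_partitions=10)
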
    def view_event_histogram(
        self,
        dfpid: int,
        ncol: int = 2,
        bins: Sequence[int] = None,
        axes: Sequence[str] = None,
        ranges: Sequence[tuple[float, float]] = None,
        backend: str = "bokeh",
        legend: bool = True,
        histkwds: dict = None,
        legkwds: dict = None,
        **kwds,
    ):
        """Plot individual histograms of specified dimensions (axes) from a selected
        dataframe partition.

        Args:
            dfpid (int): Number of the data frame partition to look at.
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
            bins (Sequence[int], optional): Number of bins to use for the specified
                axes. Defaults to config["histogram"]["bins"].
            axes (Sequence[str], optional): Names of the axes to display.
                Defaults to config["histogram"]["axes"].
            ranges (Sequence[tuple[float, float]], optional): Value ranges of all
                specified axes. Defaults to config["histogram"]["ranges"].
            backend (str, optional): Backend of the plotting library
                ('matplotlib' or 'bokeh'). Defaults to "bokeh".
            legend (bool, optional): Option to include a legend in the histogram plots.
                Defaults to True.
            histkwds (dict, optional): Keyword arguments for histograms
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
            legkwds (dict, optional): Keyword arguments for legend
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
            **kwds: Extra keyword arguments passed to
                ``sed.diagnostics.grid_histogram()``.

        Raises:
            TypeError: Raised when the input values are not of the correct type.
        """
        if bins is None:
            bins = self._config["histogram"]["bins"]
        if axes is None:
            axes = self._config["histogram"]["axes"]
        axes = list(axes)
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
        if ranges is None:
            ranges = list(self._config["histogram"]["ranges"])
            for loc, axis in enumerate(axes):
                if axis == self._config["dataframe"]["tof_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
                        self._config["dataframe"]["tof_binning"] - 1
                    )
                elif axis == self._config["dataframe"]["adc_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
                        self._config["dataframe"]["adc_binning"] - 1
                    )

        input_types = map(type, [axes, bins, ranges])
        allowed_types = [list, tuple]

        df = self._dataframe

        if not set(input_types).issubset(allowed_types):
            raise TypeError(
                "Inputs of axes, bins, ranges need to be list or tuple!",
            )

        # Read out the values for the specified groups
        group_dict_dd = {}
        dfpart = df.get_partition(dfpid)
        cols = dfpart.columns
        for ax in axes:
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
        group_dict = ddf.compute(group_dict_dd)[0]

        # Plot multiple histograms in a grid
        grid_histogram(
            group_dict,
            ncol=ncol,
            rvs=axes,
            rvbins=bins,
            rvranges=ranges,
            backend=backend,
            legend=legend,
            histkwds=histkwds,
            legkwds=legkwds,
            **kwds,
        )

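    # Usage sketch (illustrative): inspect raw event distributions of the first
    # dataframe partition, using the axes configured in config["histogram"]:
    #
    #     sp.view_event_histogram(dfpid=0)
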
    def save(
        self,
        faddr: str,
        **kwds,
    ):
        """Saves the binned data to the provided path and filename.

        Args:
            faddr (str): Path and name of the file to write. Its extension determines
                the file type to write. Valid file types are:

                - "*.tiff", "*.tif": Saves a TIFF stack.
                - "*.h5", "*.hdf5": Saves an HDF5 file.
                - "*.nxs", "*.nexus": Saves a NeXus file.

            **kwds: Keyword arguments, which are passed to the writer functions:
                For TIFF writing:

                - **alias_dict**: Dictionary of dimension aliases to use.

                For HDF5 writing:

                - **mode**: hdf5 read/write mode. Defaults to "w".

                For NeXus:

                - **reader**: Name of the pynxtools reader to use.
                  Defaults to config["nexus"]["reader"]
                - **definition**: NeXus application definition to use for saving.
                  Must be supported by the used ``reader``. Defaults to
                  config["nexus"]["definition"]
                - **input_files**: A list of input files to pass to the reader.
                  Defaults to config["nexus"]["input_files"]
                - **eln_data**: An electronic-lab-notebook file in '.yaml' format
                  to add to the list of files to pass to the reader.
        """
        if self._binned is None:
            raise NameError("Need to bin data first!")

        if self._normalized is not None:
            data = self._normalized
        else:
            data = self._binned

        extension = pathlib.Path(faddr).suffix

        if extension in (".tif", ".tiff"):
            to_tiff(
                data=data,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".h5", ".hdf5"):
            to_h5(
                data=data,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".nxs", ".nexus"):
            try:
                reader = kwds.pop("reader", self._config["nexus"]["reader"])
                definition = kwds.pop(
                    "definition",
                    self._config["nexus"]["definition"],
                )
                input_files = kwds.pop(
                    "input_files",
                    self._config["nexus"]["input_files"],
                )
            except KeyError as exc:
                raise ValueError(
                    "The nexus reader, definition and input files need to be provided!",
                ) from exc

            if isinstance(input_files, str):
                input_files = [input_files]

            if "eln_data" in kwds:
                input_files.append(kwds.pop("eln_data"))

            to_nexus(
                data=data,
                faddr=faddr,
                reader=reader,
                definition=definition,
                input_files=input_files,
                **kwds,
            )

        else:
            raise NotImplementedError(
                f"Unrecognized file format: {extension}.",
            )