OpenCOMPES / sed / 10064604320

23 Jul 2024 06:45PM UTC coverage: 92.688% (+0.8%) from 91.936%

Pull Request #482: Update benchmark targets (github-actions[bot])

1205 of 1250 new or added lines in 50 files covered. (96.4%)

2 existing lines in 2 files now uncovered.

7073 of 7631 relevant lines covered (92.69%)

0.93 hits per line

Source File: /sed/core/processor.py (86.56% covered)
"""This module contains the core class for the sed package

"""
from __future__ import annotations

import pathlib
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from typing import cast

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.binning.binning import normalization_histogram_from_timed_dataframe
from sed.binning.binning import normalization_histogram_from_timestamps
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import add_time_stamped_data
from sed.core.dfops import apply_filter
from sed.core.dfops import apply_jitter
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader
from sed.loader.mpes.loader import get_archiver_data
from sed.loader.mpes.loader import MpesLoader

N_CPU = psutil.cpu_count()


class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a workflow for data
    correction, calibration and binning.

    Args:
        metadata (dict, optional): Dict of external Metadata. Defaults to None.
        config (dict | str, optional): Config dictionary or config file name.
            Defaults to None.
        dataframe (pd.DataFrame | ddf.DataFrame, optional): dataframe to load
            into the class. Defaults to None.
        files (list[str], optional): List of files to pass to the loader defined in
            the config. Defaults to None.
        folder (str, optional): Folder containing files to pass to the loader
            defined in the config. Defaults to None.
        runs (Sequence[str], optional): List of run identifiers to pass to the loader
            defined in the config. Defaults to None.
        collect_metadata (bool, optional): Option to collect metadata from files.
            Defaults to False.
        verbose (bool, optional): Option to print out diagnostic information.
            Defaults to config["core"]["verbose"] or False.
        **kwds: Keyword arguments passed to the reader.
    """

    def __init__(
        self,
        metadata: dict = None,
        config: dict | str = None,
        dataframe: pd.DataFrame | ddf.DataFrame = None,
        files: list[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            config (dict | str, optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (pd.DataFrame | ddf.DataFrame, optional): dataframe to load
                into the class. Defaults to None.
            files (list[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"] or False.
            **kwds: Keyword arguments passed to parse_config and to the reader.
        """
        # split off config keywords
        config_kwds = {
            key: value for key, value in kwds.items() if key in parse_config.__code__.co_varnames
        }
        for key in config_kwds.keys():
            del kwds[key]
        self._config = parse_config(config, **config_kwds)
        num_cores = self._config["core"].get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["core"]["num_cores"] = num_cores

        if verbose is None:
            self.verbose = self._config["core"].get("verbose", False)
        else:
            self.verbose = verbose

        self._dataframe: pd.DataFrame | ddf.DataFrame = None
        self._timed_dataframe: pd.DataFrame | ddf.DataFrame = None
        self._files: list[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None
        self._normalization_histogram: xr.DataArray = None
        self._normalized: xr.DataArray = None

        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
        )

        self.ec = EnergyCalibrator(
            loader=get_loader(
                loader_name=loader_name,
                config=self._config,
            ),
            config=self._config,
        )

        self.mc = MomentumCorrector(
            config=self._config,
        )

        self.dc = DelayCalibrator(
            config=self._config,
        )

        self.use_copy_tool = self._config.get("core", {}).get(
            "use_copy_tool",
            False,
        )
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    source=self._config["core"]["copy_tool_source"],
                    dest=self._config["core"]["copy_tool_dest"],
                    num_cores=self._config["core"]["num_cores"],
                    **self._config["core"].get("copy_tool_kwds", {}),
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )

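    # Illustrative usage sketch (not part of the original file; the config file
    # name and data path are hypothetical):
    #
    #   processor = SedProcessor(config="sed_config.yaml", folder="/path/to/data")
    #
    # This parses the config, instantiates the configured loader and the
    # momentum/energy/delay calibrators, and, since a folder is given, directly
    # loads the data into the internal dataframe.
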
    def __repr__(self):
        if self._dataframe is None:
            df_str = "Dataframe: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        pretty_str = df_str + "\n" + "Metadata: " + "\n" + self._attributes.__repr__()
        return pretty_str

    def _repr_html_(self):
        html = "<div>"

        if self._dataframe is None:
            df_html = "Dataframe: No Data loaded"
        else:
            df_html = self._dataframe._repr_html_()

        html += f"<details><summary>Dataframe</summary>{df_html}</details>"

        # Add expandable section for attributes
        html += "<details><summary>Metadata</summary>"
        html += "<div style='padding-left: 10px;'>"
        html += self._attributes._repr_html_()
        html += "</div></details>"

        html += "</div>"

        return html

    ## Suggestion:
    # @property
    # def overview_panel(self):
    #     """Provides an overview panel with plots of different data attributes."""
    #     self.view_event_histogram(dfpid=2, backend="matplotlib")

    @property
    def dataframe(self) -> pd.DataFrame | ddf.DataFrame:
        """Accessor to the underlying dataframe.

        Returns:
            pd.DataFrame | ddf.DataFrame: Dataframe object.
        """
        return self._dataframe

    @dataframe.setter
    def dataframe(self, dataframe: pd.DataFrame | ddf.DataFrame):
        """Setter for the underlying dataframe.

        Args:
            dataframe (pd.DataFrame | ddf.DataFrame): The dataframe object to set.
        """
        if not isinstance(dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            dataframe,
            self._dataframe.__class__,
        ):
            raise ValueError(
                "'dataframe' has to be a Pandas or Dask dataframe and has to be of the same kind "
                "as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._dataframe.__class__}, provided type: {dataframe.__class__}.",
            )
        self._dataframe = dataframe

    @property
    def timed_dataframe(self) -> pd.DataFrame | ddf.DataFrame:
        """Accessor to the underlying timed_dataframe.

        Returns:
            pd.DataFrame | ddf.DataFrame: Timed Dataframe object.
        """
        return self._timed_dataframe

    @timed_dataframe.setter
    def timed_dataframe(self, timed_dataframe: pd.DataFrame | ddf.DataFrame):
        """Setter for the underlying timed dataframe.

        Args:
            timed_dataframe (pd.DataFrame | ddf.DataFrame): The timed dataframe object to set.
        """
        if not isinstance(timed_dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            timed_dataframe,
            self._timed_dataframe.__class__,
        ):
            raise ValueError(
                "'timed_dataframe' has to be a Pandas or Dask dataframe and has to be of the same "
                "kind as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._timed_dataframe.__class__}, "
                f"provided type: {timed_dataframe.__class__}.",
            )
        self._timed_dataframe = timed_dataframe

    @property
    def attributes(self) -> MetaHandler:
        """Accessor to the metadata dict.

        Returns:
            MetaHandler: The metadata object
        """
        return self._attributes

    def add_attribute(self, attributes: dict, name: str, **kwds):
        """Function to add an element to the attributes dict.

        Args:
            attributes (dict): The attributes dictionary object to add.
            name (str): Key under which to add the dictionary to the attributes.
            **kwds: Additional keywords are passed to the ``MetaHandler.add()`` function.
        """
        self._attributes.add(
            entry=attributes,
            name=name,
            **kwds,
        )

    @property
    def config(self) -> dict[Any, Any]:
        """Getter attribute for the config dictionary

        Returns:
            dict: The config dictionary.
        """
        return self._config

    @property
    def files(self) -> list[str]:
        """Getter attribute for the list of files

        Returns:
            list[str]: The list of loaded files
        """
        return self._files

    @property
    def binned(self) -> xr.DataArray:
        """Getter attribute for the binned data array

        Returns:
            xr.DataArray: The binned data array
        """
        if self._binned is None:
            raise ValueError("No binned data available, need to compute histogram first!")
        return self._binned

    @property
    def normalized(self) -> xr.DataArray:
        """Getter attribute for the normalized data array

        Returns:
            xr.DataArray: The normalized data array
        """
        if self._normalized is None:
            raise ValueError(
                "No normalized data available, compute data with normalization enabled!",
            )
        return self._normalized

    @property
    def normalization_histogram(self) -> xr.DataArray:
        """Getter attribute for the normalization histogram

        Returns:
            xr.DataArray: The normalization histogram
        """
        if self._normalization_histogram is None:
            raise ValueError("No normalization histogram available, generate histogram first!")
        return self._normalization_histogram

    def cpy(self, path: str | list[str]) -> str | list[str]:
        """Function to mirror a list of files or a folder from a network drive to a
        local storage. Returns either the original or the copied path to the given
        path. The option to use this functionality is set by
        config["core"]["use_copy_tool"].

        Args:
            path (str | list[str]): Source path or path list.

        Returns:
            str | list[str]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        return path

    def load(
        self,
        dataframe: pd.DataFrame | ddf.DataFrame = None,
        metadata: dict = None,
        files: list[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (pd.DataFrame | ddf.DataFrame, optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            files (list[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            collect_metadata (bool, optional): Option for collecting metadata in the reader.
                Defaults to False.
            **kwds:
                - *timed_dataframe*: timed dataframe if dataframe is provided.

                Additional keyword parameters are passed to ``loader.read_dataframe()``.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            timed_dataframe = kwds.pop("timed_dataframe", None)
        elif runs is not None:
            # If runs are provided, we only use the copy tool if also folder is provided.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        elif files is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                files=cast(list[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._timed_dataframe = timed_dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )

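    # Illustrative usage sketch (not part of the original file; paths and run
    # identifiers are hypothetical):
    #
    #   processor.load(runs=["run_001", "run_002"], folder="/path/to/raw")
    #
    # With both runs and a folder given, the folder (possibly mirrored by the
    # copy tool) is passed to the loader as the base folder in which to look
    # up the runs.
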
    def filter_column(
        self,
        column: str,
        min_value: float = -np.inf,
        max_value: float = np.inf,
    ) -> None:
        """Filter values in a column which are outside of a given range.

        Args:
            column (str): Name of the column to filter.
            min_value (float, optional): Minimum value to keep. Defaults to -np.inf.
            max_value (float, optional): Maximum value to keep. Defaults to np.inf.
        """
        if column != "index" and column not in self._dataframe.columns:
            raise KeyError(f"Column {column} not found in dataframe!")
        if min_value >= max_value:
            raise ValueError("min_value has to be smaller than max_value!")
        if self._dataframe is not None:
            self._dataframe = apply_filter(
                self._dataframe,
                col=column,
                lower_bound=min_value,
                upper_bound=max_value,
            )
        if self._timed_dataframe is not None and column in self._timed_dataframe.columns:
            self._timed_dataframe = apply_filter(
                self._timed_dataframe,
                column,
                lower_bound=min_value,
                upper_bound=max_value,
            )
        metadata = {
            "filter": {
                "column": column,
                "min_value": min_value,
                "max_value": max_value,
            },
        }
        self._attributes.add(metadata, "filter", duplicate_policy="merge")

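    # Illustrative usage sketch (not part of the original file; the column name
    # and bounds are hypothetical and depend on the loader/config in use):
    #
    #   processor.filter_column("t", min_value=65000, max_value=67000)
    #
    # Events outside the given range are dropped from both the main and the
    # timed dataframe, and the filter settings are recorded in the metadata.
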
    # Momentum calibration workflow
    # 1. Bin raw detector data for distortion correction
    def bin_and_load_momentum_calibration(
        self,
        df_partitions: int | Sequence[int] = 100,
        axes: list[str] = None,
        bins: list[int] = None,
        ranges: Sequence[tuple[float, float]] = None,
        plane: int = 0,
        width: int = 5,
        apply: bool = False,
        **kwds,
    ):
        """1st step of the momentum correction workflow. Function to do an initial binning
        of the dataframe loaded to the class, slice a plane from it using an
        interactive view, and load it into the momentum corrector class.

        Args:
            df_partitions (int | Sequence[int], optional): Number of dataframe partitions
                to use for the initial binning. Defaults to 100.
            axes (list[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (list[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (Sequence[tuple[float, float]], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
            width (int, optional): Initial value for the width slider. Defaults to 5.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword argument passed to the pre_binning function.
        """
        self._pre_binned = self.pre_binning(
            df_partitions=df_partitions,
            axes=axes,
            bins=bins,
            ranges=ranges,
            **kwds,
        )

        self.mc.load_data(data=self._pre_binned)
        self.mc.select_slicer(plane=plane, width=width, apply=apply)

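    # Illustrative usage sketch (not part of the original file; slider values
    # are hypothetical):
    #
    #   processor.bin_and_load_momentum_calibration(plane=33, width=10, apply=True)
    #
    # This pre-bins the loaded dataframe with the binning defaults from the
    # config, opens the interactive slicer, and directly applies the given
    # plane/width selection.
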
    # 2. Define the features for the spline warp correction.
    # Either autoselect features, or input features from view above.
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2. Step of the distortion correction workflow: Define feature points in
        momentum space. They can be either manually selected using a GUI tool, be
        provided as list of feature points, or auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments for ``MomentumCorrector.feature_extract()`` and
                ``MomentumCorrector.feature_select()``.
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop("sigma", self._config["momentum"]["sigma"])
            fwhm = kwds.pop("fwhm", self._config["momentum"]["fwhm"])
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config["momentum"]["sigma_radius"],
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )

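    # Illustrative usage sketch (not part of the original file; the pixel
    # coordinates are hypothetical):
    #
    #   features = np.array([[203.2, 341.9], [299.2, 345.2], [350.2, 244.8],
    #                        [304.4, 149.9], [199.1, 152.7], [154.3, 243.2]])
    #   processor.define_features(features=features, rotation_symmetry=6,
    #                             include_center=False, apply=True)
    #
    # Alternatively, auto_detect=True extracts the feature points with the
    # peak-detection parameters from config["momentum"].
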
    # 3. Generate the spline warp correction from momentum features.
    # If no features have been selected before, use class defaults.
    def generate_splinewarp(
        self,
        use_center: bool = None,
        verbose: bool = None,
        **kwds,
    ):
        """3. Step of the distortion correction workflow: Generate the correction
        function restoring the symmetry in the image using a splinewarp algorithm.

        Args:
            use_center (bool, optional): Option to use the position of the
                center point in the correction. Default is read from config, or set to True.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds: Keyword arguments for ``MomentumCorrector.spline_warp_estimate()``.
        """
        if verbose is None:
            verbose = self.verbose

        self.mc.spline_warp_estimate(use_center=use_center, verbose=verbose, **kwds)

        if self.mc.slice is not None and verbose:
            print("Original slice with reference features")
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)

            print("Corrected slice with target features")
            self.mc.view(
                image=self.mc.slice_corrected,
                annotated=True,
                points={"feats": self.mc.ptargs},
                backend="bokeh",
                crosshair=True,
            )

            print("Original slice with target features")
            self.mc.view(
                image=self.mc.slice,
                points={"feats": self.mc.ptargs},
                annotated=True,
                backend="bokeh",
            )

    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.correction) == 0:
            raise ValueError("No momentum correction parameters to save!")
        correction = {}
        for key, value in self.mc.correction.items():
            if key in ["reference_points", "target_points", "cdeform_field", "rdeform_field"]:
                continue
            if key in ["use_center", "rotation_symmetry"]:
                correction[key] = value
            elif key in ["center_point", "ascale"]:
                correction[key] = [float(i) for i in value]
            elif key in ["outer_points", "feature_points"]:
                correction[key] = []
                for point in value:
                    correction[key].append([float(i) for i in point])
            else:
                correction[key] = float(value)

        if "creation_date" not in correction:
            correction["creation_date"] = datetime.now().timestamp()

        config = {
            "momentum": {
                "correction": correction,
            },
        }
        save_config(config, filename, overwrite)
        print(f'Saved momentum correction parameters to "{filename}".')

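    # Illustrative sketch of the resulting config entry (not part of the
    # original file; all values are hypothetical). save_splinewarp() merges a
    # nested dict of plain floats/lists into the folder config, e.g.:
    #
    #   momentum:
    #     correction:
    #       rotation_symmetry: 6
    #       center_point: [256.0, 256.0]
    #       feature_points: [[203.2, 341.9], [299.2, 345.2], [350.2, 244.8]]
    #       creation_date: 1721760000.0
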
    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    def pose_adjustment(
        self,
        transformations: dict[str, Any] = None,
        apply: bool = False,
        use_correction: bool = True,
        reset: bool = True,
        verbose: bool = None,
        **kwds,
    ):
        """4. step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            transformations (dict[str, Any], optional): Dictionary with transformations.
                Defaults to self.transformations or config["momentum"]["transformations"].
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
            reset (bool, optional): Option to reset the correction before transformation.
                Defaults to True.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds: Keyword parameters defining defaults for the transformations:

                - **scale** (float): Initial value of the scaling slider.
                - **xtrans** (float): Initial value of the xtrans slider.
                - **ytrans** (float): Initial value of the ytrans slider.
                - **angle** (float): Initial value of the angle slider.
        """
        if verbose is None:
            verbose = self.verbose

        # Generate homography as default if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                self.mc.slice = np.zeros(self._config["momentum"]["bins"][0:2])
            self.mc.slice_corrected = self.mc.slice

        if not use_correction:
            self.mc.reset_deformation()

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate distortion correction from config values
            self.mc.spline_warp_estimate(verbose=verbose)

        self.mc.pose_adjustment(
            transformations=transformations,
            apply=apply,
            reset=reset,
            verbose=verbose,
            **kwds,
        )

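    # Illustrative usage sketch (not part of the original file; the values
    # are hypothetical):
    #
    #   processor.pose_adjustment(
    #       transformations={"scale": 1.0, "xtrans": 8.0, "ytrans": -3.0, "angle": 1.5},
    #       apply=True,
    #   )
    #
    # Without apply=True, an interactive panel with scale/xtrans/ytrans/angle
    # sliders is shown instead.
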
    # 4a. Save pose adjustment parameters to config file.
    def save_transformations(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the pose adjustment parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.transformations) == 0:
            raise ValueError("No momentum transformation parameters to save!")
        transformations = {}
        for key, value in self.mc.transformations.items():
            transformations[key] = float(value)

        if "creation_date" not in transformations:
            transformations["creation_date"] = datetime.now().timestamp()

        config = {
            "momentum": {
                "transformations": transformations,
            },
        }
        save_config(config, filename, overwrite)
        print(f'Saved momentum transformation parameters to "{filename}".')

    # 5. Apply the momentum correction to the dataframe
    def apply_momentum_correction(
        self,
        preview: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """Applies the distortion correction and pose adjustment (optional)
        to the dataframe.

        Args:
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds: Keyword parameters for ``MomentumCorrector.apply_corrections``:

                - **rdeform_field** (np.ndarray, optional): Row deformation field.
                - **cdeform_field** (np.ndarray, optional): Column deformation field.
                - **inv_dfield** (np.ndarray, optional): Inverse deformation field.
        """
        if verbose is None:
            verbose = self.verbose

        x_column = self._config["dataframe"]["x_column"]
        y_column = self._config["dataframe"]["y_column"]

        if self._dataframe is not None:
            if verbose:
                print("Adding corrected X/Y columns to dataframe:")
            df, metadata = self.mc.apply_corrections(
                df=self._dataframe,
                verbose=verbose,
                **kwds,
            )
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                tdf, _ = self.mc.apply_corrections(
                    self._timed_dataframe,
                    verbose=False,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            self._dataframe = df
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        else:
            if self.verbose:
                print(self._dataframe)

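    # Illustrative end-to-end sketch of the distortion correction workflow
    # (not part of the original file; parameter values are hypothetical):
    #
    #   processor.bin_and_load_momentum_calibration(apply=True)   # 1. bin & slice
    #   processor.define_features(features=feats, apply=True)     # 2. features
    #   processor.generate_splinewarp()                           # 3. spline warp
    #   processor.pose_adjustment(apply=True)                     # 4. affine pose
    #   processor.apply_momentum_correction()                     # 5. apply to df
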
    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    def calibrate_momentum_axes(
        self,
        point_a: np.ndarray | list[int] = None,
        point_b: np.ndarray | list[int] = None,
        k_distance: float = None,
        k_coord_a: np.ndarray | list[float] = None,
        k_coord_b: np.ndarray | list[float] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply=False,
    ):
        """1. step of the momentum calibration workflow. Calibrate momentum
        axes using either provided pixel coordinates of a high-symmetry point and its
        distance to the BZ center, or the k-coordinates of two points in the BZ
        (depending on the equiscale option). Opens an interactive panel for selecting
        the points.

        Args:
            point_a (np.ndarray | list[int], optional): Pixel coordinates of the first
                point used for momentum calibration.
            point_b (np.ndarray | list[int], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between point a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (np.ndarray | list[float], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (np.ndarray | list[float], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option to apply different scales to kx and ky.
                If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config["momentum"]["center_pixel"]

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )

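    # Illustrative usage sketch (not part of the original file; pixel
    # coordinates and the k-distance are hypothetical):
    #
    #   processor.calibrate_momentum_axes(
    #       point_a=[308, 345],   # pixel position of a high-symmetry point
    #       k_distance=1.28,      # its momentum distance to the BZ center
    #       apply=True,
    #   )
    #
    # point_b defaults to config["momentum"]["center_pixel"], i.e. the BZ center.
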
    # 1a. Save momentum calibration parameters to config file.
    def save_momentum_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated momentum calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.calibration) == 0:
            raise ValueError("No momentum calibration parameters to save!")
        calibration = {}
        for key, value in self.mc.calibration.items():
            if key in ["kx_axis", "ky_axis", "grid", "extent"]:
                continue

            calibration[key] = float(value)

        if "creation_date" not in calibration:
            calibration["creation_date"] = datetime.now().timestamp()

        config = {"momentum": {"calibration": calibration}}
        save_config(config, filename, overwrite)
        print(f"Saved momentum calibration parameters to {filename}")

    # 2. Apply correction and calibration to the dataframe
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """2. step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes exist,
        these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds: Keyword args passed to ``MomentumCorrector.append_k_axis``.
        """
        if verbose is None:
            verbose = self.verbose

        x_column = self._config["dataframe"]["x_column"]
        y_column = self._config["dataframe"]["y_column"]

        if self._dataframe is not None:
            if verbose:
                print("Adding kx/ky columns to dataframe:")
            df, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                tdf, _ = self.mc.append_k_axis(
                    df=self._timed_dataframe,
                    calibration=calibration,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            self._dataframe = df
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        else:
            if self.verbose:
                print(self._dataframe)

    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: tuple[float, float] = None,
        apply=False,
        **kwds,
    ):
        """1. step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data if
        they are not present yet.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
            **kwds: Keyword parameters passed to ``EnergyCalibrator.adjust_energy_correction()``.
        """
        if self._pre_binned is None:
            print(
                "Pre-binned data not present, binning using defaults from config...",
            )
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )

    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.correction) == 0:
            raise ValueError("No energy correction parameters to save!")
        correction = {}
        for key, val in self.ec.correction.items():
            if key == "correction_type":
                correction[key] = val
            elif key == "center":
                correction[key] = [float(i) for i in val]
            else:
                correction[key] = float(val)

        if "creation_date" not in correction:
            correction["creation_date"] = datetime.now().timestamp()

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)
        print(f"Saved energy correction parameters to {filename}")

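    # Illustrative usage sketch (not part of the original file; the parameter
    # values are hypothetical):
    #
    #   processor.adjust_energy_correction(
    #       correction_type="Lorentzian",
    #       amplitude=2.5,
    #       center=(730, 730),
    #       apply=True,
    #   )
    #   processor.save_energy_correction()
    #
    # The second call persists the adjusted parameters to sed_config.yaml so
    # they are picked up as defaults in later sessions.
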
    # 2. Apply energy correction to dataframe
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """2. step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction()``.
        """
        if verbose is None:
            verbose = self.verbose

        tof_column = self._config["dataframe"]["tof_column"]

        if self._dataframe is not None:
            if verbose:
                print("Applying energy correction to dataframe...")
            df, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                verbose=verbose,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.apply_energy_correction(
                    df=self._timed_dataframe,
                    correction=correction,
                    verbose=False,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        else:
            if verbose:
                print(self._dataframe)

    # Energy calibrator workflow
    # 1. Load and normalize data
    def load_bias_series(
        self,
        binned_data: xr.DataArray | tuple[np.ndarray, np.ndarray, np.ndarray] = None,
        data_files: list[str] = None,
        axes: list[str] = None,
        bins: list = None,
        ranges: Sequence[tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1. step of the energy calibration workflow: Load and bin data from
        single-event files, or load binned bias/TOF traces.

        Args:
            binned_data (xr.DataArray | tuple[np.ndarray, np.ndarray, np.ndarray], optional):
                Binned data. If provided as a DataArray, it needs to contain the dimensions
                config["dataframe"]["tof_column"] and config["dataframe"]["bias_column"]. If
                provided as a tuple, it needs to contain the elements tof, biases, traces.
            data_files (list[str], optional): list of file paths to bin.
            axes (list[str], optional): bin axes.
                Defaults to config["dataframe"]["tof_column"].
            bins (list, optional): number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[tuple[float, float]], optional): bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        if binned_data is not None:
            if isinstance(binned_data, xr.DataArray):
                if (
                    self._config["dataframe"]["tof_column"] not in binned_data.dims
                    or self._config["dataframe"]["bias_column"] not in binned_data.dims
                ):
                    raise ValueError(
                        "If binned_data is provided as an xarray, it needs to contain dimensions "
                        f"'{self._config['dataframe']['tof_column']}' and "
                        f"'{self._config['dataframe']['bias_column']}'!",
                    )
                tof = binned_data.coords[self._config["dataframe"]["tof_column"]].values
                biases = binned_data.coords[self._config["dataframe"]["bias_column"]].values
                traces = binned_data.values[:, :]
            else:
                try:
                    (tof, biases, traces) = binned_data
                except ValueError as exc:
                    raise ValueError(
                        "If binned_data is provided as tuple, it needs to contain "
                        "(tof, biases, traces)!",
                    ) from exc
            self.ec.load_data(biases=biases, traces=traces, tof=tof)

        elif data_files is not None:
            self.ec.bin_data(
                data_files=cast(list[str], self.cpy(data_files)),
                axes=axes,
                bins=bins,
                ranges=ranges,
                biases=biases,
                bias_key=bias_key,
            )

        else:
            raise ValueError("Either binned_data or data_files needs to be provided!")

        if (normalize is not None and normalize is True) or (
            normalize is None and self._config["energy"]["normalize"]
        ):
            if span is None:
                span = self._config["energy"]["normalize_span"]
            if order is None:
                order = self._config["energy"]["normalize_order"]
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="bokeh",
        )

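    # Illustrative usage sketch (not part of the original file; file names are
    # hypothetical):
    #
    #   processor.load_bias_series(
    #       data_files=["scan_01.h5", "scan_02.h5", "scan_03.h5"],
    #       normalize=True,
    #   )
    #
    # The bias voltages are read from the files via config["energy"]["bias_key"];
    # alternatively, pre-binned traces can be passed as binned_data=(tof, biases, traces).
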
    # 2. Extract ranges and get peak positions
    def find_bias_peaks(
        self,
        ranges: list[tuple] | tuple,
        ref_id: int = 0,
        infer_others: bool = True,
        mode: str = "replace",
        radius: int = None,
        peak_window: int = None,
        apply: bool = False,
    ):
        """2. step of the energy calibration workflow: Find a peak within a given range
        for the indicated reference trace, and try to find the same peak for all
        other traces. Uses fast_dtw to align curves, which might not perform well if
        the shape of the curves changes qualitatively. Ideally, choose a reference
        trace in the middle of the set, and don't choose the range too narrow around
        the peak. Alternatively, a list of ranges for all traces can be provided.

        Args:
            ranges (list[tuple] | tuple): Tuple of TOF values indicating a range.
                Alternatively, a list of ranges for all traces can be given.
            ref_id (int, optional): The id of the trace the range refers to.
                Defaults to 0.
            infer_others (bool, optional): Whether to determine the range for the other
                traces. Defaults to True.
            mode (str, optional): Whether to "add" or "replace" existing ranges.
                Defaults to "replace".
            radius (int, optional): Radius parameter for fast_dtw.
                Defaults to config["energy"]["fastdtw_radius"].
            peak_window (int, optional): Peak_window parameter for the peak detection
                algorithm: number of points that have to behave monotonically
                around a peak. Defaults to config["energy"]["peak_window"].
            apply (bool, optional): Option to directly apply the provided parameters.
                Defaults to False.
        """
        if radius is None:
            radius = self._config["energy"]["fastdtw_radius"]
        if peak_window is None:
            peak_window = self._config["energy"]["peak_window"]
        if not infer_others:
            self.ec.add_ranges(
                ranges=ranges,
                ref_id=ref_id,
                infer_others=infer_others,
                mode=mode,
                radius=radius,
            )
            print(self.ec.featranges)
            try:
                self.ec.feature_extract(peak_window=peak_window)
                self.ec.view(
                    traces=self.ec.traces_normed,
                    segs=self.ec.featranges,
                    xaxis=self.ec.tof,
                    peaks=self.ec.peaks,
                    backend="bokeh",
                )
            except IndexError:
                print("Could not determine all peaks!")
                raise
        else:
            # New adjustment tool
            assert isinstance(ranges, tuple)
            self.ec.adjust_ranges(
                ranges=ranges,
                ref_id=ref_id,
                traces=self.ec.traces_normed,
                infer_others=infer_others,
                radius=radius,
                peak_window=peak_window,
                apply=apply,
            )

1295
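    # Usage sketch (illustrative comment, not part of the class): assuming `sp` is
    # a SedProcessor instance with a bias series already loaded and normalized, the
    # peak-finding step might look like the following; the TOF range and ref_id are
    # hypothetical values that depend on the data:
    #
    #   sp.find_bias_peaks(ranges=(64000, 66000), ref_id=5, apply=True)
    #
    # With infer_others=False, a list of per-trace ranges can be passed instead.
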
    # 3. Fit the energy calibration relation
    def calibrate_energy_axis(
        self,
        ref_energy: float,
        method: str = None,
        energy_scale: str = None,
        verbose: bool = None,
        **kwds,
    ):
        """3. Step of the energy calibration workflow: Calculate the calibration
        function for the energy axis. Two approximations are implemented, a
        (normally 3rd order) polynomial approximation, and a d^2/(t-t0)^2 relation.

        Args:
            ref_energy (float): Binding/kinetic energy of the detected feature.
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to config["energy"]["calibration_method"]
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to config["energy"]["energy_scale"]
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds: Keyword parameters passed to ``EnergyCalibrator.calibrate()``.
        """
        if verbose is None:
            verbose = self.verbose

        if method is None:
            method = self._config["energy"]["calibration_method"]

        if energy_scale is None:
            energy_scale = self._config["energy"]["energy_scale"]

        self.ec.calibrate(
            ref_energy=ref_energy,
            method=method,
            energy_scale=energy_scale,
            verbose=verbose,
            **kwds,
        )
        if verbose:
            print("Quality of Calibration:")
            self.ec.view(
                traces=self.ec.traces_normed,
                xaxis=self.ec.calibration["axis"],
                align=True,
                energy_scale=energy_scale,
                backend="bokeh",
            )
            print("E/TOF relationship:")
            if energy_scale == "kinetic":
                self.ec.view(
                    traces=self.ec.calibration["axis"][None, :] + self.ec.biases[0],
                    xaxis=self.ec.tof,
                    backend="matplotlib",
                    show_legend=False,
                )
                plt.scatter(
                    self.ec.peaks[:, 0],
                    -(self.ec.biases - self.ec.biases[0]) + ref_energy,
                    s=50,
                    c="k",
                )
            elif energy_scale == "binding":
                self.ec.view(
                    traces=self.ec.calibration["axis"][None, :] - self.ec.biases[0],
                    xaxis=self.ec.tof,
                    backend="matplotlib",
                    show_legend=False,
                )
                plt.scatter(
                    self.ec.peaks[:, 0],
                    self.ec.biases - self.ec.biases[0] + ref_energy,
                    s=50,
                    c="k",
                )
            else:
                raise ValueError(
                    'energy_scale needs to be either "binding" or "kinetic", '
                    f"got {energy_scale}.",
                )
            plt.xlabel("Time-of-flight", fontsize=15)
            plt.ylabel("Energy (eV)", fontsize=15)
            plt.show()

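    # Usage sketch (illustrative comment): for a bias series referenced to a known
    # feature, e.g. a core level at -30 eV binding energy, the fit step might read
    # (values hypothetical):
    #
    #   sp.calibrate_energy_axis(ref_energy=-30, method="lmfit", energy_scale="kinetic")
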
    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.calibration) == 0:
            raise ValueError("No energy calibration parameters to save!")
        calibration = {}
        for key, value in self.ec.calibration.items():
            if key in ["axis", "refid", "Tmat", "bvec"]:
                continue
            if key == "energy_scale":
                calibration[key] = value
            elif key == "coeffs":
                calibration[key] = [float(i) for i in value]
            else:
                calibration[key] = float(value)

        if "creation_date" not in calibration:
            calibration["creation_date"] = datetime.now().timestamp()

        config = {"energy": {"calibration": calibration}}
        save_config(config, filename, overwrite)
        print(f'Saved energy calibration parameters to "{filename}".')

    # 4. Apply energy calibration to the dataframe
    def append_energy_axis(
        self,
        calibration: dict = None,
        bias_voltage: float = None,
        preview: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """4. step of the energy calibration workflow: Apply the calibration function
        to the dataframe. Two approximations are implemented, a (normally 3rd order)
        polynomial approximation, and a d^2/(t-t0)^2 relation. A calibration dictionary
        can be provided.

        Args:
            calibration (dict, optional): Calibration dict containing calibration
                parameters. Overrides calibration from class or config.
                Defaults to None.
            bias_voltage (float, optional): Sample bias voltage of the scan data. If omitted,
                the bias voltage is read from the dataframe. If it is not found there,
                a warning is printed and the calibrated data might have an offset.
            preview (bool): Option to preview the first elements of the data frame.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds:
                Keyword args passed to ``EnergyCalibrator.append_energy_axis()``.
        """
        if verbose is None:
            verbose = self.verbose

        tof_column = self._config["dataframe"]["tof_column"]

        if self._dataframe is not None:
            if verbose:
                print("Adding energy column to dataframe:")
            df, metadata = self.ec.append_energy_axis(
                df=self._dataframe,
                calibration=calibration,
                bias_voltage=bias_voltage,
                verbose=verbose,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.append_energy_axis(
                    df=self._timed_dataframe,
                    calibration=calibration,
                    bias_voltage=bias_voltage,
                    verbose=False,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_calibration",
                duplicate_policy="merge",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf

        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        else:
            if verbose:
                print(self._dataframe)

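    # Usage sketch (illustrative comment): once a calibration exists (from the
    # workflow above, or loaded from config), it is applied to the dataframe as
    # follows; the bias_voltage value is hypothetical, and if omitted it is read
    # from the dataframe where available:
    #
    #   sp.append_energy_axis(bias_voltage=16.8, preview=True)
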
    def add_energy_offset(
        self,
        constant: float = None,
        columns: str | Sequence[str] = None,
        weights: float | Sequence[float] = None,
        reductions: str | Sequence[str] = None,
        preserve_mean: bool | Sequence[bool] = None,
        preview: bool = False,
        verbose: bool = None,
    ) -> None:
        """Shift the energy axis of the dataframe by a given amount.

        Args:
            constant (float, optional): The constant to shift the energy axis by.
            columns (str | Sequence[str], optional): Name of the column(s) to apply the shift from.
            weights (float | Sequence[float], optional): weights to apply to the columns.
                Can also be used to flip the sign (e.g. -1). Defaults to 1.
            reductions (str | Sequence[str], optional): The reduction to apply to the column.
                Should be an available method of dask.dataframe.Series. For example "mean". In this
                case the function is applied to the column to generate a single value for the whole
                dataset. If None, the shift is applied per-dataframe-row. Defaults to None.
                Currently only "mean" is supported.
            preserve_mean (bool | Sequence[bool], optional): Whether to subtract the mean of the
                column before applying the shift. Defaults to False.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].

        Raises:
            ValueError: If the energy column is not in the dataframe.
        """
        if verbose is None:
            verbose = self.verbose

        energy_column = self._config["dataframe"]["energy_column"]
        if energy_column not in self._dataframe.columns:
            raise ValueError(
                f"Energy column {energy_column} not found in dataframe! "
                "Run `append_energy_axis()` first.",
            )
        if self.dataframe is not None:
            if verbose:
                print("Adding energy offset to dataframe:")
            df, metadata = self.ec.add_offsets(
                df=self._dataframe,
                constant=constant,
                columns=columns,
                energy_column=energy_column,
                weights=weights,
                reductions=reductions,
                preserve_mean=preserve_mean,
                verbose=verbose,
            )
            if self._timed_dataframe is not None and energy_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.add_offsets(
                    df=self._timed_dataframe,
                    constant=constant,
                    columns=columns,
                    energy_column=energy_column,
                    weights=weights,
                    reductions=reductions,
                    preserve_mean=preserve_mean,
                )

            self._attributes.add(
                metadata,
                "add_energy_offset",
                # TODO: allow only appending when no offset along this column(s) was applied
                # TODO: clear memory of modifications if the energy axis is recalculated
                duplicate_policy="append",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and energy_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        elif verbose:
            print(self._dataframe)

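    # Usage sketch (illustrative comment): shift the energy axis by a constant and
    # by a monitored column, e.g. a photon-energy channel (column name and values
    # hypothetical):
    #
    #   sp.add_energy_offset(
    #       constant=-0.5,
    #       columns="monochromatorPhotonEnergy",
    #       weights=-1,
    #       preserve_mean=True,
    #   )
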
    def save_energy_offset(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy offset parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.offsets) == 0:
            raise ValueError("No energy offset parameters to save!")

        if "creation_date" not in self.ec.offsets.keys():
            self.ec.offsets["creation_date"] = datetime.now().timestamp()

        config = {"energy": {"offsets": self.ec.offsets}}
        save_config(config, filename, overwrite)
        print(f'Saved energy offset parameters to "{filename}".')

    def append_tof_ns_axis(
        self,
        preview: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """Convert time-of-flight channel steps to nanoseconds.

        Args:
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds: additional arguments are passed to ``EnergyCalibrator.append_tof_ns_axis()``.

                - **tof_ns_column**: Name of the generated column containing the
                  time-of-flight in nanoseconds.
                  Defaults to config["dataframe"]["tof_ns_column"].
        """
        if verbose is None:
            verbose = self.verbose

        tof_column = self._config["dataframe"]["tof_column"]

        if self._dataframe is not None:
            if verbose:
                print("Adding time-of-flight column in nanoseconds to dataframe:")
            # TODO assert order of execution through metadata

            df, metadata = self.ec.append_tof_ns_axis(
                df=self._dataframe,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.append_tof_ns_axis(
                    df=self._timed_dataframe,
                    **kwds,
                )

            self._attributes.add(
                metadata,
                "tof_ns_conversion",
                duplicate_policy="overwrite",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        else:
            if verbose:
                print(self._dataframe)

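    # Usage sketch (illustrative comment): convert TOF channel steps to
    # nanoseconds, previewing the resulting dataframe:
    #
    #   sp.append_tof_ns_axis(preview=True)
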
    def align_dld_sectors(
        self,
        sector_delays: np.ndarray = None,
        preview: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """Align the 8 sectors of the HEXTOF endstation.

        Args:
            sector_delays (np.ndarray, optional): Array containing the sector delays. Defaults to
                config["dataframe"]["sector_delays"].
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds: additional arguments are passed to ``EnergyCalibrator.align_dld_sectors()``.
        """
        if verbose is None:
            verbose = self.verbose

        tof_column = self._config["dataframe"]["tof_column"]

        if self._dataframe is not None:
            if verbose:
                print("Aligning 8 sectors of dataframe")
            # TODO assert order of execution through metadata

            df, metadata = self.ec.align_dld_sectors(
                df=self._dataframe,
                sector_delays=sector_delays,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.align_dld_sectors(
                    df=self._timed_dataframe,
                    sector_delays=sector_delays,
                    **kwds,
                )

            self._attributes.add(
                metadata,
                "dld_sector_alignment",
                duplicate_policy="raise",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        else:
            if verbose:
                print(self._dataframe)

    # Delay calibration function
    def calibrate_delay_axis(
        self,
        delay_range: tuple[float, float] = None,
        datafile: str = None,
        preview: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """Append delay column to dataframe. Either provide delay ranges, or read
        them from a file.

        Args:
            delay_range (tuple[float, float], optional): The scanned delay range in
                picoseconds. Defaults to None.
            datafile (str, optional): The file from which to read the delay ranges.
                Defaults to None.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
        """
        if verbose is None:
            verbose = self.verbose

        adc_column = self._config["dataframe"]["adc_column"]
        if adc_column not in self._dataframe.columns:
            raise ValueError(f"ADC column {adc_column} not found in dataframe, cannot calibrate!")

        if self._dataframe is not None:
            if verbose:
                print("Adding delay column to dataframe:")

            if delay_range is None and datafile is None:
                if len(self.dc.calibration) == 0:
                    try:
                        datafile = self._files[0]
                    except IndexError as exc:
                        raise IndexError(
                            "No datafile available, specify either 'datafile' or 'delay_range'",
                        ) from exc

            df, metadata = self.dc.append_delay_axis(
                self._dataframe,
                delay_range=delay_range,
                datafile=datafile,
                verbose=verbose,
                **kwds,
            )
            if self._timed_dataframe is not None and adc_column in self._timed_dataframe.columns:
                tdf, _ = self.dc.append_delay_axis(
                    self._timed_dataframe,
                    delay_range=delay_range,
                    datafile=datafile,
                    verbose=False,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "delay_calibration",
                duplicate_policy="overwrite",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and adc_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        else:
            if verbose:
                print(self._dataframe)

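    # Usage sketch (illustrative comment): calibrate the delay axis either from an
    # explicit scan range in picoseconds, or from ranges stored in a data file
    # (values and path hypothetical):
    #
    #   sp.calibrate_delay_axis(delay_range=(-500, 1500), preview=True)
    #   # or: sp.calibrate_delay_axis(datafile="raw/scan_0123.h5")
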
    def save_delay_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Save the generated delay calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"

        if len(self.dc.calibration) == 0:
            raise ValueError("No delay calibration parameters to save!")
        calibration = {}
        for key, value in self.dc.calibration.items():
            if key == "datafile":
                calibration[key] = value
            elif key in ["adc_range", "delay_range", "delay_range_mm"]:
                calibration[key] = [float(i) for i in value]
            else:
                calibration[key] = float(value)

        if "creation_date" not in calibration:
            calibration["creation_date"] = datetime.now().timestamp()

        config = {
            "delay": {
                "calibration": calibration,
            },
        }
        save_config(config, filename, overwrite)

    def add_delay_offset(
        self,
        constant: float = None,
        flip_delay_axis: bool = None,
        columns: str | Sequence[str] = None,
        weights: float | Sequence[float] = 1.0,
        reductions: str | Sequence[str] = None,
        preserve_mean: bool | Sequence[bool] = False,
        preview: bool = False,
        verbose: bool = None,
    ) -> None:
        """Shift the delay axis of the dataframe by a constant or other columns.

        Args:
            constant (float, optional): The constant to shift the delay axis by.
            flip_delay_axis (bool, optional): Option to reverse the direction of the delay axis.
            columns (str | Sequence[str], optional): Name of the column(s) to apply the shift from.
            weights (float | Sequence[float], optional): weights to apply to the columns.
                Can also be used to flip the sign (e.g. -1). Defaults to 1.
            reductions (str | Sequence[str], optional): The reduction to apply to the column.
                Should be an available method of dask.dataframe.Series. For example "mean". In this
                case the function is applied to the column to generate a single value for the whole
                dataset. If None, the shift is applied per-dataframe-row. Defaults to None.
                Currently only "mean" is supported.
            preserve_mean (bool | Sequence[bool], optional): Whether to subtract the mean of the
                column before applying the shift. Defaults to False.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"].

        Raises:
            ValueError: If the delay column is not in the dataframe.
        """
        if verbose is None:
            verbose = self.verbose

        delay_column = self._config["dataframe"]["delay_column"]
        if delay_column not in self._dataframe.columns:
            raise ValueError(f"Delay column {delay_column} not found in dataframe! ")

        if self.dataframe is not None:
            if verbose:
                print("Adding delay offset to dataframe:")
            df, metadata = self.dc.add_offsets(
                df=self._dataframe,
                constant=constant,
                flip_delay_axis=flip_delay_axis,
                columns=columns,
                delay_column=delay_column,
                weights=weights,
                reductions=reductions,
                preserve_mean=preserve_mean,
                verbose=verbose,
            )
            if self._timed_dataframe is not None and delay_column in self._timed_dataframe.columns:
                tdf, _ = self.dc.add_offsets(
                    df=self._timed_dataframe,
                    constant=constant,
                    flip_delay_axis=flip_delay_axis,
                    columns=columns,
                    delay_column=delay_column,
                    weights=weights,
                    reductions=reductions,
                    preserve_mean=preserve_mean,
                    verbose=False,
                )

            self._attributes.add(
                metadata,
                "delay_offset",
                duplicate_policy="append",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and delay_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            print(self._dataframe.head(10))
        else:
            if verbose:
                print(self._dataframe)

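    # Usage sketch (illustrative comment): correct the delay axis using a beam
    # arrival monitor column and flip the axis direction (column name and values
    # hypothetical):
    #
    #   sp.add_delay_offset(
    #       constant=-100,
    #       flip_delay_axis=True,
    #       columns="bam",
    #       weights=0.001,
    #       preserve_mean=True,
    #   )
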
    def save_delay_offsets(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Save the generated delay offset parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.dc.offsets) == 0:
            raise ValueError("No delay offset parameters to save!")

        if "creation_date" not in self.dc.offsets.keys():
            self.dc.offsets["creation_date"] = datetime.now().timestamp()

        config = {
            "delay": {
                "offsets": self.dc.offsets,
            },
        }
        save_config(config, filename, overwrite)
        print(f'Saved delay offset parameters to "{filename}".')

    def save_workflow_params(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Run all save-calibration-parameter methods.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        for method in [
            self.save_splinewarp,
            self.save_transformations,
            self.save_momentum_calibration,
            self.save_energy_correction,
            self.save_energy_calibration,
            self.save_energy_offset,
            self.save_delay_calibration,
            self.save_delay_offsets,
        ]:
            try:
                method(filename, overwrite)
            except (ValueError, AttributeError, KeyError):
                pass

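    # Usage sketch (illustrative comment): persist all available calibrations and
    # offsets in one call, overwriting previously stored entries:
    #
    #   sp.save_workflow_params(filename="sed_config.yaml", overwrite=True)
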
    def add_jitter(
        self,
        cols: list[str] = None,
        amps: float | Sequence[float] = None,
        **kwds,
    ):
        """Add jitter to the selected dataframe columns.

        Args:
            cols (list[str], optional): The columns onto which to apply jitter.
                Defaults to config["dataframe"]["jitter_cols"].
            amps (float | Sequence[float], optional): Amplitude scalings for the
                jittering noise. If one number is given, the same is used for all axes.
                For uniform noise (default) it will cover the interval [-amp, +amp].
                Defaults to config["dataframe"]["jitter_amps"].
            **kwds: additional keyword arguments passed to ``apply_jitter``.
        """
        if cols is None:
            cols = self._config["dataframe"]["jitter_cols"]
        for loc, col in enumerate(cols):
            if col.startswith("@"):
                cols[loc] = self._config["dataframe"].get(col.strip("@"))

        if amps is None:
            amps = self._config["dataframe"]["jitter_amps"]

        self._dataframe = self._dataframe.map_partitions(
            apply_jitter,
            cols=cols,
            cols_jittered=cols,
            amps=amps,
            **kwds,
        )
        if self._timed_dataframe is not None:
            cols_timed = cols.copy()
            for col in cols:
                if col not in self._timed_dataframe.columns:
                    cols_timed.remove(col)

            if cols_timed:
                self._timed_dataframe = self._timed_dataframe.map_partitions(
                    apply_jitter,
                    cols=cols_timed,
                    cols_jittered=cols_timed,
                )
        metadata = []
        for col in cols:
            metadata.append(col)
        # TODO: allow only appending if columns are not jittered yet
        self._attributes.add(metadata, "jittering", duplicate_policy="append")

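    # Usage sketch (illustrative comment): jitter the default columns from the
    # config, or an explicit selection with per-column amplitudes (column names
    # and values hypothetical):
    #
    #   sp.add_jitter()
    #   sp.add_jitter(cols=["dldPosX", "dldPosY"], amps=[0.5, 0.5])
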
    def add_time_stamped_data(
        self,
        dest_column: str,
        time_stamps: np.ndarray = None,
        data: np.ndarray = None,
        archiver_channel: str = None,
        **kwds,
    ):
        """Add data in form of timestamp/value pairs to the dataframe, using interpolation
        to the timestamps in the dataframe. The time-stamped data can either be provided,
        or fetched from an EPICS archiver instance.

        Args:
            dest_column (str): destination column name
            time_stamps (np.ndarray, optional): Time stamps of the values to add. If omitted,
                time stamps are retrieved from the EPICS archiver.
            data (np.ndarray, optional): Values corresponding to the time stamps in time_stamps.
                If omitted, data are retrieved from the EPICS archiver.
            archiver_channel (str, optional): EPICS archiver channel from which to retrieve data.
                Either this or data and time_stamps have to be present.
            **kwds:

                - **time_stamp_column**: Dataframe column containing time-stamp data

                Additional keyword arguments passed to ``add_time_stamped_data``.
        """
        time_stamp_column = kwds.pop(
            "time_stamp_column",
            self._config["dataframe"].get("time_stamp_alias", ""),
        )

        if time_stamps is None and data is None:
            if archiver_channel is None:
                raise ValueError(
                    "Either archiver_channel or both time_stamps and data have to be present!",
                )
            if self.loader.__name__ != "mpes":
                raise NotImplementedError(
                    "This function is currently only implemented for the mpes loader!",
                )
            ts_from, ts_to = cast(MpesLoader, self.loader).get_start_and_end_time()
            # get channel data with +-5 seconds safety margin
            time_stamps, data = get_archiver_data(
                archiver_url=self._config["metadata"].get("archiver_url", ""),
                archiver_channel=archiver_channel,
                ts_from=ts_from - 5,
                ts_to=ts_to + 5,
            )

        self._dataframe = add_time_stamped_data(
            self._dataframe,
            time_stamps=time_stamps,
            data=data,
            dest_column=dest_column,
            time_stamp_column=time_stamp_column,
            **kwds,
        )
        if self._timed_dataframe is not None:
            if time_stamp_column in self._timed_dataframe:
                self._timed_dataframe = add_time_stamped_data(
                    self._timed_dataframe,
                    time_stamps=time_stamps,
                    data=data,
                    dest_column=dest_column,
                    time_stamp_column=time_stamp_column,
                    **kwds,
                )
        metadata: list[Any] = []
        metadata.append(dest_column)
        metadata.append(time_stamps)
        metadata.append(data)
        self._attributes.add(metadata, "time_stamped_data", duplicate_policy="append")

    def pre_binning(
        self,
        df_partitions: int | Sequence[int] = 100,
        axes: list[str] = None,
        bins: list[int] = None,
        ranges: Sequence[tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Function to do an initial binning of the dataframe loaded to the class.

        Args:
            df_partitions (int | Sequence[int], optional): Number of dataframe partitions to
                use for the initial binning. Defaults to 100.
            axes (list[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (list[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (Sequence[tuple[float, float]], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            **kwds: Keyword argument passed to ``compute``.

        Returns:
            xr.DataArray: pre-binned data-array.
        """
        if axes is None:
            axes = self._config["momentum"]["axes"]
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))

        if bins is None:
            bins = self._config["momentum"]["bins"]
        if ranges is None:
            ranges_ = list(self._config["momentum"]["ranges"])
            ranges_[2] = np.asarray(ranges_[2]) / self._config["dataframe"]["tof_binning"]
            ranges = [cast(tuple[float, float], tuple(v)) for v in ranges_]

        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        return self.compute(
            bins=bins,
            axes=axes,
            ranges=ranges,
            df_partitions=df_partitions,
            **kwds,
        )

    def compute(
        self,
        bins: int | dict | tuple | list[int] | list[np.ndarray] | list[tuple] = 100,
        axes: str | Sequence[str] = None,
        ranges: Sequence[tuple[float, float]] = None,
        normalize_to_acquisition_time: bool | str = False,
        **kwds,
    ) -> xr.DataArray:
        """Compute the histogram along the given dimensions.

        Args:
            bins (int | dict | tuple | list[int] | list[np.ndarray] | list[tuple], optional):
                Definition of the bins. Can be any of the following cases:

                - an integer describing the number of bins for all dimensions
                - a tuple of 3 numbers describing start, end and step of the binning
                  range
                - an np.ndarray defining the binning edges
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
                - a dictionary made of the axes as keys and any of the above as values.

                This takes priority over the axes and range arguments. Defaults to 100.
            axes (str | Sequence[str], optional): The names of the axes (columns)
                on which to calculate the histogram. The order will be the order of the
                dimensions in the resulting array. Defaults to None.
            ranges (Sequence[tuple[float, float]], optional): list of tuples containing
                the start and end point of the binning range. Defaults to None.
            normalize_to_acquisition_time (bool | str): Option to normalize the
                result to the acquisition time. If a "slow" axis was scanned, providing
                the name of the scanned axis will compute and apply the corresponding
                normalization histogram. Defaults to False.
            **kwds: Keyword arguments:

                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
                  ``bin_dataframe`` for details. Defaults to
                  config["binning"]["hist_mode"].
                - **mode**: Defines how the results from each partition are combined.
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
                  Defaults to config["binning"]["mode"].
                - **pbar**: Option to show the tqdm progress bar. Defaults to
                  config["binning"]["pbar"].
                - **num_cores**: Number of CPU cores to use for parallelization.
                  Defaults to config["core"]["num_cores"] or N_CPU-1.
                - **threads_per_worker**: Limit the number of threads that
                  multiprocessing can spawn per binning thread. Defaults to
                  config["binning"]["threads_per_worker"].
                - **threadpool_api**: The API to use for multiprocessing. "blas",
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
                  config["binning"]["threadpool_API"].
                - **df_partitions**: A sequence of dataframe partitions, or the
                  number of the dataframe partitions to use. Defaults to all partitions.
                - **filter**: A Sequence of Dictionaries with entries "col", "lower_bound",
                  "upper_bound" to apply as filter to the dataframe before binning. The
                  dataframe in the class remains unmodified by this.

                Additional kwds are passed to ``bin_dataframe``.

        Raises:
            AssertionError: Raised when no dataframe has been loaded.

        Returns:
            xr.DataArray: The result of the n-dimensional binning represented in an
            xarray object, combining the data with the axes.
        """
        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
        mode = kwds.pop("mode", self._config["binning"]["mode"])
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
        num_cores = kwds.pop("num_cores", self._config["core"]["num_cores"])
        threads_per_worker = kwds.pop(
            "threads_per_worker",
            self._config["binning"]["threads_per_worker"],
        )
        threadpool_api = kwds.pop(
            "threadpool_API",
            self._config["binning"]["threadpool_API"],
        )
        df_partitions: int | Sequence[int] = kwds.pop("df_partitions", None)
        if isinstance(df_partitions, int):
            df_partitions = list(range(0, min(df_partitions, self._dataframe.npartitions)))
        if df_partitions is not None:
            dataframe = self._dataframe.partitions[df_partitions]
        else:
            dataframe = self._dataframe

        filter_params = kwds.pop("filter", None)
        if filter_params is not None:
            try:
                for param in filter_params:
                    if "col" not in param:
                        raise ValueError(
                            "'col' needs to be defined for each filter entry! ",
                            f"Not present in {param}.",
                        )
                    assert set(param.keys()).issubset({"col", "lower_bound", "upper_bound"})
                    dataframe = apply_filter(dataframe, **param)
            except AssertionError as exc:
                invalid_keys = set(param.keys()) - {"col", "lower_bound", "upper_bound"}
                raise ValueError(
                    "Only 'col', 'lower_bound' and 'upper_bound' allowed as filter entries. ",
                    f"Parameters {invalid_keys} not valid in {param}.",
                ) from exc

        self._binned = bin_dataframe(
            df=dataframe,
            bins=bins,
            axes=axes,
            ranges=ranges,
            hist_mode=hist_mode,
            mode=mode,
            pbar=pbar,
            n_cores=num_cores,
            threads_per_worker=threads_per_worker,
            threadpool_api=threadpool_api,
            **kwds,
        )

        for dim in self._binned.dims:
            try:
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
            except KeyError:
                pass

        self._binned.attrs["units"] = "counts"
        self._binned.attrs["long_name"] = "photoelectron counts"
        self._binned.attrs["metadata"] = self._attributes.metadata

        if normalize_to_acquisition_time:
            if isinstance(normalize_to_acquisition_time, str):
                axis = normalize_to_acquisition_time
                print(
                    f"Calculate normalization histogram for axis '{axis}'...",
                )
                self._normalization_histogram = self.get_normalization_histogram(
                    axis=axis,
                    df_partitions=df_partitions,
                )
                # if the axes are named correctly, xarray figures out the normalization correctly
                self._normalized = self._binned / self._normalization_histogram
                self._attributes.add(
                    self._normalization_histogram.values,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )
            else:
                acquisition_time = self.loader.get_elapsed_time(
                    fids=df_partitions,
                )
                if acquisition_time > 0:
                    self._normalized = self._binned / acquisition_time
                self._attributes.add(
                    acquisition_time,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )

            self._normalized.attrs["units"] = "counts/second"
            self._normalized.attrs["long_name"] = "photoelectron counts per second"
            self._normalized.attrs["metadata"] = self._attributes.metadata

            return self._normalized

        return self._binned

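    # Usage sketch (illustrative comment): a 4-dimensional binning with per-axis
    # ranges, normalized to the acquisition time of a slow delay scan (axis names
    # and ranges hypothetical):
    #
    #   res = sp.compute(
    #       bins=[100, 100, 200, 50],
    #       axes=["kx", "ky", "energy", "delay"],
    #       ranges=[(-2, 2), (-2, 2), (-4, 2), (-600, 1600)],
    #       normalize_to_acquisition_time="delay",
    #   )
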
    def get_normalization_histogram(
        self,
        axis: str = "delay",
        use_time_stamps: bool = False,
        **kwds,
    ) -> xr.DataArray:
        """Generates a normalization histogram from the timed dataframe. Optionally,
        use the TimeStamps column instead.

        Args:
            axis (str, optional): The axis for which to compute the histogram.
                Defaults to "delay".
            use_time_stamps (bool, optional): Use the TimeStamps column of the
                dataframe, rather than the timed dataframe. Defaults to False.
            **kwds: Keyword arguments:

                - **df_partitions**: A sequence of dataframe partitions, or the
                  number of the dataframe partitions to use. Defaults to all partitions.

        Raises:
            ValueError: Raised if no data are binned.
            ValueError: Raised if 'axis' not in binned coordinates.
            ValueError: Raised if config["dataframe"]["time_stamp_alias"] not found
                in Dataframe.

        Returns:
            xr.DataArray: The computed normalization histogram (in TimeStamp units
            per bin).
        """

        if self._binned is None:
            raise ValueError("Need to bin data first!")
        if axis not in self._binned.coords:
            raise ValueError(f"Axis '{axis}' not found in binned data!")

        df_partitions: int | Sequence[int] = kwds.pop("df_partitions", None)

        if len(kwds) > 0:
            raise TypeError(
                f"get_normalization_histogram() got unexpected keyword arguments {kwds.keys()}.",
            )

        if isinstance(df_partitions, int):
            df_partitions = list(range(0, min(df_partitions, self._dataframe.npartitions)))
        if use_time_stamps or self._timed_dataframe is None:
            if df_partitions is not None:
                self._normalization_histogram = normalization_histogram_from_timestamps(
                    self._dataframe.partitions[df_partitions],
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["time_stamp_alias"],
                )
            else:
                self._normalization_histogram = normalization_histogram_from_timestamps(
                    self._dataframe,
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["time_stamp_alias"],
                )
        else:
            if df_partitions is not None:
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
                    self._timed_dataframe.partitions[df_partitions],
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["timed_dataframe_unit_time"],
                )
            else:
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
                    self._timed_dataframe,
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["timed_dataframe_unit_time"],
                )

        return self._normalization_histogram

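    # Usage sketch (illustrative comment): after binning, inspect the
    # acquisition-time histogram along the scanned axis on a subset of partitions:
    #
    #   hist = sp.get_normalization_histogram(axis="delay", df_partitions=10)
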
    def view_event_histogram(
        self,
        dfpid: int,
        ncol: int = 2,
        bins: Sequence[int] = None,
        axes: Sequence[str] = None,
        ranges: Sequence[tuple[float, float]] = None,
        backend: str = "bokeh",
        legend: bool = True,
        histkwds: dict = None,
        legkwds: dict = None,
        **kwds,
    ):
        """Plot individual histograms of specified dimensions (axes) from a selected
        dataframe partition.

        Args:
            dfpid (int): Number of the data frame partition to look at.
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
            bins (Sequence[int], optional): Number of bins to use for the specified
                axes. Defaults to config["histogram"]["bins"].
            axes (Sequence[str], optional): Names of the axes to display.
                Defaults to config["histogram"]["axes"].
            ranges (Sequence[tuple[float, float]], optional): Value ranges of all
                specified axes. Defaults to config["histogram"]["ranges"].
            backend (str, optional): Backend of the plotting library
                ('matplotlib' or 'bokeh'). Defaults to "bokeh".
            legend (bool, optional): Option to include a legend in the histogram plots.
                Defaults to True.
            histkwds (dict, optional): Keyword arguments for histograms
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
            legkwds (dict, optional): Keyword arguments for legend
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
            **kwds: Extra keyword arguments passed to
                ``sed.diagnostics.grid_histogram()``.

        Raises:
            TypeError: Raises when the input values are not of the correct type.
        """
        if bins is None:
            bins = self._config["histogram"]["bins"]
        if axes is None:
            axes = self._config["histogram"]["axes"]
        axes = list(axes)
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
        if ranges is None:
            ranges = list(self._config["histogram"]["ranges"])
            for loc, axis in enumerate(axes):
                if axis == self._config["dataframe"]["tof_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / self._config["dataframe"]["tof_binning"]
                elif axis == self._config["dataframe"]["adc_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / self._config["dataframe"]["adc_binning"]

        input_types = map(type, [axes, bins, ranges])
        allowed_types = [list, tuple]

        df = self._dataframe

        if not set(input_types).issubset(allowed_types):
            raise TypeError(
                "Inputs of axes, bins, ranges need to be list or tuple!",
            )

        # Read out the values for the specified groups
        group_dict_dd = {}
        dfpart = df.get_partition(dfpid)
        cols = dfpart.columns
        for ax in axes:
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
        group_dict = ddf.compute(group_dict_dd)[0]

        # Plot multiple histograms in a grid
        grid_histogram(
            group_dict,
            ncol=ncol,
            rvs=axes,
            rvbins=bins,
            rvranges=ranges,
            backend=backend,
            legend=legend,
            histkwds=histkwds,
            legkwds=legkwds,
            **kwds,
        )

    def save(
        self,
        faddr: str,
        **kwds,
    ):
        """Saves the binned data to the provided path and filename.

        Args:
            faddr (str): Path and name of the file to write. Its extension determines
                the file type to write. Valid file types are:

                - "*.tiff", "*.tif": Saves a TIFF stack.
                - "*.h5", "*.hdf5": Saves an HDF5 file.
                - "*.nxs", "*.nexus": Saves a NeXus file.

            **kwds: Keyword arguments, which are passed to the writer functions:
                For TIFF writing:

                - **alias_dict**: Dictionary of dimension aliases to use.

                For HDF5 writing:

                - **mode**: hdf5 read/write mode. Defaults to "w".

                For NeXus:

                - **reader**: Name of the pynxtools reader to use.
                  Defaults to config["nexus"]["reader"]
                - **definition**: NeXus application definition to use for saving.
                  Must be supported by the used ``reader``. Defaults to
                  config["nexus"]["definition"]
                - **input_files**: A list of input files to pass to the reader.
                  Defaults to config["nexus"]["input_files"]
                - **eln_data**: An electronic-lab-notebook file in '.yaml' format
                  to add to the list of files to pass to the reader.
        """
        if self._binned is None:
            raise NameError("Need to bin data first!")

        if self._normalized is not None:
            data = self._normalized
        else:
            data = self._binned

        extension = pathlib.Path(faddr).suffix

        if extension in (".tif", ".tiff"):
            to_tiff(
                data=data,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".h5", ".hdf5"):
            to_h5(
                data=data,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".nxs", ".nexus"):
            try:
                reader = kwds.pop("reader", self._config["nexus"]["reader"])
                definition = kwds.pop(
                    "definition",
                    self._config["nexus"]["definition"],
                )
                input_files = kwds.pop(
                    "input_files",
                    self._config["nexus"]["input_files"],
                )
            except KeyError as exc:
                raise ValueError(
                    "The nexus reader, definition and input files need to be provided!",
                ) from exc

            if isinstance(input_files, str):
                input_files = [input_files]

            if "eln_data" in kwds:
                input_files.append(kwds.pop("eln_data"))

            to_nexus(
                data=data,
                faddr=faddr,
                reader=reader,
                definition=definition,
                input_files=input_files,
                **kwds,
            )

        else:
            raise NotImplementedError(
                f"Unrecognized file format: {extension}.",
            )
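
    # Usage sketch (illustrative comment): write the binned (or normalized) result
    # to disk; the extension selects the writer (paths and reader/definition values
    # hypothetical):
    #
    #   sp.save("binned.h5")
    #   sp.save("scan.nxs", reader="mpes", definition="NXmpes")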