OpenCOMPES / sed / 10948948523

19 Sep 2024 09:08PM UTC coverage: 92.522% (-0.2%) from 92.693%

Pull Request #490: Logging
rettigl: Make verbose a private property, and add a getter and setter for it, to propagate verbosity

283 of 341 new or added lines in 12 files covered. (82.99%)
1 existing line in 1 file now uncovered.
7238 of 7823 relevant lines covered (92.52%)
0.93 hits per line

Source File: /sed/core/processor.py (86.03% covered)

"""This module contains the core class for the sed package."""
from __future__ import annotations

import pathlib
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from typing import cast

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.binning.binning import normalization_histogram_from_timed_dataframe
from sed.binning.binning import normalization_histogram_from_timestamps
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import add_time_stamped_data
from sed.core.dfops import apply_filter
from sed.core.dfops import apply_jitter
from sed.core.logging import call_logger
from sed.core.logging import set_verbosity
from sed.core.logging import setup_logging
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader
from sed.loader.mpes.loader import get_archiver_data
from sed.loader.mpes.loader import MpesLoader

N_CPU = psutil.cpu_count()

# Configure logging
logger = setup_logging("processor")


class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a workflow for data
    correction, calibration, and binning.

    Args:
        metadata (dict, optional): Dict of external Metadata. Defaults to None.
        config (dict | str, optional): Config dictionary or config file name.
            Defaults to None.
        dataframe (pd.DataFrame | ddf.DataFrame, optional): dataframe to load
            into the class. Defaults to None.
        files (list[str], optional): List of files to pass to the loader defined in
            the config. Defaults to None.
        folder (str, optional): Folder containing files to pass to the loader
            defined in the config. Defaults to None.
        runs (Sequence[str], optional): List of run identifiers to pass to the loader
            defined in the config. Defaults to None.
        collect_metadata (bool, optional): Option to collect metadata from files.
            Defaults to False.
        verbose (bool, optional): Option to print out diagnostic information.
            Defaults to config["core"]["verbose"] or True.
        **kwds: Keyword arguments passed to parse_config and to the reader.
    """

    @call_logger(logger)
    def __init__(
        self,
        metadata: dict = None,
        config: dict | str = None,
        dataframe: pd.DataFrame | ddf.DataFrame = None,
        files: list[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            config (dict | str, optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (pd.DataFrame | ddf.DataFrame, optional): dataframe to load
                into the class. Defaults to None.
            files (list[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"] or True.
            **kwds: Keyword arguments passed to parse_config and to the reader.
        """
        # split off config keywords
        config_kwds = {
            key: value for key, value in kwds.items() if key in parse_config.__code__.co_varnames
        }
        for key in config_kwds.keys():
            del kwds[key]
        self._config = parse_config(config, **config_kwds)
        num_cores = self._config["core"].get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["core"]["num_cores"] = num_cores
        logger.debug(f"Use {num_cores} cores for processing.")

        if verbose is None:
            self._verbose = self._config["core"].get("verbose", True)
        else:
            self._verbose = verbose
        set_verbosity(logger, self._verbose)

        self._dataframe: pd.DataFrame | ddf.DataFrame = None
        self._timed_dataframe: pd.DataFrame | ddf.DataFrame = None
        self._files: list[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None
        self._normalization_histogram: xr.DataArray = None
        self._normalized: xr.DataArray = None

        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
            verbose=verbose,
        )
        logger.debug(f"Use loader: {loader_name}")

        self.ec = EnergyCalibrator(
            loader=get_loader(
                loader_name=loader_name,
                config=self._config,
                verbose=verbose,
            ),
            config=self._config,
            verbose=self._verbose,
        )

        self.mc = MomentumCorrector(
            config=self._config,
            verbose=self._verbose,
        )

        self.dc = DelayCalibrator(
            config=self._config,
            verbose=self._verbose,
        )

        self.use_copy_tool = self._config.get("core", {}).get(
            "use_copy_tool",
            False,
        )
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    source=self._config["core"]["copy_tool_source"],
                    dest=self._config["core"]["copy_tool_dest"],
                    num_cores=self._config["core"]["num_cores"],
                    **self._config["core"].get("copy_tool_kwds", {}),
                )
                logger.debug(
                    f"Initialized copy tool: Copy file from "
                    f"'{self._config['core']['copy_tool_source']}' "
                    f"to '{self._config['core']['copy_tool_dest']}'.",
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )

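    # --- Usage sketch (illustrative, not part of the module) -------------------
    # Constructing a processor from a config file and a folder of raw data. The
    # paths are hypothetical, and the import assumes the package re-exports
    # SedProcessor at top level:
    #
    #   from sed import SedProcessor
    #
    #   sp = SedProcessor(
    #       config="sed_config.yaml",    # parsed via parse_config
    #       folder="/path/to/raw_data",  # forwarded to self.load()
    #       verbose=True,                # propagated to loader and calibrators
    #   )
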
    def __repr__(self):
        if self._dataframe is None:
            df_str = "Dataframe: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        pretty_str = df_str + "\n" + "Metadata: " + "\n" + self._attributes.__repr__()
        return pretty_str

    def _repr_html_(self):
        html = "<div>"

        if self._dataframe is None:
            df_html = "Dataframe: No Data loaded"
        else:
            df_html = self._dataframe._repr_html_()

        html += f"<details><summary>Dataframe</summary>{df_html}</details>"

        # Add expandable section for attributes
        html += "<details><summary>Metadata</summary>"
        html += "<div style='padding-left: 10px;'>"
        html += self._attributes._repr_html_()
        html += "</div></details>"

        html += "</div>"

        return html

    ## Suggestion:
    # @property
    # def overview_panel(self):
    #     """Provides an overview panel with plots of different data attributes."""
    #     self.view_event_histogram(dfpid=2, backend="matplotlib")

    @property
    def verbose(self) -> bool:
        """Accessor to the verbosity flag.

        Returns:
            bool: Verbosity flag.
        """
        return self._verbose

    @verbose.setter
    def verbose(self, verbose: bool):
        """Setter for the verbosity.

        Args:
            verbose (bool): Option to turn on verbose output. Sets loglevel to INFO.
        """
        self._verbose = verbose
        set_verbosity(logger, self._verbose)
        self.mc.verbose = verbose
        self.ec.verbose = verbose
        self.dc.verbose = verbose
        self.loader.verbose = verbose

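    # --- Usage sketch (illustrative, not part of the module) -------------------
    # Because the setter above forwards the flag to mc, ec, dc, and the loader,
    # one assignment toggles diagnostic output for the whole pipeline:
    #
    #   sp.verbose = False  # silence processor, loader, and all calibrators
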
    @property
    def dataframe(self) -> pd.DataFrame | ddf.DataFrame:
        """Accessor to the underlying dataframe.

        Returns:
            pd.DataFrame | ddf.DataFrame: Dataframe object.
        """
        return self._dataframe

    @dataframe.setter
    def dataframe(self, dataframe: pd.DataFrame | ddf.DataFrame):
        """Setter for the underlying dataframe.

        Args:
            dataframe (pd.DataFrame | ddf.DataFrame): The dataframe object to set.
        """
        if not isinstance(dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            dataframe,
            self._dataframe.__class__,
        ):
            raise ValueError(
                "'dataframe' has to be a Pandas or Dask dataframe and has to be of the same kind "
                "as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._dataframe.__class__}, provided type: {type(dataframe)}.",
            )
        self._dataframe = dataframe

    @property
    def timed_dataframe(self) -> pd.DataFrame | ddf.DataFrame:
        """Accessor to the underlying timed_dataframe.

        Returns:
            pd.DataFrame | ddf.DataFrame: Timed Dataframe object.
        """
        return self._timed_dataframe

    @timed_dataframe.setter
    def timed_dataframe(self, timed_dataframe: pd.DataFrame | ddf.DataFrame):
        """Setter for the underlying timed dataframe.

        Args:
            timed_dataframe (pd.DataFrame | ddf.DataFrame): The timed dataframe object to set.
        """
        if not isinstance(timed_dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            timed_dataframe,
            self._timed_dataframe.__class__,
        ):
            raise ValueError(
                "'timed_dataframe' has to be a Pandas or Dask dataframe and has to be of the same "
                "kind as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._timed_dataframe.__class__}, "
                f"provided type: {type(timed_dataframe)}.",
            )
        self._timed_dataframe = timed_dataframe

    @property
    def attributes(self) -> MetaHandler:
        """Accessor to the metadata dict.

        Returns:
            MetaHandler: The metadata object
        """
        return self._attributes

    def add_attribute(self, attributes: dict, name: str, **kwds):
        """Function to add an element to the attributes dict.

        Args:
            attributes (dict): The attributes dictionary object to add.
            name (str): Key under which to add the dictionary to the attributes.
            **kwds: Additional keywords are passed to the ``MetaHandler.add()`` function.
        """
        self._attributes.add(
            entry=attributes,
            name=name,
            **kwds,
        )

    @property
    def config(self) -> dict[Any, Any]:
        """Getter attribute for the config dictionary

        Returns:
            dict: The config dictionary.
        """
        return self._config

    @property
    def files(self) -> list[str]:
        """Getter attribute for the list of files

        Returns:
            list[str]: The list of loaded files
        """
        return self._files

    @property
    def binned(self) -> xr.DataArray:
        """Getter attribute for the binned data array

        Returns:
            xr.DataArray: The binned data array
        """
        if self._binned is None:
            raise ValueError("No binned data available, need to compute histogram first!")
        return self._binned

    @property
    def normalized(self) -> xr.DataArray:
        """Getter attribute for the normalized data array

        Returns:
            xr.DataArray: The normalized data array
        """
        if self._normalized is None:
            raise ValueError(
                "No normalized data available, compute data with normalization enabled!",
            )
        return self._normalized

    @property
    def normalization_histogram(self) -> xr.DataArray:
        """Getter attribute for the normalization histogram

        Returns:
            xr.DataArray: The normalization histogram
        """
        if self._normalization_histogram is None:
            raise ValueError("No normalization histogram available, generate histogram first!")
        return self._normalization_histogram

    def cpy(self, path: str | list[str]) -> str | list[str]:
        """Function to mirror a list of files or a folder from a network drive to a
        local storage. Returns either the original or the copied path to the given
        path. The option to use this functionality is set by
        config["core"]["use_copy_tool"].

        Args:
            path (str | list[str]): Source path or path list.

        Returns:
            str | list[str]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        return path

    @call_logger(logger)
    def load(
        self,
        dataframe: pd.DataFrame | ddf.DataFrame = None,
        metadata: dict = None,
        files: list[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (pd.DataFrame | ddf.DataFrame, optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            files (list[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            collect_metadata (bool, optional): Option for collecting metadata in the reader.
                Defaults to False.
            **kwds:
                - *timed_dataframe*: timed dataframe if dataframe is provided.

                Additional keyword parameters are passed to ``loader.read_dataframe()``.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            timed_dataframe = kwds.pop("timed_dataframe", None)
        elif runs is not None:
            # If runs are provided, we only use the copy tool if a folder is also provided.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        elif files is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                files=cast(list[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._timed_dataframe = timed_dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )

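    # --- Usage sketch (illustrative, not part of the module) -------------------
    # The mutually exclusive input modes of load(); paths and run identifiers
    # are hypothetical:
    #
    #   sp.load(runs=["run_0001", "run_0002"], folder="/path/to/base_folder")
    #   sp.load(files=["/path/to/scan1.h5", "/path/to/scan2.h5"])
    #   sp.load(dataframe=df, timed_dataframe=tdf)  # pre-built dataframes
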
    @call_logger(logger)
    def filter_column(
        self,
        column: str,
        min_value: float = -np.inf,
        max_value: float = np.inf,
    ) -> None:
        """Filter values in a column which are outside of a given range.

        Args:
            column (str): Name of the column to filter.
            min_value (float, optional): Minimum value to keep. Defaults to -np.inf.
            max_value (float, optional): Maximum value to keep. Defaults to np.inf.
        """
        if column != "index" and column not in self._dataframe.columns:
            raise KeyError(f"Column {column} not found in dataframe!")
        if min_value >= max_value:
            raise ValueError("min_value has to be smaller than max_value!")
        if self._dataframe is not None:
            self._dataframe = apply_filter(
                self._dataframe,
                col=column,
                lower_bound=min_value,
                upper_bound=max_value,
            )
        if self._timed_dataframe is not None and column in self._timed_dataframe.columns:
            self._timed_dataframe = apply_filter(
                self._timed_dataframe,
                column,
                lower_bound=min_value,
                upper_bound=max_value,
            )
        metadata = {
            "filter": {
                "column": column,
                "min_value": min_value,
                "max_value": max_value,
            },
        }
        self._attributes.add(metadata, "filter", duplicate_policy="merge")

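    # --- Usage sketch (illustrative, not part of the module) -------------------
    # Restricting events to a window in one column; the column name and bounds
    # are hypothetical and depend on the loader configuration:
    #
    #   sp.filter_column("X", min_value=100, max_value=1900)
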
    # Momentum calibration workflow
    # 1. Bin raw detector data for distortion correction
    @call_logger(logger)
    def bin_and_load_momentum_calibration(
        self,
        df_partitions: int | Sequence[int] = 100,
        axes: list[str] = None,
        bins: list[int] = None,
        ranges: Sequence[tuple[float, float]] = None,
        plane: int = 0,
        width: int = 5,
        apply: bool = False,
        **kwds,
    ):
        """1st step of the momentum correction workflow. Function to do an initial binning
        of the dataframe loaded to the class, slice a plane from it using an
        interactive view, and load it into the momentum corrector class.

        Args:
            df_partitions (int | Sequence[int], optional): Number of dataframe partitions
                to use for the initial binning. Defaults to 100.
            axes (list[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (list[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (Sequence[tuple[float, float]], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
            width (int, optional): Initial value for the width slider. Defaults to 5.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments passed to the pre_binning function.
        """
        self._pre_binned = self.pre_binning(
            df_partitions=df_partitions,
            axes=axes,
            bins=bins,
            ranges=ranges,
            **kwds,
        )

        self.mc.load_data(data=self._pre_binned)
        self.mc.select_slicer(plane=plane, width=width, apply=apply)

    # 2. Select features for the spline warp correction.
    # Either autoselect features, or input features from view above.
    @call_logger(logger)
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2nd step of the distortion correction workflow: Define feature points in
        momentum space. They can either be manually selected using a GUI tool, be
        provided as a list of feature points, or be auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments for ``MomentumCorrector.feature_extract()`` and
                ``MomentumCorrector.feature_select()``.
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop("sigma", self._config["momentum"]["sigma"])
            fwhm = kwds.pop("fwhm", self._config["momentum"]["fwhm"])
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config["momentum"]["sigma_radius"],
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )

    # 3. Generate the spline warp correction from momentum features.
    # If no features have been selected before, use class defaults.
    @call_logger(logger)
    def generate_splinewarp(
        self,
        use_center: bool = None,
        **kwds,
    ):
        """3rd step of the distortion correction workflow: Generate the correction
        function restoring the symmetry in the image using a splinewarp algorithm.

        Args:
            use_center (bool, optional): Option to use the position of the
                center point in the correction. Default is read from config, or set to True.
            **kwds: Keyword arguments for ``MomentumCorrector.spline_warp_estimate()``.
        """

        self.mc.spline_warp_estimate(use_center=use_center, **kwds)

        if self.mc.slice is not None and self._verbose:
            print("Original slice with reference features")
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)

            print("Corrected slice with target features")
            self.mc.view(
                image=self.mc.slice_corrected,
                annotated=True,
                points={"feats": self.mc.ptargs},
                backend="bokeh",
                crosshair=True,
            )

            print("Original slice with target features")
            self.mc.view(
                image=self.mc.slice,
                points={"feats": self.mc.ptargs},
                annotated=True,
                backend="bokeh",
            )

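    # --- Usage sketch (illustrative, not part of the module) -------------------
    # Steps 1-3 of the distortion correction workflow chained together; slider
    # values are hypothetical:
    #
    #   sp.bin_and_load_momentum_calibration(plane=33, width=10, apply=True)
    #   sp.define_features(rotation_symmetry=6, include_center=True, apply=True)
    #   sp.generate_splinewarp(use_center=True)
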
    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.correction) == 0:
            raise ValueError("No momentum correction parameters to save!")
        correction = {}
        for key, value in self.mc.correction.items():
            if key in ["reference_points", "target_points", "cdeform_field", "rdeform_field"]:
                continue
            if key in ["use_center", "rotation_symmetry"]:
                correction[key] = value
            elif key in ["center_point", "ascale"]:
                correction[key] = [float(i) for i in value]
            elif key in ["outer_points", "feature_points"]:
                correction[key] = []
                for point in value:
                    correction[key].append([float(i) for i in point])
            else:
                correction[key] = float(value)

        if "creation_date" not in correction:
            correction["creation_date"] = datetime.now().timestamp()

        config = {
            "momentum": {
                "correction": correction,
            },
        }
        save_config(config, filename, overwrite)
        logger.info(f'Saved momentum correction parameters to "{filename}".')

    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    @call_logger(logger)
    def pose_adjustment(
        self,
        transformations: dict[str, Any] = None,
        apply: bool = False,
        use_correction: bool = True,
        reset: bool = True,
        **kwds,
    ):
        """4th step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            transformations (dict[str, Any], optional): Dictionary with transformations.
                Defaults to self.transformations or config["momentum"]["transformations"].
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
            reset (bool, optional): Option to reset the correction before transformation.
                Defaults to True.
            **kwds: Keyword parameters defining defaults for the transformations:

                - **scale** (float): Initial value of the scaling slider.
                - **xtrans** (float): Initial value of the xtrans slider.
                - **ytrans** (float): Initial value of the ytrans slider.
                - **angle** (float): Initial value of the angle slider.
        """
        # Generate homography as default if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                self.mc.slice = np.zeros(self._config["momentum"]["bins"][0:2])
            self.mc.slice_corrected = self.mc.slice

        if not use_correction:
            self.mc.reset_deformation()

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate distortion correction from config values
            self.mc.spline_warp_estimate()

        self.mc.pose_adjustment(
            transformations=transformations,
            apply=apply,
            reset=reset,
            **kwds,
        )

    # 4a. Save pose adjustment parameters to config file.
    @call_logger(logger)
    def save_transformations(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the pose adjustment parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.transformations) == 0:
            raise ValueError("No momentum transformation parameters to save!")
        transformations = {}
        for key, value in self.mc.transformations.items():
            transformations[key] = float(value)

        if "creation_date" not in transformations:
            transformations["creation_date"] = datetime.now().timestamp()

        config = {
            "momentum": {
                "transformations": transformations,
            },
        }
        save_config(config, filename, overwrite)
        logger.info(f'Saved momentum transformation parameters to "{filename}".')

    # 5. Apply the momentum correction to the dataframe
    @call_logger(logger)
    def apply_momentum_correction(
        self,
        preview: bool = False,
        **kwds,
    ):
        """Applies the distortion correction and (optional) pose adjustment
        to the dataframe.

        Args:
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds: Keyword parameters for ``MomentumCorrector.apply_corrections``:

                - **rdeform_field** (np.ndarray, optional): Row deformation field.
                - **cdeform_field** (np.ndarray, optional): Column deformation field.
                - **inv_dfield** (np.ndarray, optional): Inverse deformation field.

        """
        x_column = self._config["dataframe"]["x_column"]
        y_column = self._config["dataframe"]["y_column"]

        if self._dataframe is not None:
            logger.info("Adding corrected X/Y columns to dataframe:")
            df, metadata = self.mc.apply_corrections(
                df=self._dataframe,
                **kwds,
            )
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                tdf, _ = self.mc.apply_corrections(
                    self._timed_dataframe,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            self._dataframe = df
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

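    # --- Usage sketch (illustrative, not part of the module) -------------------
    # Steps 4-5: adjust the pose with explicit transformation values
    # (hypothetical numbers), then write the corrected X/Y columns:
    #
    #   sp.pose_adjustment(
    #       transformations={"scale": 1.0, "xtrans": 10.0, "ytrans": -5.0, "angle": 2.0},
    #       apply=True,
    #   )
    #   sp.save_transformations()
    #   sp.apply_momentum_correction(preview=True)
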
    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    @call_logger(logger)
    def calibrate_momentum_axes(
        self,
        point_a: np.ndarray | list[int] = None,
        point_b: np.ndarray | list[int] = None,
        k_distance: float = None,
        k_coord_a: np.ndarray | list[float] = None,
        k_coord_b: np.ndarray | list[float] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply=False,
    ):
        """1st step of the momentum calibration workflow. Calibrate momentum
        axes using either provided pixel coordinates of a high-symmetry point and its
        distance to the BZ center, or the k-coordinates of two points in the BZ
        (depending on the equiscale option). Opens an interactive panel for selecting
        the points.

        Args:
            point_a (np.ndarray | list[int], optional): Pixel coordinates of the first
                point used for momentum calibration.
            point_b (np.ndarray | list[int], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between point a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (np.ndarray | list[float], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (np.ndarray | list[float], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option to apply the same scale to kx and ky.
                If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config["momentum"]["center_pixel"]

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )

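    # --- Usage sketch (illustrative, not part of the module) -------------------
    # Equiscale momentum calibration from one high-symmetry point and its known
    # distance to the BZ center; pixel coordinates and k_distance are
    # hypothetical:
    #
    #   sp.calibrate_momentum_axes(point_a=[308, 345], k_distance=1.28, apply=True)
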
    # 1a. Save momentum calibration parameters to config file.
    def save_momentum_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated momentum calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.calibration) == 0:
            raise ValueError("No momentum calibration parameters to save!")
        calibration = {}
        for key, value in self.mc.calibration.items():
            if key in ["kx_axis", "ky_axis", "grid", "extent"]:
                continue

            calibration[key] = float(value)

        if "creation_date" not in calibration:
            calibration["creation_date"] = datetime.now().timestamp()

        config = {"momentum": {"calibration": calibration}}
        save_config(config, filename, overwrite)
        logger.info(f"Saved momentum calibration parameters to {filename}")

    # 2. Apply correction and calibration to the dataframe
    @call_logger(logger)
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2nd step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes exist,
        these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds: Keyword args passed to ``MomentumCorrector.append_k_axis``.
        """
        x_column = self._config["dataframe"]["x_column"]
        y_column = self._config["dataframe"]["y_column"]

        if self._dataframe is not None:
            logger.info("Adding kx/ky columns to dataframe:")
            df, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                tdf, _ = self.mc.append_k_axis(
                    df=self._timed_dataframe,
                    calibration=calibration,
                    suppress_output=True,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            self._dataframe = df
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    @call_logger(logger)
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: tuple[float, float] = None,
        apply=False,
        **kwds,
    ):
        """1st step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data
        if not yet present.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
            **kwds: Keyword parameters passed to ``EnergyCalibrator.adjust_energy_correction()``.
        """
        if self._pre_binned is None:
            logger.warning("Pre-binned data not present, binning using defaults from config...")
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )

    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.correction) == 0:
            raise ValueError("No energy correction parameters to save!")
        correction = {}
        for key, val in self.ec.correction.items():
            if key == "correction_type":
                correction[key] = val
            elif key == "center":
                correction[key] = [float(i) for i in val]
            else:
                correction[key] = float(val)

        if "creation_date" not in correction:
            correction["creation_date"] = datetime.now().timestamp()

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)
        logger.info(f"Saved energy correction parameters to {filename}")

    # 2. Apply energy correction to dataframe
    @call_logger(logger)
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2nd step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction()``.
        """
        tof_column = self._config["dataframe"]["tof_column"]

        if self._dataframe is not None:
            logger.info("Applying energy correction to dataframe...")
            df, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.apply_energy_correction(
                    df=self._timed_dataframe,
                    correction=correction,
                    suppress_output=True,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

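    # --- Usage sketch (illustrative, not part of the module) -------------------
    # The full energy correction pass: interactive adjustment, saving the
    # parameters, and applying them; amplitude and center are hypothetical:
    #
    #   sp.adjust_energy_correction(
    #       correction_type="Lorentzian", amplitude=2.5, center=(730, 730), apply=True,
    #   )
    #   sp.save_energy_correction()
    #   sp.apply_energy_correction(preview=True)
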
    # Energy calibrator workflow
    # 1. Load and normalize data
    @call_logger(logger)
    def load_bias_series(
        self,
        binned_data: xr.DataArray | tuple[np.ndarray, np.ndarray, np.ndarray] = None,
        data_files: list[str] = None,
        axes: list[str] = None,
        bins: list = None,
        ranges: Sequence[tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1st step of the energy calibration workflow: Load and bin data from
        single-event files, or load binned bias/TOF traces.

        Args:
            binned_data (xr.DataArray | tuple[np.ndarray, np.ndarray, np.ndarray], optional):
                Binned data. If provided as a DataArray, it needs to contain the dimensions
                config["dataframe"]["tof_column"] and config["dataframe"]["bias_column"]. If
                provided as a tuple, it needs to contain the elements tof, biases, traces.
            data_files (list[str], optional): List of file paths to bin.
            axes (list[str], optional): Bin axes.
                Defaults to config["dataframe"]["tof_column"].
            bins (list, optional): Number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[tuple[float, float]], optional): Bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): Span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): Order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        if binned_data is not None:
            if isinstance(binned_data, xr.DataArray):
                if (
                    self._config["dataframe"]["tof_column"] not in binned_data.dims
                    or self._config["dataframe"]["bias_column"] not in binned_data.dims
                ):
                    raise ValueError(
                        "If binned_data is provided as an xarray, it needs to contain dimensions "
                        f"'{self._config['dataframe']['tof_column']}' and "
                        f"'{self._config['dataframe']['bias_column']}'!",
                    )
                tof = binned_data.coords[self._config["dataframe"]["tof_column"]].values
                biases = binned_data.coords[self._config["dataframe"]["bias_column"]].values
                traces = binned_data.values[:, :]
            else:
                try:
                    (tof, biases, traces) = binned_data
                except ValueError as exc:
                    raise ValueError(
                        "If binned_data is provided as tuple, it needs to contain "
                        "(tof, biases, traces)!",
                    ) from exc
            logger.debug(f"Energy calibration data loaded from binned data. Bias values={biases}.")
            self.ec.load_data(biases=biases, traces=traces, tof=tof)

        elif data_files is not None:
            self.ec.bin_data(
                data_files=cast(list[str], self.cpy(data_files)),
                axes=axes,
                bins=bins,
                ranges=ranges,
                biases=biases,
                bias_key=bias_key,
            )
            logger.debug(
                f"Energy calibration data binned from files {data_files}. "
                f"Bias values={biases}.",
            )

        else:
            raise ValueError("Either binned_data or data_files needs to be provided!")

        if (normalize is not None and normalize is True) or (
            normalize is None and self._config["energy"]["normalize"]
        ):
            if span is None:
                span = self._config["energy"]["normalize_span"]
            if order is None:
                order = self._config["energy"]["normalize_order"]
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="bokeh",
        )

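    # --- Usage sketch (illustrative, not part of the module) -------------------
    # Loading a bias series either from raw files or from pre-binned traces;
    # file paths and array names are hypothetical:
    #
    #   sp.load_bias_series(data_files=["/path/to/bias_scan.h5"], normalize=True)
    #   sp.load_bias_series(binned_data=(tof, biases, traces))  # pre-binned tuple
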
    # 2. extract ranges and get peak positions
    @call_logger(logger)
    def find_bias_peaks(
        self,
        ranges: list[tuple] | tuple,
        ref_id: int = 0,
        infer_others: bool = True,
        mode: str = "replace",
        radius: int = None,
        peak_window: int = None,
        apply: bool = False,
    ):
        """2. step of the energy calibration workflow: Find a peak within a given range
        for the indicated reference trace, and try to find the same peak for all
        other traces. Uses fast_dtw to align curves, which may perform poorly if the
        shape of the curves changes qualitatively. Ideally, choose a reference trace
        in the middle of the set, and don't choose the range too narrow around the
        peak. Alternatively, a list of ranges for all traces can be provided.

        Args:
            ranges (list[tuple] | tuple): Tuple of TOF values indicating a range.
                Alternatively, a list of ranges for all traces can be given.
            ref_id (int, optional): The id of the trace the range refers to.
                Defaults to 0.
            infer_others (bool, optional): Whether to determine the range for the other
                traces. Defaults to True.
            mode (str, optional): Whether to "add" or "replace" existing ranges.
                Defaults to "replace".
            radius (int, optional): Radius parameter for fast_dtw.
                Defaults to config["energy"]["fastdtw_radius"].
            peak_window (int, optional): Peak_window parameter for the peak detection
                algorithm: number of points that have to behave monotonically
                around a peak. Defaults to config["energy"]["peak_window"].
            apply (bool, optional): Option to directly apply the provided parameters.
                Defaults to False.
        """
        if radius is None:
            radius = self._config["energy"]["fastdtw_radius"]
        if peak_window is None:
            peak_window = self._config["energy"]["peak_window"]
        if not infer_others:
            self.ec.add_ranges(
                ranges=ranges,
                ref_id=ref_id,
                infer_others=infer_others,
                mode=mode,
                radius=radius,
            )
            logger.info(f"Use feature ranges: {self.ec.featranges}.")
            try:
                self.ec.feature_extract(peak_window=peak_window)
                logger.info(f"Extracted energy features: {self.ec.peaks}.")
                self.ec.view(
                    traces=self.ec.traces_normed,
                    segs=self.ec.featranges,
                    xaxis=self.ec.tof,
                    peaks=self.ec.peaks,
                    backend="bokeh",
                )
            except IndexError:
                logger.error("Could not determine all peaks!")
                raise
        else:
            # New adjustment tool
            assert isinstance(ranges, tuple)
            self.ec.adjust_ranges(
                ranges=ranges,
                ref_id=ref_id,
                traces=self.ec.traces_normed,
                infer_others=infer_others,
                radius=radius,
                peak_window=peak_window,
                apply=apply,
            )
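
    # A minimal usage sketch (TOF range and trace id are illustrative only):
    # give a range around a feature of the reference trace and let the ranges
    # for the other traces be inferred.
    #
    #   >>> sp.find_bias_peaks(ranges=(64000, 66000), ref_id=5, apply=True)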

    # 3. Fit the energy calibration relation
    @call_logger(logger)
    def calibrate_energy_axis(
        self,
        ref_energy: float,
        method: str = None,
        energy_scale: str = None,
        **kwds,
    ):
        """3. step of the energy calibration workflow: Calculate the calibration
        function for the energy axis, and apply it to the dataframe. Two
        approximations are implemented, a (normally 3rd order) polynomial
        approximation, and a d^2/(t-t0)^2 relation.

        Args:
            ref_energy (float): Binding/kinetic energy of the detected feature.
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to config["energy"]["calibration_method"]
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to config["energy"]["energy_scale"]
            **kwds: Keyword parameters passed to ``EnergyCalibrator.calibrate()``.
        """
        if method is None:
            method = self._config["energy"]["calibration_method"]

        if energy_scale is None:
            energy_scale = self._config["energy"]["energy_scale"]

        self.ec.calibrate(
            ref_energy=ref_energy,
            method=method,
            energy_scale=energy_scale,
            **kwds,
        )
        if self._verbose:
            self.ec.view(
                traces=self.ec.traces_normed,
                xaxis=self.ec.calibration["axis"],
                align=True,
                energy_scale=energy_scale,
                backend="matplotlib",
                title="Quality of Calibration",
            )
            plt.xlabel("Energy (eV)")
            plt.ylabel("Intensity")
            plt.tight_layout()
            plt.show()
            if energy_scale == "kinetic":
                self.ec.view(
                    traces=self.ec.calibration["axis"][None, :] + self.ec.biases[0],
                    xaxis=self.ec.tof,
                    backend="matplotlib",
                    show_legend=False,
                    title="E/TOF relationship",
                )
                plt.scatter(
                    self.ec.peaks[:, 0],
                    -(self.ec.biases - self.ec.biases[0]) + ref_energy,
                    s=50,
                    c="k",
                )
                plt.tight_layout()
            elif energy_scale == "binding":
                self.ec.view(
                    traces=self.ec.calibration["axis"][None, :] - self.ec.biases[0],
                    xaxis=self.ec.tof,
                    backend="matplotlib",
                    show_legend=False,
                    title="E/TOF relationship",
                )
                plt.scatter(
                    self.ec.peaks[:, 0],
                    self.ec.biases - self.ec.biases[0] + ref_energy,
                    s=50,
                    c="k",
                )
            else:
                raise ValueError(
                    f'energy_scale needs to be either "binding" or "kinetic", got {energy_scale}.',
                )
            plt.xlabel("Time-of-flight")
            plt.ylabel("Energy (eV)")
            plt.tight_layout()
            plt.show()
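
    # A minimal usage sketch (the reference energy value is illustrative):
    # calibrate against a known feature of the reference trace using the lmfit
    # method on a kinetic energy scale.
    #
    #   >>> sp.calibrate_energy_axis(
    #   ...     ref_energy=-0.5,
    #   ...     method="lmfit",
    #   ...     energy_scale="kinetic",
    #   ... )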

    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.calibration) == 0:
            raise ValueError("No energy calibration parameters to save!")
        calibration = {}
        for key, value in self.ec.calibration.items():
            if key in ["axis", "refid", "Tmat", "bvec"]:
                continue
            if key == "energy_scale":
                calibration[key] = value
            elif key == "coeffs":
                calibration[key] = [float(i) for i in value]
            else:
                calibration[key] = float(value)

        if "creation_date" not in calibration:
            calibration["creation_date"] = datetime.now().timestamp()

        config = {"energy": {"calibration": calibration}}
        save_config(config, filename, overwrite)
        logger.info(f'Saved energy calibration parameters to "{filename}".')

    # 4. Apply energy calibration to the dataframe
    @call_logger(logger)
    def append_energy_axis(
        self,
        calibration: dict = None,
        bias_voltage: float = None,
        preview: bool = False,
        **kwds,
    ):
        """4. step of the energy calibration workflow: Apply the calibration function
        to the dataframe. Two approximations are implemented, a (normally 3rd order)
        polynomial approximation, and a d^2/(t-t0)^2 relation. A calibration dictionary
        can be provided.

        Args:
            calibration (dict, optional): Calibration dict containing calibration
                parameters. Overrides calibration from class or config.
                Defaults to None.
            bias_voltage (float, optional): Sample bias voltage of the scan data. If omitted,
                the bias voltage is read from the dataframe. If it is not found there,
                a warning is printed and the calibrated data might have an offset.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.append_energy_axis()``.
        """
        tof_column = self._config["dataframe"]["tof_column"]

        if self._dataframe is not None:
            logger.info("Adding energy column to dataframe:")
            df, metadata = self.ec.append_energy_axis(
                df=self._dataframe,
                calibration=calibration,
                bias_voltage=bias_voltage,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.append_energy_axis(
                    df=self._timed_dataframe,
                    calibration=calibration,
                    bias_voltage=bias_voltage,
                    suppress_output=True,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_calibration",
                duplicate_policy="merge",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf

        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)
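
    # A minimal usage sketch (the bias voltage value is illustrative): apply
    # the calibration obtained above, or pass a stored calibration dict.
    #
    #   >>> sp.append_energy_axis(bias_voltage=16.8, preview=True)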

    @call_logger(logger)
    def add_energy_offset(
        self,
        constant: float = None,
        columns: str | Sequence[str] = None,
        weights: float | Sequence[float] = None,
        reductions: str | Sequence[str] = None,
        preserve_mean: bool | Sequence[bool] = None,
        preview: bool = False,
    ) -> None:
        """Shift the energy axis of the dataframe by a given amount.

        Args:
            constant (float, optional): The constant to shift the energy axis by.
            columns (str | Sequence[str], optional): Name of the column(s) to apply the shift from.
            weights (float | Sequence[float], optional): weights to apply to the columns.
                Can also be used to flip the sign (e.g. -1). Defaults to 1.
            reductions (str | Sequence[str], optional): The reduction to apply to the column.
                Should be an available method of dask.dataframe.Series. For example "mean". In this
                case the function is applied to the column to generate a single value for the whole
                dataset. If None, the shift is applied per-dataframe-row. Defaults to None.
                Currently only "mean" is supported.
            preserve_mean (bool | Sequence[bool], optional): Whether to subtract the mean of the
                column before applying the shift. Defaults to False.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.

        Raises:
            ValueError: If the energy column is not in the dataframe.
        """
        energy_column = self._config["dataframe"]["energy_column"]
        if energy_column not in self._dataframe.columns:
            raise ValueError(
                f"Energy column {energy_column} not found in dataframe! "
                "Run `append_energy_axis()` first.",
            )
        if self.dataframe is not None:
            logger.info("Adding energy offset to dataframe:")
            df, metadata = self.ec.add_offsets(
                df=self._dataframe,
                constant=constant,
                columns=columns,
                energy_column=energy_column,
                weights=weights,
                reductions=reductions,
                preserve_mean=preserve_mean,
            )
            if self._timed_dataframe is not None and energy_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.add_offsets(
                    df=self._timed_dataframe,
                    constant=constant,
                    columns=columns,
                    energy_column=energy_column,
                    weights=weights,
                    reductions=reductions,
                    preserve_mean=preserve_mean,
                    suppress_output=True,
                )

            self._attributes.add(
                metadata,
                "add_energy_offset",
                # TODO: allow only appending when no offset along this column(s) was applied
                # TODO: clear memory of modifications if the energy axis is recalculated
                duplicate_policy="append",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and energy_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)
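
    # A minimal usage sketch (the column name and values are hypothetical):
    # shift the energy axis by a constant plus a per-event column contribution.
    #
    #   >>> sp.add_energy_offset(
    #   ...     constant=1.5,
    #   ...     columns=["monochromatorPhotonEnergy"],
    #   ...     weights=[-1],
    #   ...     preserve_mean=[True],
    #   ... )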

    def save_energy_offset(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy offset parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.offsets) == 0:
            raise ValueError("No energy offset parameters to save!")

        if "creation_date" not in self.ec.offsets.keys():
            self.ec.offsets["creation_date"] = datetime.now().timestamp()

        config = {"energy": {"offsets": self.ec.offsets}}
        save_config(config, filename, overwrite)
        logger.info(f'Saved energy offset parameters to "{filename}".')

    @call_logger(logger)
    def append_tof_ns_axis(
        self,
        preview: bool = False,
        **kwds,
    ):
        """Convert time-of-flight channel steps to nanoseconds.

        Args:
            tof_ns_column (str, optional): Name of the generated column containing the
                time-of-flight in nanoseconds.
                Defaults to config["dataframe"]["tof_ns_column"].
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds: additional arguments are passed to ``EnergyCalibrator.append_tof_ns_axis()``.

        """
        tof_column = self._config["dataframe"]["tof_column"]

        if self._dataframe is not None:
            logger.info("Adding time-of-flight column in nanoseconds to dataframe.")
            # TODO assert order of execution through metadata

            df, metadata = self.ec.append_tof_ns_axis(
                df=self._dataframe,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.append_tof_ns_axis(
                    df=self._timed_dataframe,
                    **kwds,
                )

            self._attributes.add(
                metadata,
                "tof_ns_conversion",
                duplicate_policy="overwrite",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)
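
    # A minimal usage sketch: assuming the conversion parameters are present in
    # config["dataframe"], no arguments are required.
    #
    #   >>> sp.append_tof_ns_axis(preview=True)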

    @call_logger(logger)
    def align_dld_sectors(
        self,
        sector_delays: np.ndarray = None,
        preview: bool = False,
        **kwds,
    ):
        """Align the 8 sectors of the HEXTOF endstation.

        Args:
            sector_delays (np.ndarray, optional): Array containing the sector delays. Defaults to
                config["dataframe"]["sector_delays"].
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds: additional arguments are passed to ``EnergyCalibrator.align_dld_sectors()``.
        """
        tof_column = self._config["dataframe"]["tof_column"]

        if self._dataframe is not None:
            logger.info("Aligning 8 sectors of dataframe")
            # TODO assert order of execution through metadata

            df, metadata = self.ec.align_dld_sectors(
                df=self._dataframe,
                sector_delays=sector_delays,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.align_dld_sectors(
                    df=self._timed_dataframe,
                    sector_delays=sector_delays,
                    **kwds,
                )

            self._attributes.add(
                metadata,
                "dld_sector_alignment",
                duplicate_policy="raise",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)
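
    # A minimal usage sketch: sector delays default to
    # config["dataframe"]["sector_delays"], so no arguments are needed.
    #
    #   >>> sp.align_dld_sectors()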

    # Delay calibration function
    @call_logger(logger)
    def calibrate_delay_axis(
        self,
        delay_range: tuple[float, float] = None,
        datafile: str = None,
        preview: bool = False,
        **kwds,
    ):
        """Append delay column to dataframe. Either provide delay ranges, or read
        them from a file.

        Args:
            delay_range (tuple[float, float], optional): The scanned delay range in
                picoseconds. Defaults to None.
            datafile (str, optional): The file from which to read the delay ranges.
                Defaults to None.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
        """
        adc_column = self._config["dataframe"]["adc_column"]
        if adc_column not in self._dataframe.columns:
            raise ValueError(f"ADC column {adc_column} not found in dataframe, cannot calibrate!")

        if self._dataframe is not None:
            logger.info("Adding delay column to dataframe:")

            if delay_range is None and datafile is None:
                if len(self.dc.calibration) == 0:
                    try:
                        datafile = self._files[0]
                    except IndexError as exc:
                        raise IndexError(
                            "No datafile available, specify either 'datafile' or 'delay_range'",
                        ) from exc

            df, metadata = self.dc.append_delay_axis(
                self._dataframe,
                delay_range=delay_range,
                datafile=datafile,
                **kwds,
            )
            if self._timed_dataframe is not None and adc_column in self._timed_dataframe.columns:
                tdf, _ = self.dc.append_delay_axis(
                    self._timed_dataframe,
                    delay_range=delay_range,
                    datafile=datafile,
                    suppress_output=True,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "delay_calibration",
                duplicate_policy="overwrite",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and adc_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.debug(self._dataframe)
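
    # A minimal usage sketch (the delay range is illustrative): provide the
    # scanned range in picoseconds, or omit it to read the range from a file.
    #
    #   >>> sp.calibrate_delay_axis(delay_range=(-100.0, 200.0), preview=True)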

    def save_delay_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Save the generated delay calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"

        if len(self.dc.calibration) == 0:
            raise ValueError("No delay calibration parameters to save!")
        calibration = {}
        for key, value in self.dc.calibration.items():
            if key == "datafile":
                calibration[key] = value
            elif key in ["adc_range", "delay_range", "delay_range_mm"]:
                calibration[key] = [float(i) for i in value]
            else:
                calibration[key] = float(value)

        if "creation_date" not in calibration:
            calibration["creation_date"] = datetime.now().timestamp()

        config = {
            "delay": {
                "calibration": calibration,
            },
        }
        save_config(config, filename, overwrite)

    @call_logger(logger)
    def add_delay_offset(
        self,
        constant: float = None,
        flip_delay_axis: bool = None,
        columns: str | Sequence[str] = None,
        weights: float | Sequence[float] = 1.0,
        reductions: str | Sequence[str] = None,
        preserve_mean: bool | Sequence[bool] = False,
        preview: bool = False,
    ) -> None:
        """Shift the delay axis of the dataframe by a constant or other columns.

        Args:
            constant (float, optional): The constant to shift the delay axis by.
            flip_delay_axis (bool, optional): Option to reverse the direction of the delay axis.
            columns (str | Sequence[str], optional): Name of the column(s) to apply the shift from.
            weights (float | Sequence[float], optional): weights to apply to the columns.
                Can also be used to flip the sign (e.g. -1). Defaults to 1.
            reductions (str | Sequence[str], optional): The reduction to apply to the column.
                Should be an available method of dask.dataframe.Series. For example "mean". In this
                case the function is applied to the column to generate a single value for the whole
                dataset. If None, the shift is applied per-dataframe-row. Defaults to None.
                Currently only "mean" is supported.
            preserve_mean (bool | Sequence[bool], optional): Whether to subtract the mean of the
                column before applying the shift. Defaults to False.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.

        Raises:
            ValueError: If the delay column is not in the dataframe.
        """
        delay_column = self._config["dataframe"]["delay_column"]
        if delay_column not in self._dataframe.columns:
            raise ValueError(f"Delay column {delay_column} not found in dataframe!")

        if self.dataframe is not None:
            logger.info("Adding delay offset to dataframe:")
            df, metadata = self.dc.add_offsets(
                df=self._dataframe,
                constant=constant,
                flip_delay_axis=flip_delay_axis,
                columns=columns,
                delay_column=delay_column,
                weights=weights,
                reductions=reductions,
                preserve_mean=preserve_mean,
            )
            if self._timed_dataframe is not None and delay_column in self._timed_dataframe.columns:
                tdf, _ = self.dc.add_offsets(
                    df=self._timed_dataframe,
                    constant=constant,
                    flip_delay_axis=flip_delay_axis,
                    columns=columns,
                    delay_column=delay_column,
                    weights=weights,
                    reductions=reductions,
                    preserve_mean=preserve_mean,
                    suppress_output=True,
                )

            self._attributes.add(
                metadata,
                "delay_offset",
                duplicate_policy="append",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and delay_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)
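
    # A minimal usage sketch (the column name "bam" and all values are
    # hypothetical): shift by a constant, correct with a per-event column,
    # and flip the axis direction.
    #
    #   >>> sp.add_delay_offset(
    #   ...     constant=13.2,
    #   ...     columns=["bam"],
    #   ...     weights=[0.001],
    #   ...     flip_delay_axis=True,
    #   ... )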

    def save_delay_offsets(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Save the generated delay offset parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.dc.offsets) == 0:
            raise ValueError("No delay offset parameters to save!")

        if "creation_date" not in self.dc.offsets.keys():
            self.dc.offsets["creation_date"] = datetime.now().timestamp()

        config = {
            "delay": {
                "offsets": self.dc.offsets,
            },
        }
        save_config(config, filename, overwrite)
        logger.info(f'Saved delay offset parameters to "{filename}".')

    def save_workflow_params(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Run all methods for saving calibration parameters.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        for method in [
            self.save_splinewarp,
            self.save_transformations,
            self.save_momentum_calibration,
            self.save_energy_correction,
            self.save_energy_calibration,
            self.save_energy_offset,
            self.save_delay_calibration,
            self.save_delay_offsets,
        ]:
            try:
                method(filename, overwrite)
            except (ValueError, AttributeError, KeyError):
                pass
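
    # A minimal usage sketch: persist all available calibration and offset
    # parameters in one call; steps without parameters are silently skipped.
    #
    #   >>> sp.save_workflow_params(filename="sed_config.yaml", overwrite=True)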

    @call_logger(logger)
    def add_jitter(
        self,
        cols: list[str] = None,
        amps: float | Sequence[float] = None,
        **kwds,
    ):
        """Add jitter to the selected dataframe columns.

        Args:
            cols (list[str], optional): The columns onto which to apply jitter.
                Defaults to config["dataframe"]["jitter_cols"].
            amps (float | Sequence[float], optional): Amplitude scalings for the
                jittering noise. If one number is given, the same is used for all axes.
                For uniform noise (default) it will cover the interval [-amp, +amp].
                Defaults to config["dataframe"]["jitter_amps"].
            **kwds: additional keyword arguments passed to ``apply_jitter``.
        """
        if cols is None:
            cols = self._config["dataframe"]["jitter_cols"]
        for loc, col in enumerate(cols):
            if col.startswith("@"):
                cols[loc] = self._config["dataframe"].get(col.strip("@"))

        if amps is None:
            amps = self._config["dataframe"]["jitter_amps"]

        self._dataframe = self._dataframe.map_partitions(
            apply_jitter,
            cols=cols,
            cols_jittered=cols,
            amps=amps,
            **kwds,
        )
        if self._timed_dataframe is not None:
            cols_timed = cols.copy()
            for col in cols:
                if col not in self._timed_dataframe.columns:
                    cols_timed.remove(col)

            if cols_timed:
                self._timed_dataframe = self._timed_dataframe.map_partitions(
                    apply_jitter,
                    cols=cols_timed,
                    cols_jittered=cols_timed,
                )
        metadata = []
        for col in cols:
            metadata.append(col)
        # TODO: allow only appending if columns are not jittered yet
        self._attributes.add(metadata, "jittering", duplicate_policy="append")
        logger.info(f"add_jitter: Added jitter to columns {cols}.")
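
    # A minimal usage sketch (column names and amplitude are illustrative):
    # jitter the configured columns, or name columns and amplitudes explicitly.
    #
    #   >>> sp.add_jitter()
    #   >>> sp.add_jitter(cols=["dldPosX", "dldPosY"], amps=0.5)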

    @call_logger(logger)
    def add_time_stamped_data(
        self,
        dest_column: str,
        time_stamps: np.ndarray = None,
        data: np.ndarray = None,
        archiver_channel: str = None,
        **kwds,
    ):
        """Add data in form of timestamp/value pairs to the dataframe using interpolation to the
        timestamps in the dataframe. The time-stamped data can either be provided, or fetched from
        an EPICS archiver instance.

        Args:
            dest_column (str): destination column name
            time_stamps (np.ndarray, optional): Time stamps of the values to add. If omitted,
                time stamps are retrieved from the EPICS archiver.
            data (np.ndarray, optional): Values corresponding to the time stamps in time_stamps.
                If omitted, data are retrieved from the EPICS archiver.
            archiver_channel (str, optional): EPICS archiver channel from which to retrieve data.
                Either this or data and time_stamps have to be present.
            **kwds:

                - **time_stamp_column**: Dataframe column containing time-stamp data

                Additional keyword arguments passed to ``add_time_stamped_data``.
        """
        time_stamp_column = kwds.pop(
            "time_stamp_column",
            self._config["dataframe"].get("time_stamp_alias", ""),
        )

        if time_stamps is None and data is None:
            if archiver_channel is None:
                raise ValueError(
                    "Either archiver_channel or both time_stamps and data have to be present!",
                )
            if self.loader.__name__ != "mpes":
                raise NotImplementedError(
                    "This function is currently only implemented for the mpes loader!",
                )
            ts_from, ts_to = cast(MpesLoader, self.loader).get_start_and_end_time()
            # get channel data with +-5 seconds safety margin
            time_stamps, data = get_archiver_data(
                archiver_url=self._config["metadata"].get("archiver_url", ""),
                archiver_channel=archiver_channel,
                ts_from=ts_from - 5,
                ts_to=ts_to + 5,
            )

        self._dataframe = add_time_stamped_data(
            self._dataframe,
            time_stamps=time_stamps,
            data=data,
            dest_column=dest_column,
            time_stamp_column=time_stamp_column,
            **kwds,
        )
        if self._timed_dataframe is not None:
            if time_stamp_column in self._timed_dataframe:
                self._timed_dataframe = add_time_stamped_data(
                    self._timed_dataframe,
                    time_stamps=time_stamps,
                    data=data,
                    dest_column=dest_column,
                    time_stamp_column=time_stamp_column,
                    **kwds,
                )
        metadata: list[Any] = []
        metadata.append(dest_column)
        metadata.append(time_stamps)
        metadata.append(data)
        self._attributes.add(metadata, "time_stamped_data", duplicate_policy="append")
        logger.info(f"add_time_stamped_data: Added time-stamped data as column {dest_column}.")
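
    # A minimal usage sketch (the column name and values are hypothetical):
    # add a slowly varying sample temperature, interpolated onto the event
    # time stamps.
    #
    #   >>> import numpy as np
    #   >>> sp.add_time_stamped_data(
    #   ...     dest_column="sampleTemperature",
    #   ...     time_stamps=np.array([1690000000.0, 1690000600.0]),
    #   ...     data=np.array([20.0, 21.5]),
    #   ... )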

    @call_logger(logger)
    def pre_binning(
        self,
        df_partitions: int | Sequence[int] = 100,
        axes: list[str] = None,
        bins: list[int] = None,
        ranges: Sequence[tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Function to do an initial binning of the dataframe loaded to the class.

        Args:
            df_partitions (int | Sequence[int], optional): Number of dataframe partitions to
                use for the initial binning. Defaults to 100.
            axes (list[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (list[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (Sequence[tuple[float, float]], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            **kwds: Keyword argument passed to ``compute``.

        Returns:
            xr.DataArray: pre-binned data-array.
        """
        if axes is None:
            axes = self._config["momentum"]["axes"]
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))

        if bins is None:
            bins = self._config["momentum"]["bins"]
        if ranges is None:
            ranges_ = list(self._config["momentum"]["ranges"])
            ranges_[2] = np.asarray(ranges_[2]) / self._config["dataframe"]["tof_binning"]
            ranges = [cast(tuple[float, float], tuple(v)) for v in ranges_]

        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        return self.compute(
            bins=bins,
            axes=axes,
            ranges=ranges,
            df_partitions=df_partitions,
            **kwds,
        )

    @call_logger(logger)
    def compute(
        self,
        bins: int | dict | tuple | list[int] | list[np.ndarray] | list[tuple] = 100,
        axes: str | Sequence[str] = None,
        ranges: Sequence[tuple[float, float]] = None,
        normalize_to_acquisition_time: bool | str = False,
        **kwds,
    ) -> xr.DataArray:
        """Compute the histogram along the given dimensions.

        Args:
            bins (int | dict | tuple | list[int] | list[np.ndarray] | list[tuple], optional):
                Definition of the bins. Can be any of the following cases:

                - an integer describing the number of bins for all dimensions
                - a tuple of 3 numbers describing start, end and step of the binning
                  range
                - an np.ndarray defining the binning edges
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
                - a dictionary made of the axes as keys and any of the above as values.

                This takes priority over the axes and range arguments. Defaults to 100.
            axes (str | Sequence[str], optional): The names of the axes (columns)
                on which to calculate the histogram. The order will be the order of the
                dimensions in the resulting array. Defaults to None.
            ranges (Sequence[tuple[float, float]], optional): list of tuples containing
                the start and end point of the binning range. Defaults to None.
            normalize_to_acquisition_time (bool | str): Option to normalize the
                result to the acquisition time. If a "slow" axis was scanned, providing
                the name of the scanned axis will compute and apply the corresponding
                normalization histogram. Defaults to False.
            **kwds: Keyword arguments:

                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
                  ``bin_dataframe`` for details. Defaults to
                  config["binning"]["hist_mode"].
                - **mode**: Defines how the results from each partition are combined.
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
                  Defaults to config["binning"]["mode"].
                - **pbar**: Option to show the tqdm progress bar. Defaults to
                  config["binning"]["pbar"].
                - **num_cores**: Number of CPU cores to use for parallelization.
                  Defaults to config["core"]["num_cores"] or N_CPU-1.
                - **threads_per_worker**: Limit the number of threads that
                  multiprocessing can spawn per binning thread. Defaults to
                  config["binning"]["threads_per_worker"].
                - **threadpool_API**: The API to use for multiprocessing. "blas",
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
                  config["binning"]["threadpool_API"].
                - **df_partitions**: A sequence of dataframe partitions, or the
                  number of the dataframe partitions to use. Defaults to all partitions.
                - **filter**: A Sequence of Dictionaries with entries "col", "lower_bound",
                  "upper_bound" to apply as filter to the dataframe before binning. The
                  dataframe in the class remains unmodified by this.

                Additional kwds are passed to ``bin_dataframe``.

        Raises:
            AssertionError: Raised when no dataframe has been loaded.

        Returns:
            xr.DataArray: The result of the n-dimensional binning represented in an
            xarray object, combining the data with the axes.
        """
        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
        mode = kwds.pop("mode", self._config["binning"]["mode"])
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
        num_cores = kwds.pop("num_cores", self._config["core"]["num_cores"])
        threads_per_worker = kwds.pop(
            "threads_per_worker",
            self._config["binning"]["threads_per_worker"],
        )
        threadpool_api = kwds.pop(
            "threadpool_API",
            self._config["binning"]["threadpool_API"],
        )
        df_partitions: int | Sequence[int] = kwds.pop("df_partitions", None)
        if isinstance(df_partitions, int):
            df_partitions = list(range(0, min(df_partitions, self._dataframe.npartitions)))
        if df_partitions is not None:
            dataframe = self._dataframe.partitions[df_partitions]
        else:
            dataframe = self._dataframe

        filter_params = kwds.pop("filter", None)
        if filter_params is not None:
            try:
                for param in filter_params:
                    if "col" not in param:
                        raise ValueError(
                            "'col' needs to be defined for each filter entry! ",
                            f"Not present in {param}.",
                        )
                    assert set(param.keys()).issubset({"col", "lower_bound", "upper_bound"})
                    dataframe = apply_filter(dataframe, **param)
            except AssertionError as exc:
                invalid_keys = set(param.keys()) - {"col", "lower_bound", "upper_bound"}
                raise ValueError(
                    "Only 'col', 'lower_bound' and 'upper_bound' allowed as filter entries. ",
                    f"Parameters {invalid_keys} not valid in {param}.",
                ) from exc

        self._binned = bin_dataframe(
            df=dataframe,
            bins=bins,
            axes=axes,
            ranges=ranges,
            hist_mode=hist_mode,
            mode=mode,
            pbar=pbar,
            n_cores=num_cores,
            threads_per_worker=threads_per_worker,
            threadpool_api=threadpool_api,
            **kwds,
        )

        for dim in self._binned.dims:
            try:
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
            except KeyError:
                pass

        self._binned.attrs["units"] = "counts"
        self._binned.attrs["long_name"] = "photoelectron counts"
        self._binned.attrs["metadata"] = self._attributes.metadata

        if normalize_to_acquisition_time:
            if isinstance(normalize_to_acquisition_time, str):
                axis = normalize_to_acquisition_time
                logger.info(f"Calculate normalization histogram for axis '{axis}'...")
                self._normalization_histogram = self.get_normalization_histogram(
                    axis=axis,
                    df_partitions=df_partitions,
                )
                # if the axes are named correctly, xarray figures out the normalization correctly
                self._normalized = self._binned / self._normalization_histogram
                self._attributes.add(
                    self._normalization_histogram.values,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )
            else:
                acquisition_time = self.loader.get_elapsed_time(
                    fids=df_partitions,
                )
                if acquisition_time > 0:
                    self._normalized = self._binned / acquisition_time
                self._attributes.add(
                    acquisition_time,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )

            self._normalized.attrs["units"] = "counts/second"
            self._normalized.attrs["long_name"] = "photoelectron counts per second"
            self._normalized.attrs["metadata"] = self._attributes.metadata

            return self._normalized

        return self._binned
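
    # A minimal usage sketch (axes, bins and ranges are illustrative): bin the
    # calibrated dataframe into a 3D momentum/energy volume, restricted to the
    # first 100 partitions.
    #
    #   >>> res = sp.compute(
    #   ...     bins=[200, 200, 200],
    #   ...     axes=["kx", "ky", "energy"],
    #   ...     ranges=[(-2.0, 2.0), (-2.0, 2.0), (-4.0, 1.0)],
    #   ...     df_partitions=100,
    #   ... )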

    @call_logger(logger)
    def get_normalization_histogram(
        self,
        axis: str = "delay",
        use_time_stamps: bool = False,
        **kwds,
    ) -> xr.DataArray:
        """Generates a normalization histogram from the timed dataframe. Optionally,
        use the TimeStamps column instead.

        Args:
            axis (str, optional): The axis for which to compute histogram.
                Defaults to "delay".
            use_time_stamps (bool, optional): Use the TimeStamps column of the
                dataframe, rather than the timed dataframe. Defaults to False.
            **kwds: Keyword arguments:

                - **df_partitions**: A sequence of dataframe partitions, or the
                  number of the dataframe partitions to use. Defaults to all partitions.

        Raises:
            ValueError: Raised if no data are binned.
            ValueError: Raised if 'axis' not in binned coordinates.
            ValueError: Raised if config["dataframe"]["time_stamp_alias"] not found
                in Dataframe.

        Returns:
            xr.DataArray: The computed normalization histogram (in TimeStamp units
            per bin).
        """

        if self._binned is None:
            raise ValueError("Need to bin data first!")
        if axis not in self._binned.coords:
            raise ValueError(f"Axis '{axis}' not found in binned data!")

        df_partitions: int | Sequence[int] = kwds.pop("df_partitions", None)

        if len(kwds) > 0:
            raise TypeError(
                f"get_normalization_histogram() got unexpected keyword arguments {kwds.keys()}.",
            )

        if isinstance(df_partitions, int):
            df_partitions = list(range(0, min(df_partitions, self._dataframe.npartitions)))
        if use_time_stamps or self._timed_dataframe is None:
            if df_partitions is not None:
                self._normalization_histogram = normalization_histogram_from_timestamps(
                    self._dataframe.partitions[df_partitions],
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["time_stamp_alias"],
                )
            else:
                self._normalization_histogram = normalization_histogram_from_timestamps(
                    self._dataframe,
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["time_stamp_alias"],
                )
        else:
            if df_partitions is not None:
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
                    self._timed_dataframe.partitions[df_partitions],
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["timed_dataframe_unit_time"],
                )
            else:
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
                    self._timed_dataframe,
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["timed_dataframe_unit_time"],
                )

        return self._normalization_histogram

    def view_event_histogram(
        self,
        dfpid: int,
        ncol: int = 2,
        bins: Sequence[int] = None,
        axes: Sequence[str] = None,
        ranges: Sequence[tuple[float, float]] = None,
        backend: str = "bokeh",
        legend: bool = True,
        histkwds: dict = None,
        legkwds: dict = None,
        **kwds,
    ):
        """Plot individual histograms of specified dimensions (axes) from a single
        dataframe partition.

        Args:
            dfpid (int): Number of the data frame partition to look at.
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
            bins (Sequence[int], optional): Number of bins to use for the specified
                axes. Defaults to config["histogram"]["bins"].
            axes (Sequence[str], optional): Names of the axes to display.
                Defaults to config["histogram"]["axes"].
            ranges (Sequence[tuple[float, float]], optional): Value ranges of all
                specified axes. Defaults to config["histogram"]["ranges"].
            backend (str, optional): Backend of the plotting library
                ('matplotlib' or 'bokeh'). Defaults to "bokeh".
            legend (bool, optional): Option to include a legend in the histogram plots.
                Defaults to True.
            histkwds (dict, optional): Keyword arguments for histograms
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
            legkwds (dict, optional): Keyword arguments for legend
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
            **kwds: Extra keyword arguments passed to
                ``sed.diagnostics.grid_histogram()``.

        Raises:
            TypeError: Raises when the input values are not of the correct type.
        """
        if bins is None:
            bins = self._config["histogram"]["bins"]
        if axes is None:
            axes = self._config["histogram"]["axes"]
        axes = list(axes)
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
        if ranges is None:
            ranges = list(self._config["histogram"]["ranges"])
            for loc, axis in enumerate(axes):
                if axis == self._config["dataframe"]["tof_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / self._config["dataframe"]["tof_binning"]
                elif axis == self._config["dataframe"]["adc_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / self._config["dataframe"]["adc_binning"]

        input_types = map(type, [axes, bins, ranges])
        allowed_types = [list, tuple]

        df = self._dataframe

        if not set(input_types).issubset(allowed_types):
            raise TypeError(
                "Inputs of axes, bins, ranges need to be list or tuple!",
            )

        # Read out the values for the specified groups
        group_dict_dd = {}
        dfpart = df.get_partition(dfpid)
        cols = dfpart.columns
        for ax in axes:
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
        group_dict = ddf.compute(group_dict_dd)[0]

        # Plot multiple histograms in a grid
        grid_histogram(
            group_dict,
            ncol=ncol,
            rvs=axes,
            rvbins=bins,
            rvranges=ranges,
            backend=backend,
            legend=legend,
            histkwds=histkwds,
            legkwds=legkwds,
            **kwds,
        )
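
    # A minimal usage sketch: inspect the raw event distributions of the
    # configured histogram axes in the first dataframe partition.
    #
    #   >>> sp.view_event_histogram(dfpid=0, ncol=2)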

    @call_logger(logger)
    def save(
        self,
        faddr: str,
        **kwds,
    ):
        """Saves the binned data to the provided path and filename.

        Args:
            faddr (str): Path and name of the file to write. Its extension determines
                the file type to write. Valid file types are:

                - "*.tiff", "*.tif": Saves a TIFF stack.
                - "*.h5", "*.hdf5": Saves an HDF5 file.
                - "*.nxs", "*.nexus": Saves a NeXus file.

            **kwds: Keyword arguments, which are passed to the writer functions:
                For TIFF writing:

                - **alias_dict**: Dictionary of dimension aliases to use.

                For HDF5 writing:

                - **mode**: hdf5 read/write mode. Defaults to "w".

                For NeXus:

                - **reader**: Name of the pynxtools reader to use.
                  Defaults to config["nexus"]["reader"]
                - **definition**: NeXus application definition to use for saving.
                  Must be supported by the used ``reader``. Defaults to
                  config["nexus"]["definition"]
                - **input_files**: A list of input files to pass to the reader.
                  Defaults to config["nexus"]["input_files"]
                - **eln_data**: An electronic-lab-notebook file in '.yaml' format
                  to add to the list of files to pass to the reader.
        """
        if self._binned is None:
            raise NameError("Need to bin data first!")

        if self._normalized is not None:
            data = self._normalized
        else:
            data = self._binned

        extension = pathlib.Path(faddr).suffix

        if extension in (".tif", ".tiff"):
            to_tiff(
                data=data,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".h5", ".hdf5"):
            to_h5(
                data=data,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".nxs", ".nexus"):
            try:
                reader = kwds.pop("reader", self._config["nexus"]["reader"])
                definition = kwds.pop(
                    "definition",
                    self._config["nexus"]["definition"],
                )
                input_files = kwds.pop(
                    "input_files",
                    self._config["nexus"]["input_files"],
                )
            except KeyError as exc:
                raise ValueError(
                    "The nexus reader, definition and input files need to be provided!",
                ) from exc

            if isinstance(input_files, str):
                input_files = [input_files]

            if "eln_data" in kwds:
                input_files.append(kwds.pop("eln_data"))

            to_nexus(
                data=data,
                faddr=faddr,
                reader=reader,
                definition=definition,
                input_files=input_files,
                **kwds,
            )

        else:
            raise NotImplementedError(
                f"Unrecognized file format: {extension}.",
            )
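
    # A minimal usage sketch (file names, reader and definition values are
    # hypothetical): the file extension selects the writer.
    #
    #   >>> sp.save("binned.h5")
    #   >>> sp.save("scan.nxs", reader="mpes", definition="NXmpes")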