OpenCOMPES / sed, build 10383513665

14 Aug 2024 07:44AM UTC. Coverage: 92.668% (-0.02%) from 92.688%
Pull Request #490: Logging (rettigl: "some further fixes")

262 of 299 new or added lines in 12 files covered (87.63%).
4 existing lines in 2 files now uncovered.
7154 of 7720 relevant lines covered (92.67%).
0.93 hits per line.

Source file: /sed/core/processor.py (86.67% covered)
"""This module contains the core class for the sed package."""
from __future__ import annotations

import pathlib
from collections.abc import Sequence
from datetime import datetime
from logging import INFO
from logging import WARNING
from typing import Any
from typing import cast

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.binning.binning import normalization_histogram_from_timed_dataframe
from sed.binning.binning import normalization_histogram_from_timestamps
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import add_time_stamped_data
from sed.core.dfops import apply_filter
from sed.core.dfops import apply_jitter
from sed.core.logging import call_logger
from sed.core.logging import setup_logging
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader
from sed.loader.mpes.loader import get_archiver_data
from sed.loader.mpes.loader import MpesLoader

N_CPU = psutil.cpu_count()

# Configure logging
logger = setup_logging("processor")


class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a work flow for data
52
    correction, calibration and binning.
53

54
    Args:
55
        metadata (dict, optional): Dict of external Metadata. Defaults to None.
56
        config (dict | str, optional): Config dictionary or config file name.
57
            Defaults to None.
58
        dataframe (pd.DataFrame | ddf.DataFrame, optional): dataframe to load
59
            into the class. Defaults to None.
60
        files (list[str], optional): List of files to pass to the loader defined in
61
            the config. Defaults to None.
62
        folder (str, optional): Folder containing files to pass to the loader
63
            defined in the config. Defaults to None.
64
        runs (Sequence[str], optional): List of run identifiers to pass to the loader
65
            defined in the config. Defaults to None.
66
        collect_metadata (bool): Option to collect metadata from files.
67
            Defaults to False.
68
        verbose (bool, optional): Option to print out diagnostic information.
69
            Defaults to config["core"]["verbose"] or True.
70
        **kwds: Keyword arguments passed to the reader.
71
    """
72

    @call_logger(logger)
    def __init__(
        self,
        metadata: dict = None,
        config: dict | str = None,
        dataframe: pd.DataFrame | ddf.DataFrame = None,
        files: list[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            config (dict | str, optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (pd.DataFrame | ddf.DataFrame, optional): dataframe to load
                into the class. Defaults to None.
            files (list[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"] or True.
            **kwds: Keyword arguments passed to parse_config and to the reader.
        """
        # split off config keywords
        config_kwds = {
            key: value for key, value in kwds.items() if key in parse_config.__code__.co_varnames
        }
        for key in config_kwds.keys():
            del kwds[key]
        self._config = parse_config(config, **config_kwds)
        num_cores = self._config["core"].get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["core"]["num_cores"] = num_cores
        logger.debug(f"Use {num_cores} cores for processing.")

        if verbose is None:
            self.verbose = self._config["core"].get("verbose", True)
        else:
            self.verbose = verbose
        if self.verbose:
            logger.handlers[0].setLevel(INFO)
        else:
            logger.handlers[0].setLevel(WARNING)

        self._dataframe: pd.DataFrame | ddf.DataFrame = None
        self._timed_dataframe: pd.DataFrame | ddf.DataFrame = None
        self._files: list[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None
        self._normalization_histogram: xr.DataArray = None
        self._normalized: xr.DataArray = None

        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
            verbose=verbose,
        )
        logger.debug(f"Use loader: {loader_name}")

        self.ec = EnergyCalibrator(
            loader=get_loader(
                loader_name=loader_name,
                config=self._config,
                verbose=verbose,
            ),
            config=self._config,
            verbose=self.verbose,
        )

        self.mc = MomentumCorrector(
            config=self._config,
            verbose=self.verbose,
        )

        self.dc = DelayCalibrator(
            config=self._config,
            verbose=self.verbose,
        )

        self.use_copy_tool = self._config.get("core", {}).get(
            "use_copy_tool",
            False,
        )
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    source=self._config["core"]["copy_tool_source"],
                    dest=self._config["core"]["copy_tool_dest"],
                    num_cores=self._config["core"]["num_cores"],
                    **self._config["core"].get("copy_tool_kwds", {}),
                )
                logger.debug(
                    f"Initialized copy tool: Copy file from "
                    f"'{self._config['core']['copy_tool_source']}' "
                    f"to '{self._config['core']['copy_tool_dest']}'.",
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )

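    ## Example (sketch, not part of the original file): constructing a processor
    ## from a config file and loading data in one step; the config path, run ID,
    ## and folder are hypothetical placeholders.
    # processor = SedProcessor(
    #     config="sed_config.yaml",
    #     runs=["44797"],
    #     folder="/path/to/raw_data",
    #     collect_metadata=False,
    #     verbose=True,
    # )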
    def __repr__(self):
        if self._dataframe is None:
            df_str = "Dataframe: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        pretty_str = df_str + "\n" + "Metadata: " + "\n" + self._attributes.__repr__()
        return pretty_str

    def _repr_html_(self):
        html = "<div>"

        if self._dataframe is None:
            df_html = "Dataframe: No Data loaded"
        else:
            df_html = self._dataframe._repr_html_()

        html += f"<details><summary>Dataframe</summary>{df_html}</details>"

        # Add expandable section for attributes
        html += "<details><summary>Metadata</summary>"
        html += "<div style='padding-left: 10px;'>"
        html += self._attributes._repr_html_()
        html += "</div></details>"

        html += "</div>"

        return html

    ## Suggestion:
    # @property
    # def overview_panel(self):
    #     """Provides an overview panel with plots of different data attributes."""
    #     self.view_event_histogram(dfpid=2, backend="matplotlib")

    @property
    def dataframe(self) -> pd.DataFrame | ddf.DataFrame:
        """Accessor to the underlying dataframe.

        Returns:
            pd.DataFrame | ddf.DataFrame: Dataframe object.
        """
        return self._dataframe

    @dataframe.setter
    def dataframe(self, dataframe: pd.DataFrame | ddf.DataFrame):
        """Setter for the underlying dataframe.

        Args:
            dataframe (pd.DataFrame | ddf.DataFrame): The dataframe object to set.
        """
        if not isinstance(dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            dataframe,
            self._dataframe.__class__,
        ):
            raise ValueError(
                "'dataframe' has to be a Pandas or Dask dataframe and has to be of the same kind "
                "as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._dataframe.__class__}, provided type: {type(dataframe)}.",
            )
        self._dataframe = dataframe

    @property
    def timed_dataframe(self) -> pd.DataFrame | ddf.DataFrame:
        """Accessor to the underlying timed_dataframe.

        Returns:
            pd.DataFrame | ddf.DataFrame: Timed Dataframe object.
        """
        return self._timed_dataframe

    @timed_dataframe.setter
    def timed_dataframe(self, timed_dataframe: pd.DataFrame | ddf.DataFrame):
        """Setter for the underlying timed dataframe.

        Args:
            timed_dataframe (pd.DataFrame | ddf.DataFrame): The timed dataframe object to set.
        """
        if not isinstance(timed_dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            timed_dataframe,
            self._timed_dataframe.__class__,
        ):
            raise ValueError(
                "'timed_dataframe' has to be a Pandas or Dask dataframe and has to be of the same "
                "kind as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._timed_dataframe.__class__}, "
                f"provided type: {type(timed_dataframe)}.",
            )
        self._timed_dataframe = timed_dataframe

    @property
    def attributes(self) -> MetaHandler:
        """Accessor to the metadata dict.

        Returns:
            MetaHandler: The metadata object
        """
        return self._attributes

    def add_attribute(self, attributes: dict, name: str, **kwds):
        """Function to add an element to the attributes dict.

        Args:
            attributes (dict): The attributes dictionary object to add.
            name (str): Key under which to add the dictionary to the attributes.
            **kwds: Additional keywords are passed to the ``MetaHandler.add()`` function.
        """
        self._attributes.add(
            entry=attributes,
            name=name,
            **kwds,
        )

    @property
    def config(self) -> dict[Any, Any]:
        """Getter attribute for the config dictionary

        Returns:
            dict: The config dictionary.
        """
        return self._config

    @property
    def files(self) -> list[str]:
        """Getter attribute for the list of files

        Returns:
            list[str]: The list of loaded files
        """
        return self._files

    @property
    def binned(self) -> xr.DataArray:
        """Getter attribute for the binned data array

        Returns:
            xr.DataArray: The binned data array
        """
        if self._binned is None:
            raise ValueError("No binned data available, need to compute histogram first!")
        return self._binned

    @property
    def normalized(self) -> xr.DataArray:
        """Getter attribute for the normalized data array

        Returns:
            xr.DataArray: The normalized data array
        """
        if self._normalized is None:
            raise ValueError(
                "No normalized data available, compute data with normalization enabled!",
            )
        return self._normalized

    @property
    def normalization_histogram(self) -> xr.DataArray:
        """Getter attribute for the normalization histogram

        Returns:
            xr.DataArray: The normalization histogram
        """
        if self._normalization_histogram is None:
            raise ValueError("No normalization histogram available, generate histogram first!")
        return self._normalization_histogram

    def cpy(self, path: str | list[str]) -> str | list[str]:
        """Function to mirror a list of files or a folder from a network drive to
        local storage. Returns either the original path(s) or the copied path(s),
        depending on the setting of config["core"]["use_copy_tool"].

        Args:
            path (str | list[str]): Source path or path list.

        Returns:
            str | list[str]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        return path

    @call_logger(logger)
    def load(
        self,
        dataframe: pd.DataFrame | ddf.DataFrame = None,
        metadata: dict = None,
        files: list[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (pd.DataFrame | ddf.DataFrame, optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            files (list[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            collect_metadata (bool, optional): Option for collecting metadata in the reader.
                Defaults to False.
            **kwds:
                - *timed_dataframe*: timed dataframe if dataframe is provided.

                Additional keyword parameters are passed to ``loader.read_dataframe()``.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            timed_dataframe = kwds.pop("timed_dataframe", None)
        elif runs is not None:
            # If runs are provided, we only use the copy tool if a folder is provided as well.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        elif files is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                files=cast(list[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._timed_dataframe = timed_dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )

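    ## Example (sketch): alternative ways of loading data into an existing
    ## processor; run IDs, file names, and the dataframe are hypothetical.
    # processor.load(runs=["44797"], folder="/path/to/raw_data")
    # processor.load(files=["scan_001.h5", "scan_002.h5"])
    # processor.load(dataframe=my_dataframe, metadata={"user": "A. Researcher"})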
    @call_logger(logger)
    def filter_column(
        self,
        column: str,
        min_value: float = -np.inf,
        max_value: float = np.inf,
    ) -> None:
        """Filter values in a column which are outside of a given range.

        Args:
            column (str): Name of the column to filter.
            min_value (float, optional): Minimum value to keep. Defaults to -np.inf.
            max_value (float, optional): Maximum value to keep. Defaults to np.inf.
        """
        if column != "index" and column not in self._dataframe.columns:
            raise KeyError(f"Column {column} not found in dataframe!")
        if min_value >= max_value:
            raise ValueError("min_value has to be smaller than max_value!")
        if self._dataframe is not None:
            self._dataframe = apply_filter(
                self._dataframe,
                col=column,
                lower_bound=min_value,
                upper_bound=max_value,
            )
        if self._timed_dataframe is not None and column in self._timed_dataframe.columns:
            self._timed_dataframe = apply_filter(
                self._timed_dataframe,
                column,
                lower_bound=min_value,
                upper_bound=max_value,
            )
        metadata = {
            "filter": {
                "column": column,
                "min_value": min_value,
                "max_value": max_value,
            },
        }
        self._attributes.add(metadata, "filter", duplicate_policy="merge")

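    ## Example (sketch): keep only events within a detector range; the column
    ## name and bounds are hypothetical and depend on the loader/config in use.
    # processor.filter_column("X", min_value=100, max_value=3900)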
    # Momentum calibration workflow
    # 1. Bin raw detector data for distortion correction
    @call_logger(logger)
    def bin_and_load_momentum_calibration(
        self,
        df_partitions: int | Sequence[int] = 100,
        axes: list[str] = None,
        bins: list[int] = None,
        ranges: Sequence[tuple[float, float]] = None,
        plane: int = 0,
        width: int = 5,
        apply: bool = False,
        **kwds,
    ):
        """1. Step of the momentum correction workflow: Do an initial binning
        of the dataframe loaded into the class, slice a plane from it using an
        interactive view, and load it into the momentum corrector class.

        Args:
            df_partitions (int | Sequence[int], optional): Number of dataframe partitions
                to use for the initial binning. Defaults to 100.
            axes (list[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (list[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (Sequence[tuple[float, float]], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
            width (int, optional): Initial value for the width slider. Defaults to 5.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments passed to the ``pre_binning`` function.
        """
        self._pre_binned = self.pre_binning(
            df_partitions=df_partitions,
            axes=axes,
            bins=bins,
            ranges=ranges,
            **kwds,
        )

        self.mc.load_data(data=self._pre_binned)
        self.mc.select_slicer(plane=plane, width=width, apply=apply)

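    ## Example (sketch): bin a subset of partitions and directly select a slice;
    ## the partition count, plane, and width values are hypothetical.
    # processor.bin_and_load_momentum_calibration(
    #     df_partitions=50,
    #     plane=33,
    #     width=10,
    #     apply=True,
    # )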
    # 2. Generate the spline warp correction from momentum features.
    # Either autoselect features, or input features from view above.
    @call_logger(logger)
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2. Step of the distortion correction workflow: Define feature points in
        momentum space. They can either be selected manually using a GUI tool,
        provided as a list of feature points, or auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments for ``MomentumCorrector.feature_extract()`` and
                ``MomentumCorrector.feature_select()``.
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop("sigma", self._config["momentum"]["sigma"])
            fwhm = kwds.pop("fwhm", self._config["momentum"]["fwhm"])
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config["momentum"]["sigma_radius"],
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )

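    ## Example (sketch): provide six feature points plus a center point for a
    ## six-fold symmetric pattern; all pixel coordinates are hypothetical.
    # features = np.array(
    #     [[203, 341], [321, 551], [501, 551], [579, 341],
    #      [501, 131], [321, 131], [401, 341]],
    # )
    # processor.define_features(
    #     features=features,
    #     rotation_symmetry=6,
    #     include_center=True,
    #     apply=True,
    # )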
    # 3. Generate the spline warp correction from momentum features.
    # If no features have been selected before, use class defaults.
    @call_logger(logger)
    def generate_splinewarp(
        self,
        use_center: bool = None,
        **kwds,
    ):
        """3. Step of the distortion correction workflow: Generate the correction
        function restoring the symmetry in the image using a splinewarp algorithm.

        Args:
            use_center (bool, optional): Option to use the position of the
                center point in the correction. Default is read from config, or set to True.
            **kwds: Keyword arguments for ``MomentumCorrector.spline_warp_estimate()``.
        """

        self.mc.spline_warp_estimate(use_center=use_center, **kwds)

        if self.mc.slice is not None and self.verbose:
            print("Original slice with reference features")
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)

            print("Corrected slice with target features")
            self.mc.view(
                image=self.mc.slice_corrected,
                annotated=True,
                points={"feats": self.mc.ptargs},
                backend="bokeh",
                crosshair=True,
            )

            print("Original slice with target features")
            self.mc.view(
                image=self.mc.slice,
                points={"feats": self.mc.ptargs},
                annotated=True,
                backend="bokeh",
            )

    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.correction) == 0:
            raise ValueError("No momentum correction parameters to save!")
        correction = {}
        for key, value in self.mc.correction.items():
            if key in ["reference_points", "target_points", "cdeform_field", "rdeform_field"]:
                continue
            if key in ["use_center", "rotation_symmetry"]:
                correction[key] = value
            elif key in ["center_point", "ascale"]:
                correction[key] = [float(i) for i in value]
            elif key in ["outer_points", "feature_points"]:
                correction[key] = []
                for point in value:
                    correction[key].append([float(i) for i in point])
            else:
                correction[key] = float(value)

        if "creation_date" not in correction:
            correction["creation_date"] = datetime.now().timestamp()

        config = {
            "momentum": {
                "correction": correction,
            },
        }
        save_config(config, filename, overwrite)
        logger.info(f'Saved momentum correction parameters to "{filename}".')

    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    @call_logger(logger)
    def pose_adjustment(
        self,
        transformations: dict[str, Any] = None,
        apply: bool = False,
        use_correction: bool = True,
        reset: bool = True,
        **kwds,
    ):
        """4. Step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            transformations (dict[str, Any], optional): Dictionary with transformations.
                Defaults to self.transformations or config["momentum"]["transformations"].
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
            reset (bool, optional): Option to reset the correction before transformation.
                Defaults to True.
            **kwds: Keyword parameters defining defaults for the transformations:

                - **scale** (float): Initial value of the scaling slider.
                - **xtrans** (float): Initial value of the xtrans slider.
                - **ytrans** (float): Initial value of the ytrans slider.
                - **angle** (float): Initial value of the angle slider.
        """
        # Generate homography as default if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                self.mc.slice = np.zeros(self._config["momentum"]["bins"][0:2])
            self.mc.slice_corrected = self.mc.slice

        if not use_correction:
            self.mc.reset_deformation()

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate distortion correction from config values
            self.mc.spline_warp_estimate()

        self.mc.pose_adjustment(
            transformations=transformations,
            apply=apply,
            reset=reset,
            **kwds,
        )

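    ## Example (sketch): apply a known pose correction non-interactively;
    ## the scale, translation, and angle values are hypothetical.
    # processor.pose_adjustment(
    #     transformations={"scale": 1.03, "xtrans": 8.0, "ytrans": -5.0, "angle": 1.5},
    #     apply=True,
    # )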
    # 4a. Save pose adjustment parameters to config file.
    @call_logger(logger)
    def save_transformations(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the pose adjustment parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.transformations) == 0:
            raise ValueError("No momentum transformation parameters to save!")
        transformations = {}
        for key, value in self.mc.transformations.items():
            transformations[key] = float(value)

        if "creation_date" not in transformations:
            transformations["creation_date"] = datetime.now().timestamp()

        config = {
            "momentum": {
                "transformations": transformations,
            },
        }
        save_config(config, filename, overwrite)
        logger.info(f'Saved momentum transformation parameters to "{filename}".')

    # 5. Apply the momentum correction to the dataframe
    @call_logger(logger)
    def apply_momentum_correction(
        self,
        preview: bool = False,
        **kwds,
    ):
        """Applies the distortion correction and pose adjustment (optional)
        to the dataframe.

        Args:
            preview (bool, optional): Option to preview the first elements of the dataframe.
                Defaults to False.
            **kwds: Keyword parameters for ``MomentumCorrector.apply_correction``:

                - **rdeform_field** (np.ndarray, optional): Row deformation field.
                - **cdeform_field** (np.ndarray, optional): Column deformation field.
                - **inv_dfield** (np.ndarray, optional): Inverse deformation field.

        """
        x_column = self._config["dataframe"]["x_column"]
        y_column = self._config["dataframe"]["y_column"]

        if self._dataframe is not None:
            logger.info("Adding corrected X/Y columns to dataframe:")
            df, metadata = self.mc.apply_corrections(
                df=self._dataframe,
                **kwds,
            )
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                tdf, _ = self.mc.apply_corrections(
                    self._timed_dataframe,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            self._dataframe = df
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

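    ## Example (sketch): apply the stored corrections and preview the result.
    # processor.apply_momentum_correction(preview=True)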
    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    @call_logger(logger)
    def calibrate_momentum_axes(
        self,
        point_a: np.ndarray | list[int] = None,
        point_b: np.ndarray | list[int] = None,
        k_distance: float = None,
        k_coord_a: np.ndarray | list[float] = None,
        k_coord_b: np.ndarray | list[float] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply=False,
    ):
        """1. Step of the momentum calibration workflow: Calibrate momentum
        axes using either provided pixel coordinates of a high-symmetry point and its
        distance to the BZ center, or the k-coordinates of two points in the BZ
        (depending on the equiscale option). Opens an interactive panel for selecting
        the points.

        Args:
            point_a (np.ndarray | list[int], optional): Pixel coordinates of the first
                point used for momentum calibration.
            point_b (np.ndarray | list[int], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between point a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (np.ndarray | list[float], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (np.ndarray | list[float], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option to use the same scale for kx and ky.
                If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config["momentum"]["center_pixel"]

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )

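    ## Example (sketch): calibrate using a high-symmetry point at a known
    ## momentum distance from the zone center; pixel and k values are hypothetical.
    # processor.calibrate_momentum_axes(point_a=[308, 345], k_distance=1.28, apply=True)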
    # 1a. Save momentum calibration parameters to config file.
    def save_momentum_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated momentum calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.calibration) == 0:
            raise ValueError("No momentum calibration parameters to save!")
        calibration = {}
        for key, value in self.mc.calibration.items():
            if key in ["kx_axis", "ky_axis", "grid", "extent"]:
                continue

            calibration[key] = float(value)

        if "creation_date" not in calibration:
            calibration["creation_date"] = datetime.now().timestamp()

        config = {"momentum": {"calibration": calibration}}
        save_config(config, filename, overwrite)
        logger.info(f"Saved momentum calibration parameters to {filename}")

    # 2. Apply correction and calibration to the dataframe
    @call_logger(logger)
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2. Step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes exist,
        these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool, optional): Option to preview the first elements of the dataframe.
                Defaults to False.
            **kwds: Keyword args passed to ``MomentumCorrector.append_k_axis``.
        """
        x_column = self._config["dataframe"]["x_column"]
        y_column = self._config["dataframe"]["y_column"]

        if self._dataframe is not None:
            logger.info("Adding kx/ky columns to dataframe:")
            df, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                tdf, _ = self.mc.append_k_axis(
                    df=self._timed_dataframe,
                    calibration=calibration,
                    suppress_output=True,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            self._dataframe = df
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

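    ## Example (sketch): apply the stored (or a previously saved) calibration.
    # processor.apply_momentum_calibration(preview=True)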
    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    @call_logger(logger)
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: tuple[float, float] = None,
        apply=False,
        **kwds,
    ):
        """1. Step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data
        if not yet present.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
            **kwds: Keyword parameters passed to ``EnergyCalibrator.adjust_energy_correction()``.
        """
        if self._pre_binned is None:
            logger.warning("Pre-binned data not present, binning using defaults from config...")
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )

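    ## Example (sketch): directly apply a spherical correction; the amplitude
    ## and detector-center coordinates are hypothetical.
    # processor.adjust_energy_correction(
    #     correction_type="spherical",
    #     amplitude=2.5,
    #     center=(730, 730),
    #     apply=True,
    # )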
    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.correction) == 0:
            raise ValueError("No energy correction parameters to save!")
        correction = {}
        for key, val in self.ec.correction.items():
            if key == "correction_type":
                correction[key] = val
            elif key == "center":
                correction[key] = [float(i) for i in val]
            else:
                correction[key] = float(val)

        if "creation_date" not in correction:
            correction["creation_date"] = datetime.now().timestamp()

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)
        logger.info(f"Saved energy correction parameters to {filename}")

    # 2. Apply energy correction to dataframe
    @call_logger(logger)
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2. Step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool, optional): Option to preview the first elements of the dataframe.
                Defaults to False.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction()``.
        """
        tof_column = self._config["dataframe"]["tof_column"]

        if self._dataframe is not None:
            logger.info("Applying energy correction to dataframe...")
            df, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.apply_energy_correction(
                    df=self._timed_dataframe,
                    correction=correction,
                    suppress_output=True,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

    # Energy calibrator workflow
    # 1. Load and normalize data
    @call_logger(logger)
    def load_bias_series(
        self,
        binned_data: xr.DataArray | tuple[np.ndarray, np.ndarray, np.ndarray] = None,
        data_files: list[str] = None,
        axes: list[str] = None,
        bins: list = None,
        ranges: Sequence[tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1. Step of the energy calibration workflow: Load and bin data from
        single-event files, or load binned bias/TOF traces.

        Args:
            binned_data (xr.DataArray | tuple[np.ndarray, np.ndarray, np.ndarray], optional):
                Binned data. If provided as a DataArray, needs to contain the dimensions
                config["dataframe"]["tof_column"] and config["dataframe"]["bias_column"]. If
                provided as a tuple, needs to contain the elements tof, biases, traces.
            data_files (list[str], optional): List of file paths to bin.
            axes (list[str], optional): Bin axes.
                Defaults to config["dataframe"]["tof_column"].
            bins (list, optional): Number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[tuple[float, float]], optional): Bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): Span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): Order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        if binned_data is not None:
            if isinstance(binned_data, xr.DataArray):
                if (
                    self._config["dataframe"]["tof_column"] not in binned_data.dims
                    or self._config["dataframe"]["bias_column"] not in binned_data.dims
                ):
                    raise ValueError(
                        "If binned_data is provided as an xarray, it needs to contain dimensions "
                        f"'{self._config['dataframe']['tof_column']}' and "
                        f"'{self._config['dataframe']['bias_column']}'!",
                    )
                tof = binned_data.coords[self._config["dataframe"]["tof_column"]].values
                biases = binned_data.coords[self._config["dataframe"]["bias_column"]].values
                traces = binned_data.values[:, :]
            else:
                try:
                    (tof, biases, traces) = binned_data
                except ValueError as exc:
                    raise ValueError(
                        "If binned_data is provided as tuple, it needs to contain "
                        "(tof, biases, traces)!",
                    ) from exc
            logger.debug(f"Energy calibration data loaded from binned data. Bias values={biases}.")
            self.ec.load_data(biases=biases, traces=traces, tof=tof)

        elif data_files is not None:
            self.ec.bin_data(
                data_files=cast(list[str], self.cpy(data_files)),
                axes=axes,
                bins=bins,
                ranges=ranges,
                biases=biases,
                bias_key=bias_key,
            )
            logger.debug(
                f"Energy calibration data binned from files {data_files}. "
                f"Bias values={biases}.",
            )

        else:
            raise ValueError("Either binned_data or data_files needs to be provided!")

        if (normalize is not None and normalize is True) or (
            normalize is None and self._config["energy"]["normalize"]
        ):
            if span is None:
                span = self._config["energy"]["normalize_span"]
            if order is None:
                order = self._config["energy"]["normalize_order"]
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="bokeh",
        )

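    ## Example (sketch): load a pre-binned bias series as a (tof, biases, traces)
    ## tuple; the arrays below are hypothetical placeholders.
    # tof = np.linspace(64000, 68000, 1000)
    # biases = np.arange(22, 33)
    # traces = np.random.rand(len(biases), len(tof))
    # processor.load_bias_series(binned_data=(tof, biases, traces), normalize=True)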
1224
    # 2. extract ranges and get peak positions
1225
    @call_logger(logger)
1✔
1226
    def find_bias_peaks(
1✔
1227
        self,
1228
        ranges: list[tuple] | tuple,
1229
        ref_id: int = 0,
1230
        infer_others: bool = True,
1231
        mode: str = "replace",
1232
        radius: int = None,
1233
        peak_window: int = None,
1234
        apply: bool = False,
1235
    ):
1236
        """2. step of the energy calibration workflow: Find a peak within a given range
1237
        for the indicated reference trace, and tries to find the same peak for all
1238
        other traces. Uses fast_dtw to align curves, which might not be too good if the
1239
        shape of curves changes qualitatively. Ideally, choose a reference trace in the
1240
        middle of the set, and don't choose the range too narrow around the peak.
1241
        Alternatively, a list of ranges for all traces can be provided.
1242

1243
        Args:
1244
            ranges (list[tuple] | tuple): Tuple of TOF values indicating a range.
1245
                Alternatively, a list of ranges for all traces can be given.
1246
            ref_id (int, optional): The id of the trace the range refers to.
1247
                Defaults to 0.
1248
            infer_others (bool, optional): Whether to determine the range for the other
1249
                traces. Defaults to True.
1250
            mode (str, optional): Whether to "add" or "replace" existing ranges.
1251
                Defaults to "replace".
1252
            radius (int, optional): Radius parameter for fast_dtw.
1253
                Defaults to config["energy"]["fastdtw_radius"].
1254
            peak_window (int, optional): Peak_window parameter for the peak detection
1255
                algorithm. amount of points that have to have to behave monotonously
1256
                around a peak. Defaults to config["energy"]["peak_window"].
1257
            apply (bool, optional): Option to directly apply the provided parameters.
1258
                Defaults to False.
1259
        """
1260
        if radius is None:
1✔
1261
            radius = self._config["energy"]["fastdtw_radius"]
1✔
1262
        if peak_window is None:
1✔
1263
            peak_window = self._config["energy"]["peak_window"]
1✔
1264
        if not infer_others:
1✔
1265
            self.ec.add_ranges(
1✔
1266
                ranges=ranges,
1267
                ref_id=ref_id,
1268
                infer_others=infer_others,
1269
                mode=mode,
1270
                radius=radius,
1271
            )
1272
            logger.info(f"Use feature ranges: {self.ec.featranges}.")
1✔
1273
            try:
1✔
1274
                self.ec.feature_extract(peak_window=peak_window)
1✔
1275
                logger.info(f"Extracted energy features: {self.ec.peaks}.")
1✔
1276
                self.ec.view(
1✔
1277
                    traces=self.ec.traces_normed,
1278
                    segs=self.ec.featranges,
1279
                    xaxis=self.ec.tof,
1280
                    peaks=self.ec.peaks,
1281
                    backend="bokeh",
1282
                )
1283
            except IndexError:
×
NEW
1284
                logger.error("Could not determine all peaks!")
×
1285
                raise
×
1286
        else:
1287
            # New adjustment tool
1288
            assert isinstance(ranges, tuple)
1✔
1289
            self.ec.adjust_ranges(
1✔
1290
                ranges=ranges,
1291
                ref_id=ref_id,
1292
                traces=self.ec.traces_normed,
1293
                infer_others=infer_others,
1294
                radius=radius,
1295
                peak_window=peak_window,
1296
                apply=apply,
1297
            )
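
    # Usage sketch (illustrative; the TOF range and trace id are hypothetical values):
    # with infer_others=True, one approximate range around a feature of the reference
    # trace suffices, and the ranges for all other traces are inferred:
    #
    #   sp.find_bias_peaks(ranges=(64000, 66000), ref_id=5, apply=True)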

    # 3. Fit the energy calibration relation
    @call_logger(logger)
    def calibrate_energy_axis(
        self,
        ref_energy: float,
        method: str = None,
        energy_scale: str = None,
        **kwds,
    ):
        """3. Step of the energy calibration workflow: Calculate the calibration
        function for the energy axis. Two approximations are implemented, a
        (normally 3rd order) polynomial approximation, and a d^2/(t-t0)^2 relation.

        Args:
            ref_energy (float): Binding/kinetic energy of the detected feature.
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to config["energy"]["calibration_method"]
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to config["energy"]["energy_scale"]
            **kwds: Keyword parameters passed to ``EnergyCalibrator.calibrate()``.
        """
        if method is None:
            method = self._config["energy"]["calibration_method"]

        if energy_scale is None:
            energy_scale = self._config["energy"]["energy_scale"]

        self.ec.calibrate(
            ref_energy=ref_energy,
            method=method,
            energy_scale=energy_scale,
            **kwds,
        )
        if self.verbose:
            self.ec.view(
                traces=self.ec.traces_normed,
                xaxis=self.ec.calibration["axis"],
                align=True,
                energy_scale=energy_scale,
                backend="matplotlib",
                title="Quality of Calibration",
            )
            plt.xlabel("Energy (eV)")
            plt.ylabel("Intensity")
            plt.tight_layout()
            plt.show()
            if energy_scale == "kinetic":
                self.ec.view(
                    traces=self.ec.calibration["axis"][None, :] + self.ec.biases[0],
                    xaxis=self.ec.tof,
                    backend="matplotlib",
                    show_legend=False,
                    title="E/TOF relationship",
                )
                plt.scatter(
                    self.ec.peaks[:, 0],
                    -(self.ec.biases - self.ec.biases[0]) + ref_energy,
                    s=50,
                    c="k",
                )
                plt.tight_layout()
            elif energy_scale == "binding":
                self.ec.view(
                    traces=self.ec.calibration["axis"][None, :] - self.ec.biases[0],
                    xaxis=self.ec.tof,
                    backend="matplotlib",
                    show_legend=False,
                    title="E/TOF relationship",
                )
                plt.scatter(
                    self.ec.peaks[:, 0],
                    self.ec.biases - self.ec.biases[0] + ref_energy,
                    s=50,
                    c="k",
                )
            else:
                raise ValueError(
                    f'energy_scale needs to be either "binding" or "kinetic", got {energy_scale}.',
                )
            plt.xlabel("Time-of-flight")
            plt.ylabel("Energy (eV)")
            plt.tight_layout()
            plt.show()
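
    # Usage sketch (illustrative): calibrate against a feature of known energy, e.g.
    # the Fermi level of a metal reference at 0 eV:
    #
    #   sp.calibrate_energy_axis(ref_energy=0, method="lmfit", energy_scale="kinetic")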

    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.calibration) == 0:
            raise ValueError("No energy calibration parameters to save!")
        calibration = {}
        for key, value in self.ec.calibration.items():
            if key in ["axis", "refid", "Tmat", "bvec"]:
                continue
            if key == "energy_scale":
                calibration[key] = value
            elif key == "coeffs":
                calibration[key] = [float(i) for i in value]
            else:
                calibration[key] = float(value)

        if "creation_date" not in calibration:
            calibration["creation_date"] = datetime.now().timestamp()

        config = {"energy": {"calibration": calibration}}
        save_config(config, filename, overwrite)
        logger.info(f'Saved energy calibration parameters to "{filename}".')

    # 4. Apply energy calibration to the dataframe
    @call_logger(logger)
    def append_energy_axis(
        self,
        calibration: dict = None,
        bias_voltage: float = None,
        preview: bool = False,
        **kwds,
    ):
        """4. step of the energy calibration workflow: Apply the calibration function
        to the dataframe. Two approximations are implemented, a (normally 3rd order)
        polynomial approximation, and a d^2/(t-t0)^2 relation. A calibration dictionary
        can be provided.

        Args:
            calibration (dict, optional): Calibration dict containing calibration
                parameters. Overrides calibration from class or config.
                Defaults to None.
            bias_voltage (float, optional): Sample bias voltage of the scan data. If omitted,
                the bias voltage is read from the dataframe. If it is not found there,
                a warning is printed and the calibrated data might have an offset.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.append_energy_axis()``.
        """
        tof_column = self._config["dataframe"]["tof_column"]

        if self._dataframe is not None:
            logger.info("Adding energy column to dataframe:")
            df, metadata = self.ec.append_energy_axis(
                df=self._dataframe,
                calibration=calibration,
                bias_voltage=bias_voltage,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.append_energy_axis(
                    df=self._timed_dataframe,
                    calibration=calibration,
                    bias_voltage=bias_voltage,
                    suppress_output=True,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_calibration",
                duplicate_policy="merge",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf

        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)
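
    # Usage sketch (illustrative; the bias value is hypothetical): apply the stored
    # or configured calibration to the dataframe:
    #
    #   sp.append_energy_axis(bias_voltage=16.8, preview=True)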

    @call_logger(logger)
    def add_energy_offset(
        self,
        constant: float = None,
        columns: str | Sequence[str] = None,
        weights: float | Sequence[float] = None,
        reductions: str | Sequence[str] = None,
        preserve_mean: bool | Sequence[bool] = None,
        preview: bool = False,
    ) -> None:
        """Shift the energy axis of the dataframe by a given amount.

        Args:
            constant (float, optional): The constant to shift the energy axis by.
            columns (str | Sequence[str], optional): Name of the column(s) to apply the shift from.
            weights (float | Sequence[float], optional): Weights to apply to the columns.
                Can also be used to flip the sign (e.g. -1). Defaults to 1.
            reductions (str | Sequence[str], optional): The reduction to apply to the column.
                Should be an available method of dask.dataframe.Series. For example "mean". In this
                case the function is applied to the column to generate a single value for the whole
                dataset. If None, the shift is applied per-dataframe-row. Defaults to None.
                Currently only "mean" is supported.
            preserve_mean (bool | Sequence[bool], optional): Whether to subtract the mean of the
                column before applying the shift. Defaults to False.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.

        Raises:
            ValueError: If the energy column is not in the dataframe.
        """
        energy_column = self._config["dataframe"]["energy_column"]
        if energy_column not in self._dataframe.columns:
            raise ValueError(
                f"Energy column {energy_column} not found in dataframe! "
                "Run `append_energy_axis()` first.",
            )
        if self.dataframe is not None:
            logger.info("Adding energy offset to dataframe:")
            df, metadata = self.ec.add_offsets(
                df=self._dataframe,
                constant=constant,
                columns=columns,
                energy_column=energy_column,
                weights=weights,
                reductions=reductions,
                preserve_mean=preserve_mean,
            )
            if self._timed_dataframe is not None and energy_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.add_offsets(
                    df=self._timed_dataframe,
                    constant=constant,
                    columns=columns,
                    energy_column=energy_column,
                    weights=weights,
                    reductions=reductions,
                    preserve_mean=preserve_mean,
                    suppress_output=True,
                )

            self._attributes.add(
                metadata,
                "add_energy_offset",
                # TODO: allow only appending when no offset along this column(s) was applied
                # TODO: clear memory of modifications if the energy axis is recalculated
                duplicate_policy="append",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and energy_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)
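
    # Usage sketch (illustrative; the column name is hypothetical): shift the energy
    # axis per-event by a photon-energy column, with the sign flipped via its weight:
    #
    #   sp.add_energy_offset(
    #       columns="monochromatorPhotonEnergy",
    #       weights=-1,
    #       preserve_mean=True,
    #   )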

    def save_energy_offset(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy offset parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.offsets) == 0:
            raise ValueError("No energy offset parameters to save!")

        if "creation_date" not in self.ec.offsets.keys():
            self.ec.offsets["creation_date"] = datetime.now().timestamp()

        config = {"energy": {"offsets": self.ec.offsets}}
        save_config(config, filename, overwrite)
        logger.info(f'Saved energy offset parameters to "{filename}".')

    @call_logger(logger)
    def append_tof_ns_axis(
        self,
        preview: bool = False,
        **kwds,
    ):
        """Convert time-of-flight channel steps to nanoseconds.

        Args:
            tof_ns_column (str, optional): Name of the generated column containing the
                time-of-flight in nanoseconds.
                Defaults to config["dataframe"]["tof_ns_column"].
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds: additional arguments are passed to ``EnergyCalibrator.append_tof_ns_axis()``.
        """
        tof_column = self._config["dataframe"]["tof_column"]

        if self._dataframe is not None:
            logger.info("Adding time-of-flight column in nanoseconds to dataframe.")
            # TODO assert order of execution through metadata

            df, metadata = self.ec.append_tof_ns_axis(
                df=self._dataframe,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.append_tof_ns_axis(
                    df=self._timed_dataframe,
                    **kwds,
                )

            self._attributes.add(
                metadata,
                "tof_ns_conversion",
                duplicate_policy="overwrite",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)
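
    # Usage sketch (illustrative): the conversion parameters are taken from the
    # config, so a plain call is usually sufficient:
    #
    #   sp.append_tof_ns_axis(preview=True)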

    @call_logger(logger)
    def align_dld_sectors(
        self,
        sector_delays: np.ndarray = None,
        preview: bool = False,
        **kwds,
    ):
        """Align the 8s sectors of the HEXTOF endstation.

        Args:
            sector_delays (np.ndarray, optional): Array containing the sector delays. Defaults to
                config["dataframe"]["sector_delays"].
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds: additional arguments are passed to ``EnergyCalibrator.align_dld_sectors()``.
        """
        tof_column = self._config["dataframe"]["tof_column"]

        if self._dataframe is not None:
            logger.info("Aligning 8s sectors of dataframe")
            # TODO assert order of execution through metadata

            df, metadata = self.ec.align_dld_sectors(
                df=self._dataframe,
                sector_delays=sector_delays,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.align_dld_sectors(
                    df=self._timed_dataframe,
                    sector_delays=sector_delays,
                    **kwds,
                )

            self._attributes.add(
                metadata,
                "dld_sector_alignment",
                duplicate_policy="raise",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

    # Delay calibration function
    @call_logger(logger)
    def calibrate_delay_axis(
        self,
        delay_range: tuple[float, float] = None,
        datafile: str = None,
        preview: bool = False,
        **kwds,
    ):
        """Append delay column to dataframe. Either provide delay ranges, or read
        them from a file.

        Args:
            delay_range (tuple[float, float], optional): The scanned delay range in
                picoseconds. Defaults to None.
            datafile (str, optional): The file from which to read the delay ranges.
                Defaults to None.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
        """
        adc_column = self._config["dataframe"]["adc_column"]
        if adc_column not in self._dataframe.columns:
            raise ValueError(f"ADC column {adc_column} not found in dataframe, cannot calibrate!")

        if self._dataframe is not None:
            logger.info("Adding delay column to dataframe:")

            if delay_range is None and datafile is None:
                if len(self.dc.calibration) == 0:
                    try:
                        datafile = self._files[0]
                    except IndexError as exc:
                        raise IndexError(
                            "No datafile available, specify either 'datafile' or 'delay_range'",
                        ) from exc

            df, metadata = self.dc.append_delay_axis(
                self._dataframe,
                delay_range=delay_range,
                datafile=datafile,
                **kwds,
            )
            if self._timed_dataframe is not None and adc_column in self._timed_dataframe.columns:
                tdf, _ = self.dc.append_delay_axis(
                    self._timed_dataframe,
                    delay_range=delay_range,
                    datafile=datafile,
                    suppress_output=True,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "delay_calibration",
                duplicate_policy="overwrite",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and adc_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.debug(self._dataframe)
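
    # Usage sketch (illustrative; the delay range is a hypothetical value): either
    # pass the scanned range explicitly, or let it be read from the first datafile:
    #
    #   sp.calibrate_delay_axis(delay_range=(-500, 1500), preview=True)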

    def save_delay_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Save the generated delay calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"

        if len(self.dc.calibration) == 0:
            raise ValueError("No delay calibration parameters to save!")
        calibration = {}
        for key, value in self.dc.calibration.items():
            if key == "datafile":
                calibration[key] = value
            elif key in ["adc_range", "delay_range", "delay_range_mm"]:
                calibration[key] = [float(i) for i in value]
            else:
                calibration[key] = float(value)

        if "creation_date" not in calibration:
            calibration["creation_date"] = datetime.now().timestamp()

        config = {
            "delay": {
                "calibration": calibration,
            },
        }
        save_config(config, filename, overwrite)

    @call_logger(logger)
    def add_delay_offset(
        self,
        constant: float = None,
        flip_delay_axis: bool = None,
        columns: str | Sequence[str] = None,
        weights: float | Sequence[float] = 1.0,
        reductions: str | Sequence[str] = None,
        preserve_mean: bool | Sequence[bool] = False,
        preview: bool = False,
    ) -> None:
        """Shift the delay axis of the dataframe by a constant or other columns.

        Args:
            constant (float, optional): The constant to shift the delay axis by.
            flip_delay_axis (bool, optional): Option to reverse the direction of the delay axis.
            columns (str | Sequence[str], optional): Name of the column(s) to apply the shift from.
            weights (float | Sequence[float], optional): Weights to apply to the columns.
                Can also be used to flip the sign (e.g. -1). Defaults to 1.
            reductions (str | Sequence[str], optional): The reduction to apply to the column.
                Should be an available method of dask.dataframe.Series. For example "mean". In this
                case the function is applied to the column to generate a single value for the whole
                dataset. If None, the shift is applied per-dataframe-row. Defaults to None.
                Currently only "mean" is supported.
            preserve_mean (bool | Sequence[bool], optional): Whether to subtract the mean of the
                column before applying the shift. Defaults to False.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.

        Raises:
            ValueError: If the delay column is not in the dataframe.
        """
        delay_column = self._config["dataframe"]["delay_column"]
        if delay_column not in self._dataframe.columns:
            raise ValueError(f"Delay column {delay_column} not found in dataframe! ")

        if self.dataframe is not None:
            logger.info("Adding delay offset to dataframe:")
            df, metadata = self.dc.add_offsets(
                df=self._dataframe,
                constant=constant,
                flip_delay_axis=flip_delay_axis,
                columns=columns,
                delay_column=delay_column,
                weights=weights,
                reductions=reductions,
                preserve_mean=preserve_mean,
            )
            if self._timed_dataframe is not None and delay_column in self._timed_dataframe.columns:
                tdf, _ = self.dc.add_offsets(
                    df=self._timed_dataframe,
                    constant=constant,
                    flip_delay_axis=flip_delay_axis,
                    columns=columns,
                    delay_column=delay_column,
                    weights=weights,
                    reductions=reductions,
                    preserve_mean=preserve_mean,
                    suppress_output=True,
                )

            self._attributes.add(
                metadata,
                "delay_offset",
                duplicate_policy="append",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and delay_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)
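
    # Usage sketch (illustrative; column name and weight are hypothetical): correct
    # the delay axis per-event by a beam-arrival-monitor column and flip its sign:
    #
    #   sp.add_delay_offset(columns="bam", weights=0.001, flip_delay_axis=True)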

    def save_delay_offsets(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Save the generated delay offset parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.dc.offsets) == 0:
            raise ValueError("No delay offset parameters to save!")

        if "creation_date" not in self.dc.offsets.keys():
            self.dc.offsets["creation_date"] = datetime.now().timestamp()

        config = {
            "delay": {
                "offsets": self.dc.offsets,
            },
        }
        save_config(config, filename, overwrite)
        logger.info(f'Saved delay offset parameters to "{filename}".')

    def save_workflow_params(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Run all save calibration parameter methods.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        for method in [
            self.save_splinewarp,
            self.save_transformations,
            self.save_momentum_calibration,
            self.save_energy_correction,
            self.save_energy_calibration,
            self.save_energy_offset,
            self.save_delay_calibration,
            self.save_delay_offsets,
        ]:
            try:
                method(filename, overwrite)
            except (ValueError, AttributeError, KeyError):
                pass
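
    # Usage sketch (illustrative): persist all calibrations determined in this
    # session to the folder config file, overwriting existing entries:
    #
    #   sp.save_workflow_params(overwrite=True)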

    @call_logger(logger)
    def add_jitter(
        self,
        cols: list[str] = None,
        amps: float | Sequence[float] = None,
        **kwds,
    ):
        """Add jitter to the selected dataframe columns.

        Args:
            cols (list[str], optional): The columns onto which to apply jitter.
                Defaults to config["dataframe"]["jitter_cols"].
            amps (float | Sequence[float], optional): Amplitude scalings for the
                jittering noise. If one number is given, the same is used for all axes.
                For uniform noise (default) it will cover the interval [-amp, +amp].
                Defaults to config["dataframe"]["jitter_amps"].
            **kwds: additional keyword arguments passed to ``apply_jitter``.
        """
        if cols is None:
            cols = self._config["dataframe"]["jitter_cols"]
        for loc, col in enumerate(cols):
            if col.startswith("@"):
                cols[loc] = self._config["dataframe"].get(col.strip("@"))

        if amps is None:
            amps = self._config["dataframe"]["jitter_amps"]

        self._dataframe = self._dataframe.map_partitions(
            apply_jitter,
            cols=cols,
            cols_jittered=cols,
            amps=amps,
            **kwds,
        )
        if self._timed_dataframe is not None:
            cols_timed = cols.copy()
            for col in cols:
                if col not in self._timed_dataframe.columns:
                    cols_timed.remove(col)

            if cols_timed:
                self._timed_dataframe = self._timed_dataframe.map_partitions(
                    apply_jitter,
                    cols=cols_timed,
                    cols_jittered=cols_timed,
                )
        metadata = []
        for col in cols:
            metadata.append(col)
        # TODO: allow only appending if columns are not jittered yet
        self._attributes.add(metadata, "jittering", duplicate_policy="append")
        logger.info(f"add_jitter: Added jitter to columns {cols}.")

    @call_logger(logger)
    def add_time_stamped_data(
        self,
        dest_column: str,
        time_stamps: np.ndarray = None,
        data: np.ndarray = None,
        archiver_channel: str = None,
        **kwds,
    ):
        """Add data in form of timestamp/value pairs to the dataframe, using interpolation to the
        timestamps in the dataframe. The time-stamped data can either be provided, or fetched from
        an EPICS archiver instance.

        Args:
            dest_column (str): destination column name
            time_stamps (np.ndarray, optional): Time stamps of the values to add. If omitted,
                time stamps are retrieved from the epics archiver
            data (np.ndarray, optional): Values corresponding to the time stamps in time_stamps.
                If omitted, data are retrieved from the epics archiver.
            archiver_channel (str, optional): EPICS archiver channel from which to retrieve data.
                Either this or data and time_stamps have to be present.
            **kwds:

                - **time_stamp_column**: Dataframe column containing time-stamp data

                Additional keyword arguments passed to ``add_time_stamped_data``.
        """
        time_stamp_column = kwds.pop(
            "time_stamp_column",
            self._config["dataframe"].get("time_stamp_alias", ""),
        )

        if time_stamps is None and data is None:
            if archiver_channel is None:
                raise ValueError(
                    "Either archiver_channel or both time_stamps and data have to be present!",
                )
            if self.loader.__name__ != "mpes":
                raise NotImplementedError(
                    "This function is currently only implemented for the mpes loader!",
                )
            ts_from, ts_to = cast(MpesLoader, self.loader).get_start_and_end_time()
            # get channel data with +-5 seconds safety margin
            time_stamps, data = get_archiver_data(
                archiver_url=self._config["metadata"].get("archiver_url", ""),
                archiver_channel=archiver_channel,
                ts_from=ts_from - 5,
                ts_to=ts_to + 5,
            )

        self._dataframe = add_time_stamped_data(
            self._dataframe,
            time_stamps=time_stamps,
            data=data,
            dest_column=dest_column,
            time_stamp_column=time_stamp_column,
            **kwds,
        )
        if self._timed_dataframe is not None:
            if time_stamp_column in self._timed_dataframe:
                self._timed_dataframe = add_time_stamped_data(
                    self._timed_dataframe,
                    time_stamps=time_stamps,
                    data=data,
                    dest_column=dest_column,
                    time_stamp_column=time_stamp_column,
                    **kwds,
                )
        metadata: list[Any] = []
        metadata.append(dest_column)
        metadata.append(time_stamps)
        metadata.append(data)
        self._attributes.add(metadata, "time_stamped_data", duplicate_policy="append")
        logger.info(f"add_time_stamped_data: Added time-stamped data as column {dest_column}.")

    @call_logger(logger)
    def pre_binning(
        self,
        df_partitions: int | Sequence[int] = 100,
        axes: list[str] = None,
        bins: list[int] = None,
        ranges: Sequence[tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Function to do an initial binning of the dataframe loaded to the class.

        Args:
            df_partitions (int | Sequence[int], optional): Number of dataframe partitions to
                use for the initial binning. Defaults to 100.
            axes (list[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (list[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (Sequence[tuple[float, float]], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            **kwds: Keyword argument passed to ``compute``.

        Returns:
            xr.DataArray: pre-binned data-array.
        """
        if axes is None:
            axes = self._config["momentum"]["axes"]
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))

        if bins is None:
            bins = self._config["momentum"]["bins"]
        if ranges is None:
            ranges_ = list(self._config["momentum"]["ranges"])
            ranges_[2] = np.asarray(ranges_[2]) / self._config["dataframe"]["tof_binning"]
            ranges = [cast(tuple[float, float], tuple(v)) for v in ranges_]

        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        return self.compute(
            bins=bins,
            axes=axes,
            ranges=ranges,
            df_partitions=df_partitions,
            **kwds,
        )

    @call_logger(logger)
    def compute(
        self,
        bins: int | dict | tuple | list[int] | list[np.ndarray] | list[tuple] = 100,
        axes: str | Sequence[str] = None,
        ranges: Sequence[tuple[float, float]] = None,
        normalize_to_acquisition_time: bool | str = False,
        **kwds,
    ) -> xr.DataArray:
        """Compute the histogram along the given dimensions.

        Args:
            bins (int | dict | tuple | list[int] | list[np.ndarray] | list[tuple], optional):
                Definition of the bins. Can be any of the following cases:

                - an integer describing the number of bins on all dimensions
                - a tuple of 3 numbers describing start, end and step of the binning
                  range
                - an np.ndarray defining the binning edges
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
                - a dictionary made of the axes as keys and any of the above as values.

                This takes priority over the axes and range arguments. Defaults to 100.
            axes (str | Sequence[str], optional): The names of the axes (columns)
                on which to calculate the histogram. The order will be the order of the
                dimensions in the resulting array. Defaults to None.
            ranges (Sequence[tuple[float, float]], optional): list of tuples containing
                the start and end point of the binning range. Defaults to None.
            normalize_to_acquisition_time (bool | str): Option to normalize the
                result to the acquisition time. If a "slow" axis was scanned, providing
                the name of the scanned axis will compute and apply the corresponding
                normalization histogram. Defaults to False.
            **kwds: Keyword arguments:

                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
                  ``bin_dataframe`` for details. Defaults to
                  config["binning"]["hist_mode"].
                - **mode**: Defines how the results from each partition are combined.
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
                  Defaults to config["binning"]["mode"].
                - **pbar**: Option to show the tqdm progress bar. Defaults to
                  config["binning"]["pbar"].
                - **num_cores**: Number of CPU cores to use for parallelization.
                  Defaults to config["core"]["num_cores"] or N_CPU-1.
                - **threads_per_worker**: Limit the number of threads that
                  multiprocessing can spawn per binning thread. Defaults to
                  config["binning"]["threads_per_worker"].
                - **threadpool_api**: The API to use for multiprocessing. "blas",
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
                  config["binning"]["threadpool_API"].
                - **df_partitions**: A sequence of dataframe partitions, or the
                  number of the dataframe partitions to use. Defaults to all partitions.
                - **filter**: A Sequence of Dictionaries with entries "col", "lower_bound",
                  "upper_bound" to apply as filter to the dataframe before binning. The
                  dataframe in the class remains unmodified by this.

                Additional kwds are passed to ``bin_dataframe``.

        Raises:
            AssertionError: Raised when no dataframe has been loaded.

        Returns:
            xr.DataArray: The result of the n-dimensional binning represented in an
            xarray object, combining the data with the axes.
        """
        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
        mode = kwds.pop("mode", self._config["binning"]["mode"])
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
        num_cores = kwds.pop("num_cores", self._config["core"]["num_cores"])
        threads_per_worker = kwds.pop(
            "threads_per_worker",
            self._config["binning"]["threads_per_worker"],
        )
        threadpool_api = kwds.pop(
            "threadpool_API",
            self._config["binning"]["threadpool_API"],
        )
        df_partitions: int | Sequence[int] = kwds.pop("df_partitions", None)
        if isinstance(df_partitions, int):
            df_partitions = list(range(0, min(df_partitions, self._dataframe.npartitions)))
        if df_partitions is not None:
            dataframe = self._dataframe.partitions[df_partitions]
        else:
            dataframe = self._dataframe

        filter_params = kwds.pop("filter", None)
        if filter_params is not None:
            try:
                for param in filter_params:
                    if "col" not in param:
                        raise ValueError(
                            f"'col' needs to be defined for each filter entry! "
                            f"Not present in {param}.",
                        )
                    assert set(param.keys()).issubset({"col", "lower_bound", "upper_bound"})
                    dataframe = apply_filter(dataframe, **param)
            except AssertionError as exc:
                invalid_keys = set(param.keys()) - {"lower_bound", "upper_bound"}
                raise ValueError(
                    f"Only 'col', 'lower_bound' and 'upper_bound' allowed as filter entries. "
                    f"Parameters {invalid_keys} not valid in {param}.",
                ) from exc

        self._binned = bin_dataframe(
            df=dataframe,
            bins=bins,
            axes=axes,
            ranges=ranges,
            hist_mode=hist_mode,
            mode=mode,
            pbar=pbar,
            n_cores=num_cores,
            threads_per_worker=threads_per_worker,
            threadpool_api=threadpool_api,
            **kwds,
        )

        for dim in self._binned.dims:
            try:
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
            except KeyError:
                pass

        self._binned.attrs["units"] = "counts"
        self._binned.attrs["long_name"] = "photoelectron counts"
        self._binned.attrs["metadata"] = self._attributes.metadata

        if normalize_to_acquisition_time:
            if isinstance(normalize_to_acquisition_time, str):
                axis = normalize_to_acquisition_time
                logger.info(f"Calculate normalization histogram for axis '{axis}'...")
                self._normalization_histogram = self.get_normalization_histogram(
                    axis=axis,
                    df_partitions=df_partitions,
                )
                # if the axes are named correctly, xarray figures out the normalization correctly
                self._normalized = self._binned / self._normalization_histogram
                self._attributes.add(
                    self._normalization_histogram.values,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )
            else:
                acquisition_time = self.loader.get_elapsed_time(
                    fids=df_partitions,
                )
                if acquisition_time > 0:
                    self._normalized = self._binned / acquisition_time
                self._attributes.add(
                    acquisition_time,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )

            self._normalized.attrs["units"] = "counts/second"
            self._normalized.attrs["long_name"] = "photoelectron counts per second"
            self._normalized.attrs["metadata"] = self._attributes.metadata

            return self._normalized

        return self._binned
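
    # Usage sketch (illustrative; axes, bin numbers and ranges are hypothetical):
    #
    #   res = sp.compute(
    #       bins=[100, 100, 200],
    #       axes=["kx", "ky", "energy"],
    #       ranges=[(-2, 2), (-2, 2), (-4, 1)],
    #   )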

    @call_logger(logger)
    def get_normalization_histogram(
        self,
        axis: str = "delay",
        use_time_stamps: bool = False,
        **kwds,
    ) -> xr.DataArray:
        """Generates a normalization histogram from the timed dataframe. Optionally,
        use the TimeStamps column instead.

        Args:
            axis (str, optional): The axis for which to compute histogram.
                Defaults to "delay".
            use_time_stamps (bool, optional): Use the TimeStamps column of the
                dataframe, rather than the timed dataframe. Defaults to False.
            **kwds: Keyword arguments:

                - **df_partitions**: A sequence of dataframe partitions, or the
                  number of the dataframe partitions to use. Defaults to all partitions.

        Raises:
            ValueError: Raised if no data are binned.
            ValueError: Raised if 'axis' not in binned coordinates.
            ValueError: Raised if config["dataframe"]["time_stamp_alias"] not found
                in Dataframe.

        Returns:
            xr.DataArray: The computed normalization histogram (in TimeStamp units
            per bin).
        """

        if self._binned is None:
            raise ValueError("Need to bin data first!")
        if axis not in self._binned.coords:
            raise ValueError(f"Axis '{axis}' not found in binned data!")

        df_partitions: int | Sequence[int] = kwds.pop("df_partitions", None)

        if len(kwds) > 0:
            raise TypeError(
                f"get_normalization_histogram() got unexpected keyword arguments {kwds.keys()}.",
            )

        if isinstance(df_partitions, int):
            df_partitions = list(range(0, min(df_partitions, self._dataframe.npartitions)))
        if use_time_stamps or self._timed_dataframe is None:
            if df_partitions is not None:
                self._normalization_histogram = normalization_histogram_from_timestamps(
                    self._dataframe.partitions[df_partitions],
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["time_stamp_alias"],
                )
            else:
                self._normalization_histogram = normalization_histogram_from_timestamps(
                    self._dataframe,
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["time_stamp_alias"],
                )
        else:
            if df_partitions is not None:
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
                    self._timed_dataframe.partitions[df_partitions],
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["timed_dataframe_unit_time"],
                )
            else:
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
                    self._timed_dataframe,
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["timed_dataframe_unit_time"],
                )

        return self._normalization_histogram

    def view_event_histogram(
        self,
        dfpid: int,
        ncol: int = 2,
        bins: Sequence[int] = None,
        axes: Sequence[str] = None,
        ranges: Sequence[tuple[float, float]] = None,
        backend: str = "bokeh",
        legend: bool = True,
        histkwds: dict = None,
        legkwds: dict = None,
        **kwds,
    ):
        """Plot individual histograms of specified dimensions (axes) from a given
        dataframe partition.

        Args:
            dfpid (int): Number of the data frame partition to look at.
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
            bins (Sequence[int], optional): Number of bins to use for the specified
                axes. Defaults to config["histogram"]["bins"].
            axes (Sequence[str], optional): Names of the axes to display.
                Defaults to config["histogram"]["axes"].
            ranges (Sequence[tuple[float, float]], optional): Value ranges of all
                specified axes. Defaults to config["histogram"]["ranges"].
            backend (str, optional): Backend of the plotting library
                ('matplotlib' or 'bokeh'). Defaults to "bokeh".
            legend (bool, optional): Option to include a legend in the histogram plots.
                Defaults to True.
            histkwds (dict, optional): Keyword arguments for histograms
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
            legkwds (dict, optional): Keyword arguments for legend
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
            **kwds: Extra keyword arguments passed to
                ``sed.diagnostics.grid_histogram()``.

        Raises:
            TypeError: Raises when the input values are not of the correct type.
        """
        if bins is None:
            bins = self._config["histogram"]["bins"]
        if axes is None:
            axes = self._config["histogram"]["axes"]
        axes = list(axes)
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
        if ranges is None:
            ranges = list(self._config["histogram"]["ranges"])
            for loc, axis in enumerate(axes):
                if axis == self._config["dataframe"]["tof_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / self._config["dataframe"]["tof_binning"]
                elif axis == self._config["dataframe"]["adc_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / self._config["dataframe"]["adc_binning"]

        input_types = map(type, [axes, bins, ranges])
        allowed_types = [list, tuple]

        df = self._dataframe

        if not set(input_types).issubset(allowed_types):
            raise TypeError(
                "Inputs of axes, bins, ranges need to be list or tuple!",
            )

        # Read out the values for the specified groups
        group_dict_dd = {}
        dfpart = df.get_partition(dfpid)
        cols = dfpart.columns
        for ax in axes:
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
        group_dict = ddf.compute(group_dict_dd)[0]

        # Plot multiple histograms in a grid
        grid_histogram(
            group_dict,
            ncol=ncol,
            rvs=axes,
            rvbins=bins,
            rvranges=ranges,
            backend=backend,
            legend=legend,
            histkwds=histkwds,
            legkwds=legkwds,
            **kwds,
        )

2431
    @call_logger(logger)
1✔
2432
    def save(
1✔
2433
        self,
2434
        faddr: str,
2435
        **kwds,
2436
    ):
2437
        """Saves the binned data to the provided path and filename.
2438

2439
        Args:
2440
            faddr (str): Path and name of the file to write. Its extension determines
2441
                the file type to write. Valid file types are:
2442

2443
                - "*.tiff", "*.tif": Saves a TIFF stack.
2444
                - "*.h5", "*.hdf5": Saves an HDF5 file.
2445
                - "*.nxs", "*.nexus": Saves a NeXus file.
2446

2447
            **kwds: Keyword arguments, which are passed to the writer functions:
2448
                For TIFF writing:
2449

2450
                - **alias_dict**: Dictionary of dimension aliases to use.
2451

2452
                For HDF5 writing:
2453

2454
                - **mode**: hdf5 read/write mode. Defaults to "w".
2455

2456
                For NeXus:
2457

2458
                - **reader**: Name of the pynxtools reader to use.
2459
                  Defaults to config["nexus"]["reader"]
2460
                - **definition**: NeXus application definition to use for saving.
2461
                  Must be supported by the used ``reader``. Defaults to
2462
                  config["nexus"]["definition"]
2463
                - **input_files**: A list of input files to pass to the reader.
2464
                  Defaults to config["nexus"]["input_files"]
2465
                - **eln_data**: An electronic-lab-notebook file in '.yaml' format
2466
                  to add to the list of files to pass to the reader.
2467
        """
2468
        if self._binned is None:
1✔
2469
            raise NameError("Need to bin data first!")
1✔
2470

2471
        if self._normalized is not None:
1✔
2472
            data = self._normalized
×
2473
        else:
2474
            data = self._binned
1✔
2475

2476
        extension = pathlib.Path(faddr).suffix
1✔
2477

2478
        if extension in (".tif", ".tiff"):
1✔
2479
            to_tiff(
1✔
2480
                data=data,
2481
                faddr=faddr,
2482
                **kwds,
2483
            )
2484
        elif extension in (".h5", ".hdf5"):
1✔
2485
            to_h5(
1✔
2486
                data=data,
2487
                faddr=faddr,
2488
                **kwds,
2489
            )
2490
        elif extension in (".nxs", ".nexus"):
1✔
2491
            try:
1✔
2492
                reader = kwds.pop("reader", self._config["nexus"]["reader"])
1✔
2493
                definition = kwds.pop(
1✔
2494
                    "definition",
2495
                    self._config["nexus"]["definition"],
2496
                )
2497
                input_files = kwds.pop(
1✔
2498
                    "input_files",
2499
                    self._config["nexus"]["input_files"],
2500
                )
2501
            except KeyError as exc:
×
2502
                raise ValueError(
×
2503
                    "The nexus reader, definition and input files need to be provide!",
2504
                ) from exc
2505

2506
            if isinstance(input_files, str):
1✔
2507
                input_files = [input_files]
1✔
2508

2509
            if "eln_data" in kwds:
1✔
2510
                input_files.append(kwds.pop("eln_data"))
1✔
2511

2512
            to_nexus(
1✔
2513
                data=data,
2514
                faddr=faddr,
2515
                reader=reader,
2516
                definition=definition,
2517
                input_files=input_files,
2518
                **kwds,
2519
            )
2520

2521
        else:
2522
            raise NotImplementedError(
1✔
2523
                f"Unrecognized file format: {extension}.",
2524
            )
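
    # Usage sketch (illustrative; file names are hypothetical): the file extension
    # selects the writer, e.g.
    #
    #   sp.save("binned.h5")
    #   sp.save("binned.nxs", eln_data="eln_data.yaml")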