OpenCOMPES / sed / build 11344168648

15 Oct 2024 10:26AM UTC coverage: 92.633% (+0.1%) from 92.524%

Pull Request #487: Pydantic model (rettigl, "add review suggestions", via github)

446 of 465 new or added lines in 19 files covered (95.91%)
9 existing lines in 4 files now uncovered
7532 of 8131 relevant lines covered (92.63%)
0.93 hits per line

Source File: /sed/core/processor.py (85.7% covered)

"""This module contains the core class for the sed package."""
from __future__ import annotations

import pathlib
from collections.abc import Sequence
from copy import deepcopy
from datetime import datetime
from typing import Any
from typing import cast

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.binning.binning import normalization_histogram_from_timed_dataframe
from sed.binning.binning import normalization_histogram_from_timestamps
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import add_time_stamped_data
from sed.core.dfops import apply_filter
from sed.core.dfops import apply_jitter
from sed.core.logging import call_logger
from sed.core.logging import set_verbosity
from sed.core.logging import setup_logging
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader
from sed.loader.mpes.loader import get_archiver_data
from sed.loader.mpes.loader import MpesLoader

N_CPU = psutil.cpu_count()

# Configure logging
logger = setup_logging("processor")


class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a workflow for data
    correction, calibration, and binning.

    Args:
        metadata (dict, optional): Dict of external metadata. Defaults to None.
        config (dict | str, optional): Config dictionary or config file name.
            Defaults to None.
        dataframe (pd.DataFrame | ddf.DataFrame, optional): dataframe to load
            into the class. Defaults to None.
        files (list[str], optional): List of files to pass to the loader defined in
            the config. Defaults to None.
        folder (str, optional): Folder containing files to pass to the loader
            defined in the config. Defaults to None.
        runs (Sequence[str], optional): List of run identifiers to pass to the loader
            defined in the config. Defaults to None.
        collect_metadata (bool, optional): Option to collect metadata from files.
            Defaults to False.
        verbose (bool, optional): Option to print out diagnostic information.
            Defaults to config["core"]["verbose"] or True.
        **kwds: Keyword arguments passed to parse_config and to the reader.
    """

    @call_logger(logger)
    def __init__(
        self,
        metadata: dict = None,
        config: dict | str = None,
        dataframe: pd.DataFrame | ddf.DataFrame = None,
        files: list[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external metadata. Defaults to None.
            config (dict | str, optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (pd.DataFrame | ddf.DataFrame, optional): dataframe to load
                into the class. Defaults to None.
            files (list[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"] or True.
            **kwds: Keyword arguments passed to parse_config and to the reader.
        """
        # split off config keywords
        config_kwds = {
            key: value for key, value in kwds.items() if key in parse_config.__code__.co_varnames
        }
        for key in config_kwds.keys():
            del kwds[key]
        self._config = parse_config(config, **config_kwds)
        num_cores = self._config["core"].get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["core"]["num_cores"] = num_cores
        logger.debug(f"Use {num_cores} cores for processing.")

        if verbose is None:
            self._verbose = self._config["core"].get("verbose", True)
        else:
            self._verbose = verbose
        set_verbosity(logger, self._verbose)

        self._dataframe: pd.DataFrame | ddf.DataFrame = None
        self._timed_dataframe: pd.DataFrame | ddf.DataFrame = None
        self._files: list[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None
        self._normalization_histogram: xr.DataArray = None
        self._normalized: xr.DataArray = None

        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
            verbose=verbose,
        )
        logger.debug(f"Use loader: {loader_name}")

        self.ec = EnergyCalibrator(
            loader=get_loader(
                loader_name=loader_name,
                config=self._config,
                verbose=verbose,
            ),
            config=self._config,
            verbose=self._verbose,
        )

        self.mc = MomentumCorrector(
            config=self._config,
            verbose=self._verbose,
        )

        self.dc = DelayCalibrator(
            config=self._config,
            verbose=self._verbose,
        )

        self.use_copy_tool = "copy_tool" in self._config["core"]
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    num_cores=self._config["core"]["num_cores"],
                    **self._config["core"]["copy_tool"],
                )
                logger.debug(
                    f"Initialized copy tool: Copy files from "
                    f"'{self._config['core']['copy_tool']['source']}' "
                    f"to '{self._config['core']['copy_tool']['dest']}'.",
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )

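    # Hypothetical usage sketch (config file name and data folder below are
    # placeholders, not from this repository):
    #
    #     processor = SedProcessor(
    #         config="sed_config.yaml",
    #         folder="/path/to/raw/data",
    #         collect_metadata=False,
    #     )
    #
    # Keywords known to parse_config are split off and forwarded to it; all
    # remaining keywords are passed on to the loader.
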
    def __repr__(self):
        if self._dataframe is None:
            df_str = "Dataframe: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        pretty_str = df_str + "\n" + "Metadata: " + "\n" + self._attributes.__repr__()
        return pretty_str

    def _repr_html_(self):
        html = "<div>"

        if self._dataframe is None:
            df_html = "Dataframe: No Data loaded"
        else:
            df_html = self._dataframe._repr_html_()

        html += f"<details><summary>Dataframe</summary>{df_html}</details>"

        # Add expandable section for attributes
        html += "<details><summary>Metadata</summary>"
        html += "<div style='padding-left: 10px;'>"
        html += self._attributes._repr_html_()
        html += "</div></details>"

        html += "</div>"

        return html

    ## Suggestion:
    # @property
    # def overview_panel(self):
    #     """Provides an overview panel with plots of different data attributes."""
    #     self.view_event_histogram(dfpid=2, backend="matplotlib")

    @property
    def verbose(self) -> bool:
        """Accessor to the verbosity flag.

        Returns:
            bool: Verbosity flag.
        """
        return self._verbose

    @verbose.setter
    def verbose(self, verbose: bool):
        """Setter for the verbosity.

        Args:
            verbose (bool): Option to turn on verbose output. Sets loglevel to INFO.
        """
        self._verbose = verbose
        set_verbosity(logger, self._verbose)
        self.mc.verbose = verbose
        self.ec.verbose = verbose
        self.dc.verbose = verbose
        self.loader.verbose = verbose

    @property
    def dataframe(self) -> pd.DataFrame | ddf.DataFrame:
        """Accessor to the underlying dataframe.

        Returns:
            pd.DataFrame | ddf.DataFrame: Dataframe object.
        """
        return self._dataframe

    @dataframe.setter
    def dataframe(self, dataframe: pd.DataFrame | ddf.DataFrame):
        """Setter for the underlying dataframe.

        Args:
            dataframe (pd.DataFrame | ddf.DataFrame): The dataframe object to set.
        """
        if not isinstance(dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            dataframe,
            self._dataframe.__class__,
        ):
            raise ValueError(
                "'dataframe' has to be a Pandas or Dask dataframe and has to be of the same kind "
                "as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._dataframe.__class__}, "
                f"provided type: {dataframe.__class__}.",
            )
        self._dataframe = dataframe

    @property
    def timed_dataframe(self) -> pd.DataFrame | ddf.DataFrame:
        """Accessor to the underlying timed_dataframe.

        Returns:
            pd.DataFrame | ddf.DataFrame: Timed Dataframe object.
        """
        return self._timed_dataframe

    @timed_dataframe.setter
    def timed_dataframe(self, timed_dataframe: pd.DataFrame | ddf.DataFrame):
        """Setter for the underlying timed dataframe.

        Args:
            timed_dataframe (pd.DataFrame | ddf.DataFrame): The timed dataframe object to set.
        """
        if not isinstance(timed_dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            timed_dataframe,
            self._timed_dataframe.__class__,
        ):
            raise ValueError(
                "'timed_dataframe' has to be a Pandas or Dask dataframe and has to be of the same "
                "kind as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._timed_dataframe.__class__}, "
                f"provided type: {timed_dataframe.__class__}.",
            )
        self._timed_dataframe = timed_dataframe

    @property
    def attributes(self) -> MetaHandler:
        """Accessor to the metadata dict.

        Returns:
            MetaHandler: The metadata object.
        """
        return self._attributes

    def add_attribute(self, attributes: dict, name: str, **kwds):
        """Function to add an element to the attributes dict.

        Args:
            attributes (dict): The attributes dictionary object to add.
            name (str): Key under which to add the dictionary to the attributes.
            **kwds: Additional keywords are passed to the ``MetaHandler.add()`` function.
        """
        self._attributes.add(
            entry=attributes,
            name=name,
            **kwds,
        )

    @property
    def config(self) -> dict[Any, Any]:
        """Getter attribute for the config dictionary.

        Returns:
            dict: The config dictionary.
        """
        return self._config

    @property
    def files(self) -> list[str]:
        """Getter attribute for the list of files.

        Returns:
            list[str]: The list of loaded files.
        """
        return self._files

    @property
    def binned(self) -> xr.DataArray:
        """Getter attribute for the binned data array.

        Returns:
            xr.DataArray: The binned data array.
        """
        if self._binned is None:
            raise ValueError("No binned data available, need to compute histogram first!")
        return self._binned

    @property
    def normalized(self) -> xr.DataArray:
        """Getter attribute for the normalized data array.

        Returns:
            xr.DataArray: The normalized data array.
        """
        if self._normalized is None:
            raise ValueError(
                "No normalized data available, compute data with normalization enabled!",
            )
        return self._normalized

    @property
    def normalization_histogram(self) -> xr.DataArray:
        """Getter attribute for the normalization histogram.

        Returns:
            xr.DataArray: The normalization histogram.
        """
        if self._normalization_histogram is None:
            raise ValueError("No normalization histogram available, generate histogram first!")
        return self._normalization_histogram

    def cpy(self, path: str | list[str]) -> str | list[str]:
        """Function to mirror a list of files or a folder from a network drive to a
        local storage. Returns either the original or the copied path to the given
        path. This functionality is enabled by the presence of a
        config["core"]["copy_tool"] entry.

        Args:
            path (str | list[str]): Source path or path list.

        Returns:
            str | list[str]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        return path

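    # The copy tool is activated by a "copy_tool" entry in the core config.
    # A minimal sketch of such an entry (paths are placeholders):
    #
    #     core:
    #       copy_tool:
    #         source: "/network/drive/data"
    #         dest: "/local/scratch/data"
    #
    # With this entry present, cpy() mirrors the requested files or folders to
    # the local destination and returns the local paths.
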
    @call_logger(logger)
    def load(
        self,
        dataframe: pd.DataFrame | ddf.DataFrame = None,
        metadata: dict = None,
        files: list[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (pd.DataFrame | ddf.DataFrame, optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external metadata. Defaults to None.
            files (list[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            collect_metadata (bool, optional): Option for collecting metadata in the reader.
                Defaults to False.
            **kwds:
                - *timed_dataframe*: timed dataframe if dataframe is provided.

                Additional keyword parameters are passed to ``loader.read_dataframe()``.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            timed_dataframe = kwds.pop("timed_dataframe", None)
        elif runs is not None:
            # If runs are provided, we only use the copy tool if a folder is also provided.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        elif files is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                files=cast(list[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._timed_dataframe = timed_dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )

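    # load() accepts exactly one of four input modes (all values below are
    # placeholders):
    #
    #     processor.load(folder="/path/to/data")                # whole folder
    #     processor.load(files=["f1.h5", "f2.h5"])              # explicit files
    #     processor.load(runs=["run01"], folder="/path/base")   # runs in a base folder
    #     processor.load(dataframe=df, timed_dataframe=tdf)     # pre-built dataframes
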
    @call_logger(logger)
    def filter_column(
        self,
        column: str,
        min_value: float = -np.inf,
        max_value: float = np.inf,
    ) -> None:
        """Filter values in a column which are outside of a given range.

        Args:
            column (str): Name of the column to filter.
            min_value (float, optional): Minimum value to keep. Defaults to -np.inf.
            max_value (float, optional): Maximum value to keep. Defaults to np.inf.
        """
        if column != "index" and column not in self._dataframe.columns:
            raise KeyError(f"Column {column} not found in dataframe!")
        if min_value >= max_value:
            raise ValueError("min_value has to be smaller than max_value!")
        if self._dataframe is not None:
            self._dataframe = apply_filter(
                self._dataframe,
                col=column,
                lower_bound=min_value,
                upper_bound=max_value,
            )
        if self._timed_dataframe is not None and column in self._timed_dataframe.columns:
            self._timed_dataframe = apply_filter(
                self._timed_dataframe,
                column,
                lower_bound=min_value,
                upper_bound=max_value,
            )
        metadata = {
            "filter": {
                "column": column,
                "min_value": min_value,
                "max_value": max_value,
            },
        }
        self._attributes.add(metadata, "filter", duplicate_policy="merge")

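    # Example (hypothetical column name and bounds): keep only events whose
    # time-of-flight value lies inside a window; the filter is also recorded in
    # the metadata under "filter":
    #
    #     processor.filter_column("t", min_value=65000, max_value=66000)
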
    # Momentum calibration workflow
    # 1. Bin raw detector data for distortion correction
    @call_logger(logger)
    def bin_and_load_momentum_calibration(
        self,
        df_partitions: int | Sequence[int] = 100,
        axes: list[str] = None,
        bins: list[int] = None,
        ranges: Sequence[tuple[float, float]] = None,
        plane: int = 0,
        width: int = 5,
        apply: bool = False,
        **kwds,
    ):
        """1st step of the momentum correction workflow. Function to do an initial binning
        of the dataframe loaded to the class, slice a plane from it using an
        interactive view, and load it into the momentum corrector class.

        Args:
            df_partitions (int | Sequence[int], optional): Number of dataframe partitions
                to use for the initial binning. Defaults to 100.
            axes (list[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (list[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (Sequence[tuple[float, float]], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
            width (int, optional): Initial value for the width slider. Defaults to 5.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments passed to the pre_binning function.
        """
        self._pre_binned = self.pre_binning(
            df_partitions=df_partitions,
            axes=axes,
            bins=bins,
            ranges=ranges,
            **kwds,
        )

        self.mc.load_data(data=self._pre_binned)
        self.mc.select_slicer(plane=plane, width=width, apply=apply)

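    # Typical call (slider values are placeholders); apply=True selects the
    # slice directly instead of waiting for interactive confirmation:
    #
    #     processor.bin_and_load_momentum_calibration(plane=33, width=10, apply=True)
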
    # 2. Generate the spline warp correction from momentum features.
    # Either autoselect features, or input features from view above.
    @call_logger(logger)
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2nd step of the distortion correction workflow: Define feature points in
        momentum space. They can either be selected manually using a GUI tool, be
        provided as a list of feature points, or be auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments for ``MomentumCorrector.feature_extract()`` and
                ``MomentumCorrector.feature_select()``.
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop("sigma", self._config["momentum"]["sigma"])
            fwhm = kwds.pop("fwhm", self._config["momentum"]["fwhm"])
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config["momentum"]["sigma_radius"],
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )

    # 3. Generate the spline warp correction from momentum features.
    # If no features have been selected before, use class defaults.
    @call_logger(logger)
    def generate_splinewarp(
        self,
        use_center: bool = None,
        **kwds,
    ):
        """3rd step of the distortion correction workflow: Generate the correction
        function restoring the symmetry in the image using a splinewarp algorithm.

        Args:
            use_center (bool, optional): Option to use the position of the
                center point in the correction. Default is read from config, or set to True.
            **kwds: Keyword arguments for ``MomentumCorrector.spline_warp_estimate()``.
        """
        self.mc.spline_warp_estimate(use_center=use_center, **kwds)

        if self.mc.slice is not None and self._verbose:
            print("Original slice with reference features")
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)

            print("Corrected slice with target features")
            self.mc.view(
                image=self.mc.slice_corrected,
                annotated=True,
                points={"feats": self.mc.ptargs},
                backend="bokeh",
                crosshair=True,
            )

            print("Original slice with target features")
            self.mc.view(
                image=self.mc.slice,
                points={"feats": self.mc.ptargs},
                annotated=True,
                backend="bokeh",
            )

    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.correction) == 0:
            raise ValueError("No momentum correction parameters to save!")
        correction = {}
        for key, value in self.mc.correction.items():
            if key in ["reference_points", "target_points", "cdeform_field", "rdeform_field"]:
                continue
            if key in ["use_center", "rotation_symmetry"]:
                correction[key] = value
            elif key in ["center_point", "ascale"]:
                correction[key] = [float(i) for i in value]
            elif key in ["outer_points", "feature_points"]:
                correction[key] = []
                for point in value:
                    correction[key].append([float(i) for i in point])
            elif key == "creation_date":
                correction[key] = value.isoformat()
            else:
                correction[key] = float(value)

        if "creation_date" not in correction:
            correction["creation_date"] = datetime.now().isoformat()

        config = {
            "momentum": {
                "correction": correction,
            },
        }
        save_config(config, filename, overwrite)
        logger.info(f'Saved momentum correction parameters to "{filename}".')

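    # Sketch of the YAML block written by save_splinewarp (which keys appear
    # depends on the stored correction; all values here are placeholders):
    #
    #     momentum:
    #       correction:
    #         rotation_symmetry: 6
    #         use_center: true
    #         center_point: [256.0, 256.0]
    #         feature_points: [[100.0, 150.0], [400.0, 150.0]]
    #         creation_date: "2024-10-15T10:26:00"
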
    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    @call_logger(logger)
    def pose_adjustment(
        self,
        transformations: dict[str, Any] = None,
        apply: bool = False,
        use_correction: bool = True,
        reset: bool = True,
        **kwds,
    ):
        """4th step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            transformations (dict[str, Any], optional): Dictionary with transformations.
                Defaults to self.transformations or config["momentum"]["transformations"].
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
            reset (bool, optional): Option to reset the correction before transformation.
                Defaults to True.
            **kwds: Keyword parameters defining defaults for the transformations:

                - **scale** (float): Initial value of the scaling slider.
                - **xtrans** (float): Initial value of the xtrans slider.
                - **ytrans** (float): Initial value of the ytrans slider.
                - **angle** (float): Initial value of the angle slider.
        """
        # Generate homography as default if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                self.mc.slice = np.zeros(self._config["momentum"]["bins"][0:2])
            self.mc.slice_corrected = self.mc.slice

        if not use_correction:
            self.mc.reset_deformation()

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate distortion correction from config values
            self.mc.spline_warp_estimate()

        self.mc.pose_adjustment(
            transformations=transformations,
            apply=apply,
            reset=reset,
            **kwds,
        )

    # 4a. Save pose adjustment parameters to config file.
    @call_logger(logger)
    def save_transformations(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the pose adjustment parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.transformations) == 0:
            raise ValueError("No momentum transformation parameters to save!")
        transformations = {}
        for key, value in self.mc.transformations.items():
            if key == "creation_date":
                transformations[key] = value.isoformat()
            else:
                transformations[key] = float(value)

        if "creation_date" not in transformations:
            transformations["creation_date"] = datetime.now().isoformat()

        config = {
            "momentum": {
                "transformations": transformations,
            },
        }
        save_config(config, filename, overwrite)
        logger.info(f'Saved momentum transformation parameters to "{filename}".')

    # 5. Apply the momentum correction to the dataframe
    @call_logger(logger)
    def apply_momentum_correction(
        self,
        preview: bool = False,
        **kwds,
    ):
        """Applies the distortion correction and (optional) pose adjustment
        to the dataframe.

        Args:
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds: Keyword parameters for ``MomentumCorrector.apply_correction``:

                - **rdeform_field** (np.ndarray, optional): Row deformation field.
                - **cdeform_field** (np.ndarray, optional): Column deformation field.
                - **inv_dfield** (np.ndarray, optional): Inverse deformation field.
        """
        x_column = self._config["dataframe"]["columns"]["x"]
        y_column = self._config["dataframe"]["columns"]["y"]

        if self._dataframe is not None:
            logger.info("Adding corrected X/Y columns to dataframe:")
            df, metadata = self.mc.apply_corrections(
                df=self._dataframe,
                **kwds,
            )
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                tdf, _ = self.mc.apply_corrections(
                    self._timed_dataframe,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            self._dataframe = df
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

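    # The distortion-correction steps in sequence (a hypothetical
    # non-interactive session using apply=True throughout):
    #
    #     processor.bin_and_load_momentum_calibration(apply=True)      # 1. bin + slice
    #     processor.define_features(rotation_symmetry=6, apply=True)   # 2. features
    #     processor.generate_splinewarp()                              # 3. spline warp
    #     processor.pose_adjustment(apply=True)                        # 4. affine adjust
    #     processor.apply_momentum_correction()                        # 5. apply to dataframe
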
    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    @call_logger(logger)
    def calibrate_momentum_axes(
        self,
        point_a: np.ndarray | list[int] = None,
        point_b: np.ndarray | list[int] = None,
        k_distance: float = None,
        k_coord_a: np.ndarray | list[float] = None,
        k_coord_b: np.ndarray | list[float] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply: bool = False,
    ):
        """1st step of the momentum calibration workflow. Calibrate momentum
        axes using either provided pixel coordinates of a high-symmetry point and its
        distance to the BZ center, or the k-coordinates of two points in the BZ
        (depending on the equiscale option). Opens an interactive panel for selecting
        the points.

        Args:
            point_a (np.ndarray | list[int], optional): Pixel coordinates of the first
                point used for momentum calibration.
            point_b (np.ndarray | list[int], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between point a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (np.ndarray | list[float], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (np.ndarray | list[float], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option to use the same scale for kx and ky.
                If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config["momentum"]["center_pixel"]

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )

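    # Example with equiscale=True (pixel coordinates and k-distance are
    # placeholders); point_b defaults to config["momentum"]["center_pixel"]:
    #
    #     processor.calibrate_momentum_axes(
    #         point_a=[308, 345],
    #         k_distance=1.28,
    #         apply=True,
    #     )
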
    # 1a. Save momentum calibration parameters to config file.
    def save_momentum_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated momentum calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.calibration) == 0:
            raise ValueError("No momentum calibration parameters to save!")
        calibration = {}
        for key, value in self.mc.calibration.items():
            if key in ["kx_axis", "ky_axis", "grid", "extent"]:
                continue
            elif key == "creation_date":
                calibration[key] = value.isoformat()
            else:
                calibration[key] = float(value)

        if "creation_date" not in calibration:
            calibration["creation_date"] = datetime.now().isoformat()

        config = {"momentum": {"calibration": calibration}}
        save_config(config, filename, overwrite)
        logger.info(f"Saved momentum calibration parameters to {filename}")

    # 2. Apply correction and calibration to the dataframe
    @call_logger(logger)
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2nd step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes exist,
        these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds: Keyword args passed to ``MomentumCorrector.append_k_axis()``.
        """
        x_column = self._config["dataframe"]["columns"]["x"]
        y_column = self._config["dataframe"]["columns"]["y"]

        if self._dataframe is not None:
            logger.info("Adding kx/ky columns to dataframe:")
            df, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                tdf, _ = self.mc.append_k_axis(
                    df=self._timed_dataframe,
                    calibration=calibration,
                    suppress_output=True,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            self._dataframe = df
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    @call_logger(logger)
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: tuple[float, float] = None,
        apply: bool = False,
        **kwds,
    ):
        """1st step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data if
        they are not present yet.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
            **kwds: Keyword parameters passed to ``EnergyCalibrator.adjust_energy_correction()``.
        """
        if self._pre_binned is None:
            logger.warning("Pre-binned data not present, binning using defaults from config...")
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )

    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.correction) == 0:
            raise ValueError("No energy correction parameters to save!")
        correction = {}
        for key, value in self.ec.correction.items():
            if key == "correction_type":
                correction[key] = value
            elif key == "center":
                correction[key] = [float(i) for i in value]
            elif key == "creation_date":
                correction[key] = value.isoformat()
            else:
                correction[key] = float(value)

        if "creation_date" not in correction:
            correction["creation_date"] = datetime.now().isoformat()

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)
        logger.info(f"Saved energy correction parameters to {filename}")

    # 2. Apply energy correction to dataframe
    @call_logger(logger)
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2nd step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction()``.
        """
        tof_column = self._config["dataframe"]["columns"]["tof"]

        if self._dataframe is not None:
            logger.info("Applying energy correction to dataframe...")
            df, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.apply_energy_correction(
                    df=self._timed_dataframe,
                    correction=correction,
                    suppress_output=True,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

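    # The energy-correction steps in sequence (correction values are
    # placeholders):
    #
    #     processor.adjust_energy_correction(
    #         correction_type="Lorentzian",
    #         amplitude=2.5,
    #         center=(730, 730),
    #         apply=True,
    #     )
    #     processor.save_energy_correction()
    #     processor.apply_energy_correction()
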
    # Energy calibrator workflow
    # 1. Load and normalize data
    @call_logger(logger)
    def load_bias_series(
        self,
        binned_data: xr.DataArray | tuple[np.ndarray, np.ndarray, np.ndarray] = None,
        data_files: list[str] = None,
        axes: list[str] = None,
        bins: list = None,
        ranges: Sequence[tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1st step of the energy calibration workflow: Load and bin data from
        single-event files, or load binned bias/TOF traces.

        Args:
            binned_data (xr.DataArray | tuple[np.ndarray, np.ndarray, np.ndarray], optional):
                Binned data. If provided as a DataArray, needs to contain the dimensions
                config["dataframe"]["columns"]["tof"] and config["dataframe"]["columns"]["bias"].
                If provided as a tuple, needs to contain the elements tof, biases, traces.
            data_files (list[str], optional): List of file paths to bin.
            axes (list[str], optional): Bin axes.
                Defaults to config["dataframe"]["columns"]["tof"].
            bins (list, optional): Number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[tuple[float, float]], optional): Bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): Span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): Order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        if binned_data is not None:
            if isinstance(binned_data, xr.DataArray):
                if (
                    self._config["dataframe"]["columns"]["tof"] not in binned_data.dims
                    or self._config["dataframe"]["columns"]["bias"] not in binned_data.dims
                ):
                    raise ValueError(
                        "If binned_data is provided as an xarray, it needs to contain dimensions "
                        f"'{self._config['dataframe']['columns']['tof']}' and "
                        f"'{self._config['dataframe']['columns']['bias']}'!",
                    )
                tof = binned_data.coords[self._config["dataframe"]["columns"]["tof"]].values
                biases = binned_data.coords[self._config["dataframe"]["columns"]["bias"]].values
                traces = binned_data.values[:, :]
            else:
                try:
                    (tof, biases, traces) = binned_data
                except ValueError as exc:
                    raise ValueError(
                        "If binned_data is provided as tuple, it needs to contain "
                        "(tof, biases, traces)!",
                    ) from exc
            logger.debug(f"Energy calibration data loaded from binned data. Bias values={biases}.")
            self.ec.load_data(biases=biases, traces=traces, tof=tof)

        elif data_files is not None:
            self.ec.bin_data(
                data_files=cast(list[str], self.cpy(data_files)),
                axes=axes,
                bins=bins,
                ranges=ranges,
                biases=biases,
                bias_key=bias_key,
            )
            logger.debug(
                f"Energy calibration data binned from files {data_files}. "
                f"Bias values={biases}.",
            )

        else:
            raise ValueError("Either binned_data or data_files needs to be provided!")

        if (normalize is not None and normalize is True) or (
            normalize is None and self._config["energy"]["normalize"]
        ):
            if span is None:
                span = self._config["energy"]["normalize_span"]
            if order is None:
                order = self._config["energy"]["normalize_order"]
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="bokeh",
        )

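    # Either pre-binned data or raw files can be supplied (inputs are
    # placeholders):
    #
    #     processor.load_bias_series(binned_data=(tof, biases, traces))
    #     # or, binning from single-event files:
    #     processor.load_bias_series(data_files=bias_files, normalize=True)
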
    # 2. extract ranges and get peak positions
1249
    @call_logger(logger)
1✔
1250
    def find_bias_peaks(
1✔
1251
        self,
1252
        ranges: list[tuple] | tuple,
1253
        ref_id: int = 0,
1254
        infer_others: bool = True,
1255
        mode: str = "replace",
1256
        radius: int = None,
1257
        peak_window: int = None,
1258
        apply: bool = False,
1259
    ):
1260
        """2. step of the energy calibration workflow: Find a peak within a given range
1261
        for the indicated reference trace, and tries to find the same peak for all
1262
        other traces. Uses fast_dtw to align curves, which might not be too good if the
1263
        shape of curves changes qualitatively. Ideally, choose a reference trace in the
1264
        middle of the set, and don't choose the range too narrow around the peak.
1265
        Alternatively, a list of ranges for all traces can be provided.
1266

1267
        Args:
1268
            ranges (list[tuple] | tuple): Tuple of TOF values indicating a range.
1269
                Alternatively, a list of ranges for all traces can be given.
1270
            ref_id (int, optional): The id of the trace the range refers to.
1271
                Defaults to 0.
1272
            infer_others (bool, optional): Whether to determine the range for the other
1273
                traces. Defaults to True.
1274
            mode (str, optional): Whether to "add" or "replace" existing ranges.
1275
                Defaults to "replace".
1276
            radius (int, optional): Radius parameter for fast_dtw.
1277
                Defaults to config["energy"]["fastdtw_radius"].
1278
            peak_window (int, optional): Peak_window parameter for the peak detection
1279
                algorithm: the number of points that need to behave monotonically
1280
                around a peak. Defaults to config["energy"]["peak_window"].
1281
            apply (bool, optional): Option to directly apply the provided parameters.
1282
                Defaults to False.
1283
        """
1284
        if radius is None:
1✔
1285
            radius = self._config["energy"]["fastdtw_radius"]
1✔
1286
        if peak_window is None:
1✔
1287
            peak_window = self._config["energy"]["peak_window"]
1✔
1288
        if not infer_others:
1✔
1289
            self.ec.add_ranges(
1✔
1290
                ranges=ranges,
1291
                ref_id=ref_id,
1292
                infer_others=infer_others,
1293
                mode=mode,
1294
                radius=radius,
1295
            )
1296
            logger.info(f"Use feature ranges: {self.ec.featranges}.")
1✔
1297
            try:
1✔
1298
                self.ec.feature_extract(peak_window=peak_window)
1✔
1299
                logger.info(f"Extracted energy features: {self.ec.peaks}.")
1✔
1300
                self.ec.view(
1✔
1301
                    traces=self.ec.traces_normed,
1302
                    segs=self.ec.featranges,
1303
                    xaxis=self.ec.tof,
1304
                    peaks=self.ec.peaks,
1305
                    backend="bokeh",
1306
                )
1307
            except IndexError:
×
1308
                logger.error("Could not determine all peaks!")
×
1309
                raise
×
1310
        else:
1311
            # New adjustment tool
1312
            assert isinstance(ranges, tuple)
1✔
1313
            self.ec.adjust_ranges(
1✔
1314
                ranges=ranges,
1315
                ref_id=ref_id,
1316
                traces=self.ec.traces_normed,
1317
                infer_others=infer_others,
1318
                radius=radius,
1319
                peak_window=peak_window,
1320
                apply=apply,
1321
            )
1322

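# Usage sketch (editorial addition): with the bias series loaded above,
# step 2 selects a feature window on a reference trace and infers the
# corresponding windows for all other traces. The TOF window and reference
# id are illustrative placeholders.
sp.find_bias_peaks(ranges=(64000, 66000), ref_id=5, apply=True)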
1323
    # 3. Fit the energy calibration relation
1324
    @call_logger(logger)
1✔
1325
    def calibrate_energy_axis(
1✔
1326
        self,
1327
        ref_energy: float,
1328
        method: str = None,
1329
        energy_scale: str = None,
1330
        **kwds,
1331
    ):
1332
        """3. Step of the energy calibration workflow: Calculate the calibration
1333
        function for the energy axis, and apply it to the dataframe. Two
1334
        approximations are implemented, a (normally 3rd order) polynomial
1335
        approximation, and a d^2/(t-t0)^2 relation.
1336

1337
        Args:
1338
            ref_energy (float): Binding/kinetic energy of the detected feature.
1339
            method (str, optional): Method for determining the energy calibration.
1340

1341
                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
1342
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.
1343

1344
                Defaults to config["energy"]["calibration_method"]
1345
            energy_scale (str, optional): Direction of increasing energy scale.
1346

1347
                - **'kinetic'**: increasing energy with decreasing TOF.
1348
                - **'binding'**: increasing energy with increasing TOF.
1349

1350
                Defaults to config["energy"]["energy_scale"]
1351
            **kwds**: Keyword parameters passed to ``EnergyCalibrator.calibrate()``.
1352
        """
1353
        if method is None:
1✔
1354
            method = self._config["energy"]["calibration_method"]
1✔
1355

1356
        if energy_scale is None:
1✔
1357
            energy_scale = self._config["energy"]["energy_scale"]
1✔
1358

1359
        self.ec.calibrate(
1✔
1360
            ref_energy=ref_energy,
1361
            method=method,
1362
            energy_scale=energy_scale,
1363
            **kwds,
1364
        )
1365
        if self._verbose:
1✔
1366
            self.ec.view(
1✔
1367
                traces=self.ec.traces_normed,
1368
                xaxis=self.ec.calibration["axis"],
1369
                align=True,
1370
                energy_scale=energy_scale,
1371
                backend="matplotlib",
1372
                title="Quality of Calibration",
1373
            )
1374
            plt.xlabel("Energy (eV)")
1✔
1375
            plt.ylabel("Intensity")
1✔
1376
            plt.tight_layout()
1✔
1377
            plt.show()
1✔
1378
            if energy_scale == "kinetic":
1✔
1379
                self.ec.view(
1✔
1380
                    traces=self.ec.calibration["axis"][None, :] + self.ec.biases[0],
1381
                    xaxis=self.ec.tof,
1382
                    backend="matplotlib",
1383
                    show_legend=False,
1384
                    title="E/TOF relationship",
1385
                )
1386
                plt.scatter(
1✔
1387
                    self.ec.peaks[:, 0],
1388
                    -(self.ec.biases - self.ec.biases[0]) + ref_energy,
1389
                    s=50,
1390
                    c="k",
1391
                )
1392
                plt.tight_layout()
1✔
1393
            elif energy_scale == "binding":
1✔
1394
                self.ec.view(
1✔
1395
                    traces=self.ec.calibration["axis"][None, :] - self.ec.biases[0],
1396
                    xaxis=self.ec.tof,
1397
                    backend="matplotlib",
1398
                    show_legend=False,
1399
                    title="E/TOF relationship",
1400
                )
1401
                plt.scatter(
1✔
1402
                    self.ec.peaks[:, 0],
1403
                    self.ec.biases - self.ec.biases[0] + ref_energy,
1404
                    s=50,
1405
                    c="k",
1406
                )
1407
            else:
1408
                raise ValueError(
×
1409
                    'energy_scale needs to be either "binding" or "kinetic"',
1410
                    f", got {energy_scale}.",
1411
                )
1412
            plt.xlabel("Time-of-flight")
1✔
1413
            plt.ylabel("Energy (eV)")
1✔
1414
            plt.tight_layout()
1✔
1415
            plt.show()
1✔
1416

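# Usage sketch (editorial addition): step 3 fits the calibration function
# to the extracted peak positions. The reference energy of the tracked
# feature (e.g. a known core level or the Fermi edge) is an illustrative
# value.
sp.calibrate_energy_axis(ref_energy=-0.5, method="lmfit", energy_scale="kinetic")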
1417
    # 3a. Save energy calibration parameters to config file.
1418
    def save_energy_calibration(
1✔
1419
        self,
1420
        filename: str = None,
1421
        overwrite: bool = False,
1422
    ):
1423
        """Save the generated energy calibration parameters to the folder config file.
1424

1425
        Args:
1426
            filename (str, optional): Filename of the config dictionary to save to.
1427
                Defaults to "sed_config.yaml" in the current folder.
1428
            overwrite (bool, optional): Option to overwrite the present dictionary.
1429
                Defaults to False.
1430
        """
1431
        if filename is None:
1✔
1432
            filename = "sed_config.yaml"
×
1433
        if len(self.ec.calibration) == 0:
1✔
1434
            raise ValueError("No energy calibration parameters to save!")
×
1435
        calibration = {}
1✔
1436
        for key, value in self.ec.calibration.items():
1✔
1437
            if key in ["axis", "refid", "Tmat", "bvec"]:
1✔
1438
                continue
1✔
1439
            if key == "energy_scale":
1✔
1440
                calibration[key] = value
1✔
1441
            elif key == "coeffs":
1✔
1442
                calibration[key] = [float(i) for i in value]
1✔
1443
            elif key == "creation_date":
1✔
1444
                calibration[key] = value.isoformat()
1✔
1445
            else:
1446
                calibration[key] = float(value)
1✔
1447

1448
        if "creation_date" not in calibration:
1✔
NEW
1449
            calibration["creation_date"] = datetime.now().isoformat()
×
1450

1451
        config = {"energy": {"calibration": calibration}}
1✔
1452
        save_config(config, filename, overwrite)
1✔
1453
        logger.info(f'Saved energy calibration parameters to "{filename}".')
1✔
1454

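# Usage sketch (editorial addition): persist the fitted parameters. This
# merges a dict of the form {"energy": {"calibration": {...}}} into the
# given config file (default "sed_config.yaml").
sp.save_energy_calibration(overwrite=True)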
1455
    # 4. Apply energy calibration to the dataframe
1456
    @call_logger(logger)
1✔
1457
    def append_energy_axis(
1✔
1458
        self,
1459
        calibration: dict = None,
1460
        bias_voltage: float = None,
1461
        preview: bool = False,
1462
        **kwds,
1463
    ):
1464
        """4. step of the energy calibration workflow: Apply the calibration function
1465
        to to the dataframe. Two approximations are implemented, a (normally 3rd order)
1466
        polynomial approximation, and a d^2/(t-t0)^2 relation. a calibration dictionary
1467
        can be provided.
1468

1469
        Args:
1470
            calibration (dict, optional): Calibration dict containing calibration
1471
                parameters. Overrides calibration from class or config.
1472
                Defaults to None.
1473
            bias_voltage (float, optional): Sample bias voltage of the scan data. If omitted,
1474
                the bias voltage is read from the dataframe. If it is not found there,
1475
                a warning is printed and the calibrated data might have an offset.
1476
            preview (bool): Option to preview the first elements of the data frame.
1477
            **kwds:
1478
                Keyword args passed to ``EnergyCalibrator.append_energy_axis()``.
1479
        """
1480
        tof_column = self._config["dataframe"]["columns"]["tof"]
1✔
1481

1482
        if self._dataframe is not None:
1✔
1483
            logger.info("Adding energy column to dataframe:")
1✔
1484
            df, metadata = self.ec.append_energy_axis(
1✔
1485
                df=self._dataframe,
1486
                calibration=calibration,
1487
                bias_voltage=bias_voltage,
1488
                **kwds,
1489
            )
1490
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
1✔
1491
                tdf, _ = self.ec.append_energy_axis(
1✔
1492
                    df=self._timed_dataframe,
1493
                    calibration=calibration,
1494
                    bias_voltage=bias_voltage,
1495
                    suppress_output=True,
1496
                    **kwds,
1497
                )
1498

1499
            # Add Metadata
1500
            self._attributes.add(
1✔
1501
                metadata,
1502
                "energy_calibration",
1503
                duplicate_policy="merge",
1504
            )
1505
            self._dataframe = df
1✔
1506
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
1✔
1507
                self._timed_dataframe = tdf
1✔
1508

1509
        else:
1510
            raise ValueError("No dataframe loaded!")
×
1511
        if preview:
1✔
1512
            logger.info(self._dataframe.head(10))
×
1513
        else:
1514
            logger.info(self._dataframe)
1✔
1515

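# Usage sketch (editorial addition): step 4 applies the stored calibration
# to the dataframe; the bias voltage is an illustrative placeholder.
sp.append_energy_axis(bias_voltage=16.8)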
1516
    @call_logger(logger)
1✔
1517
    def add_energy_offset(
1✔
1518
        self,
1519
        constant: float = None,
1520
        columns: str | Sequence[str] = None,
1521
        weights: float | Sequence[float] = None,
1522
        reductions: str | Sequence[str] = None,
1523
        preserve_mean: bool | Sequence[bool] = None,
1524
        preview: bool = False,
1525
    ) -> None:
1526
        """Shift the energy axis of the dataframe by a given amount.
1527

1528
        Args:
1529
            constant (float, optional): The constant to shift the energy axis by.
1530
            columns (str | Sequence[str], optional): Name of the column(s) to apply the shift from.
1531
            weights (float | Sequence[float], optional): weights to apply to the columns.
1532
                Can also be used to flip the sign (e.g. -1). Defaults to 1.
1533
            reductions (str | Sequence[str], optional): The reduction to apply to the column.
1534
                Should be an available method of dask.dataframe.Series. For example "mean". In this
1535
                case the function is applied to the column to generate a single value for the whole
1536
                dataset. If None, the shift is applied per-dataframe-row. Defaults to None.
1537
                Currently only "mean" is supported.
1538
            preserve_mean (bool | Sequence[bool], optional): Whether to subtract the mean of the
1539
                column before applying the shift. Defaults to False.
1540
            preview (bool, optional): Option to preview the first elements of the data frame.
1541
                Defaults to False.
1542

1543
        Raises:
1544
            ValueError: If the energy column is not in the dataframe.
1545
        """
1546
        energy_column = self._config["dataframe"]["columns"]["energy"]
1✔
1547
        if energy_column not in self._dataframe.columns:
1✔
1548
            raise ValueError(
1✔
1549
                f"Energy column {energy_column} not found in dataframe! "
1550
                "Run `append_energy_axis()` first.",
1551
            )
1552
        if self.dataframe is not None:
1✔
1553
            logger.info("Adding energy offset to dataframe:")
1✔
1554
            df, metadata = self.ec.add_offsets(
1✔
1555
                df=self._dataframe,
1556
                constant=constant,
1557
                columns=columns,
1558
                energy_column=energy_column,
1559
                weights=weights,
1560
                reductions=reductions,
1561
                preserve_mean=preserve_mean,
1562
            )
1563
            if self._timed_dataframe is not None and energy_column in self._timed_dataframe.columns:
1✔
1564
                tdf, _ = self.ec.add_offsets(
1✔
1565
                    df=self._timed_dataframe,
1566
                    constant=constant,
1567
                    columns=columns,
1568
                    energy_column=energy_column,
1569
                    weights=weights,
1570
                    reductions=reductions,
1571
                    preserve_mean=preserve_mean,
1572
                    suppress_output=True,
1573
                )
1574

1575
            self._attributes.add(
1✔
1576
                metadata,
1577
                "add_energy_offset",
1578
                # TODO: allow only appending when no offset along this column(s) was applied
1579
                # TODO: clear memory of modifications if the energy axis is recalculated
1580
                duplicate_policy="append",
1581
            )
1582
            self._dataframe = df
1✔
1583
            if self._timed_dataframe is not None and energy_column in self._timed_dataframe.columns:
1✔
1584
                self._timed_dataframe = tdf
1✔
1585
        else:
1586
            raise ValueError("No dataframe loaded!")
×
1587
        if preview:
1✔
1588
            logger.info(self._dataframe.head(10))
×
1589
        else:
1590
            logger.info(self._dataframe)
1✔
1591

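# Usage sketch (editorial addition): shift the energy axis by a constant
# and by a per-row column. "monochromatorPhotonEnergy" is a hypothetical
# column name; weights=-1 flips its sign before adding.
sp.add_energy_offset(
    constant=-30.0,
    columns="monochromatorPhotonEnergy",
    weights=-1,
    preserve_mean=True,
)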
1592
    def save_energy_offset(
1✔
1593
        self,
1594
        filename: str = None,
1595
        overwrite: bool = False,
1596
    ):
1597
        """Save the generated energy calibration parameters to the folder config file.
1598

1599
        Args:
1600
            filename (str, optional): Filename of the config dictionary to save to.
1601
                Defaults to "sed_config.yaml" in the current folder.
1602
            overwrite (bool, optional): Option to overwrite the present dictionary.
1603
                Defaults to False.
1604
        """
1605
        if filename is None:
×
1606
            filename = "sed_config.yaml"
×
1607
        if len(self.ec.offsets) == 0:
×
1608
            raise ValueError("No energy offset parameters to save!")
×
1609

NEW
1610
        offsets = deepcopy(self.ec.offsets)
×
1611

NEW
1612
        if "creation_date" not in offsets.keys():
×
NEW
1613
            offsets["creation_date"] = datetime.now()
×
1614

NEW
1615
        offsets["creation_date"] = offsets["creation_date"].isoformat()
×
1616

NEW
1617
        config = {"energy": {"offsets": offsets}}
×
1618
        save_config(config, filename, overwrite)
×
1619
        logger.info(f'Saved energy offset parameters to "{filename}".')
×
1620

1621
    @call_logger(logger)
1✔
1622
    def append_tof_ns_axis(
1✔
1623
        self,
1624
        preview: bool = False,
1625
        **kwds,
1626
    ):
1627
        """Convert time-of-flight channel steps to nanoseconds.
1628

1629
        Args:
1630
            tof_ns_column (str, optional): Name of the generated column containing the
1631
                time-of-flight in nanoseconds.
1632
                Defaults to config["dataframe"]["tof_ns_column"].
1633
            preview (bool, optional): Option to preview the first elements of the data frame.
1634
                Defaults to False.
1635
            **kwds: additional arguments are passed to ``EnergyCalibrator.append_tof_ns_axis()``.
1636

1637
        """
1638
        tof_column = self._config["dataframe"]["columns"]["tof"]
1✔
1639

1640
        if self._dataframe is not None:
1✔
1641
            logger.info("Adding time-of-flight column in nanoseconds to dataframe.")
1✔
1642
            # TODO assert order of execution through metadata
1643

1644
            df, metadata = self.ec.append_tof_ns_axis(
1✔
1645
                df=self._dataframe,
1646
                **kwds,
1647
            )
1648
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
1✔
1649
                tdf, _ = self.ec.append_tof_ns_axis(
1✔
1650
                    df=self._timed_dataframe,
1651
                    **kwds,
1652
                )
1653

1654
            self._attributes.add(
1✔
1655
                metadata,
1656
                "tof_ns_conversion",
1657
                duplicate_policy="overwrite",
1658
            )
1659
            self._dataframe = df
1✔
1660
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
1✔
1661
                self._timed_dataframe = tdf
1✔
1662
        else:
1663
            raise ValueError("No dataframe loaded!")
×
1664
        if preview:
1✔
1665
            logger.info(self._dataframe.head(10))
×
1666
        else:
1667
            logger.info(self._dataframe)
1✔
1668

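# Usage sketch (editorial addition): convert TOF steps to nanoseconds,
# using the binwidth and binning settings from the config.
sp.append_tof_ns_axis()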
1669
    @call_logger(logger)
1✔
1670
    def align_dld_sectors(
1✔
1671
        self,
1672
        sector_delays: np.ndarray = None,
1673
        preview: bool = False,
1674
        **kwds,
1675
    ):
1676
        """Align the 8s sectors of the HEXTOF endstation.
1677

1678
        Args:
1679
            sector_delays (np.ndarray, optional): Array containing the sector delays. Defaults to
1680
                config["dataframe"]["sector_delays"].
1681
            preview (bool, optional): Option to preview the first elements of the data frame.
1682
                Defaults to False.
1683
            **kwds: additional arguments are passed to ``EnergyCalibrator.align_dld_sectors()``.
1684
        """
1685
        tof_column = self._config["dataframe"]["columns"]["tof"]
1✔
1686

1687
        if self._dataframe is not None:
1✔
1688
            logger.info("Aligning 8s sectors of dataframe")
1✔
1689
            # TODO assert order of execution through metadata
1690

1691
            df, metadata = self.ec.align_dld_sectors(
1✔
1692
                df=self._dataframe,
1693
                sector_delays=sector_delays,
1694
                **kwds,
1695
            )
1696
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
1✔
1697
                tdf, _ = self.ec.align_dld_sectors(
×
1698
                    df=self._timed_dataframe,
1699
                    sector_delays=sector_delays,
1700
                    **kwds,
1701
                )
1702

1703
            self._attributes.add(
1✔
1704
                metadata,
1705
                "dld_sector_alignment",
1706
                duplicate_policy="raise",
1707
            )
1708
            self._dataframe = df
1✔
1709
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
1✔
1710
                self._timed_dataframe = tdf
×
1711
        else:
1712
            raise ValueError("No dataframe loaded!")
×
1713
        if preview:
1✔
1714
            logger.info(self._dataframe.head(10))
×
1715
        else:
1716
            logger.info(self._dataframe)
1✔
1717

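# Usage sketch (editorial addition): align the detector sectors, either
# from config["dataframe"]["sector_delays"] or with explicit values (here
# zeros, i.e. a no-op, purely for illustration).
import numpy as np

sp.align_dld_sectors(sector_delays=np.zeros(8))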
1718
    # Delay calibration function
1719
    @call_logger(logger)
1✔
1720
    def calibrate_delay_axis(
1✔
1721
        self,
1722
        delay_range: tuple[float, float] = None,
1723
        datafile: str = None,
1724
        preview: bool = False,
1725
        **kwds,
1726
    ):
1727
        """Append delay column to dataframe. Either provide delay ranges, or read
1728
        them from a file.
1729

1730
        Args:
1731
            delay_range (tuple[float, float], optional): The scanned delay range in
1732
                picoseconds. Defaults to None.
1733
            datafile (str, optional): The file from which to read the delay ranges.
1734
                Defaults to None.
1735
            preview (bool, optional): Option to preview the first elements of the data frame.
1736
                Defaults to False.
1737
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
1738
        """
1739
        adc_column = self._config["dataframe"]["columns"]["adc"]
1✔
1740
        if adc_column not in self._dataframe.columns:
1✔
1741
            raise ValueError(f"ADC column {adc_column} not found in dataframe, cannot calibrate!")
×
1742

1743
        if self._dataframe is not None:
1✔
1744
            logger.info("Adding delay column to dataframe:")
1✔
1745

1746
            if delay_range is None and datafile is None:
1✔
1747
                if len(self.dc.calibration) == 0:
1✔
1748
                    try:
1✔
1749
                        datafile = self._files[0]
1✔
1750
                    except IndexError as exc:
×
1751
                        raise IndexError(
×
1752
                            "No datafile available, specify either 'datafile' or 'delay_range'",
1753
                        ) from exc
1754

1755
            df, metadata = self.dc.append_delay_axis(
1✔
1756
                self._dataframe,
1757
                delay_range=delay_range,
1758
                datafile=datafile,
1759
                **kwds,
1760
            )
1761
            if self._timed_dataframe is not None and adc_column in self._timed_dataframe.columns:
1✔
1762
                tdf, _ = self.dc.append_delay_axis(
1✔
1763
                    self._timed_dataframe,
1764
                    delay_range=delay_range,
1765
                    datafile=datafile,
1766
                    suppress_output=True,
1767
                    **kwds,
1768
                )
1769

1770
            # Add Metadata
1771
            self._attributes.add(
1✔
1772
                metadata,
1773
                "delay_calibration",
1774
                duplicate_policy="overwrite",
1775
            )
1776
            self._dataframe = df
1✔
1777
            if self._timed_dataframe is not None and adc_column in self._timed_dataframe.columns:
1✔
1778
                self._timed_dataframe = tdf
1✔
1779
        else:
1780
            raise ValueError("No dataframe loaded!")
×
1781
        if preview:
1✔
1782
            logger.info(self._dataframe.head(10))
1✔
1783
        else:
1784
            logger.debug(self._dataframe)
1✔
1785

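# Usage sketch (editorial addition): map the ADC column to pump-probe
# delay, either from an explicit scan range (illustrative values, in ps)
# or from ranges stored in a data file.
sp.calibrate_delay_axis(delay_range=(-500.0, 1500.0), preview=True)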
1786
    def save_delay_calibration(
1✔
1787
        self,
1788
        filename: str = None,
1789
        overwrite: bool = False,
1790
    ) -> None:
1791
        """Save the generated delay calibration parameters to the folder config file.
1792

1793
        Args:
1794
            filename (str, optional): Filename of the config dictionary to save to.
1795
                Defaults to "sed_config.yaml" in the current folder.
1796
            overwrite (bool, optional): Option to overwrite the present dictionary.
1797
                Defaults to False.
1798
        """
1799
        if filename is None:
1✔
1800
            filename = "sed_config.yaml"
×
1801

1802
        if len(self.dc.calibration) == 0:
1✔
1803
            raise ValueError("No delay calibration parameters to save!")
×
1804
        calibration = {}
1✔
1805
        for key, value in self.dc.calibration.items():
1✔
1806
            if key == "datafile":
1✔
1807
                calibration[key] = value
1✔
1808
            elif key in ["adc_range", "delay_range", "delay_range_mm"]:
1✔
1809
                calibration[key] = [float(i) for i in value]
1✔
1810
            elif key == "creation_date":
1✔
1811
                calibration[key] = value.isoformat()
1✔
1812
            else:
1813
                calibration[key] = float(value)
1✔
1814

1815
        if "creation_date" not in calibration:
1✔
NEW
1816
            calibration["creation_date"] = datetime.now().isoformat()
×
1817

1818
        config = {
1✔
1819
            "delay": {
1820
                "calibration": calibration,
1821
            },
1822
        }
1823
        save_config(config, filename, overwrite)
1✔
1824

1825
    @call_logger(logger)
1✔
1826
    def add_delay_offset(
1✔
1827
        self,
1828
        constant: float = None,
1829
        flip_delay_axis: bool = None,
1830
        columns: str | Sequence[str] = None,
1831
        weights: float | Sequence[float] = 1.0,
1832
        reductions: str | Sequence[str] = None,
1833
        preserve_mean: bool | Sequence[bool] = False,
1834
        preview: bool = False,
1835
    ) -> None:
1836
        """Shift the delay axis of the dataframe by a constant or other columns.
1837

1838
        Args:
1839
            constant (float, optional): The constant to shift the delay axis by.
1840
            flip_delay_axis (bool, optional): Option to reverse the direction of the delay axis.
1841
            columns (str | Sequence[str], optional): Name of the column(s) to apply the shift from.
1842
            weights (float | Sequence[float], optional): weights to apply to the columns.
1843
                Can also be used to flip the sign (e.g. -1). Defaults to 1.
1844
            reductions (str | Sequence[str], optional): The reduction to apply to the column.
1845
                Should be an available method of dask.dataframe.Series. For example "mean". In this
1846
                case the function is applied to the column to generate a single value for the whole
1847
                dataset. If None, the shift is applied per-dataframe-row. Defaults to None.
1848
                Currently only "mean" is supported.
1849
            preserve_mean (bool | Sequence[bool], optional): Whether to subtract the mean of the
1850
                column before applying the shift. Defaults to False.
1851
            preview (bool, optional): Option to preview the first elements of the data frame.
1852
                Defaults to False.
1853

1854
        Raises:
1855
            ValueError: If the delay column is not in the dataframe.
1856
        """
1857
        delay_column = self._config["dataframe"]["columns"]["delay"]
1✔
1858
        if delay_column not in self._dataframe.columns:
1✔
1859
            raise ValueError(f"Delay column {delay_column} not found in dataframe! ")
1✔
1860

1861
        if self.dataframe is not None:
1✔
1862
            logger.info("Adding delay offset to dataframe:")
1✔
1863
            df, metadata = self.dc.add_offsets(
1✔
1864
                df=self._dataframe,
1865
                constant=constant,
1866
                flip_delay_axis=flip_delay_axis,
1867
                columns=columns,
1868
                delay_column=delay_column,
1869
                weights=weights,
1870
                reductions=reductions,
1871
                preserve_mean=preserve_mean,
1872
            )
1873
            if self._timed_dataframe is not None and delay_column in self._timed_dataframe.columns:
1✔
1874
                tdf, _ = self.dc.add_offsets(
1✔
1875
                    df=self._timed_dataframe,
1876
                    constant=constant,
1877
                    flip_delay_axis=flip_delay_axis,
1878
                    columns=columns,
1879
                    delay_column=delay_column,
1880
                    weights=weights,
1881
                    reductions=reductions,
1882
                    preserve_mean=preserve_mean,
1883
                    suppress_output=True,
1884
                )
1885

1886
            self._attributes.add(
1✔
1887
                metadata,
1888
                "delay_offset",
1889
                duplicate_policy="append",
1890
            )
1891
            self._dataframe = df
1✔
1892
            if self._timed_dataframe is not None and delay_column in self._timed_dataframe.columns:
1✔
1893
                self._timed_dataframe = tdf
1✔
1894
        else:
1895
            raise ValueError("No dataframe loaded!")
×
1896
        if preview:
1✔
1897
            logger.info(self._dataframe.head(10))
1✔
1898
        else:
1899
            logger.info(self._dataframe)
1✔
1900

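# Usage sketch (editorial addition): correct the delay axis by a constant,
# flip its sign, and subtract a per-row jitter column. "bam" is a
# hypothetical beam-arrival-monitor column; the weight rescales its units.
sp.add_delay_offset(constant=26.9, flip_delay_axis=True, columns="bam", weights=0.001)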
1901
    def save_delay_offsets(
1✔
1902
        self,
1903
        filename: str = None,
1904
        overwrite: bool = False,
1905
    ) -> None:
1906
        """Save the generated delay calibration parameters to the folder config file.
1907

1908
        Args:
1909
            filename (str, optional): Filename of the config dictionary to save to.
1910
                Defaults to "sed_config.yaml" in the current folder.
1911
            overwrite (bool, optional): Option to overwrite the present dictionary.
1912
                Defaults to False.
1913
        """
1914
        if filename is None:
1✔
1915
            filename = "sed_config.yaml"
×
1916
        if len(self.dc.offsets) == 0:
1✔
1917
            raise ValueError("No delay offset parameters to save!")
×
1918

1919
        offsets = deepcopy(self.dc.offsets)
1✔
1920

1921
        if "creation_date" not in offsets.keys():
1✔
NEW
1922
            offsets["creation_date"] = datetime.now()
×
1923

1924
        offsets["creation_date"] = offsets["creation_date"].isoformat()
1✔
1925

1926
        config = {"delay": {"offsets": offsets}}
1✔
1927
        save_config(config, filename, overwrite)
1✔
1928
        logger.info(f'Saved delay offset parameters to "{filename}".')
1✔
1929

1930
    def save_workflow_params(
1✔
1931
        self,
1932
        filename: str = None,
1933
        overwrite: bool = False,
1934
    ) -> None:
1935
        """run all save calibration parameter methods
1936

1937
        Args:
1938
            filename (str, optional): Filename of the config dictionary to save to.
1939
                Defaults to "sed_config.yaml" in the current folder.
1940
            overwrite (bool, optional): Option to overwrite the present dictionary.
1941
                Defaults to False.
1942
        """
1943
        for method in [
×
1944
            self.save_splinewarp,
1945
            self.save_transformations,
1946
            self.save_momentum_calibration,
1947
            self.save_energy_correction,
1948
            self.save_energy_calibration,
1949
            self.save_energy_offset,
1950
            self.save_delay_calibration,
1951
            self.save_delay_offsets,
1952
        ]:
1953
            try:
×
1954
                method(filename, overwrite)
×
1955
            except (ValueError, AttributeError, KeyError):
×
1956
                pass
×
1957

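# Usage sketch (editorial addition): persist all available calibration and
# offset parameters in one call; each save_* method above that has nothing
# to save is silently skipped.
sp.save_workflow_params(filename="sed_config.yaml", overwrite=True)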
1958
    @call_logger(logger)
1✔
1959
    def add_jitter(
1✔
1960
        self,
1961
        cols: list[str] = None,
1962
        amps: float | Sequence[float] = None,
1963
        **kwds,
1964
    ):
1965
        """Add jitter to the selected dataframe columns.
1966

1967
        Args:
1968
            cols (list[str], optional): The columns onto which to apply jitter.
1969
                Defaults to config["dataframe"]["jitter_cols"].
1970
            amps (float | Sequence[float], optional): Amplitude scalings for the
1971
                jittering noise. If one number is given, the same is used for all axes.
1972
                For uniform noise (default) it will cover the interval [-amp, +amp].
1973
                Defaults to config["dataframe"]["jitter_amps"].
1974
            **kwds: additional keyword arguments passed to ``apply_jitter``.
1975
        """
1976
        if cols is None:
1✔
1977
            cols = self._config["dataframe"]["jitter_cols"]
1✔
1978
        for loc, col in enumerate(cols):
1✔
1979
            if col.startswith("@"):
1✔
1980
                cols[loc] = self._config["dataframe"]["columns"].get(col.strip("@"))
1✔
1981

1982
        if amps is None:
1✔
1983
            amps = self._config["dataframe"]["jitter_amps"]
1✔
1984

1985
        self._dataframe = self._dataframe.map_partitions(
1✔
1986
            apply_jitter,
1987
            cols=cols,
1988
            cols_jittered=cols,
1989
            amps=amps,
1990
            **kwds,
1991
        )
1992
        if self._timed_dataframe is not None:
1✔
1993
            cols_timed = cols.copy()
1✔
1994
            for col in cols:
1✔
1995
                if col not in self._timed_dataframe.columns:
1✔
1996
                    cols_timed.remove(col)
×
1997

1998
            if cols_timed:
1✔
1999
                self._timed_dataframe = self._timed_dataframe.map_partitions(
1✔
2000
                    apply_jitter,
2001
                    cols=cols_timed,
2002
                    cols_jittered=cols_timed,
2003
                )
2004
        metadata = []
1✔
2005
        for col in cols:
1✔
2006
            metadata.append(col)
1✔
2007
        # TODO: allow only appending if columns are not jittered yet
2008
        self._attributes.add(metadata, "jittering", duplicate_policy="append")
1✔
2009
        logger.info(f"add_jitter: Added jitter to columns {cols}.")
1✔
2010

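# Usage sketch (editorial addition): dither the digitized columns to
# suppress binning artifacts; with no arguments, columns and amplitudes
# are taken from the config.
sp.add_jitter()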
2011
    @call_logger(logger)
1✔
2012
    def add_time_stamped_data(
1✔
2013
        self,
2014
        dest_column: str,
2015
        time_stamps: np.ndarray = None,
2016
        data: np.ndarray = None,
2017
        archiver_channel: str = None,
2018
        **kwds,
2019
    ):
2020
        """Add data in form of timestamp/value pairs to the dataframe using interpolation to the
2021
        timestamps in the dataframe. The time-stamped data can either be provided, or fetched from
2022
        an EPICS archiver instance.
2023

2024
        Args:
2025
            dest_column (str): destination column name
2026
            time_stamps (np.ndarray, optional): Time stamps of the values to add. If omitted,
2027
                time stamps are retrieved from the EPICS archiver.
2028
            data (np.ndarray, optional): Values corresponding to the time stamps in time_stamps.
2029
                If omitted, data are retrieved from the EPICS archiver.
2030
            archiver_channel (str, optional): EPICS archiver channel from which to retrieve data.
2031
                Either this or data and time_stamps have to be present.
2032
            **kwds:
2033

2034
                - **time_stamp_column**: Dataframe column containing time-stamp data
2035

2036
                Additional keyword arguments passed to ``add_time_stamped_data``.
2037
        """
2038
        time_stamp_column = kwds.pop(
1✔
2039
            "time_stamp_column",
2040
            self._config["dataframe"]["columns"].get("timestamp", ""),
2041
        )
2042

2043
        if time_stamps is None and data is None:
1✔
2044
            if archiver_channel is None:
×
2045
                raise ValueError(
×
2046
                    "Either archiver_channel or both time_stamps and data have to be present!",
2047
                )
2048
            if self.loader.__name__ != "mpes":
×
2049
                raise NotImplementedError(
×
2050
                    "This function is currently only implemented for the mpes loader!",
2051
                )
2052
            ts_from, ts_to = cast(MpesLoader, self.loader).get_start_and_end_time()
×
2053
            # get channel data with +-5 seconds safety margin
2054
            time_stamps, data = get_archiver_data(
×
2055
                archiver_url=self._config["metadata"].get("archiver_url", ""),
2056
                archiver_channel=archiver_channel,
2057
                ts_from=ts_from - 5,
2058
                ts_to=ts_to + 5,
2059
            )
2060

2061
        self._dataframe = add_time_stamped_data(
1✔
2062
            self._dataframe,
2063
            time_stamps=time_stamps,
2064
            data=data,
2065
            dest_column=dest_column,
2066
            time_stamp_column=time_stamp_column,
2067
            **kwds,
2068
        )
2069
        if self._timed_dataframe is not None:
1✔
2070
            if time_stamp_column in self._timed_dataframe:
1✔
2071
                self._timed_dataframe = add_time_stamped_data(
1✔
2072
                    self._timed_dataframe,
2073
                    time_stamps=time_stamps,
2074
                    data=data,
2075
                    dest_column=dest_column,
2076
                    time_stamp_column=time_stamp_column,
2077
                    **kwds,
2078
                )
2079
        metadata: list[Any] = []
1✔
2080
        metadata.append(dest_column)
1✔
2081
        metadata.append(time_stamps)
1✔
2082
        metadata.append(data)
1✔
2083
        self._attributes.add(metadata, "time_stamped_data", duplicate_policy="append")
1✔
2084
        logger.info(f"add_time_stamped_data: Added time-stamped data as column {dest_column}.")
1✔
2085

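# Usage sketch (editorial addition): interpolate externally recorded
# timestamp/value pairs onto the dataframe timestamps. The column name and
# temperature readings are fabricated for illustration.
import numpy as np

ts = np.array([1.69e9, 1.69e9 + 60, 1.69e9 + 120])  # epoch seconds
vals = np.array([300.0, 295.0, 290.0])  # sample temperature in K
sp.add_time_stamped_data(dest_column="sampleTemperature", time_stamps=ts, data=vals)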
2086
    @call_logger(logger)
1✔
2087
    def pre_binning(
1✔
2088
        self,
2089
        df_partitions: int | Sequence[int] = 100,
2090
        axes: list[str] = None,
2091
        bins: list[int] = None,
2092
        ranges: Sequence[tuple[float, float]] = None,
2093
        **kwds,
2094
    ) -> xr.DataArray:
2095
        """Function to do an initial binning of the dataframe loaded to the class.
2096

2097
        Args:
2098
            df_partitions (int | Sequence[int], optional): Number of dataframe partitions to
2099
                use for the initial binning. Defaults to 100.
2100
            axes (list[str], optional): Axes to bin.
2101
                Defaults to config["momentum"]["axes"].
2102
            bins (list[int], optional): Bin numbers to use for binning.
2103
                Defaults to config["momentum"]["bins"].
2104
            ranges (Sequence[tuple[float, float]], optional): Ranges to use for binning.
2105
                Defaults to config["momentum"]["ranges"].
2106
            **kwds: Keyword argument passed to ``compute``.
2107

2108
        Returns:
2109
            xr.DataArray: pre-binned data-array.
2110
        """
2111
        if axes is None:
1✔
2112
            axes = self._config["momentum"]["axes"]
1✔
2113
        for loc, axis in enumerate(axes):
1✔
2114
            if axis.startswith("@"):
1✔
2115
                axes[loc] = self._config["dataframe"]["columns"].get(axis.strip("@"))
1✔
2116

2117
        if bins is None:
1✔
2118
            bins = self._config["momentum"]["bins"]
1✔
2119
        if ranges is None:
1✔
2120
            ranges_ = list(self._config["momentum"]["ranges"])
1✔
2121
            ranges_[2] = np.asarray(ranges_[2]) / self._config["dataframe"]["tof_binning"]
1✔
2122
            ranges = [cast(tuple[float, float], tuple(v)) for v in ranges_]
1✔
2123

2124
        assert self._dataframe is not None, "dataframe needs to be loaded first!"
1✔
2125

2126
        return self.compute(
1✔
2127
            bins=bins,
2128
            axes=axes,
2129
            ranges=ranges,
2130
            df_partitions=df_partitions,
2131
            **kwds,
2132
        )
2133

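# Usage sketch (editorial addition): quick overview histogram over the
# momentum-space axes defined in the config, using only the first 20
# dataframe partitions.
overview = sp.pre_binning(df_partitions=20)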
2134
    @call_logger(logger)
1✔
2135
    def compute(
1✔
2136
        self,
2137
        bins: int | dict | tuple | list[int] | list[np.ndarray] | list[tuple] = 100,
2138
        axes: str | Sequence[str] = None,
2139
        ranges: Sequence[tuple[float, float]] = None,
2140
        normalize_to_acquisition_time: bool | str = False,
2141
        **kwds,
2142
    ) -> xr.DataArray:
2143
        """Compute the histogram along the given dimensions.
2144

2145
        Args:
2146
            bins (int | dict | tuple | list[int] | list[np.ndarray] | list[tuple], optional):
2147
                Definition of the bins. Can be any of the following cases:
2148

2149
                - an integer describing the number of bins on all dimensions
2150
                - a tuple of 3 numbers describing start, end and step of the binning
2151
                  range
2152
                - an np.ndarray defining the binning edges
2153
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
2154
                - a dictionary made of the axes as keys and any of the above as values.
2155

2156
                This takes priority over the axes and range arguments. Defaults to 100.
2157
            axes (str | Sequence[str], optional): The names of the axes (columns)
2158
                on which to calculate the histogram. The order will be the order of the
2159
                dimensions in the resulting array. Defaults to None.
2160
            ranges (Sequence[tuple[float, float]], optional): list of tuples containing
2161
                the start and end point of the binning range. Defaults to None.
2162
            normalize_to_acquisition_time (bool | str): Option to normalize the
2163
                result to the acquisition time. If a "slow" axis was scanned, providing
2164
                the name of the scanned axis will compute and apply the corresponding
2165
                normalization histogram. Defaults to False.
2166
            **kwds: Keyword arguments:
2167

2168
                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
2169
                  ``bin_dataframe`` for details. Defaults to
2170
                  config["binning"]["hist_mode"].
2171
                - **mode**: Defines how the results from each partition are combined.
2172
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
2173
                  Defaults to config["binning"]["mode"].
2174
                - **pbar**: Option to show the tqdm progress bar. Defaults to
2175
                  config["binning"]["pbar"].
2176
                - **n_cores**: Number of CPU cores to use for parallelization.
2177
                  Defaults to config["core"]["num_cores"] or N_CPU-1.
2178
                - **threads_per_worker**: Limit the number of threads that
2179
                  multiprocessing can spawn per binning thread. Defaults to
2180
                  config["binning"]["threads_per_worker"].
2181
                - **threadpool_api**: The API to use for multiprocessing. "blas",
2182
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
2183
                  config["binning"]["threadpool_API"].
2184
                - **df_partitions**: A sequence of dataframe partitions, or the
2185
                  number of dataframe partitions to use. Defaults to all partitions.
2186
                - **filter**: A Sequence of Dictionaries with entries "col", "lower_bound",
2187
                  "upper_bound" to apply as filter to the dataframe before binning. The
2188
                  dataframe in the class remains unmodified by this.
2189

2190
                Additional kwds are passed to ``bin_dataframe``.
2191

2192
        Raises:
2193
            AssertionError: Raised when no dataframe has been loaded.
2194

2195
        Returns:
2196
            xr.DataArray: The result of the n-dimensional binning represented in an
2197
            xarray object, combining the data with the axes.
2198
        """
2199
        assert self._dataframe is not None, "dataframe needs to be loaded first!"
1✔
2200

2201
        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
1✔
2202
        mode = kwds.pop("mode", self._config["binning"]["mode"])
1✔
2203
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
1✔
2204
        num_cores = kwds.pop("num_cores", self._config["core"]["num_cores"])
1✔
2205
        threads_per_worker = kwds.pop(
1✔
2206
            "threads_per_worker",
2207
            self._config["binning"]["threads_per_worker"],
2208
        )
2209
        threadpool_api = kwds.pop(
1✔
2210
            "threadpool_API",
2211
            self._config["binning"]["threadpool_API"],
2212
        )
2213
        df_partitions: int | Sequence[int] = kwds.pop("df_partitions", None)
1✔
2214
        if isinstance(df_partitions, int):
1✔
2215
            df_partitions = list(range(0, min(df_partitions, self._dataframe.npartitions)))
1✔
2216
        if df_partitions is not None:
1✔
2217
            dataframe = self._dataframe.partitions[df_partitions]
1✔
2218
        else:
2219
            dataframe = self._dataframe
1✔
2220

2221
        filter_params = kwds.pop("filter", None)
1✔
2222
        if filter_params is not None:
1✔
2223
            try:
1✔
2224
                for param in filter_params:
1✔
2225
                    if "col" not in param:
1✔
2226
                        raise ValueError(
1✔
2227
                            "'col' needs to be defined for each filter entry! ",
2228
                            f"Not present in {param}.",
2229
                        )
2230
                    assert set(param.keys()).issubset({"col", "lower_bound", "upper_bound"})
1✔
2231
                    dataframe = apply_filter(dataframe, **param)
1✔
2232
            except AssertionError as exc:
1✔
2233
                invalid_keys = set(param.keys()) - {"col", "lower_bound", "upper_bound"}
1✔
2234
                raise ValueError(
1✔
2235
                    "Only 'col', 'lower_bound' and 'upper_bound' allowed as filter entries. ",
2236
                    f"Parameters {invalid_keys} not valid in {param}.",
2237
                ) from exc
2238

2239
        self._binned = bin_dataframe(
1✔
2240
            df=dataframe,
2241
            bins=bins,
2242
            axes=axes,
2243
            ranges=ranges,
2244
            hist_mode=hist_mode,
2245
            mode=mode,
2246
            pbar=pbar,
2247
            n_cores=num_cores,
2248
            threads_per_worker=threads_per_worker,
2249
            threadpool_api=threadpool_api,
2250
            **kwds,
2251
        )
2252

2253
        for dim in self._binned.dims:
1✔
2254
            try:
1✔
2255
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
1✔
2256
            except KeyError:
1✔
2257
                pass
1✔
2258

2259
        self._binned.attrs["units"] = "counts"
1✔
2260
        self._binned.attrs["long_name"] = "photoelectron counts"
1✔
2261
        self._binned.attrs["metadata"] = self._attributes.metadata
1✔
2262

2263
        if normalize_to_acquisition_time:
1✔
2264
            if isinstance(normalize_to_acquisition_time, str):
1✔
2265
                axis = normalize_to_acquisition_time
1✔
2266
                logger.info(f"Calculate normalization histogram for axis '{axis}'...")
1✔
2267
                self._normalization_histogram = self.get_normalization_histogram(
1✔
2268
                    axis=axis,
2269
                    df_partitions=df_partitions,
2270
                )
2271
                # if the axes are named correctly, xarray figures out the normalization correctly
2272
                self._normalized = self._binned / self._normalization_histogram
1✔
2273
                self._attributes.add(
1✔
2274
                    self._normalization_histogram.values,
2275
                    name="normalization_histogram",
2276
                    duplicate_policy="overwrite",
2277
                )
2278
            else:
2279
                acquisition_time = self.loader.get_elapsed_time(
×
2280
                    fids=df_partitions,
2281
                )
2282
                if acquisition_time > 0:
×
2283
                    self._normalized = self._binned / acquisition_time
×
2284
                self._attributes.add(
×
2285
                    acquisition_time,
2286
                    name="normalization_histogram",
2287
                    duplicate_policy="overwrite",
2288
                )
2289

2290
            self._normalized.attrs["units"] = "counts/second"
1✔
2291
            self._normalized.attrs["long_name"] = "photoelectron counts per second"
1✔
2292
            self._normalized.attrs["metadata"] = self._attributes.metadata
1✔
2293

2294
            return self._normalized
1✔
2295

2296
        return self._binned
1✔
2297

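# Usage sketch (editorial addition): full binning over two calibrated
# axes. Bin counts and ranges are illustrative; column names assume the
# default "energy" and "delay" aliases from the config.
res = sp.compute(
    bins=[100, 100],
    axes=["energy", "delay"],
    ranges=[(-4.0, 1.0), (-500.0, 1500.0)],
)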
2298
    @call_logger(logger)
1✔
2299
    def get_normalization_histogram(
1✔
2300
        self,
2301
        axis: str = "delay",
2302
        use_time_stamps: bool = False,
2303
        **kwds,
2304
    ) -> xr.DataArray:
2305
        """Generates a normalization histogram from the timed dataframe. Optionally,
2306
        use the TimeStamps column instead.
2307

2308
        Args:
2309
            axis (str, optional): The axis for which to compute histogram.
2310
                Defaults to "delay".
2311
            use_time_stamps (bool, optional): Use the TimeStamps column of the
2312
                dataframe, rather than the timed dataframe. Defaults to False.
2313
            **kwds: Keyword arguments:
2314

2315
                - **df_partitions**: A sequence of dataframe partitions, or the
2316
                  number of dataframe partitions to use. Defaults to all partitions.
2317

2318
        Raises:
2319
            ValueError: Raised if no data are binned.
2320
            ValueError: Raised if 'axis' not in binned coordinates.
2321
            ValueError: Raised if config["dataframe"]["time_stamp_alias"] not found
2322
                in the dataframe.
2323

2324
        Returns:
2325
            xr.DataArray: The computed normalization histogram (in TimeStamp units
2326
            per bin).
2327
        """
2328

2329
        if self._binned is None:
1✔
2330
            raise ValueError("Need to bin data first!")
1✔
2331
        if axis not in self._binned.coords:
1✔
2332
            raise ValueError(f"Axis '{axis}' not found in binned data!")
1✔
2333

2334
        df_partitions: int | Sequence[int] = kwds.pop("df_partitions", None)
1✔
2335

2336
        if len(kwds) > 0:
1✔
2337
            raise TypeError(
1✔
2338
                f"get_normalization_histogram() got unexpected keyword arguments {kwds.keys()}.",
2339
            )
2340

2341
        if isinstance(df_partitions, int):
1✔
2342
            df_partitions = list(range(0, min(df_partitions, self._dataframe.npartitions)))
1✔
2343
        if use_time_stamps or self._timed_dataframe is None:
1✔
2344
            if df_partitions is not None:
1✔
2345
                self._normalization_histogram = normalization_histogram_from_timestamps(
1✔
2346
                    self._dataframe.partitions[df_partitions],
2347
                    axis,
2348
                    self._binned.coords[axis].values,
2349
                    self._config["dataframe"]["columns"]["timestamp"],
2350
                )
2351
            else:
2352
                self._normalization_histogram = normalization_histogram_from_timestamps(
×
2353
                    self._dataframe,
2354
                    axis,
2355
                    self._binned.coords[axis].values,
2356
                    self._config["dataframe"]["columns"]["timestamp"],
2357
                )
2358
        else:
2359
            if df_partitions is not None:
1✔
2360
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
1✔
2361
                    self._timed_dataframe.partitions[df_partitions],
2362
                    axis,
2363
                    self._binned.coords[axis].values,
2364
                    self._config["dataframe"]["timed_dataframe_unit_time"],
2365
                )
2366
            else:
2367
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
×
2368
                    self._timed_dataframe,
2369
                    axis,
2370
                    self._binned.coords[axis].values,
2371
                    self._config["dataframe"]["timed_dataframe_unit_time"],
2372
                )
2373

2374
        return self._normalization_histogram
1✔
2375

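# Usage sketch (editorial addition): after binning with a "delay" axis (as
# in the compute() sketch above), the acquisition-time histogram can be
# obtained and applied by hand; xarray aligns the shared axis.
hist = sp.get_normalization_histogram(axis="delay", df_partitions=10)
normalized = res / hist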
2376
    def view_event_histogram(
1✔
2377
        self,
2378
        dfpid: int,
2379
        ncol: int = 2,
2380
        bins: Sequence[int] = None,
2381
        axes: Sequence[str] = None,
2382
        ranges: Sequence[tuple[float, float]] = None,
2383
        backend: str = "bokeh",
2384
        legend: bool = True,
2385
        histkwds: dict = None,
2386
        legkwds: dict = None,
2387
        **kwds,
2388
    ):
2389
        """Plot individual histograms of specified dimensions (axes) from a substituent
2390
        dataframe partition.
2391

2392
        Args:
2393
            dfpid (int): Number of the data frame partition to look at.
2394
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
2395
            bins (Sequence[int], optional): Number of bins to use for the specified
2396
                axes. Defaults to config["histogram"]["bins"].
2397
            axes (Sequence[str], optional): Names of the axes to display.
2398
                Defaults to config["histogram"]["axes"].
2399
            ranges (Sequence[tuple[float, float]], optional): Value ranges of all
2400
                specified axes. Defaults to config["histogram"]["ranges"].
2401
            backend (str, optional): Backend of the plotting library
2402
                ('matplotlib' or 'bokeh'). Defaults to "bokeh".
2403
            legend (bool, optional): Option to include a legend in the histogram plots.
2404
                Defaults to True.
2405
            histkwds (dict, optional): Keyword arguments for histograms
2406
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
2407
            legkwds (dict, optional): Keyword arguments for legend
2408
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
2409
            **kwds: Extra keyword arguments passed to
2410
                ``sed.diagnostics.grid_histogram()``.
2411

2412
        Raises:
2413
            TypeError: Raises when the input values are not of the correct type.
2414
        """
2415
        if bins is None:
1✔
2416
            bins = self._config["histogram"]["bins"]
1✔
2417
        if axes is None:
1✔
2418
            axes = self._config["histogram"]["axes"]
1✔
2419
        axes = list(axes)
1✔
2420
        for loc, axis in enumerate(axes):
1✔
2421
            if axis.startswith("@"):
1✔
2422
                axes[loc] = self._config["dataframe"]["columns"].get(axis.strip("@"))
1✔
2423
        if ranges is None:
1✔
2424
            ranges = list(self._config["histogram"]["ranges"])
1✔
2425
            for loc, axis in enumerate(axes):
1✔
2426
                if axis == self._config["dataframe"]["columns"]["tof"]:
1✔
2427
                    ranges[loc] = np.asarray(ranges[loc]) / self._config["dataframe"]["tof_binning"]
1✔
2428
                elif axis == self._config["dataframe"]["columns"]["adc"]:
1✔
2429
                    ranges[loc] = np.asarray(ranges[loc]) / self._config["dataframe"]["adc_binning"]
×
2430

2431
        input_types = map(type, [axes, bins, ranges])
1✔
2432
        allowed_types = [list, tuple]
1✔
2433

2434
        df = self._dataframe
1✔
2435

2436
        if not set(input_types).issubset(allowed_types):
1✔
2437
            raise TypeError(
×
2438
                "Inputs of axes, bins, ranges need to be list or tuple!",
2439
            )
2440

2441
        # Read out the values for the specified groups
2442
        group_dict_dd = {}
1✔
2443
        dfpart = df.get_partition(dfpid)
1✔
2444
        cols = dfpart.columns
1✔
2445
        for ax in axes:
1✔
2446
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
1✔
2447
        group_dict = ddf.compute(group_dict_dd)[0]
1✔
2448

2449
        # Plot multiple histograms in a grid
2450
        grid_histogram(
1✔
2451
            group_dict,
2452
            ncol=ncol,
2453
            rvs=axes,
2454
            rvbins=bins,
2455
            rvranges=ranges,
2456
            backend=backend,
2457
            legend=legend,
2458
            histkwds=histkwds,
2459
            legkwds=legkwds,
2460
            **kwds,
2461
        )
2462

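# Usage sketch (editorial addition): inspect the raw event distributions
# of a single dataframe partition before binning.
sp.view_event_histogram(dfpid=0, ncol=2, backend="bokeh")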
2463
    @call_logger(logger)
    def save(
        self,
        faddr: str,
        **kwds,
    ):
        """Saves the binned data to the provided path and filename.

        Args:
            faddr (str): Path and name of the file to write. Its extension determines
                the file type to write. Valid file types are:

                - "*.tiff", "*.tif": Saves a TIFF stack.
                - "*.h5", "*.hdf5": Saves an HDF5 file.
                - "*.nxs", "*.nexus": Saves a NeXus file.

            **kwds: Keyword arguments, which are passed to the writer functions:

                For TIFF writing:

                - **alias_dict**: Dictionary of dimension aliases to use.

                For HDF5 writing:

                - **mode**: hdf5 read/write mode. Defaults to "w".

                For NeXus writing:

                - **reader**: Name of the pynxtools reader to use.
                  Defaults to config["nexus"]["reader"].
                - **definition**: NeXus application definition to use for saving.
                  Must be supported by the used ``reader``. Defaults to
                  config["nexus"]["definition"].
                - **input_files**: A list of input files to pass to the reader.
                  Defaults to config["nexus"]["input_files"].
                - **eln_data**: An electronic-lab-notebook file in '.yaml' format
                  to add to the list of files to pass to the reader.
        """
        if self._binned is None:
            raise NameError("Need to bin data first!")

        # Prefer normalized data if a normalization has been computed
        if self._normalized is not None:
            # uncovered in this report
            data = self._normalized
        else:
            data = self._binned

        extension = pathlib.Path(faddr).suffix

        if extension in (".tif", ".tiff"):
            to_tiff(
                data=data,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".h5", ".hdf5"):
            to_h5(
                data=data,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".nxs", ".nexus"):
            try:
                reader = kwds.pop("reader", self._config["nexus"]["reader"])
                definition = kwds.pop(
                    "definition",
                    self._config["nexus"]["definition"],
                )
                input_files = kwds.pop(
                    "input_files",
                    self._config["nexus"]["input_files"],
                )
            except KeyError as exc:  # uncovered in this report
                raise ValueError(
                    "The nexus reader, definition and input files need to be provided!",
                ) from exc

            if isinstance(input_files, str):
                input_files = [input_files]

            if "eln_data" in kwds:
                input_files.append(kwds.pop("eln_data"))

            to_nexus(
                data=data,
                faddr=faddr,
                reader=reader,
                definition=definition,
                input_files=input_files,
                **kwds,
            )

        else:
            raise NotImplementedError(
                f"Unrecognized file format: {extension}.",
            )
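
    # Editor's usage sketch (not part of the original source): hypothetical
    # calls of save(), assuming binned (and optionally normalized) data is
    # present on the processor. File names and the ELN file are illustrative:
    #
    #   processor.save("binned.h5")            # HDF5; writer mode defaults to "w"
    #   processor.save("binned.tiff")          # TIFF stack
    #   processor.save(
    #       "binned.nxs",                      # NeXus via pynxtools
    #       eln_data="eln_data.yaml",          # appended to the reader's input_files
    #   )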