# /src/sed/core/processor.py
"""This module contains the core class for the sed package."""
from __future__ import annotations

import pathlib
from collections.abc import Sequence
from copy import deepcopy
from datetime import datetime
from typing import Any
from typing import cast

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.binning.binning import normalization_histogram_from_timed_dataframe
from sed.binning.binning import normalization_histogram_from_timestamps
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import add_time_stamped_data
from sed.core.dfops import apply_filter
from sed.core.dfops import apply_jitter
from sed.core.logging import call_logger
from sed.core.logging import set_verbosity
from sed.core.logging import setup_logging
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader
from sed.loader.mpes.loader import MpesLoader
from sed.loader.mpes.metadata import get_archiver_data

N_CPU = psutil.cpu_count()

# Configure logging
logger = setup_logging("processor")

class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a workflow for data
    correction, calibration and binning.

    Args:
        metadata (dict, optional): Dict of external metadata. Defaults to None.
        config (dict | str, optional): Config dictionary or config file name.
            Defaults to None.
        dataframe (pd.DataFrame | ddf.DataFrame, optional): Dataframe to load
            into the class. Defaults to None.
        files (list[str], optional): List of files to pass to the loader defined in
            the config. Defaults to None.
        folder (str, optional): Folder containing files to pass to the loader
            defined in the config. Defaults to None.
        runs (Sequence[str], optional): List of run identifiers to pass to the loader
            defined in the config. Defaults to None.
        collect_metadata (bool): Option to collect metadata from files.
            Defaults to False.
        verbose (bool, optional): Option to print out diagnostic information.
            Defaults to config["core"]["verbose"] or True.
        **kwds: Keyword arguments passed to the reader.
    """

    @call_logger(logger)
    def __init__(
        self,
        metadata: dict = None,
        config: dict | str = None,
        dataframe: pd.DataFrame | ddf.DataFrame = None,
        files: list[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        verbose: bool = None,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external metadata. Defaults to None.
            config (dict | str, optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (pd.DataFrame | ddf.DataFrame, optional): Dataframe to load
                into the class. Defaults to None.
            files (list[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to config["core"]["verbose"] or True.
            **kwds: Keyword arguments passed to parse_config and to the reader.
        """
        # split off config keywords
        config_kwds = {
            key: value for key, value in kwds.items() if key in parse_config.__code__.co_varnames
        }
        for key in config_kwds.keys():
            del kwds[key]
        self._config = parse_config(config, **config_kwds)
        num_cores = self._config["core"].get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["core"]["num_cores"] = num_cores
        logger.debug(f"Use {num_cores} cores for processing.")

        if verbose is None:
            self._verbose = self._config["core"].get("verbose", True)
        else:
            self._verbose = verbose
        set_verbosity(logger, self._verbose)

        self._dataframe: pd.DataFrame | ddf.DataFrame = None
        self._timed_dataframe: pd.DataFrame | ddf.DataFrame = None
        self._files: list[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None
        self._normalization_histogram: xr.DataArray = None
        self._normalized: xr.DataArray = None

        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
            verbose=verbose,
        )
        logger.debug(f"Use loader: {loader_name}")

        self.ec = EnergyCalibrator(
            loader=get_loader(
                loader_name=loader_name,
                config=self._config,
                verbose=verbose,
            ),
            config=self._config,
            verbose=self._verbose,
        )

        self.mc = MomentumCorrector(
            config=self._config,
            verbose=self._verbose,
        )

        self.dc = DelayCalibrator(
            config=self._config,
            verbose=self._verbose,
        )

        self.use_copy_tool = "copy_tool" in self._config["core"]
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    num_cores=self._config["core"]["num_cores"],
                    **self._config["core"]["copy_tool"],
                )
                logger.debug(
                    f"Initialized copy tool: Copy files from "
                    f"'{self._config['core']['copy_tool']['source']}' "
                    f"to '{self._config['core']['copy_tool']['dest']}'.",
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )

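    ## Example (illustrative sketch; config file name and data folder are hypothetical):
    #
    #     processor = SedProcessor(
    #         config="sed_config.yaml",
    #         folder="/path/to/data",
    #         collect_metadata=False,
    #         verbose=True,
    #     )
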
    def __repr__(self):
        if self._dataframe is None:
            df_str = "Dataframe: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        pretty_str = df_str + "\n" + "Metadata: " + "\n" + self._attributes.__repr__()
        return pretty_str

    def _repr_html_(self):
        html = "<div>"

        if self._dataframe is None:
            df_html = "Dataframe: No Data loaded"
        else:
            df_html = self._dataframe._repr_html_()

        html += f"<details><summary>Dataframe</summary>{df_html}</details>"

        # Add expandable section for attributes
        html += "<details><summary>Metadata</summary>"
        html += "<div style='padding-left: 10px;'>"
        html += self._attributes._repr_html_()
        html += "</div></details>"

        html += "</div>"

        return html

    ## Suggestion:
    # @property
    # def overview_panel(self):
    #     """Provides an overview panel with plots of different data attributes."""
    #     self.view_event_histogram(dfpid=2, backend="matplotlib")

    @property
    def verbose(self) -> bool:
        """Accessor to the verbosity flag.

        Returns:
            bool: Verbosity flag.
        """
        return self._verbose

    @verbose.setter
    def verbose(self, verbose: bool):
        """Setter for the verbosity.

        Args:
            verbose (bool): Option to turn on verbose output. Sets loglevel to INFO.
        """
        self._verbose = verbose
        set_verbosity(logger, self._verbose)
        self.mc.verbose = verbose
        self.ec.verbose = verbose
        self.dc.verbose = verbose
        self.loader.verbose = verbose

    @property
    def dataframe(self) -> pd.DataFrame | ddf.DataFrame:
        """Accessor to the underlying dataframe.

        Returns:
            pd.DataFrame | ddf.DataFrame: Dataframe object.
        """
        return self._dataframe

    @dataframe.setter
    def dataframe(self, dataframe: pd.DataFrame | ddf.DataFrame):
        """Setter for the underlying dataframe.

        Args:
            dataframe (pd.DataFrame | ddf.DataFrame): The dataframe object to set.
        """
        if not isinstance(dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            dataframe,
            self._dataframe.__class__,
        ):
            raise ValueError(
                "'dataframe' has to be a Pandas or Dask dataframe and has to be of the same kind "
                "as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._dataframe.__class__}, "
                f"provided type: {dataframe.__class__}.",
            )
        self._dataframe = dataframe

    @property
    def timed_dataframe(self) -> pd.DataFrame | ddf.DataFrame:
        """Accessor to the underlying timed_dataframe.

        Returns:
            pd.DataFrame | ddf.DataFrame: Timed Dataframe object.
        """
        return self._timed_dataframe

    @timed_dataframe.setter
    def timed_dataframe(self, timed_dataframe: pd.DataFrame | ddf.DataFrame):
        """Setter for the underlying timed dataframe.

        Args:
            timed_dataframe (pd.DataFrame | ddf.DataFrame): The timed dataframe object to set.
        """
        if not isinstance(timed_dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            timed_dataframe,
            self._timed_dataframe.__class__,
        ):
            raise ValueError(
                "'timed_dataframe' has to be a Pandas or Dask dataframe and has to be of the same "
                "kind as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._timed_dataframe.__class__}, "
                f"provided type: {timed_dataframe.__class__}.",
            )
        self._timed_dataframe = timed_dataframe

    @property
    def attributes(self) -> MetaHandler:
        """Accessor to the metadata dict.

        Returns:
            MetaHandler: The metadata object
        """
        return self._attributes

    def add_attribute(self, attributes: dict, name: str, **kwds):
        """Function to add element to the attributes dict.

        Args:
            attributes (dict): The attributes dictionary object to add.
            name (str): Key under which to add the dictionary to the attributes.
            **kwds: Additional keywords are passed to the ``MetaHandler.add()`` function.
        """
        self._attributes.add(
            entry=attributes,
            name=name,
            **kwds,
        )

    @property
    def config(self) -> dict[Any, Any]:
        """Getter attribute for the config dictionary

        Returns:
            dict: The config dictionary.
        """
        return self._config

    @property
    def files(self) -> list[str]:
        """Getter attribute for the list of files

        Returns:
            list[str]: The list of loaded files
        """
        return self._files

    @property
    def binned(self) -> xr.DataArray:
        """Getter attribute for the binned data array

        Returns:
            xr.DataArray: The binned data array
        """
        if self._binned is None:
            raise ValueError("No binned data available, need to compute histogram first!")
        return self._binned

    @property
    def normalized(self) -> xr.DataArray:
        """Getter attribute for the normalized data array

        Returns:
            xr.DataArray: The normalized data array
        """
        if self._normalized is None:
            raise ValueError(
                "No normalized data available, compute data with normalization enabled!",
            )
        return self._normalized

    @property
    def normalization_histogram(self) -> xr.DataArray:
        """Getter attribute for the normalization histogram

        Returns:
            xr.DataArray: The normalization histogram
        """
        if self._normalization_histogram is None:
            raise ValueError("No normalization histogram available, generate histogram first!")
        return self._normalization_histogram

    def cpy(self, path: str | list[str]) -> str | list[str]:
        """Function to mirror a list of files or a folder from a network drive to a
        local storage. Returns either the original or the copied path to the given
        path. The option to use this functionality is set by the presence of the
        config["core"]["copy_tool"] entry.

        Args:
            path (str | list[str]): Source path or path list.

        Returns:
            str | list[str]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        return path

    @call_logger(logger)
    def load(
        self,
        dataframe: pd.DataFrame | ddf.DataFrame = None,
        metadata: dict = None,
        files: list[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (pd.DataFrame | ddf.DataFrame, optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external metadata. Defaults to None.
            files (list[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            collect_metadata (bool, optional): Option for collecting metadata in the reader.
                Defaults to False.
            **kwds:
                - *timed_dataframe*: timed dataframe if dataframe is provided.

                Additional keyword parameters are passed to ``loader.read_dataframe()``.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            timed_dataframe = kwds.pop("timed_dataframe", None)
        elif runs is not None:
            # If runs are provided, we only use the copy tool if a folder is also provided.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        elif files is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                files=cast(list[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._timed_dataframe = timed_dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )

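    ## Example (illustrative sketch; run identifiers and paths are hypothetical):
    #
    #     processor.load(runs=["run_0001", "run_0002"], folder="/path/to/base_folder")
    #     # or, equivalently, from explicit file paths:
    #     processor.load(files=["/path/to/file_0.h5", "/path/to/file_1.h5"])
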
    @call_logger(logger)
    def filter_column(
        self,
        column: str,
        min_value: float = -np.inf,
        max_value: float = np.inf,
    ) -> None:
        """Filter values in a column which are outside of a given range.

        Args:
            column (str): Name of the column to filter.
            min_value (float, optional): Minimum value to keep. Defaults to -np.inf.
            max_value (float, optional): Maximum value to keep. Defaults to np.inf.
        """
        if column != "index" and column not in self._dataframe.columns:
            raise KeyError(f"Column {column} not found in dataframe!")
        if min_value >= max_value:
            raise ValueError("min_value has to be smaller than max_value!")
        if self._dataframe is not None:
            self._dataframe = apply_filter(
                self._dataframe,
                col=column,
                lower_bound=min_value,
                upper_bound=max_value,
            )
        if self._timed_dataframe is not None and column in self._timed_dataframe.columns:
            self._timed_dataframe = apply_filter(
                self._timed_dataframe,
                column,
                lower_bound=min_value,
                upper_bound=max_value,
            )
        metadata = {
            "filter": {
                "column": column,
                "min_value": min_value,
                "max_value": max_value,
            },
        }
        self._attributes.add(metadata, "filter", duplicate_policy="merge")

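    ## Example (illustrative sketch; the column name and bounds are hypothetical):
    #
    #     processor.filter_column("X", min_value=100, max_value=900)
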
    # Momentum calibration workflow
535
    # 1. Bin raw detector data for distortion correction
536
    @call_logger(logger)
1✔
537
    def bin_and_load_momentum_calibration(
1✔
538
        self,
539
        df_partitions: int | Sequence[int] = 100,
540
        axes: list[str] = None,
541
        bins: list[int] = None,
542
        ranges: Sequence[tuple[float, float]] = None,
543
        plane: int = 0,
544
        width: int = 5,
545
        apply: bool = False,
546
        **kwds,
547
    ):
548
        """1st step of momentum correction work flow. Function to do an initial binning
549
        of the dataframe loaded to the class, slice a plane from it using an
550
        interactive view, and load it into the momentum corrector class.
551

552
        Args:
553
            df_partitions (int | Sequence[int], optional): Number of dataframe partitions
554
                to use for the initial binning. Defaults to 100.
555
            axes (list[str], optional): Axes to bin.
556
                Defaults to config["momentum"]["axes"].
557
            bins (list[int], optional): Bin numbers to use for binning.
558
                Defaults to config["momentum"]["bins"].
559
            ranges (Sequence[tuple[float, float]], optional): Ranges to use for binning.
560
                Defaults to config["momentum"]["ranges"].
561
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
562
            width (int, optional): Initial value for the width slider. Defaults to 5.
563
            apply (bool, optional): Option to directly apply the values and select the
564
                slice. Defaults to False.
565
            **kwds: Keyword argument passed to the pre_binning function.
566
        """
567
        self._pre_binned = self.pre_binning(
1✔
568
            df_partitions=df_partitions,
569
            axes=axes,
570
            bins=bins,
571
            ranges=ranges,
572
            **kwds,
573
        )
574

575
        self.mc.load_data(data=self._pre_binned)
1✔
576
        self.mc.select_slicer(plane=plane, width=width, apply=apply)
1✔
577

578
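    ## Example (illustrative sketch; slider values are hypothetical):
    #
    #     processor.bin_and_load_momentum_calibration(plane=33, width=10, apply=True)
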
    # 2. Generate the spline warp correction from momentum features.
    # Either autoselect features, or input features from view above.
    @call_logger(logger)
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2. step of the distortion correction workflow: Define feature points in
        momentum space. They can either be manually selected using a GUI tool, be
        provided as a list of feature points, or be auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments for ``MomentumCorrector.feature_extract()`` and
                ``MomentumCorrector.feature_select()``.
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop("sigma", self._config["momentum"]["sigma"])
            fwhm = kwds.pop("fwhm", self._config["momentum"]["fwhm"])
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config["momentum"]["sigma_radius"],
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )

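    ## Example (illustrative sketch; coordinates are hypothetical):
    #
    #     processor.define_features(rotation_symmetry=6, auto_detect=True, apply=True)
    #     # or, passing manually selected feature coordinates (in pixels):
    #     # processor.define_features(features=np.array([[203.1, 341.9], ...]), apply=True)
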
    # 3. Generate the spline warp correction from momentum features.
633
    # If no features have been selected before, use class defaults.
634
    @call_logger(logger)
1✔
635
    def generate_splinewarp(
1✔
636
        self,
637
        use_center: bool = None,
638
        **kwds,
639
    ):
640
        """3. Step of the distortion correction workflow: Generate the correction
641
        function restoring the symmetry in the image using a splinewarp algorithm.
642

643
        Args:
644
            use_center (bool, optional): Option to use the position of the
645
                center point in the correction. Default is read from config, or set to True.
646
            **kwds: Keyword arguments for MomentumCorrector.spline_warp_estimate().
647
        """
648

649
        self.mc.spline_warp_estimate(use_center=use_center, **kwds)
1✔
650

651
        if self.mc.slice is not None and self._verbose:
1✔
652
            self.mc.view(
1✔
653
                annotated=True,
654
                backend="matplotlib",
655
                crosshair=True,
656
                title="Original slice with reference features",
657
            )
658

659
            self.mc.view(
1✔
660
                image=self.mc.slice_corrected,
661
                annotated=True,
662
                points={"feats": self.mc.ptargs},
663
                backend="matplotlib",
664
                crosshair=True,
665
                title="Corrected slice with target features",
666
            )
667

668
            self.mc.view(
1✔
669
                image=self.mc.slice,
670
                points={"feats": self.mc.ptargs},
671
                annotated=True,
672
                backend="matplotlib",
673
                title="Original slice with target features",
674
            )
675

    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.correction) == 0:
            raise ValueError("No momentum correction parameters to save!")
        correction = {}
        for key, value in self.mc.correction.items():
            if key in ["reference_points", "target_points", "cdeform_field", "rdeform_field"]:
                continue
            if key in ["use_center", "rotation_symmetry"]:
                correction[key] = value
            elif key in ["center_point", "ascale"]:
                correction[key] = [float(i) for i in value]
            elif key in ["outer_points", "feature_points"]:
                correction[key] = []
                for point in value:
                    correction[key].append([float(i) for i in point])
            elif key == "creation_date":
                correction[key] = value.isoformat()
            else:
                correction[key] = float(value)

        if "creation_date" not in correction:
            correction["creation_date"] = datetime.now().isoformat()

        config = {
            "momentum": {
                "correction": correction,
            },
        }
        save_config(config, filename, overwrite)
        logger.info(f'Saved momentum correction parameters to "{filename}".')

    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    @call_logger(logger)
    def pose_adjustment(
        self,
        transformations: dict[str, Any] = None,
        apply: bool = False,
        use_correction: bool = True,
        reset: bool = True,
        **kwds,
    ):
        """4. step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            transformations (dict[str, Any], optional): Dictionary with transformations.
                Defaults to self.transformations or config["momentum"]["transformations"].
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
            reset (bool, optional): Option to reset the correction before transformation.
                Defaults to True.
            **kwds: Keyword parameters defining defaults for the transformations:

                - **scale** (float): Initial value of the scaling slider.
                - **xtrans** (float): Initial value of the xtrans slider.
                - **ytrans** (float): Initial value of the ytrans slider.
                - **angle** (float): Initial value of the angle slider.
        """
        # Generate homography as default if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                self.mc.slice = np.zeros(self._config["momentum"]["bins"][0:2])
                self.mc.bin_ranges = self._config["momentum"]["ranges"]
            self.mc.slice_corrected = self.mc.slice

        if not use_correction:
            self.mc.reset_deformation()

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate distortion correction from config values
            self.mc.spline_warp_estimate()

        self.mc.pose_adjustment(
            transformations=transformations,
            apply=apply,
            reset=reset,
            **kwds,
        )

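    ## Example (illustrative sketch; the transformation values are hypothetical):
    #
    #     processor.pose_adjustment(
    #         transformations={"scale": 1.02, "xtrans": 8.0, "ytrans": -3.5, "angle": -1.3},
    #         apply=True,
    #     )
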
    # 4a. Save pose adjustment parameters to config file.
    @call_logger(logger)
    def save_transformations(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the pose adjustment parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.transformations) == 0:
            raise ValueError("No momentum transformation parameters to save!")
        transformations = {}
        for key, value in self.mc.transformations.items():
            if key == "creation_date":
                transformations[key] = value.isoformat()
            else:
                transformations[key] = float(value)

        if "creation_date" not in transformations:
            transformations["creation_date"] = datetime.now().isoformat()

        config = {
            "momentum": {
                "transformations": transformations,
            },
        }
        save_config(config, filename, overwrite)
        logger.info(f'Saved momentum transformation parameters to "{filename}".')

    # 5. Apply the momentum correction to the dataframe
    @call_logger(logger)
    def apply_momentum_correction(
        self,
        preview: bool = False,
        **kwds,
    ):
        """Applies the distortion correction and pose adjustment (optional)
        to the dataframe.

        Args:
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds: Keyword parameters for ``MomentumCorrector.apply_corrections``:

                - **rdeform_field** (np.ndarray, optional): Row deformation field.
                - **cdeform_field** (np.ndarray, optional): Column deformation field.
                - **inv_dfield** (np.ndarray, optional): Inverse deformation field.
        """
        x_column = self._config["dataframe"]["columns"]["x"]
        y_column = self._config["dataframe"]["columns"]["y"]

        if self._dataframe is not None:
            logger.info("Adding corrected X/Y columns to dataframe:")
            df, metadata = self.mc.apply_corrections(
                df=self._dataframe,
                **kwds,
            )
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                tdf, _ = self.mc.apply_corrections(
                    self._timed_dataframe,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            self._dataframe = df
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    @call_logger(logger)
    def calibrate_momentum_axes(
        self,
        point_a: np.ndarray | list[int] = None,
        point_b: np.ndarray | list[int] = None,
        k_distance: float = None,
        k_coord_a: np.ndarray | list[float] = None,
        k_coord_b: np.ndarray | list[float] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply=False,
    ):
        """1. step of the momentum calibration workflow. Calibrate momentum
        axes using either provided pixel coordinates of a high-symmetry point and its
        distance to the BZ center, or the k-coordinates of two points in the BZ
        (depending on the equiscale option). Opens an interactive panel for selecting
        the points.

        Args:
            point_a (np.ndarray | list[int], optional): Pixel coordinates of the first
                point used for momentum calibration.
            point_b (np.ndarray | list[int], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between point a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (np.ndarray | list[float], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (np.ndarray | list[float], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option controlling whether kx and ky are scaled
                equally. If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config["momentum"]["center_pixel"]

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )

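    ## Example (illustrative sketch; pixel coordinates and momentum distance are hypothetical):
    #
    #     processor.calibrate_momentum_axes(point_a=[308, 345], k_distance=1.28, apply=True)
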
    # 1a. Save momentum calibration parameters to config file.
    def save_momentum_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated momentum calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.mc.calibration) == 0:
            raise ValueError("No momentum calibration parameters to save!")
        calibration = {}
        for key, value in self.mc.calibration.items():
            if key in ["kx_axis", "ky_axis", "grid", "extent"]:
                continue
            elif key == "creation_date":
                calibration[key] = value.isoformat()
            else:
                calibration[key] = float(value)

        if "creation_date" not in calibration:
            calibration["creation_date"] = datetime.now().isoformat()

        config = {"momentum": {"calibration": calibration}}
        save_config(config, filename, overwrite)
        logger.info(f"Saved momentum calibration parameters to {filename}")

    # 2. Apply correction and calibration to the dataframe
    @call_logger(logger)
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2. step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes exist,
        these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds: Keyword args passed to ``MomentumCorrector.append_k_axis``.
        """
        x_column = self._config["dataframe"]["columns"]["x"]
        y_column = self._config["dataframe"]["columns"]["y"]

        if self._dataframe is not None:
            logger.info("Adding kx/ky columns to dataframe:")
            df, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                tdf, _ = self.mc.append_k_axis(
                    df=self._timed_dataframe,
                    calibration=calibration,
                    suppress_output=True,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            self._dataframe = df
            if (
                self._timed_dataframe is not None
                and x_column in self._timed_dataframe.columns
                and y_column in self._timed_dataframe.columns
            ):
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

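    ## Example (illustrative sketch): apply the stored correction and calibration in order:
    #
    #     processor.apply_momentum_correction()
    #     processor.apply_momentum_calibration()
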
    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    @call_logger(logger)
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: tuple[float, float] = None,
        apply=False,
        **kwds,
    ):
        """1. step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data if
        they are not present yet.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
            **kwds: Keyword parameters passed to ``EnergyCalibrator.adjust_energy_correction()``.
        """
        if self._pre_binned is None:
            logger.warning("Pre-binned data not present, binning using defaults from config...")
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )

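    ## Example (illustrative sketch; amplitude and center values are hypothetical):
    #
    #     processor.adjust_energy_correction(
    #         correction_type="Lorentzian",
    #         amplitude=2.5,
    #         center=(730, 730),
    #         apply=True,
    #     )
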
    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.correction) == 0:
            raise ValueError("No energy correction parameters to save!")
        correction = {}
        for key, value in self.ec.correction.items():
            if key == "correction_type":
                correction[key] = value
            elif key == "center":
                correction[key] = [float(i) for i in value]
            elif key == "creation_date":
                correction[key] = value.isoformat()
            else:
                correction[key] = float(value)

        if "creation_date" not in correction:
            correction["creation_date"] = datetime.now().isoformat()

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)
        logger.info(f"Saved energy correction parameters to {filename}")

    # 2. Apply energy correction to dataframe
    @call_logger(logger)
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2. step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction()``.
        """
        tof_column = self._config["dataframe"]["columns"]["tof"]

        if self._dataframe is not None:
            logger.info("Applying energy correction to dataframe...")
            df, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.apply_energy_correction(
                    df=self._timed_dataframe,
                    correction=correction,
                    suppress_output=True,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

    # Energy calibrator workflow
    # 1. Load and normalize data
    @call_logger(logger)
    def load_bias_series(
        self,
        binned_data: xr.DataArray | tuple[np.ndarray, np.ndarray, np.ndarray] = None,
        data_files: list[str] = None,
        axes: list[str] = None,
        bins: list = None,
        ranges: Sequence[tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1. step of the energy calibration workflow: Load and bin data from
        single-event files, or load binned bias/TOF traces.

        Args:
            binned_data (xr.DataArray | tuple[np.ndarray, np.ndarray, np.ndarray], optional):
                Binned data. If provided as a DataArray, it needs to contain the dimensions
                config["dataframe"]["columns"]["tof"] and config["dataframe"]["columns"]["bias"].
                If provided as a tuple, it needs to contain the elements tof, biases, traces.
            data_files (list[str], optional): list of file paths to bin.
            axes (list[str], optional): bin axes.
                Defaults to config["dataframe"]["columns"]["tof"].
            bins (list, optional): number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[tuple[float, float]], optional): bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        if binned_data is not None:
            if isinstance(binned_data, xr.DataArray):
                if (
                    self._config["dataframe"]["columns"]["tof"] not in binned_data.dims
                    or self._config["dataframe"]["columns"]["bias"] not in binned_data.dims
                ):
                    raise ValueError(
                        "If binned_data is provided as an xarray, it needs to contain dimensions "
                        f"'{self._config['dataframe']['columns']['tof']}' and "
                        f"'{self._config['dataframe']['columns']['bias']}'!",
                    )
                tof = binned_data.coords[self._config["dataframe"]["columns"]["tof"]].values
                biases = binned_data.coords[self._config["dataframe"]["columns"]["bias"]].values
                traces = binned_data.values[:, :]
            else:
                try:
                    (tof, biases, traces) = binned_data
                except ValueError as exc:
                    raise ValueError(
                        "If binned_data is provided as tuple, it needs to contain "
                        "(tof, biases, traces)!",
                    ) from exc
            logger.debug(f"Energy calibration data loaded from binned data. Bias values={biases}.")
            self.ec.load_data(biases=biases, traces=traces, tof=tof)

        elif data_files is not None:
            self.ec.bin_data(
                data_files=cast(list[str], self.cpy(data_files)),
                axes=axes,
                bins=bins,
                ranges=ranges,
                biases=biases,
                bias_key=bias_key,
            )
            logger.debug(
                f"Energy calibration data binned from files {data_files}. "
                f"Bias values={biases}.",
            )

        else:
            raise ValueError("Either binned_data or data_files needs to be provided!")

        if (normalize is not None and normalize is True) or (
            normalize is None and self._config["energy"]["normalize"]
        ):
            if span is None:
                span = self._config["energy"]["normalize_span"]
            if order is None:
                order = self._config["energy"]["normalize_order"]
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="matplotlib",
        )
        plt.xlabel("Time-of-flight")
        plt.ylabel("Intensity")
        plt.tight_layout()

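    ## Example (illustrative sketch; file names are hypothetical):
    #
    #     processor.load_bias_series(
    #         data_files=["bias_10V.h5", "bias_15V.h5", "bias_20V.h5"],
    #         normalize=True,
    #     )
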
    # 2. extract ranges and get peak positions
    @call_logger(logger)
    def find_bias_peaks(
        self,
        ranges: list[tuple] | tuple,
        ref_id: int = 0,
        infer_others: bool = True,
        mode: str = "replace",
        radius: int = None,
        peak_window: int = None,
        apply: bool = False,
    ):
        """2. step of the energy calibration workflow: Find a peak within a given range
        for the indicated reference trace, and try to find the same peak for all
        other traces. Uses fast_dtw to align the curves, which may perform poorly if
        the peak shape changes qualitatively between traces. Ideally, choose a
        reference trace in the middle of the set, and do not choose a range that is
        too narrow around the peak. Alternatively, a list of ranges for all traces
        can be provided.

        Args:
            ranges (list[tuple] | tuple): Tuple of TOF values indicating a range.
                Alternatively, a list of ranges for all traces can be given.
            ref_id (int, optional): The id of the trace the range refers to.
                Defaults to 0.
            infer_others (bool, optional): Whether to determine the range for the other
                traces. Defaults to True.
            mode (str, optional): Whether to "add" or "replace" existing ranges.
                Defaults to "replace".
            radius (int, optional): Radius parameter for fast_dtw.
                Defaults to config["energy"]["fastdtw_radius"].
            peak_window (int, optional): peak_window parameter for the peak detection
                algorithm: the number of points that have to behave monotonically
                around a peak. Defaults to config["energy"]["peak_window"].
            apply (bool, optional): Option to directly apply the provided parameters.
                Defaults to False.
        """
        if radius is None:
            radius = self._config["energy"]["fastdtw_radius"]
        if peak_window is None:
            peak_window = self._config["energy"]["peak_window"]
        if not infer_others:
            self.ec.add_ranges(
                ranges=ranges,
                ref_id=ref_id,
                infer_others=infer_others,
                mode=mode,
                radius=radius,
            )
            logger.info(f"Use feature ranges: {self.ec.featranges}.")
            try:
                self.ec.feature_extract(peak_window=peak_window)
                logger.info(f"Extracted energy features: {self.ec.peaks}.")
                self.ec.view(
                    traces=self.ec.traces_normed,
                    segs=self.ec.featranges,
                    xaxis=self.ec.tof,
                    peaks=self.ec.peaks,
                    backend="matplotlib",
                )
            except IndexError:
                logger.error("Could not determine all peaks!")
                raise
        else:
            # New adjustment tool
            assert isinstance(ranges, tuple)
            self.ec.adjust_ranges(
                ranges=ranges,
                ref_id=ref_id,
                traces=self.ec.traces_normed,
                infer_others=infer_others,
                radius=radius,
                peak_window=peak_window,
                apply=apply,
            )

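    # Usage sketch (illustrative comment, not executed; the range and ref_id are
    # assumptions): bracket the chosen peak in the reference trace and let the
    # ranges for all other traces be inferred via fast_dtw:
    #
    #     sp.find_bias_peaks(ranges=(64000, 66000), ref_id=5, apply=True)
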
    # 3. Fit the energy calibration relation
    @call_logger(logger)
    def calibrate_energy_axis(
        self,
        ref_energy: float,
        method: str = None,
        energy_scale: str = None,
        **kwds,
    ):
        """3. step of the energy calibration workflow: Calculate the calibration
        function for the energy axis. Two approximations are implemented: a
        (normally 3rd order) polynomial approximation, and a d^2/(t-t0)^2 relation.

        Args:
            ref_energy (float): Binding/kinetic energy of the detected feature.
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to config["energy"]["calibration_method"].
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to config["energy"]["energy_scale"].
            **kwds: Keyword parameters passed to ``EnergyCalibrator.calibrate()``.
        """
        if method is None:
            method = self._config["energy"]["calibration_method"]

        if energy_scale is None:
            energy_scale = self._config["energy"]["energy_scale"]

        self.ec.calibrate(
            ref_energy=ref_energy,
            method=method,
            energy_scale=energy_scale,
            **kwds,
        )
        if self._verbose:
            if self.ec.traces_normed is not None:
                self.ec.view(
                    traces=self.ec.traces_normed,
                    xaxis=self.ec.calibration["axis"],
                    align=True,
                    energy_scale=energy_scale,
                    backend="matplotlib",
                    title="Quality of Calibration",
                )
                plt.xlabel("Energy (eV)")
                plt.ylabel("Intensity")
                plt.tight_layout()
                plt.show()
            if energy_scale == "kinetic":
                self.ec.view(
                    traces=self.ec.calibration["axis"][None, :] + self.ec.biases[0],
                    xaxis=self.ec.tof,
                    backend="matplotlib",
                    show_legend=False,
                    title="E/TOF relationship",
                )
                plt.scatter(
                    self.ec.peaks[:, 0],
                    -(self.ec.biases - self.ec.biases[0]) + ref_energy,
                    s=50,
                    c="k",
                )
                plt.tight_layout()
            elif energy_scale == "binding":
                self.ec.view(
                    traces=self.ec.calibration["axis"][None, :] - self.ec.biases[0],
                    xaxis=self.ec.tof,
                    backend="matplotlib",
                    show_legend=False,
                    title="E/TOF relationship",
                )
                plt.scatter(
                    self.ec.peaks[:, 0],
                    self.ec.biases - self.ec.biases[0] + ref_energy,
                    s=50,
                    c="k",
                )
            else:
                raise ValueError(
                    f'energy_scale needs to be either "binding" or "kinetic", got {energy_scale}.',
                )
            plt.xlabel("Time-of-flight")
            plt.ylabel("Energy (eV)")
            plt.tight_layout()
            plt.show()

    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.calibration) == 0:
            raise ValueError("No energy calibration parameters to save!")
        calibration = {}
        for key, value in self.ec.calibration.items():
            if key in ["axis", "refid", "Tmat", "bvec"]:
                continue
            if key == "energy_scale":
                calibration[key] = value
            elif key == "coeffs":
                calibration[key] = [float(i) for i in value]
            elif key == "creation_date":
                calibration[key] = value.isoformat()
            else:
                calibration[key] = float(value)

        if "creation_date" not in calibration:
            calibration["creation_date"] = datetime.now().isoformat()

        config = {"energy": {"calibration": calibration}}
        save_config(config, filename, overwrite)
        logger.info(f'Saved energy calibration parameters to "{filename}".')

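    # The saved entry has the following shape (keys and values are illustrative
    # only; the actual keys depend on the calibration method used):
    #
    #     energy:
    #       calibration:
    #         d: 1.1
    #         t0: 7.7e-07
    #         E0: -30.4
    #         energy_scale: kinetic
    #         creation_date: "2025-03-03T21:42:00"
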
    # 4. Apply energy calibration to the dataframe
    @call_logger(logger)
    def append_energy_axis(
        self,
        calibration: dict = None,
        bias_voltage: float = None,
        preview: bool = False,
        **kwds,
    ):
        """4. step of the energy calibration workflow: Apply the calibration function
        to the dataframe. Two approximations are implemented, a (normally 3rd order)
        polynomial approximation, and a d^2/(t-t0)^2 relation. Alternatively, a
        calibration dictionary can be provided.

        Args:
            calibration (dict, optional): Calibration dict containing calibration
                parameters. Overrides calibration from class or config.
                Defaults to None.
            bias_voltage (float, optional): Sample bias voltage of the scan data. If omitted,
                the bias voltage is read from the dataframe. If it is not found there,
                a warning is printed and the calibrated data might have an offset.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.append_energy_axis()``.
        """
        tof_column = self._config["dataframe"]["columns"]["tof"]

        if self._dataframe is not None:
            logger.info("Adding energy column to dataframe:")
            df, metadata = self.ec.append_energy_axis(
                df=self._dataframe,
                calibration=calibration,
                bias_voltage=bias_voltage,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.append_energy_axis(
                    df=self._timed_dataframe,
                    calibration=calibration,
                    bias_voltage=bias_voltage,
                    suppress_output=True,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_calibration",
                duplicate_policy="merge",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf

        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

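    # Usage sketch (illustrative comment, not executed; the bias value is an
    # assumption): apply the stored calibration to the dataframe:
    #
    #     sp.append_energy_axis(bias_voltage=16.8, preview=True)
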
    @call_logger(logger)
    def add_energy_offset(
        self,
        constant: float = None,
        columns: str | Sequence[str] = None,
        weights: float | Sequence[float] = None,
        reductions: str | Sequence[str] = None,
        preserve_mean: bool | Sequence[bool] = None,
        preview: bool = False,
    ) -> None:
        """Shift the energy axis of the dataframe by a given amount.

        Args:
            constant (float, optional): The constant to shift the energy axis by.
            columns (str | Sequence[str], optional): Name of the column(s) to apply the shift from.
            weights (float | Sequence[float], optional): Weights to apply to the columns.
                Can also be used to flip the sign (e.g. -1). Defaults to 1.
            reductions (str | Sequence[str], optional): The reduction to apply to the column.
                Should be an available method of dask.dataframe.Series. For example "mean". In this
                case the function is applied to the column to generate a single value for the whole
                dataset. If None, the shift is applied per-dataframe-row. Defaults to None.
                Currently only "mean" is supported.
            preserve_mean (bool | Sequence[bool], optional): Whether to subtract the mean of the
                column before applying the shift. Defaults to False.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.

        Raises:
            ValueError: If the energy column is not in the dataframe.
        """
        energy_column = self._config["dataframe"]["columns"]["energy"]
        if energy_column not in self._dataframe.columns:
            raise ValueError(
                f"Energy column {energy_column} not found in dataframe! "
                "Run `append_energy_axis()` first.",
            )
        if self.dataframe is not None:
            logger.info("Adding energy offset to dataframe:")
            df, metadata = self.ec.add_offsets(
                df=self._dataframe,
                constant=constant,
                columns=columns,
                energy_column=energy_column,
                weights=weights,
                reductions=reductions,
                preserve_mean=preserve_mean,
            )
            if self._timed_dataframe is not None and energy_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.add_offsets(
                    df=self._timed_dataframe,
                    constant=constant,
                    columns=columns,
                    energy_column=energy_column,
                    weights=weights,
                    reductions=reductions,
                    preserve_mean=preserve_mean,
                    suppress_output=True,
                )

            self._attributes.add(
                metadata,
                "add_energy_offset",
                # TODO: allow only appending when no offset along this column(s) was applied
                # TODO: clear memory of modifications if the energy axis is recalculated
                duplicate_policy="append",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and energy_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

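    # Usage sketch (illustrative comment, not executed; column name and values
    # are assumptions): shift the energy axis by a constant and, per row, by a
    # sign-flipped photon-energy column:
    #
    #     sp.add_energy_offset(constant=-0.3, columns="photon_energy", weights=-1)
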
    def save_energy_offset(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy offset parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.offsets) == 0:
            raise ValueError("No energy offset parameters to save!")

        offsets = deepcopy(self.ec.offsets)

        if "creation_date" not in offsets.keys():
            offsets["creation_date"] = datetime.now()

        offsets["creation_date"] = offsets["creation_date"].isoformat()

        config = {"energy": {"offsets": offsets}}
        save_config(config, filename, overwrite)
        logger.info(f'Saved energy offset parameters to "{filename}".')

    @call_logger(logger)
    def append_tof_ns_axis(
        self,
        preview: bool = False,
        **kwds,
    ):
        """Convert time-of-flight channel steps to nanoseconds.

        Args:
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds: Additional arguments are passed to ``EnergyCalibrator.append_tof_ns_axis()``.

                - **tof_ns_column**: Name of the generated column containing the
                  time-of-flight in nanoseconds.
                  Defaults to config["dataframe"]["columns"]["tof_ns"].
        """
        tof_column = self._config["dataframe"]["columns"]["tof"]

        if self._dataframe is not None:
            logger.info("Adding time-of-flight column in nanoseconds to dataframe.")
            # TODO assert order of execution through metadata

            df, metadata = self.ec.append_tof_ns_axis(
                df=self._dataframe,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.append_tof_ns_axis(
                    df=self._timed_dataframe,
                    **kwds,
                )

            self._attributes.add(
                metadata,
                "tof_ns_conversion",
                duplicate_policy="overwrite",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

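    # Usage sketch (illustrative comment, not executed):
    #
    #     sp.append_tof_ns_axis(preview=True)
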
    @call_logger(logger)
    def align_dld_sectors(
        self,
        sector_delays: np.ndarray = None,
        preview: bool = False,
        **kwds,
    ):
        """Align the 8 sectors of the HEXTOF endstation.

        Args:
            sector_delays (np.ndarray, optional): Array containing the sector delays. Defaults to
                config["dataframe"]["sector_delays"].
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds: additional arguments are passed to ``EnergyCalibrator.align_dld_sectors()``.
        """
        tof_column = self._config["dataframe"]["columns"]["tof"]

        if self._dataframe is not None:
            logger.info("Aligning 8 sectors of dataframe")
            # TODO assert order of execution through metadata

            df, metadata = self.ec.align_dld_sectors(
                df=self._dataframe,
                sector_delays=sector_delays,
                **kwds,
            )
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                tdf, _ = self.ec.align_dld_sectors(
                    df=self._timed_dataframe,
                    sector_delays=sector_delays,
                    **kwds,
                )

            self._attributes.add(
                metadata,
                "dld_sector_alignment",
                duplicate_policy="raise",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and tof_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

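    # Usage sketch (illustrative comment, not executed): sector delays are taken
    # from config["dataframe"]["sector_delays"] unless passed explicitly:
    #
    #     sp.align_dld_sectors()
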
    # Delay calibration function
    @call_logger(logger)
    def calibrate_delay_axis(
        self,
        delay_range: tuple[float, float] = None,
        read_delay_ranges: bool = True,
        datafile: str = None,
        preview: bool = False,
        **kwds,
    ):
        """Append delay column to dataframe. Either provide delay ranges, or read
        them from a file.

        Args:
            delay_range (tuple[float, float], optional): The scanned delay range in
                picoseconds. Defaults to None.
            read_delay_ranges (bool, optional): Option whether to read the delay ranges from the
                data. Defaults to True. If False, the parameters in the config are used.
            datafile (str, optional): The file from which to read the delay ranges.
                Defaults to the first file of the dataset.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
        """
        adc_column = self._config["dataframe"]["columns"]["adc"]
        if adc_column not in self._dataframe.columns:
            raise ValueError(f"ADC column {adc_column} not found in dataframe, cannot calibrate!")

        if self._dataframe is not None:
            logger.info("Adding delay column to dataframe:")

            if read_delay_ranges and delay_range is None and datafile is None:
                try:
                    datafile = self._files[0]
                except IndexError as exc:
                    raise ValueError(
                        "No datafile available, specify either 'datafile' or 'delay_range'.",
                    ) from exc

            df, metadata = self.dc.append_delay_axis(
                self._dataframe,
                delay_range=delay_range,
                datafile=datafile,
                **kwds,
            )
            if self._timed_dataframe is not None and adc_column in self._timed_dataframe.columns:
                tdf, _ = self.dc.append_delay_axis(
                    self._timed_dataframe,
                    delay_range=delay_range,
                    datafile=datafile,
                    suppress_output=True,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "delay_calibration",
                duplicate_policy="overwrite",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and adc_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.debug(self._dataframe)

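    # Usage sketch (illustrative comment, not executed; the range is an
    # assumption): either read the delay range from the data file (default), or
    # pass it explicitly in picoseconds:
    #
    #     sp.calibrate_delay_axis(delay_range=(-500, 1500), preview=True)
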
    def save_delay_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Save the generated delay calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"

        if len(self.dc.calibration) == 0:
            raise ValueError("No delay calibration parameters to save!")
        calibration = {}
        for key, value in self.dc.calibration.items():
            if key == "datafile":
                calibration[key] = value
            elif key in ["adc_range", "delay_range", "delay_range_mm"]:
                calibration[key] = [float(i) for i in value]
            elif key == "creation_date":
                calibration[key] = value.isoformat()
            else:
                calibration[key] = float(value)

        if "creation_date" not in calibration:
            calibration["creation_date"] = datetime.now().isoformat()

        config = {
            "delay": {
                "calibration": calibration,
            },
        }
        save_config(config, filename, overwrite)

    @call_logger(logger)
    def add_delay_offset(
        self,
        constant: float = None,
        flip_delay_axis: bool = None,
        columns: str | Sequence[str] = None,
        weights: float | Sequence[float] = 1.0,
        reductions: str | Sequence[str] = None,
        preserve_mean: bool | Sequence[bool] = False,
        preview: bool = False,
    ) -> None:
        """Shift the delay axis of the dataframe by a constant or other columns.

        Args:
            constant (float, optional): The constant to shift the delay axis by.
            flip_delay_axis (bool, optional): Option to reverse the direction of the delay axis.
            columns (str | Sequence[str], optional): Name of the column(s) to apply the shift from.
            weights (float | Sequence[float], optional): Weights to apply to the columns.
                Can also be used to flip the sign (e.g. -1). Defaults to 1.
            reductions (str | Sequence[str], optional): The reduction to apply to the column.
                Should be an available method of dask.dataframe.Series. For example "mean". In this
                case the function is applied to the column to generate a single value for the whole
                dataset. If None, the shift is applied per-dataframe-row. Defaults to None.
                Currently only "mean" is supported.
            preserve_mean (bool | Sequence[bool], optional): Whether to subtract the mean of the
                column before applying the shift. Defaults to False.
            preview (bool, optional): Option to preview the first elements of the data frame.
                Defaults to False.

        Raises:
            ValueError: If the delay column is not in the dataframe.
        """
        delay_column = self._config["dataframe"]["columns"]["delay"]
        if delay_column not in self._dataframe.columns:
            raise ValueError(f"Delay column {delay_column} not found in dataframe!")

        if self.dataframe is not None:
            logger.info("Adding delay offset to dataframe:")
            df, metadata = self.dc.add_offsets(
                df=self._dataframe,
                constant=constant,
                flip_delay_axis=flip_delay_axis,
                columns=columns,
                delay_column=delay_column,
                weights=weights,
                reductions=reductions,
                preserve_mean=preserve_mean,
            )
            if self._timed_dataframe is not None and delay_column in self._timed_dataframe.columns:
                tdf, _ = self.dc.add_offsets(
                    df=self._timed_dataframe,
                    constant=constant,
                    flip_delay_axis=flip_delay_axis,
                    columns=columns,
                    delay_column=delay_column,
                    weights=weights,
                    reductions=reductions,
                    preserve_mean=preserve_mean,
                    suppress_output=True,
                )

            self._attributes.add(
                metadata,
                "delay_offset",
                duplicate_policy="append",
            )
            self._dataframe = df
            if self._timed_dataframe is not None and delay_column in self._timed_dataframe.columns:
                self._timed_dataframe = tdf
        else:
            raise ValueError("No dataframe loaded!")
        if preview:
            logger.info(self._dataframe.head(10))
        else:
            logger.info(self._dataframe)

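    # Usage sketch (illustrative comment, not executed; column name and weight
    # are assumptions): shift the delay axis by a constant, flip its direction,
    # and correct per row with a beam-arrival-monitor column:
    #
    #     sp.add_delay_offset(constant=26.9, flip_delay_axis=True, columns="bam", weights=-0.001)
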
    def save_delay_offsets(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Save the generated delay offset parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.dc.offsets) == 0:
            raise ValueError("No delay offset parameters to save!")

        offsets = deepcopy(self.dc.offsets)

        if "creation_date" not in offsets.keys():
            offsets["creation_date"] = datetime.now()

        offsets["creation_date"] = offsets["creation_date"].isoformat()

        config = {"delay": {"offsets": offsets}}
        save_config(config, filename, overwrite)
        logger.info(f'Saved delay offset parameters to "{filename}".')

    def save_workflow_params(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Run all methods for saving calibration parameters.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        for method in [
            self.save_splinewarp,
            self.save_transformations,
            self.save_momentum_calibration,
            self.save_energy_correction,
            self.save_energy_calibration,
            self.save_energy_offset,
            self.save_delay_calibration,
            self.save_delay_offsets,
        ]:
            try:
                method(filename, overwrite)
            except (ValueError, AttributeError, KeyError):
                pass

    @call_logger(logger)
    def add_jitter(
        self,
        cols: list[str] = None,
        amps: float | Sequence[float] = None,
        **kwds,
    ):
        """Add jitter to the selected dataframe columns.

        Args:
            cols (list[str], optional): The columns onto which to apply jitter.
                Defaults to config["dataframe"]["jitter_cols"].
            amps (float | Sequence[float], optional): Amplitude scalings for the
                jittering noise. If one number is given, the same is used for all axes.
                For uniform noise (default) it will cover the interval [-amp, +amp].
                Defaults to config["dataframe"]["jitter_amps"].
            **kwds: additional keyword arguments passed to ``apply_jitter``.
        """
        if cols is None:
            cols = self._config["dataframe"]["jitter_cols"]
        for loc, col in enumerate(cols):
            if col.startswith("@"):
                cols[loc] = self._config["dataframe"]["columns"].get(col.strip("@"))

        if amps is None:
            amps = self._config["dataframe"]["jitter_amps"]

        self._dataframe = self._dataframe.map_partitions(
            apply_jitter,
            cols=cols,
            cols_jittered=cols,
            amps=amps,
            **kwds,
        )
        if self._timed_dataframe is not None:
            cols_timed = cols.copy()
            for col in cols:
                if col not in self._timed_dataframe.columns:
                    cols_timed.remove(col)

            if cols_timed:
                self._timed_dataframe = self._timed_dataframe.map_partitions(
                    apply_jitter,
                    cols=cols_timed,
                    cols_jittered=cols_timed,
                )
        metadata = []
        for col in cols:
            metadata.append(col)
        # TODO: allow only appending if columns are not jittered yet
        self._attributes.add(metadata, "jittering", duplicate_policy="append")
        logger.info(f"add_jitter: Added jitter to columns {cols}.")

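    # Usage sketch (illustrative comment, not executed; column names and
    # amplitudes are assumptions): by default, columns and amplitudes come from
    # config["dataframe"]["jitter_cols"] and config["dataframe"]["jitter_amps"]:
    #
    #     sp.add_jitter(cols=["X", "Y", "t"], amps=[0.5, 0.5, 0.25])
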
    @call_logger(logger)
    def add_time_stamped_data(
        self,
        dest_column: str,
        time_stamps: np.ndarray = None,
        data: np.ndarray = None,
        archiver_channel: str = None,
        **kwds,
    ):
        """Add data in the form of timestamp/value pairs to the dataframe, using
        interpolation to the timestamps in the dataframe. The time-stamped data can
        either be provided, or fetched from an EPICS archiver instance.

        Args:
            dest_column (str): Destination column name.
            time_stamps (np.ndarray, optional): Time stamps of the values to add. If omitted,
                time stamps are retrieved from the EPICS archiver.
            data (np.ndarray, optional): Values corresponding to the time stamps in time_stamps.
                If omitted, data are retrieved from the EPICS archiver.
            archiver_channel (str, optional): EPICS archiver channel from which to retrieve data.
                Either this or data and time_stamps have to be present.
            **kwds:

                - **time_stamp_column**: Dataframe column containing time-stamp data

                Additional keyword arguments passed to ``add_time_stamped_data``.
        """
        time_stamp_column = kwds.pop(
            "time_stamp_column",
            self._config["dataframe"]["columns"].get("timestamp", ""),
        )

        if time_stamps is None and data is None:
            if archiver_channel is None:
                raise ValueError(
                    "Either archiver_channel or both time_stamps and data have to be present!",
                )
            if self.loader.__name__ != "mpes":
                raise NotImplementedError(
                    "This function is currently only implemented for the mpes loader!",
                )
            ts_from, ts_to = cast(MpesLoader, self.loader).get_start_and_end_time()
            # get channel data with +-5 seconds safety margin
            time_stamps, data = get_archiver_data(
                archiver_url=self._config["metadata"].get("archiver_url", ""),
                archiver_channel=archiver_channel,
                ts_from=ts_from - 5,
                ts_to=ts_to + 5,
            )

        self._dataframe = add_time_stamped_data(
            self._dataframe,
            time_stamps=time_stamps,
            data=data,
            dest_column=dest_column,
            time_stamp_column=time_stamp_column,
            **kwds,
        )
        if self._timed_dataframe is not None:
            if time_stamp_column in self._timed_dataframe:
                self._timed_dataframe = add_time_stamped_data(
                    self._timed_dataframe,
                    time_stamps=time_stamps,
                    data=data,
                    dest_column=dest_column,
                    time_stamp_column=time_stamp_column,
                    **kwds,
                )
        metadata: list[Any] = []
        metadata.append(dest_column)
        metadata.append(time_stamps)
        metadata.append(data)
        self._attributes.add(metadata, "time_stamped_data", duplicate_policy="append")
        logger.info(f"add_time_stamped_data: Added time-stamped data as column {dest_column}.")

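    # Usage sketch (illustrative comment, not executed; column name and array
    # values are assumptions): attach a measured temperature trace to the events:
    #
    #     sp.add_time_stamped_data(
    #         dest_column="sample_temperature",
    #         time_stamps=np.array([1640000000.0, 1640003600.0]),
    #         data=np.array([10.0, 12.0]),
    #     )
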
    @call_logger(logger)
    def pre_binning(
        self,
        df_partitions: int | Sequence[int] = 100,
        axes: list[str] = None,
        bins: list[int] = None,
        ranges: Sequence[tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Function to do an initial binning of the dataframe loaded to the class.

        Args:
            df_partitions (int | Sequence[int], optional): Number of dataframe partitions to
                use for the initial binning. Defaults to 100.
            axes (list[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (list[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (Sequence[tuple[float, float]], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            **kwds: Keyword arguments passed to ``compute``.

        Returns:
            xr.DataArray: pre-binned data-array.
        """
        if axes is None:
            axes = self._config["momentum"]["axes"]
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"]["columns"].get(axis.strip("@"))

        if bins is None:
            bins = self._config["momentum"]["bins"]
        if ranges is None:
            ranges_ = list(self._config["momentum"]["ranges"])
            ranges_[2] = np.asarray(ranges_[2]) / self._config["dataframe"]["tof_binning"]
            ranges = [cast(tuple[float, float], tuple(v)) for v in ranges_]

        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        return self.compute(
            bins=bins,
            axes=axes,
            ranges=ranges,
            df_partitions=df_partitions,
            **kwds,
        )

    @call_logger(logger)
    def compute(
        self,
        bins: int | dict | tuple | list[int] | list[np.ndarray] | list[tuple] = 100,
        axes: str | Sequence[str] = None,
        ranges: Sequence[tuple[float, float]] = None,
        normalize_to_acquisition_time: bool | str = False,
        **kwds,
    ) -> xr.DataArray:
        """Compute the histogram along the given dimensions.

        Args:
            bins (int | dict | tuple | list[int] | list[np.ndarray] | list[tuple], optional):
                Definition of the bins. Can be any of the following cases:

                - an integer describing the number of bins on all dimensions
                - a tuple of 3 numbers describing start, end and step of the binning
                  range
                - an np.ndarray defining the binning edges
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
                - a dictionary made of the axes as keys and any of the above as values.

                This takes priority over the axes and range arguments. Defaults to 100.
            axes (str | Sequence[str], optional): The names of the axes (columns)
                on which to calculate the histogram. The order will be the order of the
                dimensions in the resulting array. Defaults to None.
            ranges (Sequence[tuple[float, float]], optional): list of tuples containing
                the start and end point of the binning range. Defaults to None.
            normalize_to_acquisition_time (bool | str): Option to normalize the
                result to the acquisition time. If a "slow" axis was scanned, providing
                the name of the scanned axis will compute and apply the corresponding
                normalization histogram. Defaults to False.
            **kwds: Keyword arguments:

                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
                  ``bin_dataframe`` for details. Defaults to
                  config["binning"]["hist_mode"].
                - **mode**: Defines how the results from each partition are combined.
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
                  Defaults to config["binning"]["mode"].
                - **pbar**: Option to show the tqdm progress bar. Defaults to
                  config["binning"]["pbar"].
                - **num_cores**: Number of CPU cores to use for parallelization.
                  Defaults to config["core"]["num_cores"] or N_CPU-1.
                - **threads_per_worker**: Limit the number of threads that
                  multiprocessing can spawn per binning thread. Defaults to
                  config["binning"]["threads_per_worker"].
                - **threadpool_API**: The API to use for multiprocessing. "blas",
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
                  config["binning"]["threadpool_API"].
                - **df_partitions**: A sequence of dataframe partitions, or the
                  number of the dataframe partitions to use. Defaults to all partitions.
                - **filter**: A Sequence of Dictionaries with entries "col", "lower_bound",
                  "upper_bound" to apply as filter to the dataframe before binning. The
                  dataframe in the class remains unmodified by this.

                Additional kwds are passed to ``bin_dataframe``.

        Raises:
            AssertionError: Raised when no dataframe has been loaded.

        Returns:
            xr.DataArray: The result of the n-dimensional binning represented in an
            xarray object, combining the data with the axes.
        """
        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
        mode = kwds.pop("mode", self._config["binning"]["mode"])
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
        num_cores = kwds.pop("num_cores", self._config["core"]["num_cores"])
        threads_per_worker = kwds.pop(
            "threads_per_worker",
            self._config["binning"]["threads_per_worker"],
        )
        threadpool_api = kwds.pop(
            "threadpool_API",
            self._config["binning"]["threadpool_API"],
        )
        df_partitions: int | Sequence[int] = kwds.pop("df_partitions", None)
        if isinstance(df_partitions, int):
            df_partitions = list(range(0, min(df_partitions, self._dataframe.npartitions)))
        if df_partitions is not None:
            dataframe = self._dataframe.partitions[df_partitions]
        else:
            dataframe = self._dataframe

        filter_params = kwds.pop("filter", None)
        if filter_params is not None:
            try:
                for param in filter_params:
                    if "col" not in param:
                        raise ValueError(
                            "'col' needs to be defined for each filter entry! "
                            f"Not present in {param}.",
                        )
                    assert set(param.keys()).issubset({"col", "lower_bound", "upper_bound"})
                    dataframe = apply_filter(dataframe, **param)
            except AssertionError as exc:
                invalid_keys = set(param.keys()) - {"col", "lower_bound", "upper_bound"}
                raise ValueError(
                    "Only 'col', 'lower_bound' and 'upper_bound' allowed as filter entries. "
                    f"Parameters {invalid_keys} not valid in {param}.",
                ) from exc

        self._binned = bin_dataframe(
            df=dataframe,
            bins=bins,
            axes=axes,
            ranges=ranges,
            hist_mode=hist_mode,
            mode=mode,
            pbar=pbar,
            n_cores=num_cores,
            threads_per_worker=threads_per_worker,
            threadpool_api=threadpool_api,
            **kwds,
        )

        for dim in self._binned.dims:
            try:
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
            except KeyError:
                pass

        self._binned.attrs["units"] = "counts"
        self._binned.attrs["long_name"] = "photoelectron counts"
        self._binned.attrs["metadata"] = self._attributes.metadata

        if normalize_to_acquisition_time:
            if isinstance(normalize_to_acquisition_time, str):
                axis = normalize_to_acquisition_time
                logger.info(f"Calculate normalization histogram for axis '{axis}'...")
                self._normalization_histogram = self.get_normalization_histogram(
                    axis=axis,
                    df_partitions=df_partitions,
                )
                # if the axes are named correctly, xarray figures out the normalization correctly
                self._normalized = self._binned / self._normalization_histogram
                # Set datatype of binned data
                self._normalized.data = self._normalized.data.astype(self._binned.data.dtype)
                self._attributes.add(
                    self._normalization_histogram.values,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )
            else:
                acquisition_time = self.loader.get_elapsed_time(
                    fids=df_partitions,
                )
                if acquisition_time > 0:
                    self._normalized = self._binned / acquisition_time
                self._attributes.add(
                    acquisition_time,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )

            self._normalized.attrs["units"] = "counts/second"
            self._normalized.attrs["long_name"] = "photoelectron counts per second"
            self._normalized.attrs["metadata"] = self._attributes.metadata

            return self._normalized

        return self._binned

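    # Usage sketch (illustrative comment, not executed; axis names, ranges, and
    # filter bounds are assumptions): bin a 4D histogram with a pre-binning
    # filter, normalized to acquisition time along the scanned delay axis:
    #
    #     res = sp.compute(
    #         bins=[100, 100, 200, 50],
    #         axes=["kx", "ky", "energy", "delay"],
    #         ranges=[(-2, 2), (-2, 2), (-4, 2), (-600, 1600)],
    #         filter=[{"col": "energy", "lower_bound": -4, "upper_bound": 2}],
    #         normalize_to_acquisition_time="delay",
    #     )
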
    @call_logger(logger)
    def get_normalization_histogram(
        self,
        axis: str = "delay",
        use_time_stamps: bool = False,
        **kwds,
    ) -> xr.DataArray:
        """Generates a normalization histogram from the timed dataframe. Optionally,
        use the TimeStamps column instead.

        Args:
            axis (str, optional): The axis for which to compute the histogram.
                Defaults to "delay".
            use_time_stamps (bool, optional): Use the TimeStamps column of the
                dataframe, rather than the timed dataframe. Defaults to False.
            **kwds: Keyword arguments:

                - **df_partitions**: A sequence of dataframe partitions, or the
                  number of the dataframe partitions to use. Defaults to all partitions.

        Raises:
            ValueError: Raised if no data are binned.
            ValueError: Raised if 'axis' not in binned coordinates.
            ValueError: Raised if config["dataframe"]["columns"]["timestamp"] not found
                in the dataframe.

        Returns:
            xr.DataArray: The computed normalization histogram (in TimeStamp units
            per bin).
        """

        if self._binned is None:
            raise ValueError("Need to bin data first!")
        if axis not in self._binned.coords:
            raise ValueError(f"Axis '{axis}' not found in binned data!")

        df_partitions: int | Sequence[int] = kwds.pop("df_partitions", None)

        if len(kwds) > 0:
            raise TypeError(
                f"get_normalization_histogram() got unexpected keyword arguments {kwds.keys()}.",
            )

        if isinstance(df_partitions, int):
            df_partitions = list(range(0, min(df_partitions, self._dataframe.npartitions)))

        if use_time_stamps or self._timed_dataframe is None:
            if df_partitions is not None:
                dataframe = self._dataframe.partitions[df_partitions]
            else:
                dataframe = self._dataframe
            self._normalization_histogram = normalization_histogram_from_timestamps(
                df=dataframe,
                axis=axis,
                bin_centers=self._binned.coords[axis].values,
                time_stamp_column=self._config["dataframe"]["columns"]["timestamp"],
            )
        else:
            if df_partitions is not None:
                timed_dataframe = self._timed_dataframe.partitions[df_partitions]
            else:
                timed_dataframe = self._timed_dataframe
            self._normalization_histogram = normalization_histogram_from_timed_dataframe(
                df=timed_dataframe,
                axis=axis,
                bin_centers=self._binned.coords[axis].values,
                time_unit=self._config["dataframe"]["timed_dataframe_unit_time"],
                hist_mode=self.config["binning"]["hist_mode"],
                mode=self.config["binning"]["mode"],
                pbar=self.config["binning"]["pbar"],
                n_cores=self.config["core"]["num_cores"],
                threads_per_worker=self.config["binning"]["threads_per_worker"],
                threadpool_api=self.config["binning"]["threadpool_API"],
            )

        return self._normalization_histogram

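    # Usage sketch (illustrative comment, not executed): after binning along
    # "delay", the normalization histogram can also be computed directly:
    #
    #     hist = sp.get_normalization_histogram(axis="delay", df_partitions=10)
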
    def view_event_histogram(
1✔
2391
        self,
2392
        dfpid: int,
2393
        ncol: int = 2,
2394
        bins: Sequence[int] = None,
2395
        axes: Sequence[str] = None,
2396
        ranges: Sequence[tuple[float, float]] = None,
2397
        backend: str = "matplotlib",
2398
        legend: bool = True,
2399
        histkwds: dict = None,
2400
        legkwds: dict = None,
2401
        **kwds,
2402
    ):
2403
        """Plot individual histograms of specified dimensions (axes) from a substituent
2404
        dataframe partition.
2405

2406
        Args:
2407
            dfpid (int): Number of the data frame partition to look at.
2408
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
2409
            bins (Sequence[int], optional): Number of bins to use for the specified
2410
                axes. Defaults to config["histogram"]["bins"].
2411
            axes (Sequence[str], optional): Names of the axes to display.
2412
                Defaults to config["histogram"]["axes"].
2413
            ranges (Sequence[tuple[float, float]], optional): Value ranges of all
2414
                specified axes. Defaults to config["histogram"]["ranges"].
2415
            backend (str, optional): Backend of the plotting library
2416
                ("matplotlib" or "bokeh"). Defaults to "matplotlib".
2417
            legend (bool, optional): Option to include a legend in the histogram plots.
2418
                Defaults to True.
2419
            histkwds (dict, optional): Keyword arguments for histograms
2420
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
2421
            legkwds (dict, optional): Keyword arguments for legend
2422
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
2423
            **kwds: Extra keyword arguments passed to
2424
                ``sed.diagnostics.grid_histogram()``.
2425

2426
        Raises:
2427
            TypeError: Raises when the input values are not of the correct type.
2428
        """
2429
        if bins is None:
            bins = self._config["histogram"]["bins"]
        if axes is None:
            axes = self._config["histogram"]["axes"]
        axes = list(axes)
        # Resolve "@"-prefixed axis aliases to dataframe column names
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"]["columns"].get(axis.strip("@"))
        if ranges is None:
            ranges = list(self._config["histogram"]["ranges"])
            # Scale the default ranges by the binning factors of the tof and adc axes
            for loc, axis in enumerate(axes):
                if axis == self._config["dataframe"]["columns"]["tof"]:
                    ranges[loc] = np.asarray(ranges[loc]) / self._config["dataframe"]["tof_binning"]
                elif axis == self._config["dataframe"]["columns"]["adc"]:
                    ranges[loc] = np.asarray(ranges[loc]) / self._config["dataframe"]["adc_binning"]

        input_types = map(type, [axes, bins, ranges])
        allowed_types = [list, tuple]

        df = self._dataframe

        if not set(input_types).issubset(allowed_types):
            raise TypeError(
                "Inputs of axes, bins, ranges need to be list or tuple!",
            )

        # Read out the values for the specified groups
        group_dict_dd = {}
        dfpart = df.get_partition(dfpid)
        cols = dfpart.columns
        for ax in axes:
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
        group_dict = ddf.compute(group_dict_dd)[0]

        # Plot multiple histograms in a grid
        grid_histogram(
            group_dict,
            ncol=ncol,
            rvs=axes,
            rvbins=bins,
            rvranges=ranges,
            backend=backend,
            legend=legend,
            histkwds=histkwds,
            legkwds=legkwds,
            **kwds,
        )

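    # Usage sketch (axis names, bins, and ranges are illustrative assumptions,
    # not the config defaults): inspect the raw event distributions of the
    # first dataframe partition before binning.
    #
    #   processor.view_event_histogram(
    #       dfpid=0,
    #       ncol=2,
    #       axes=["@x", "@y", "@tof"],  # "@" aliases resolve via config["dataframe"]["columns"]
    #       bins=[80, 80, 120],
    #       ranges=[(0, 1800), (0, 1800), (60000, 90000)],
    #   )
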
    @call_logger(logger)
    def save(
        self,
        faddr: str,
        **kwds,
    ):
        """Saves the binned data to the provided path and filename.

        Args:
            faddr (str): Path and name of the file to write. Its extension determines
                the file type to write. Valid file types are:

                - "*.tiff", "*.tif": Saves a TIFF stack.
                - "*.h5", "*.hdf5": Saves an HDF5 file.
                - "*.nxs", "*.nexus": Saves a NeXus file.

            **kwds: Keyword arguments, which are passed to the writer functions:
                For TIFF writing:

                - **alias_dict**: Dictionary of dimension aliases to use.

                For HDF5 writing:

                - **mode**: hdf5 read/write mode. Defaults to "w".

                For NeXus:

                - **reader**: Name of the pynxtools reader to use.
                  Defaults to config["nexus"]["reader"].
                - **definition**: NeXus application definition to use for saving.
                  Must be supported by the used ``reader``. Defaults to
                  config["nexus"]["definition"].
                - **input_files**: A list of input files to pass to the reader.
                  Defaults to config["nexus"]["input_files"].
                - **eln_data**: An electronic-lab-notebook file in '.yaml' format
                  to add to the list of files to pass to the reader.
        """
        if self._binned is None:
            raise NameError("Need to bin data first!")

        # Prefer the normalized data if it exists, otherwise save the binned data
        if self._normalized is not None:
            data = self._normalized
        else:
            data = self._binned

        # Dispatch to the writer matching the file extension
        extension = pathlib.Path(faddr).suffix

        if extension in (".tif", ".tiff"):
            to_tiff(
                data=data,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".h5", ".hdf5"):
            to_h5(
                data=data,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".nxs", ".nexus"):
            try:
                reader = kwds.pop("reader", self._config["nexus"]["reader"])
                definition = kwds.pop(
                    "definition",
                    self._config["nexus"]["definition"],
                )
                input_files = kwds.pop(
                    "input_files",
                    [str(path) for path in self._config["nexus"]["input_files"]],
                )
            except KeyError as exc:
                raise ValueError(
                    "The nexus reader, definition and input files need to be provided!",
                ) from exc

            if isinstance(input_files, str):
                input_files = [input_files]

            if "eln_data" in kwds:
                input_files.append(kwds.pop("eln_data"))

            to_nexus(
                data=data,
                faddr=faddr,
                reader=reader,
                definition=definition,
                input_files=input_files,
                **kwds,
            )

        else:
            raise NotImplementedError(
                f"Unrecognized file format: {extension}.",
            )
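
    # Usage sketch (file names and the reader name are hypothetical): persist
    # the binned (or normalized, if present) result in different formats.
    #
    #   processor.save("binned.h5")    # HDF5
    #   processor.save("binned.tiff")  # TIFF stack
    #   processor.save(
    #       "binned.nxs",
    #       reader="mpes",             # assumed pynxtools reader name
    #       eln_data="eln_data.yaml",  # appended to the NeXus input files
    #   )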