OpenCOMPES / sed, build 7183089817 (github)

23 Nov 2023 01:32PM UTC. Coverage: 90.589% (-0.07% from 90.656%)

Pull Request #312 (zain-sohail): Documentation developing loaders ("add a first version of development documentation")

5564 of 6142 relevant lines covered (90.59%), 0.91 hits per line

Source file: /sed/core/processor.py (83.77% covered)
"""This module contains the core class for the sed package."""
import pathlib
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.binning.binning import normalization_histogram_from_timed_dataframe
from sed.binning.binning import normalization_histogram_from_timestamps
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import add_time_stamped_data
from sed.core.dfops import apply_filter
from sed.core.dfops import apply_jitter
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader
from sed.loader.mpes.loader import get_archiver_data
from sed.loader.mpes.loader import MpesLoader

N_CPU = psutil.cpu_count()

class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a workflow for data
    correction, calibration and binning.

    Args:
        metadata (dict, optional): Dict of external Metadata. Defaults to None.
        config (Union[dict, str], optional): Config dictionary or config file name.
            Defaults to None.
        dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
            into the class. Defaults to None.
        files (List[str], optional): List of files to pass to the loader defined in
            the config. Defaults to None.
        folder (str, optional): Folder containing files to pass to the loader
            defined in the config. Defaults to None.
        runs (Sequence[str], optional): List of run identifiers to pass to the loader
            defined in the config. Defaults to None.
        collect_metadata (bool): Option to collect metadata from files.
            Defaults to False.
        verbose (bool, optional): Option to print out diagnostic information.
            Defaults to False.
        **kwds: Keyword arguments passed to the reader.
    """

    def __init__(
        self,
        metadata: dict = None,
        config: Union[dict, str] = None,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        verbose: bool = False,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            config (Union[dict, str], optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
                into the class. Defaults to None.
            files (List[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool): Option to collect metadata from files.
                Defaults to False.
            verbose (bool, optional): Option to print out diagnostic information.
                Defaults to False.
            **kwds: Keyword arguments passed to parse_config and to the reader.
        """
        config_kwds = {
            key: value for key, value in kwds.items() if key in parse_config.__code__.co_varnames
        }
        for key in config_kwds.keys():
            del kwds[key]
        self._config = parse_config(config, **config_kwds)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

        self.verbose = verbose

        self._dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._timed_dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._files: List[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None
        self._normalization_histogram: xr.DataArray = None
        self._normalized: xr.DataArray = None

        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
        )

        self.ec = EnergyCalibrator(
            loader=get_loader(
                loader_name=loader_name,
                config=self._config,
            ),
            config=self._config,
        )

        self.mc = MomentumCorrector(
            config=self._config,
        )

        self.dc = DelayCalibrator(
            config=self._config,
        )

        self.use_copy_tool = self._config.get("core", {}).get(
            "use_copy_tool",
            False,
        )
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    source=self._config["core"]["copy_tool_source"],
                    dest=self._config["core"]["copy_tool_dest"],
                    **self._config["core"].get("copy_tool_kwds", {}),
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )
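
    # Usage sketch (illustrative, not part of the original module). Assumes a config
    # file "sed_config.yaml" and a raw data folder exist; both paths are placeholders.
    #
    #   from sed import SedProcessor
    #   sp = SedProcessor(
    #       config="sed_config.yaml",  # parsed via parse_config
    #       folder="/path/to/raw",     # handed to the loader set in config["core"]["loader"]
    #       collect_metadata=False,
    #       verbose=True,
    #   )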

    def __repr__(self):
        if self._dataframe is None:
            df_str = "Data Frame: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        attributes_str = f"Metadata: {self._attributes.metadata}"
        pretty_str = df_str + "\n" + attributes_str
        return pretty_str

    @property
    def dataframe(self) -> Union[pd.DataFrame, ddf.DataFrame]:
        """Accessor to the underlying dataframe.

        Returns:
            Union[pd.DataFrame, ddf.DataFrame]: Dataframe object.
        """
        return self._dataframe

    @dataframe.setter
    def dataframe(self, dataframe: Union[pd.DataFrame, ddf.DataFrame]):
        """Setter for the underlying dataframe.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame]): The dataframe object to set.
        """
        if not isinstance(dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            dataframe,
            self._dataframe.__class__,
        ):
            raise ValueError(
                "'dataframe' has to be a Pandas or Dask dataframe and has to be of the same kind "
                "as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._dataframe.__class__}, provided type: {type(dataframe)}.",
            )
        self._dataframe = dataframe

    @property
    def timed_dataframe(self) -> Union[pd.DataFrame, ddf.DataFrame]:
        """Accessor to the underlying timed_dataframe.

        Returns:
            Union[pd.DataFrame, ddf.DataFrame]: Timed Dataframe object.
        """
        return self._timed_dataframe

    @timed_dataframe.setter
    def timed_dataframe(self, timed_dataframe: Union[pd.DataFrame, ddf.DataFrame]):
        """Setter for the underlying timed dataframe.

        Args:
            timed_dataframe (Union[pd.DataFrame, ddf.DataFrame]): The timed dataframe object to set.
        """
        if not isinstance(timed_dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            timed_dataframe,
            self._timed_dataframe.__class__,
        ):
            raise ValueError(
                "'timed_dataframe' has to be a Pandas or Dask dataframe and has to be of the same "
                "kind as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._timed_dataframe.__class__}, "
                f"provided type: {type(timed_dataframe)}.",
            )
        self._timed_dataframe = timed_dataframe

    @property
    def attributes(self) -> dict:
        """Accessor to the metadata dict.

        Returns:
            dict: The metadata dict.
        """
        return self._attributes.metadata

    def add_attribute(self, attributes: dict, name: str, **kwds):
        """Function to add an element to the attributes dict.

        Args:
            attributes (dict): The attributes dictionary object to add.
            name (str): Key under which to add the dictionary to the attributes.
            **kwds: Keyword arguments passed to ``MetaHandler.add()``.
        """
        self._attributes.add(
            entry=attributes,
            name=name,
            **kwds,
        )

    @property
    def config(self) -> Dict[Any, Any]:
        """Getter attribute for the config dictionary.

        Returns:
            Dict: The config dictionary.
        """
        return self._config

    @property
    def files(self) -> List[str]:
        """Getter attribute for the list of files.

        Returns:
            List[str]: The list of loaded files.
        """
        return self._files

    @property
    def binned(self) -> xr.DataArray:
        """Getter attribute for the binned data array.

        Returns:
            xr.DataArray: The binned data array.
        """
        if self._binned is None:
            raise ValueError("No binned data available, need to compute histogram first!")
        return self._binned

    @property
    def normalized(self) -> xr.DataArray:
        """Getter attribute for the normalized data array.

        Returns:
            xr.DataArray: The normalized data array.
        """
        if self._normalized is None:
            raise ValueError(
                "No normalized data available, compute data with normalization enabled!",
            )
        return self._normalized

    @property
    def normalization_histogram(self) -> xr.DataArray:
        """Getter attribute for the normalization histogram.

        Returns:
            xr.DataArray: The normalization histogram.
        """
        if self._normalization_histogram is None:
            raise ValueError("No normalization histogram available, generate histogram first!")
        return self._normalization_histogram

    def cpy(self, path: Union[str, List[str]]) -> Union[str, List[str]]:
        """Function to mirror a list of files or a folder from a network drive to a
        local storage. Returns either the original or the copied path to the given
        path. The option to use this functionality is set by
        config["core"]["use_copy_tool"].

        Args:
            path (Union[str, List[str]]): Source path or path list.

        Returns:
            Union[str, List[str]]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        return path

    def load(
        self,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        metadata: dict = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            files (List[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to the loader.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            timed_dataframe = kwds.pop("timed_dataframe", None)
        elif runs is not None:
            # If runs are provided, we only use the copy tool if a folder is also provided.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        elif files is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                files=cast(List[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._timed_dataframe = timed_dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )
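
    # Usage sketches (illustrative, not part of the original module); file and run
    # identifiers are placeholders for loader-specific values.
    #
    #   sp.load(folder="/path/to/raw")                            # everything in a folder
    #   sp.load(files=["scan_0001.h5", "scan_0002.h5"])           # explicit file list
    #   sp.load(runs=["44824", "44825"], folder="/path/to/raw")   # runs within a base folder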

    def filter_column(
        self,
        column: str,
        min_value: float = -np.inf,
        max_value: float = np.inf,
    ) -> None:
        """Filter values in a column which are outside of a given range.

        Args:
            column (str): Name of the column to filter.
            min_value (float, optional): Minimum value to keep. Defaults to -np.inf.
            max_value (float, optional): Maximum value to keep. Defaults to np.inf.
        """
        if column not in self._dataframe.columns:
            raise KeyError(f"Column {column} not found in dataframe!")
        if min_value >= max_value:
            raise ValueError("min_value has to be smaller than max_value!")
        if self._dataframe is not None:
            self._dataframe = apply_filter(
                self._dataframe,
                col=column,
                lower_bound=min_value,
                upper_bound=max_value,
            )
        if self._timed_dataframe is not None and column in self._timed_dataframe.columns:
            self._timed_dataframe = apply_filter(
                self._timed_dataframe,
                column,
                lower_bound=min_value,
                upper_bound=max_value,
            )
        metadata = {
            "filter": {
                "column": column,
                "min_value": min_value,
                "max_value": max_value,
            },
        }
        self._attributes.add(metadata, "filter", duplicate_policy="merge")
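
    # Usage sketch (illustrative, not part of the original module); "X" stands for
    # any column present in the loaded dataframe.
    #
    #   sp.filter_column("X", min_value=100, max_value=1800)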

    # Momentum calibration workflow
    # 1. Bin raw detector data for distortion correction
    def bin_and_load_momentum_calibration(
        self,
        df_partitions: Union[int, Sequence[int]] = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        plane: int = 0,
        width: int = 5,
        apply: bool = False,
        **kwds,
    ):
        """1st step of the momentum correction workflow. Function to do an initial binning
        of the dataframe loaded to the class, slice a plane from it using an
        interactive view, and load it into the momentum corrector class.

        Args:
            df_partitions (Union[int, Sequence[int]], optional): Number of dataframe partitions
                to use for the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
            width (int, optional): Initial value for the width slider. Defaults to 5.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments passed to the pre_binning function.
        """
        self._pre_binned = self.pre_binning(
            df_partitions=df_partitions,
            axes=axes,
            bins=bins,
            ranges=ranges,
            **kwds,
        )

        self.mc.load_data(data=self._pre_binned)
        self.mc.select_slicer(plane=plane, width=width, apply=apply)
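
    # Usage sketch (illustrative, not part of the original module). Without explicit
    # axes/bins/ranges, the config["momentum"] defaults are used; plane and width
    # preselect the slice, and apply=True skips the interactive confirmation.
    #
    #   sp.bin_and_load_momentum_calibration(df_partitions=100, plane=33, width=10, apply=True)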

    # 2. Generate the spline warp correction from momentum features.
    # Either autoselect features, or input features from view above.
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2nd step of the distortion correction workflow: Define feature points in
        momentum space. They can either be manually selected using a GUI tool, be
        provided as a list of feature points, or be auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the selected features.
                Defaults to False.
            **kwds: Keyword arguments for MomentumCorrector.feature_extract() and
                MomentumCorrector.feature_select().
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop("sigma", self._config["momentum"]["sigma"])
            fwhm = kwds.pop("fwhm", self._config["momentum"]["fwhm"])
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config["momentum"]["sigma_radius"],
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )

    # 3. Generate the spline warp correction from momentum features.
    # If no features have been selected before, use class defaults.
    def generate_splinewarp(
        self,
        use_center: bool = None,
        **kwds,
    ):
        """3rd step of the distortion correction workflow: Generate the correction
        function restoring the symmetry in the image using a splinewarp algorithm.

        Args:
            use_center (bool, optional): Option to use the position of the
                center point in the correction. Default is read from config, or set to True.
            **kwds: Keyword arguments for MomentumCorrector.spline_warp_estimate().
        """
        self.mc.spline_warp_estimate(use_center=use_center, **kwds)

        if self.mc.slice is not None:
            print("Original slice with reference features")
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)

            print("Corrected slice with target features")
            self.mc.view(
                image=self.mc.slice_corrected,
                annotated=True,
                points={"feats": self.mc.ptargs},
                backend="bokeh",
                crosshair=True,
            )

            print("Original slice with target features")
            self.mc.view(
                image=self.mc.slice,
                points={"feats": self.mc.ptargs},
                annotated=True,
                backend="bokeh",
            )

    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        points = []
        if self.mc.pouter_ord is not None:  # if there is any calibration info
            try:
                for point in self.mc.pouter_ord:
                    points.append([float(i) for i in point])
                if self.mc.include_center:
                    points.append([float(i) for i in self.mc.pcent])
            except AttributeError as exc:
                raise AttributeError(
                    "Momentum correction parameters not found, need to generate parameters first!",
                ) from exc
            config = {
                "momentum": {
                    "correction": {
                        "rotation_symmetry": self.mc.rotsym,
                        "feature_points": points,
                        "include_center": self.mc.include_center,
                        "use_center": self.mc.use_center,
                    },
                },
            }
            save_config(config, filename, overwrite)
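
    # Usage sketch for steps 2-3a (illustrative, not part of the original module).
    # The feature coordinates are made-up placeholders for points that would normally
    # be picked in the interactive view.
    #
    #   features = np.array(
    #       [[203.2, 341.96], [299.16, 345.32], [350.25, 243.70],
    #        [304.38, 149.88], [199.52, 152.48], [154.28, 244.69],
    #        [248.29, 248.62]],  # six outer symmetry points plus the center
    #   )
    #   sp.define_features(features=features, rotation_symmetry=6, apply=True)
    #   sp.generate_splinewarp(use_center=True)
    #   sp.save_splinewarp(filename="sed_config.yaml", overwrite=True)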

    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    def pose_adjustment(
        self,
        scale: float = 1,
        xtrans: float = 0,
        ytrans: float = 0,
        angle: float = 0,
        apply: bool = False,
        use_correction: bool = True,
        reset: bool = True,
    ):
        """4th step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            scale (float, optional): Initial value of the scaling slider.
                Defaults to 1.
            xtrans (float, optional): Initial value of the xtrans slider.
                Defaults to 0.
            ytrans (float, optional): Initial value of the ytrans slider.
                Defaults to 0.
            angle (float, optional): Initial value of the angle slider.
                Defaults to 0.
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
            reset (bool, optional):
                Option to reset the correction before transformation. Defaults to True.
        """
        # Fall back to the uncorrected slice if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                raise ValueError(
                    "No slice for corrections and transformations loaded!",
                )
            self.mc.slice_corrected = self.mc.slice

        if not use_correction:
            self.mc.reset_deformation()

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate distortion correction from config values
            self.mc.add_features()
            self.mc.spline_warp_estimate()

        self.mc.pose_adjustment(
            scale=scale,
            xtrans=xtrans,
            ytrans=ytrans,
            angle=angle,
            apply=apply,
            reset=reset,
        )

    # 5. Apply the momentum correction to the dataframe
    def apply_momentum_correction(
        self,
        preview: bool = False,
    ):
        """Applies the distortion correction and pose adjustment (optional)
        to the dataframe.

        Args:
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding corrected X/Y columns to dataframe:")
            self._dataframe, metadata = self.mc.apply_corrections(
                df=self._dataframe,
            )
            if self._timed_dataframe is not None:
                if (
                    self._config["dataframe"]["x_column"] in self._timed_dataframe.columns
                    and self._config["dataframe"]["y_column"] in self._timed_dataframe.columns
                ):
                    self._timed_dataframe, _ = self.mc.apply_corrections(
                        self._timed_dataframe,
                    )
            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                if self.verbose:
                    print(self._dataframe)
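
    # Usage sketch for steps 4-5 (illustrative, not part of the original module);
    # the slider values are arbitrary examples.
    #
    #   sp.pose_adjustment(xtrans=8, ytrans=7, angle=-1, apply=True)
    #   sp.apply_momentum_correction(preview=True)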

    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    def calibrate_momentum_axes(
        self,
        point_a: Union[np.ndarray, List[int]] = None,
        point_b: Union[np.ndarray, List[int]] = None,
        k_distance: float = None,
        k_coord_a: Union[np.ndarray, List[float]] = None,
        k_coord_b: Union[np.ndarray, List[float]] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply=False,
    ):
        """1st step of the momentum calibration workflow: Calibrate the momentum
        axes using either the provided pixel coordinates of a high-symmetry point and
        its distance to the BZ center, or the k-coordinates of two points in the BZ
        (depending on the equiscale option). Opens an interactive panel for selecting
        the points.

        Args:
            point_a (Union[np.ndarray, List[int]]): Pixel coordinates of the first
                point used for momentum calibration.
            point_b (Union[np.ndarray, List[int]], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between point a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option to use the same scale for kx and ky.
                If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config["momentum"]["center_pixel"]

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )

    # 1a. Save momentum calibration parameters to config file.
    def save_momentum_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated momentum calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key in [
                "kx_scale",
                "ky_scale",
                "x_center",
                "y_center",
                "rstart",
                "cstart",
                "rstep",
                "cstep",
            ]:
                calibration[key] = float(self.mc.calibration[key])
        except KeyError as exc:
            raise KeyError(
                "Momentum calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"momentum": {"calibration": calibration}}
        save_config(config, filename, overwrite)
        print(f"Saved momentum calibration parameters to {filename}")

    # 2. Apply correction and calibration to the dataframe
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
    ):
        """2nd step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes exist,
        these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding kx/ky columns to dataframe:")
            self._dataframe, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
            )
            if self._timed_dataframe is not None:
                if (
                    self._config["dataframe"]["x_column"] in self._timed_dataframe.columns
                    and self._config["dataframe"]["y_column"] in self._timed_dataframe.columns
                ):
                    self._timed_dataframe, _ = self.mc.append_k_axis(
                        df=self._timed_dataframe,
                        calibration=calibration,
                    )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                if self.verbose:
                    print(self._dataframe)
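
    # Usage sketch for the momentum calibration workflow (illustrative, not part of
    # the original module); point_a and k_distance are made-up values for a
    # high-symmetry point picked in the interactive panel.
    #
    #   sp.calibrate_momentum_axes(point_a=[308, 345], k_distance=1.3, apply=True)
    #   sp.save_momentum_calibration()
    #   sp.apply_momentum_calibration()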

    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: Tuple[float, float] = None,
        apply=False,
        **kwds,
    ):
        """1st step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data if
        they are not present yet.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (Tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
            **kwds: Keyword arguments passed to ``EnergyCalibrator.adjust_energy_correction()``.
        """
        if self._pre_binned is None:
            print(
                "Pre-binned data not present, binning using defaults from config...",
            )
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )

    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        correction = {}
        try:
            for key, val in self.ec.correction.items():
                if key == "correction_type":
                    correction[key] = val
                elif key == "center":
                    correction[key] = [float(i) for i in val]
                else:
                    correction[key] = float(val)
        except AttributeError as exc:
            raise AttributeError(
                "Energy correction parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)
        print(f"Saved energy correction parameters to {filename}")

    # 2. Apply energy correction to dataframe
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2nd step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction``.
        """
        if self._dataframe is not None:
            print("Applying energy correction to dataframe...")
            self._dataframe, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                **kwds,
            )
            if self._timed_dataframe is not None:
                if self._config["dataframe"]["tof_column"] in self._timed_dataframe.columns:
                    self._timed_dataframe, _ = self.ec.apply_energy_correction(
                        df=self._timed_dataframe,
                        correction=correction,
                        **kwds,
                    )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                if self.verbose:
                    print(self._dataframe)
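
    # Usage sketch for the energy correction workflow (illustrative, not part of the
    # original module); amplitude and center are arbitrary example values.
    #
    #   sp.adjust_energy_correction(
    #       correction_type="Lorentzian",
    #       amplitude=2.5,
    #       center=(730, 730),
    #       apply=True,
    #   )
    #   sp.save_energy_correction()
    #   sp.apply_energy_correction()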

    # Energy calibrator workflow
    # 1. Load and normalize data
    def load_bias_series(
        self,
        binned_data: Union[xr.DataArray, Tuple[np.ndarray, np.ndarray, np.ndarray]] = None,
        data_files: List[str] = None,
        axes: List[str] = None,
        bins: List = None,
        ranges: Sequence[Tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1st step of the energy calibration workflow: Load and bin data from
        single-event files, or load binned bias/TOF traces.

        Args:
            binned_data (Union[xr.DataArray, Tuple[np.ndarray, np.ndarray, np.ndarray]], optional):
                Binned data. If provided as a DataArray, it needs to contain the dimensions
                config["dataframe"]["tof_column"] and config["dataframe"]["bias_column"]. If
                provided as a tuple, it needs to contain the elements (tof, biases, traces).
            data_files (List[str], optional): list of file paths to bin.
            axes (List[str], optional): bin axes.
                Defaults to config["dataframe"]["tof_column"].
            bins (List, optional): number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        if binned_data is not None:
            if isinstance(binned_data, xr.DataArray):
                if (
                    self._config["dataframe"]["tof_column"] not in binned_data.dims
                    or self._config["dataframe"]["bias_column"] not in binned_data.dims
                ):
                    raise ValueError(
                        "If binned_data is provided as an xarray, it needs to contain dimensions "
                        f"'{self._config['dataframe']['tof_column']}' and "
                        f"'{self._config['dataframe']['bias_column']}'!",
                    )
                tof = binned_data.coords[self._config["dataframe"]["tof_column"]].values
                biases = binned_data.coords[self._config["dataframe"]["bias_column"]].values
                traces = binned_data.values[:, :]
            else:
                try:
                    (tof, biases, traces) = binned_data
                except ValueError as exc:
                    raise ValueError(
                        "If binned_data is provided as tuple, it needs to contain "
                        "(tof, biases, traces)!",
                    ) from exc
            self.ec.load_data(biases=biases, traces=traces, tof=tof)

        elif data_files is not None:
            self.ec.bin_data(
                data_files=cast(List[str], self.cpy(data_files)),
                axes=axes,
                bins=bins,
                ranges=ranges,
                biases=biases,
                bias_key=bias_key,
            )

        else:
            raise ValueError("Either binned_data or data_files needs to be provided!")

        if (normalize is not None and normalize is True) or (
            normalize is None and self._config["energy"]["normalize"]
        ):
            if span is None:
                span = self._config["energy"]["normalize_span"]
            if order is None:
                order = self._config["energy"]["normalize_order"]
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="bokeh",
        )

    # 2. Extract ranges and get peak positions
    def find_bias_peaks(
        self,
        ranges: Union[List[Tuple], Tuple],
        ref_id: int = 0,
        infer_others: bool = True,
        mode: str = "replace",
        radius: int = None,
        peak_window: int = None,
        apply: bool = False,
    ):
        """2nd step of the energy calibration workflow: Find a peak within a given range
        for the indicated reference trace, and try to find the same peak for all
        other traces. Uses fast_dtw to align curves, which might not work well if the
        shape of the curves changes qualitatively. Ideally, choose a reference trace in
        the middle of the set, and don't choose the range too narrow around the peak.
        Alternatively, a list of ranges for all traces can be provided.

        Args:
            ranges (Union[List[Tuple], Tuple]): Tuple of TOF values indicating a range.
                Alternatively, a list of ranges for all traces can be given.
            ref_id (int, optional): The id of the trace the range refers to.
                Defaults to 0.
            infer_others (bool, optional): Whether to determine the range for the other
                traces. Defaults to True.
            mode (str, optional): Whether to "add" or "replace" existing ranges.
                Defaults to "replace".
            radius (int, optional): Radius parameter for fast_dtw.
                Defaults to config["energy"]["fastdtw_radius"].
            peak_window (int, optional): peak_window parameter for the peak detection
                algorithm: the number of points that have to behave monotonically
                around a peak. Defaults to config["energy"]["peak_window"].
            apply (bool, optional): Option to directly apply the provided parameters.
                Defaults to False.
        """
        if radius is None:
            radius = self._config["energy"]["fastdtw_radius"]
        if peak_window is None:
            peak_window = self._config["energy"]["peak_window"]
        if not infer_others:
            self.ec.add_ranges(
                ranges=ranges,
                ref_id=ref_id,
                infer_others=infer_others,
                mode=mode,
                radius=radius,
            )
            print(self.ec.featranges)
            try:
                self.ec.feature_extract(peak_window=peak_window)
                self.ec.view(
                    traces=self.ec.traces_normed,
                    segs=self.ec.featranges,
                    xaxis=self.ec.tof,
                    peaks=self.ec.peaks,
                    backend="bokeh",
                )
            except IndexError:
                print("Could not determine all peaks!")
                raise
        else:
            # New adjustment tool
            assert isinstance(ranges, tuple)
            self.ec.adjust_ranges(
                ranges=ranges,
                ref_id=ref_id,
                traces=self.ec.traces_normed,
                infer_others=infer_others,
                radius=radius,
                peak_window=peak_window,
                apply=apply,
            )
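
    # Usage sketch for energy calibration steps 1-2 (illustrative, not part of the
    # original module); the file name and TOF range are placeholders.
    #
    #   sp.load_bias_series(data_files=["bias_series.h5"], normalize=True)
    #   sp.find_bias_peaks(ranges=(64000, 66000), ref_id=5, apply=True)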

    # 3. Fit the energy calibration relation
    def calibrate_energy_axis(
        self,
        ref_id: int,
        ref_energy: float,
        method: str = None,
        energy_scale: str = None,
        **kwds,
    ):
        """3rd step of the energy calibration workflow: Calculate the calibration
        function for the energy axis, and apply it to the dataframe. Two
        approximations are implemented, a (normally 3rd order) polynomial
        approximation, and a d^2/(t-t0)^2 relation.

        Args:
            ref_id (int): id of the trace at the bias where the reference energy is
                given.
            ref_energy (float): Absolute energy of the detected feature at the bias
                of ref_id.
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to config["energy"]["calibration_method"].
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to config["energy"]["energy_scale"].
            **kwds: Keyword parameters passed to ``EnergyCalibrator.calibrate()``.
        """
        if method is None:
            method = self._config["energy"]["calibration_method"]

        if energy_scale is None:
            energy_scale = self._config["energy"]["energy_scale"]

        self.ec.calibrate(
            ref_id=ref_id,
            ref_energy=ref_energy,
            method=method,
            energy_scale=energy_scale,
            **kwds,
        )
        print("Quality of Calibration:")
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.calibration["axis"],
            align=True,
            energy_scale=energy_scale,
            backend="bokeh",
        )
        print("E/TOF relationship:")
        self.ec.view(
            traces=self.ec.calibration["axis"][None, :],
            xaxis=self.ec.tof,
            backend="matplotlib",
            show_legend=False,
        )
        if energy_scale == "kinetic":
            plt.scatter(
                self.ec.peaks[:, 0],
                -(self.ec.biases - self.ec.biases[ref_id]) + ref_energy,
                s=50,
                c="k",
            )
        elif energy_scale == "binding":
            plt.scatter(
                self.ec.peaks[:, 0],
                self.ec.biases - self.ec.biases[ref_id] + ref_energy,
                s=50,
                c="k",
            )
        else:
            raise ValueError(
                f'energy_scale needs to be either "binding" or "kinetic", got {energy_scale}.',
            )
        plt.xlabel("Time-of-flight", fontsize=15)
        plt.ylabel("Energy (eV)", fontsize=15)
        plt.show()

    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key, value in self.ec.calibration.items():
                if key in ["axis", "refid", "Tmat", "bvec"]:
                    continue
                if key == "energy_scale":
                    calibration[key] = value
                elif key == "coeffs":
                    calibration[key] = [float(i) for i in value]
                else:
                    calibration[key] = float(value)
        except AttributeError as exc:
            raise AttributeError(
                "Energy calibration parameters not found, need to generate parameters first!",
            ) from exc
        config = {"energy": {"calibration": calibration}}
        save_config(config, filename, overwrite)
        print(f'Saved energy calibration parameters to "{filename}".')

1270
    # 4. Apply energy calibration to the dataframe
1271
    def append_energy_axis(
1✔
1272
        self,
1273
        calibration: dict = None,
1274
        preview: bool = False,
1275
        **kwds,
1276
    ):
1277
        """4. step of the energy calibration workflow: Apply the calibration function
1278
        to to the dataframe. Two approximations are implemented, a (normally 3rd order)
1279
        polynomial approximation, and a d^2/(t-t0)^2 relation. a calibration dictionary
1280
        can be provided.
1281

1282
        Args:
1283
            calibration (dict, optional): Calibration dict containing calibration
1284
                parameters. Overrides calibration from class or config.
1285
                Defaults to None.
1286
            preview (bool): Option to preview the first elements of the data frame.
1287
            **kwds:
1288
                Keyword args passed to ``EnergyCalibrator.append_energy_axis``.
1289
        """
1290
        if self._dataframe is not None:
1✔
1291
            print("Adding energy column to dataframe:")
1✔
1292
            self._dataframe, metadata = self.ec.append_energy_axis(
1✔
1293
                df=self._dataframe,
1294
                calibration=calibration,
1295
                **kwds,
1296
            )
1297
            if self._timed_dataframe is not None:
1✔
1298
                if self._config["dataframe"]["tof_column"] in self._timed_dataframe.columns:
1✔
1299
                    self._timed_dataframe, _ = self.ec.append_energy_axis(
1✔
1300
                        df=self._timed_dataframe,
1301
                        calibration=calibration,
1302
                        **kwds,
1303
                    )
1304

1305
            # Add Metadata
1306
            self._attributes.add(
1✔
1307
                metadata,
1308
                "energy_calibration",
1309
                duplicate_policy="merge",
1310
            )
1311
            if preview:
1✔
1312
                print(self._dataframe.head(10))
1✔
1313
            else:
1314
                if self.verbose:
1✔
1315
                    print(self._dataframe)
×
1316

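    # Usage sketch (hypothetical; assumes a SedProcessor instance `sp` with a
    # loaded dataframe and an energy calibration available from class or config):
    #
    #     sp.append_energy_axis(preview=True)
    #     # or, overriding with an explicit, user-provided calibration dict:
    #     sp.append_energy_axis(calibration=my_calibration_dict)
    #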
    def add_energy_offset(
        self,
        constant: float = None,
        columns: Union[str, Sequence[str]] = None,
        weights: Union[float, Sequence[float]] = None,
        reductions: Union[str, Sequence[str]] = None,
        preserve_mean: Union[bool, Sequence[bool]] = None,
    ) -> None:
        """Shift the energy axis of the dataframe by a given amount.

        Args:
            constant (float, optional): The constant to shift the energy axis by.
            columns (Union[str, Sequence[str]]): Name of the column(s) to apply the shift from.
            weights (Union[float, Sequence[float]]): Weights to apply to the columns.
                Can also be used to flip the sign (e.g. -1). Defaults to 1.
            reductions (str): The reduction to apply to the column. Should be an available method
                of dask.dataframe.Series. For example "mean". In this case the function is applied
                to the column to generate a single value for the whole dataset. If None, the shift
                is applied per-dataframe-row. Defaults to None. Currently only "mean" is supported.
            preserve_mean (bool): Whether to subtract the mean of the column before applying the
                shift. Defaults to False.

        Raises:
            ValueError: If the energy column is not in the dataframe.
        """
        print("Adding energy offset to dataframe:")
        energy_column = self._config["dataframe"]["energy_column"]
        if self.dataframe is not None:
            if energy_column not in self._dataframe.columns:
                raise ValueError(
                    f"Energy column {energy_column} not found in dataframe! "
                    "Run `append_energy_axis` first.",
                )
            df, metadata = self.ec.add_offsets(
                df=self._dataframe,
                constant=constant,
                columns=columns,
                energy_column=energy_column,
                weights=weights,
                reductions=reductions,
                preserve_mean=preserve_mean,
            )
            if self._timed_dataframe is not None:
                if energy_column in self._timed_dataframe.columns:
                    self._timed_dataframe, _ = self.ec.add_offsets(
                        df=self._timed_dataframe,
                        constant=constant,
                        columns=columns,
                        energy_column=energy_column,
                        weights=weights,
                        reductions=reductions,
                        preserve_mean=preserve_mean,
                    )
            self._attributes.add(
                metadata,
                "add_energy_offset",
                # TODO: allow only appending when no offset along this column(s) was applied
                # TODO: clear memory of modifications if the energy axis is recalculated
                duplicate_policy="append",
            )
            self._dataframe = df
        else:
            raise ValueError("No dataframe loaded!")

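    # Usage sketch (hypothetical; `sp` is a SedProcessor, and the column name
    # is an assumed example): shift by a constant plus a per-row column with
    # flipped sign, subtracting the column mean first:
    #
    #     sp.add_energy_offset(
    #         constant=-16.8,
    #         columns="monochromatorPhotonEnergy",
    #         weights=-1,
    #         preserve_mean=True,
    #     )
    #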
    def save_energy_offset(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy offset parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.offsets) == 0:
            raise ValueError("No energy offset parameters to save!")
        config = {"energy": {"offsets": self.ec.offsets}}
        save_config(config, filename, overwrite)
        print(f'Saved energy offset parameters to "{filename}".')

    def append_tof_ns_axis(
        self,
        **kwargs,
    ):
        """Convert time-of-flight channel steps to nanoseconds.

        Args:
            tof_ns_column (str, optional): Name of the generated column containing the
                time-of-flight in nanoseconds.
                Defaults to config["dataframe"]["tof_ns_column"].
            kwargs: additional arguments are passed to ``energy.tof_step_to_ns``.

        """
        if self._dataframe is not None:
            print("Adding time-of-flight column in nanoseconds to dataframe:")
            # TODO assert order of execution through metadata

            self._dataframe, metadata = self.ec.append_tof_ns_axis(
                df=self._dataframe,
                **kwargs,
            )
            if self._timed_dataframe is not None:
                if self._config["dataframe"]["tof_column"] in self._timed_dataframe.columns:
                    self._timed_dataframe, _ = self.ec.append_tof_ns_axis(
                        df=self._timed_dataframe,
                        **kwargs,
                    )
            self._attributes.add(
                metadata,
                "tof_ns_conversion",
                duplicate_policy="append",
            )

    def align_dld_sectors(self, sector_delays: np.ndarray = None, **kwargs):
        """Align the 8 sectors of the HEXTOF endstation.

        Args:
            sector_delays (np.ndarray, optional): Array containing the sector delays. Defaults to
                config["dataframe"]["sector_delays"].
        """
        if self._dataframe is not None:
            print("Aligning 8 sectors of dataframe")
            # TODO assert order of execution through metadata
            self._dataframe, metadata = self.ec.align_dld_sectors(
                df=self._dataframe,
                sector_delays=sector_delays,
                **kwargs,
            )
            if self._timed_dataframe is not None:
                if self._config["dataframe"]["tof_column"] in self._timed_dataframe.columns:
                    self._timed_dataframe, _ = self.ec.align_dld_sectors(
                        df=self._timed_dataframe,
                        sector_delays=sector_delays,
                        **kwargs,
                    )
            self._attributes.add(
                metadata,
                "dld_sector_alignment",
                duplicate_policy="raise",
            )

    # Delay calibration function
    def calibrate_delay_axis(
        self,
        delay_range: Tuple[float, float] = None,
        datafile: str = None,
        preview: bool = False,
        **kwds,
    ):
        """Append delay column to dataframe. Either provide delay ranges, or read
        them from a file.

        Args:
            delay_range (Tuple[float, float], optional): The scanned delay range in
                picoseconds. Defaults to None.
            datafile (str, optional): The file from which to read the delay ranges.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
        """
        if self._dataframe is not None:
            print("Adding delay column to dataframe:")

            if delay_range is not None:
                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    delay_range=delay_range,
                    **kwds,
                )
                if self._timed_dataframe is not None:
                    if self._config["dataframe"]["adc_column"] in self._timed_dataframe.columns:
                        self._timed_dataframe, _ = self.dc.append_delay_axis(
                            self._timed_dataframe,
                            delay_range=delay_range,
                            **kwds,
                        )
            else:
                if datafile is None:
                    try:
                        datafile = self._files[0]
                    except IndexError:
                        print(
                            "No datafile available, specify either 'datafile' or 'delay_range'",
                        )
                        raise

                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    datafile=datafile,
                    **kwds,
                )
                if self._timed_dataframe is not None:
                    if self._config["dataframe"]["adc_column"] in self._timed_dataframe.columns:
                        self._timed_dataframe, _ = self.dc.append_delay_axis(
                            self._timed_dataframe,
                            datafile=datafile,
                            **kwds,
                        )

            # Add Metadata
            self._attributes.add(
                metadata,
                "delay_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                if self.verbose:
                    print(self._dataframe)

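    # Usage sketch (hypothetical; `sp` is a SedProcessor): calibrate the delay
    # axis from an explicit scan range, or let the range be read from the
    # first loaded data file:
    #
    #     sp.calibrate_delay_axis(delay_range=(-500.0, 1500.0), preview=True)
    #     sp.calibrate_delay_axis()  # falls back to datafile=self._files[0]
    #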
    def save_delay_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Save the generated delay calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"

        config = {
            "delay": {
                "calibration": self.dc.calibration,
            },
        }
        save_config(config, filename, overwrite)

    def add_delay_offset(
        self,
        constant: float = None,
        flip_delay_axis: bool = None,
        columns: Union[str, Sequence[str]] = None,
        weights: Union[float, Sequence[float]] = None,
        reductions: Union[str, Sequence[str]] = None,
        preserve_mean: Union[bool, Sequence[bool]] = None,
    ) -> None:
        """Shift the delay axis of the dataframe by a constant or other columns.

        Args:
            constant (float, optional): The constant to shift the delay axis by.
            flip_delay_axis (bool, optional): Option to flip the sign of the delay axis.
            columns (Union[str, Sequence[str]]): Name of the column(s) to apply the shift from.
            weights (Union[float, Sequence[float]]): Weights to apply to the columns.
                Can also be used to flip the sign (e.g. -1). Defaults to 1.
            reductions (str): The reduction to apply to the column. Should be an available method
                of dask.dataframe.Series. For example "mean". In this case the function is applied
                to the column to generate a single value for the whole dataset. If None, the shift
                is applied per-dataframe-row. Defaults to None. Currently only "mean" is supported.
            preserve_mean (bool): Whether to subtract the mean of the column before applying the
                shift. Defaults to False.

        Returns:
            None
        """
        print("Adding delay offset to dataframe:")
        delay_column = self._config["dataframe"]["delay_column"]
        if delay_column not in self._dataframe.columns:
            raise ValueError(f"Delay column {delay_column} not found in dataframe!")

        if self.dataframe is not None:
            df, metadata = self.dc.add_offsets(
                df=self._dataframe,
                constant=constant,
                flip_delay_axis=flip_delay_axis,
                columns=columns,
                delay_column=delay_column,
                weights=weights,
                reductions=reductions,
                preserve_mean=preserve_mean,
            )
            if self._timed_dataframe is not None:
                if delay_column in self._timed_dataframe.columns:
                    self._timed_dataframe, _ = self.dc.add_offsets(
                        df=self._timed_dataframe,
                        constant=constant,
                        flip_delay_axis=flip_delay_axis,
                        columns=columns,
                        delay_column=delay_column,
                        weights=weights,
                        reductions=reductions,
                        preserve_mean=preserve_mean,
                    )
            self._attributes.add(
                metadata,
                "add_delay_offset",
                duplicate_policy="append",
            )
            self._dataframe = df
        else:
            raise ValueError("No dataframe loaded!")

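    # Usage sketch (hypothetical values; `sp` is a SedProcessor): apply a
    # constant time-zero shift and reverse the delay axis direction:
    #
    #     sp.add_delay_offset(constant=-1448.0, flip_delay_axis=True)
    #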
    def save_delay_offsets(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Save the generated delay offset parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.dc.offsets) == 0:
            raise ValueError("No delay offset parameters to save!")
        config = {
            "delay": {
                "offsets": self.dc.offsets,
            },
        }
        save_config(config, filename, overwrite)
        print(f'Saved delay offset parameters to "{filename}".')

    def save_workflow_params(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Run all methods for saving calibration parameters.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        for method in [
            self.save_momentum_calibration,
            self.save_splinewarp,
            self.save_energy_correction,
            self.save_energy_calibration,
            self.save_energy_offset,
            self.save_delay_calibration,
            self.save_delay_offsets,
        ]:
            try:
                method(filename, overwrite)
            except (ValueError, AttributeError, KeyError):
                pass

    def add_jitter(
        self,
        cols: List[str] = None,
        amps: Union[float, Sequence[float]] = None,
        **kwds,
    ):
        """Add jitter to the selected dataframe columns.

        Args:
            cols (List[str], optional): The columns onto which to apply jitter.
                Defaults to config["dataframe"]["jitter_cols"].
            amps (Union[float, Sequence[float]], optional): Amplitude scalings for the
                jittering noise. If one number is given, the same is used for all axes.
                For uniform noise (default) it will cover the interval [-amp, +amp].
                Defaults to config["dataframe"]["jitter_amps"].
            **kwds: additional keyword arguments passed to ``apply_jitter``.
        """
        if cols is None:
            cols = self._config["dataframe"]["jitter_cols"]
        for loc, col in enumerate(cols):
            if col.startswith("@"):
                cols[loc] = self._config["dataframe"].get(col.strip("@"))

        if amps is None:
            amps = self._config["dataframe"]["jitter_amps"]

        self._dataframe = self._dataframe.map_partitions(
            apply_jitter,
            cols=cols,
            cols_jittered=cols,
            amps=amps,
            **kwds,
        )
        if self._timed_dataframe is not None:
            cols_timed = cols.copy()
            for col in cols:
                if col not in self._timed_dataframe.columns:
                    cols_timed.remove(col)

            if cols_timed:
                self._timed_dataframe = self._timed_dataframe.map_partitions(
                    apply_jitter,
                    cols=cols_timed,
                    cols_jittered=cols_timed,
                )
        metadata = []
        for col in cols:
            metadata.append(col)
        self._attributes.add(metadata, "jittering", duplicate_policy="append")

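    # Usage sketch (hypothetical; `sp` is a SedProcessor): dither the digitized
    # detector columns to suppress discretization artifacts in later binning;
    # "@"-prefixed names are resolved through config["dataframe"]:
    #
    #     sp.add_jitter(cols=["@x_column", "@y_column", "@tof_column"], amps=0.5)
    #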
    def add_time_stamped_data(
        self,
        dest_column: str,
        time_stamps: np.ndarray = None,
        data: np.ndarray = None,
        archiver_channel: str = None,
        **kwds,
    ):
        """Add data in form of timestamp/value pairs to the dataframe using interpolation to the
        timestamps in the dataframe. The time-stamped data can either be provided, or fetched from
        an EPICS archiver instance.

        Args:
            dest_column (str): destination column name
            time_stamps (np.ndarray, optional): Time stamps of the values to add. If omitted,
                time stamps are retrieved from the EPICS archiver.
            data (np.ndarray, optional): Values corresponding to the time stamps in time_stamps.
                If omitted, data are retrieved from the EPICS archiver.
            archiver_channel (str, optional): EPICS archiver channel from which to retrieve data.
                Either this or data and time_stamps have to be present.
            **kwds: additional keyword arguments passed to ``add_time_stamped_data``.
        """
        time_stamp_column = kwds.pop(
            "time_stamp_column",
            self._config["dataframe"].get("time_stamp_alias", ""),
        )

        if time_stamps is None and data is None:
            if archiver_channel is None:
                raise ValueError(
                    "Either archiver_channel or both time_stamps and data have to be present!",
                )
            if self.loader.__name__ != "mpes":
                raise NotImplementedError(
                    "This function is currently only implemented for the mpes loader!",
                )
            ts_from, ts_to = cast(MpesLoader, self.loader).get_start_and_end_time()
            # get channel data with +-5 seconds safety margin
            time_stamps, data = get_archiver_data(
                archiver_url=self._config["metadata"].get("archiver_url", ""),
                archiver_channel=archiver_channel,
                ts_from=ts_from - 5,
                ts_to=ts_to + 5,
            )

        self._dataframe = add_time_stamped_data(
            self._dataframe,
            time_stamps=time_stamps,
            data=data,
            dest_column=dest_column,
            time_stamp_column=time_stamp_column,
            **kwds,
        )
        if self._timed_dataframe is not None:
            if time_stamp_column in self._timed_dataframe:
                self._timed_dataframe = add_time_stamped_data(
                    self._timed_dataframe,
                    time_stamps=time_stamps,
                    data=data,
                    dest_column=dest_column,
                    time_stamp_column=time_stamp_column,
                    **kwds,
                )
        metadata: List[Any] = []
        metadata.append(dest_column)
        metadata.append(time_stamps)
        metadata.append(data)
        self._attributes.add(metadata, "time_stamped_data", duplicate_policy="append")

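    # Usage sketch (hypothetical data; `sp` is a SedProcessor): interpolate a
    # manually recorded temperature log onto the event time stamps:
    #
    #     import numpy as np
    #     ts = np.array([1.7e9, 1.7e9 + 600, 1.7e9 + 1200])  # epoch seconds
    #     temps = np.array([300.0, 150.0, 80.0])  # sensor readings
    #     sp.add_time_stamped_data("sample_temperature", time_stamps=ts, data=temps)
    #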
    def pre_binning(
        self,
        df_partitions: Union[int, Sequence[int]] = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Function to do an initial binning of the dataframe loaded to the class.

        Args:
            df_partitions (Union[int, Sequence[int]], optional): Number of dataframe partitions to
                use for the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            **kwds: Keyword arguments passed to ``compute``.

        Returns:
            xr.DataArray: pre-binned data-array.
        """
        if axes is None:
            axes = self._config["momentum"]["axes"]
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))

        if bins is None:
            bins = self._config["momentum"]["bins"]
        if ranges is None:
            ranges_ = list(self._config["momentum"]["ranges"])
            ranges_[2] = np.asarray(ranges_[2]) / 2 ** (
                self._config["dataframe"]["tof_binning"] - 1
            )
            ranges = [cast(Tuple[float, float], tuple(v)) for v in ranges_]

        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        return self.compute(
            bins=bins,
            axes=axes,
            ranges=ranges,
            df_partitions=df_partitions,
            **kwds,
        )

    def compute(
        self,
        bins: Union[
            int,
            dict,
            tuple,
            List[int],
            List[np.ndarray],
            List[tuple],
        ] = 100,
        axes: Union[str, Sequence[str]] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        normalize_to_acquisition_time: Union[bool, str] = False,
        **kwds,
    ) -> xr.DataArray:
        """Compute the histogram along the given dimensions.

        Args:
            bins (int, dict, tuple, List[int], List[np.ndarray], List[tuple], optional):
                Definition of the bins. Can be any of the following cases:

                - an integer describing the number of bins for all dimensions
                - a tuple of 3 numbers describing start, end and step of the binning
                  range
                - a np.ndarray defining the bin edges
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
                - a dictionary made of the axes as keys and any of the above as values.

                This takes priority over the axes and range arguments. Defaults to 100.
            axes (Union[str, Sequence[str]], optional): The names of the axes (columns)
                on which to calculate the histogram. The order will be the order of the
                dimensions in the resulting array. Defaults to None.
            ranges (Sequence[Tuple[float, float]], optional): list of tuples containing
                the start and end point of the binning range. Defaults to None.
            normalize_to_acquisition_time (Union[bool, str]): Option to normalize the
                result to the acquisition time. If a "slow" axis was scanned, providing
                the name of the scanned axis will compute and apply the corresponding
                normalization histogram. Defaults to False.
            **kwds: Keyword arguments:

                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
                  ``bin_dataframe`` for details. Defaults to
                  config["binning"]["hist_mode"].
                - **mode**: Defines how the results from each partition are combined.
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
                  Defaults to config["binning"]["mode"].
                - **pbar**: Option to show the tqdm progress bar. Defaults to
                  config["binning"]["pbar"].
                - **num_cores**: Number of CPU cores to use for parallelization.
                  Defaults to config["binning"]["num_cores"] or N_CPU-1.
                - **threads_per_worker**: Limit the number of threads that
                  multiprocessing can spawn per binning thread. Defaults to
                  config["binning"]["threads_per_worker"].
                - **threadpool_API**: The API to use for multiprocessing. "blas",
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
                  config["binning"]["threadpool_API"].
                - **df_partitions**: A sequence of dataframe partitions, or the
                  number of the dataframe partitions to use. Defaults to all partitions.

                Additional kwds are passed to ``bin_dataframe``.

        Raises:
            AssertionError: Raised when no dataframe has been loaded.

        Returns:
            xr.DataArray: The result of the n-dimensional binning represented in an
            xarray object, combining the data with the axes.
        """
        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
        mode = kwds.pop("mode", self._config["binning"]["mode"])
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
        num_cores = kwds.pop("num_cores", self._config["binning"]["num_cores"])
        threads_per_worker = kwds.pop(
            "threads_per_worker",
            self._config["binning"]["threads_per_worker"],
        )
        threadpool_api = kwds.pop(
            "threadpool_API",
            self._config["binning"]["threadpool_API"],
        )
        df_partitions: Union[int, Sequence[int]] = kwds.pop("df_partitions", None)
        if isinstance(df_partitions, int):
            df_partitions = list(range(0, min(df_partitions, self._dataframe.npartitions)))
        if df_partitions is not None:
            dataframe = self._dataframe.partitions[df_partitions]
        else:
            dataframe = self._dataframe

        self._binned = bin_dataframe(
            df=dataframe,
            bins=bins,
            axes=axes,
            ranges=ranges,
            hist_mode=hist_mode,
            mode=mode,
            pbar=pbar,
            n_cores=num_cores,
            threads_per_worker=threads_per_worker,
            threadpool_api=threadpool_api,
            **kwds,
        )

        for dim in self._binned.dims:
            try:
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
            except KeyError:
                pass

        self._binned.attrs["units"] = "counts"
        self._binned.attrs["long_name"] = "photoelectron counts"
        self._binned.attrs["metadata"] = self._attributes.metadata

        if normalize_to_acquisition_time:
            if isinstance(normalize_to_acquisition_time, str):
                axis = normalize_to_acquisition_time
                print(f"Calculate normalization histogram for axis '{axis}'...")
                self._normalization_histogram = self.get_normalization_histogram(
                    axis=axis,
                    df_partitions=df_partitions,
                )
                # if the axes are named correctly, xarray figures out the normalization correctly
                self._normalized = self._binned / self._normalization_histogram
                self._attributes.add(
                    self._normalization_histogram.values,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )
            else:
                acquisition_time = self.loader.get_elapsed_time(
                    fids=df_partitions,
                )
                if acquisition_time > 0:
                    self._normalized = self._binned / acquisition_time
                self._attributes.add(
                    acquisition_time,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )

            self._normalized.attrs["units"] = "counts/second"
            self._normalized.attrs["long_name"] = "photoelectron counts per second"
            self._normalized.attrs["metadata"] = self._attributes.metadata

            return self._normalized

        return self._binned

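    # Usage sketch (hypothetical axes and ranges; `sp` is a SedProcessor):
    # bin a four-dimensional histogram and normalize to the acquisition time
    # spent in each delay bin:
    #
    #     res = sp.compute(
    #         bins=[100, 100, 200, 50],
    #         axes=["kx", "ky", "energy", "delay"],
    #         ranges=[(-2, 2), (-2, 2), (-4, 2), (-0.5, 1.5)],
    #         normalize_to_acquisition_time="delay",
    #     )
    #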
    def get_normalization_histogram(
        self,
        axis: str = "delay",
        use_time_stamps: bool = False,
        **kwds,
    ) -> xr.DataArray:
        """Generates a normalization histogram from the timed dataframe. Optionally,
        use the TimeStamps column instead.

        Args:
            axis (str, optional): The axis for which to compute the histogram.
                Defaults to "delay".
            use_time_stamps (bool, optional): Use the TimeStamps column of the
                dataframe, rather than the timed dataframe. Defaults to False.
            **kwds: Keyword arguments:

                - **df_partitions**: A sequence of dataframe partitions, or the
                  number of the dataframe partitions to use. Defaults to all partitions.

        Raises:
            ValueError: Raised if no data are binned.
            ValueError: Raised if 'axis' not in binned coordinates.
            ValueError: Raised if config["dataframe"]["time_stamp_alias"] not found
                in the dataframe.

        Returns:
            xr.DataArray: The computed normalization histogram (in TimeStamp units
            per bin).
        """
        if self._binned is None:
            raise ValueError("Need to bin data first!")
        if axis not in self._binned.coords:
            raise ValueError(f"Axis '{axis}' not found in binned data!")

        df_partitions: Union[int, Sequence[int]] = kwds.pop("df_partitions", None)
        if isinstance(df_partitions, int):
            df_partitions = list(range(0, min(df_partitions, self._dataframe.npartitions)))
        if use_time_stamps or self._timed_dataframe is None:
            if df_partitions is not None:
                self._normalization_histogram = normalization_histogram_from_timestamps(
                    self._dataframe.partitions[df_partitions],
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["time_stamp_alias"],
                )
            else:
                self._normalization_histogram = normalization_histogram_from_timestamps(
                    self._dataframe,
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["time_stamp_alias"],
                )
        else:
            if df_partitions is not None:
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
                    self._timed_dataframe.partitions[df_partitions],
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["timed_dataframe_unit_time"],
                )
            else:
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
                    self._timed_dataframe,
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["timed_dataframe_unit_time"],
                )

        return self._normalization_histogram

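    # Usage sketch (hypothetical; `sp` is a SedProcessor whose data were
    # already binned along "delay"): inspect the acquisition-time distribution
    # that would be used for normalization:
    #
    #     hist = sp.get_normalization_histogram(axis="delay", df_partitions=10)
    #     hist.plot()
    #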
    def view_event_histogram(
        self,
        dfpid: int,
        ncol: int = 2,
        bins: Sequence[int] = None,
        axes: Sequence[str] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        backend: str = "bokeh",
        legend: bool = True,
        histkwds: dict = None,
        legkwds: dict = None,
        **kwds,
    ):
        """Plot individual histograms of specified dimensions (axes) from a single
        dataframe partition.

        Args:
            dfpid (int): Number of the data frame partition to look at.
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
            bins (Sequence[int], optional): Number of bins to use for the specified
                axes. Defaults to config["histogram"]["bins"].
            axes (Sequence[str], optional): Names of the axes to display.
                Defaults to config["histogram"]["axes"].
            ranges (Sequence[Tuple[float, float]], optional): Value ranges of all
                specified axes. Defaults to config["histogram"]["ranges"].
            backend (str, optional): Backend of the plotting library
                ('matplotlib' or 'bokeh'). Defaults to "bokeh".
            legend (bool, optional): Option to include a legend in the histogram plots.
                Defaults to True.
            histkwds (dict, optional): Keyword arguments for histograms
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
            legkwds (dict, optional): Keyword arguments for legend
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
            **kwds: Extra keyword arguments passed to
                ``sed.diagnostics.grid_histogram()``.

        Raises:
            TypeError: Raised when the input values are not of the correct type.
        """
        if bins is None:
            bins = self._config["histogram"]["bins"]
        if axes is None:
            axes = self._config["histogram"]["axes"]
        axes = list(axes)
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
        if ranges is None:
            ranges = list(self._config["histogram"]["ranges"])
            for loc, axis in enumerate(axes):
                if axis == self._config["dataframe"]["tof_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
                        self._config["dataframe"]["tof_binning"] - 1
                    )
                elif axis == self._config["dataframe"]["adc_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
                        self._config["dataframe"]["adc_binning"] - 1
                    )

        input_types = map(type, [axes, bins, ranges])
        allowed_types = [list, tuple]

        df = self._dataframe

        if not set(input_types).issubset(allowed_types):
            raise TypeError(
                "Inputs of axes, bins, ranges need to be list or tuple!",
            )

        # Read out the values for the specified groups
        group_dict_dd = {}
        dfpart = df.get_partition(dfpid)
        cols = dfpart.columns
        for ax in axes:
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
        group_dict = ddf.compute(group_dict_dd)[0]

        # Plot multiple histograms in a grid
        grid_histogram(
            group_dict,
            ncol=ncol,
            rvs=axes,
            rvbins=bins,
            rvranges=ranges,
            backend=backend,
            legend=legend,
            histkwds=histkwds,
            legkwds=legkwds,
            **kwds,
        )

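    # Usage sketch (hypothetical; `sp` is a SedProcessor): quick sanity check
    # of the raw event distributions in the first dataframe partition:
    #
    #     sp.view_event_histogram(dfpid=0, ncol=3, backend="matplotlib")
    #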
    def save(
        self,
        faddr: str,
        **kwds,
    ):
        """Saves the binned data to the provided path and filename.

        Args:
            faddr (str): Path and name of the file to write. Its extension determines
                the file type to write. Valid file types are:

                - "*.tiff", "*.tif": Saves a TIFF stack.
                - "*.h5", "*.hdf5": Saves an HDF5 file.
                - "*.nxs", "*.nexus": Saves a NeXus file.

            **kwds: Keyword arguments, which are passed to the writer functions:
                For TIFF writing:

                - **alias_dict**: Dictionary of dimension aliases to use.

                For HDF5 writing:

                - **mode**: hdf5 read/write mode. Defaults to "w".

                For NeXus:

                - **reader**: Name of the nexustools reader to use.
                  Defaults to config["nexus"]["reader"]
                - **definition**: NeXus application definition to use for saving.
                  Must be supported by the used ``reader``. Defaults to
                  config["nexus"]["definition"]
                - **input_files**: A list of input files to pass to the reader.
                  Defaults to config["nexus"]["input_files"]
                - **eln_data**: An electronic-lab-notebook file in '.yaml' format
                  to add to the list of files to pass to the reader.
        """
        if self._binned is None:
            raise NameError("Need to bin data first!")

        if self._normalized is not None:
            data = self._normalized
        else:
            data = self._binned

        extension = pathlib.Path(faddr).suffix

        if extension in (".tif", ".tiff"):
            to_tiff(
                data=data,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".h5", ".hdf5"):
            to_h5(
                data=data,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".nxs", ".nexus"):
            try:
                reader = kwds.pop("reader", self._config["nexus"]["reader"])
                definition = kwds.pop(
                    "definition",
                    self._config["nexus"]["definition"],
                )
                input_files = kwds.pop(
                    "input_files",
                    self._config["nexus"]["input_files"],
                )
            except KeyError as exc:
                raise ValueError(
                    "The nexus reader, definition and input files need to be provided!",
                ) from exc

            if isinstance(input_files, str):
                input_files = [input_files]

            if "eln_data" in kwds:
                input_files.append(kwds.pop("eln_data"))

            to_nexus(
                data=data,
                faddr=faddr,
                reader=reader,
                definition=definition,
                input_files=input_files,
                **kwds,
            )

        else:
            raise NotImplementedError(
                f"Unrecognized file format: {extension}.",
            )
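
    # Usage sketch (hypothetical file names; `sp` is a SedProcessor with binned
    # data): the output format is inferred from the file extension:
    #
    #     sp.save("binned.h5")    # HDF5
    #     sp.save("binned.tiff")  # TIFF stack
    #     sp.save("binned.nxs")   # NeXus, using reader/definition from config
    #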