
OpenCOMPES / sed / build 6786223079

07 Nov 2023 02:53PM UTC coverage: 89.985% (-0.06%) from 90.045%

Pull Request #240 (github, steinnymir): implement pump-probe calibration for hextof
Commit: fix loading boolean parameter

122 of 145 new or added lines in 3 files covered (84.14%)

45 existing lines in 2 files now uncovered.

4762 of 5292 relevant lines covered (89.98%)

0.9 hits per line

Source File: /sed/core/processor.py (85.87% covered)
"""This module contains the core class for the sed package

"""
import pathlib
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import apply_jitter
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader

N_CPU = psutil.cpu_count()


class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a workflow for data
    correction, calibration and binning.

    Args:
        metadata (dict, optional): Dict of external Metadata. Defaults to None.
        config (Union[dict, str], optional): Config dictionary or config file name.
            Defaults to None.
        dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
            into the class. Defaults to None.
        files (List[str], optional): List of files to pass to the loader defined in
            the config. Defaults to None.
        folder (str, optional): Folder containing files to pass to the loader
            defined in the config. Defaults to None.
        collect_metadata (bool): Option to collect metadata from files.
            Defaults to False.
        **kwds: Keyword arguments passed to the reader.
    """

    def __init__(
        self,
        metadata: dict = None,
        config: Union[dict, str] = None,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            config (Union[dict, str], optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
                into the class. Defaults to None.
            files (List[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to parse_config and to the reader.
        """
        config_kwds = {
            key: value for key, value in kwds.items() if key in parse_config.__code__.co_varnames
        }
        for key in config_kwds.keys():
            del kwds[key]
        self._config = parse_config(config, **config_kwds)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

        self._dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._files: List[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None

        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
        )

        self.ec = EnergyCalibrator(
            loader=self.loader,
            config=self._config,
        )

        self.mc = MomentumCorrector(
            config=self._config,
        )

        self.dc = DelayCalibrator(
            config=self._config,
        )

        self.use_copy_tool = self._config.get("core", {}).get(
            "use_copy_tool",
            False,
        )
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    source=self._config["core"]["copy_tool_source"],
                    dest=self._config["core"]["copy_tool_dest"],
                    **self._config["core"].get("copy_tool_kwds", {}),
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )
151
    def __repr__(self):
        if self._dataframe is None:
            df_str = "Data Frame: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        attributes_str = f"Metadata: {self._attributes.metadata}"
        pretty_str = df_str + "\n" + attributes_str
        return pretty_str

    @property
    def dataframe(self) -> Union[pd.DataFrame, ddf.DataFrame]:
        """Accessor to the underlying dataframe.

        Returns:
            Union[pd.DataFrame, ddf.DataFrame]: Dataframe object.
        """
        return self._dataframe

    @dataframe.setter
    def dataframe(self, dataframe: Union[pd.DataFrame, ddf.DataFrame]):
        """Setter for the underlying dataframe.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame]): The dataframe object to set.
        """
        if not isinstance(dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            dataframe,
            self._dataframe.__class__,
        ):
            raise ValueError(
                "'dataframe' has to be a Pandas or Dask dataframe and has to be of the same kind "
                "as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._dataframe.__class__}, provided type: {type(dataframe)}.",
            )
        self._dataframe = dataframe

    @property
    def attributes(self) -> dict:
        """Accessor to the metadata dict.

        Returns:
            dict: The metadata dict.
        """
        return self._attributes.metadata

    def add_attribute(self, attributes: dict, name: str, **kwds):
        """Function to add an element to the attributes dict.

        Args:
            attributes (dict): The attributes dictionary object to add.
            name (str): Key under which to add the dictionary to the attributes.
        """
        self._attributes.add(
            entry=attributes,
            name=name,
            **kwds,
        )

    @property
    def config(self) -> Dict[Any, Any]:
        """Getter attribute for the config dictionary

        Returns:
            Dict: The config dictionary.
        """
        return self._config

    @property
    def files(self) -> List[str]:
        """Getter attribute for the list of files

        Returns:
            List[str]: The list of loaded files
        """
        return self._files

    def cpy(self, path: Union[str, List[str]]) -> Union[str, List[str]]:
        """Function to mirror a list of files or a folder from a network drive to a
        local storage. Returns either the original or the copied path to the given
        path. The option to use this functionality is set by
        config["core"]["use_copy_tool"].

        Args:
            path (Union[str, List[str]]): Source path or path list.

        Returns:
            Union[str, List[str]]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        if isinstance(path, list):
            return path

        return path

    def load(
        self,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        metadata: dict = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            files (List[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            self._dataframe = dataframe
        elif runs is not None:
            # If runs are provided, we only use the copy tool if a folder is also provided.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )

        elif files is not None:
            dataframe, metadata = self.loader.read_dataframe(
                files=cast(List[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )

        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )
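
    # Typical load calls, as a sketch (run identifiers and file names are
    # hypothetical; which ones are valid depends on the configured loader):
    #
    #   processor.load(runs=["44797"], folder="/path/to/raw")  # resolve runs in a base folder
    #   processor.load(files=["scan1.h5", "scan2.h5"])         # or pass files explicitly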

    # Momentum calibration workflow
    # 1. Bin raw detector data for distortion correction
    def bin_and_load_momentum_calibration(
        self,
        df_partitions: int = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        plane: int = 0,
        width: int = 5,
        apply: bool = False,
        **kwds,
    ):
        """1st step of the momentum correction workflow. Function to do an initial binning
        of the dataframe loaded to the class, slice a plane from it using an
        interactive view, and load it into the momentum corrector class.

        Args:
            df_partitions (int, optional): Number of dataframe partitions to use for
                the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (List[Tuple], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
            width (int, optional): Initial value for the width slider. Defaults to 5.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments passed to the pre_binning function.
        """
        self._pre_binned = self.pre_binning(
            df_partitions=df_partitions,
            axes=axes,
            bins=bins,
            ranges=ranges,
            **kwds,
        )

        self.mc.load_data(data=self._pre_binned)
        self.mc.select_slicer(plane=plane, width=width, apply=apply)
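
    # Example call, with hypothetical slider values:
    #
    #   processor.bin_and_load_momentum_calibration(plane=33, width=10, apply=True)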

    # 2. Generate the spline warp correction from momentum features.
    # Either autoselect features, or input features from view above.
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2. Step of the distortion correction workflow: Define feature points in
        momentum space. They can either be manually selected using a GUI tool, be
        provided as a list of feature points, or be auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the provided features.
                Defaults to False.
            **kwds: Keyword arguments for MomentumCorrector.feature_extract() and
                MomentumCorrector.feature_select()
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop("sigma", self._config["momentum"]["sigma"])
            fwhm = kwds.pop("fwhm", self._config["momentum"]["fwhm"])
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config["momentum"]["sigma_radius"],
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )
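
    # Example with manually provided features (coordinates are made up for
    # illustration; six outer points plus the center for rotation_symmetry=6):
    #
    #   features = np.array(
    #       [[203, 341], [299, 386], [391, 335], [386, 231], [291, 186], [197, 236], [296, 284]],
    #   )
    #   processor.define_features(features=features, rotation_symmetry=6, apply=True)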

    # 3. Generate the spline warp correction from momentum features.
    # If no features have been selected before, use class defaults.
    def generate_splinewarp(
        self,
        use_center: bool = None,
        **kwds,
    ):
        """3. Step of the distortion correction workflow: Generate the correction
        function restoring the symmetry in the image using a splinewarp algorithm.

        Args:
            use_center (bool, optional): Option to use the position of the
                center point in the correction. Default is read from config, or set to True.
            **kwds: Keyword arguments for MomentumCorrector.spline_warp_estimate().
        """
        self.mc.spline_warp_estimate(use_center=use_center, **kwds)

        if self.mc.slice is not None:
            print("Original slice with reference features")
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)

            print("Corrected slice with target features")
            self.mc.view(
                image=self.mc.slice_corrected,
                annotated=True,
                points={"feats": self.mc.ptargs},
                backend="bokeh",
                crosshair=True,
            )

            print("Original slice with target features")
            self.mc.view(
                image=self.mc.slice,
                points={"feats": self.mc.ptargs},
                annotated=True,
                backend="bokeh",
            )

    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        points = []
        try:
            for point in self.mc.pouter_ord:
                points.append([float(i) for i in point])
            if self.mc.include_center:
                points.append([float(i) for i in self.mc.pcent])
        except AttributeError as exc:
            raise AttributeError(
                "Momentum correction parameters not found, need to generate parameters first!",
            ) from exc
        config = {
            "momentum": {
                "correction": {
                    "rotation_symmetry": self.mc.rotsym,
                    "feature_points": points,
                    "include_center": self.mc.include_center,
                    "use_center": self.mc.use_center,
                },
            },
        }
        save_config(config, filename, overwrite)
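
    # For reference, save_splinewarp() stores a block like the following in the
    # YAML config file (feature_points abbreviated here for illustration):
    #
    #   momentum:
    #     correction:
    #       rotation_symmetry: 6
    #       feature_points: [[203.0, 341.0], ...]
    #       include_center: true
    #       use_center: true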

    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    def pose_adjustment(
        self,
        scale: float = 1,
        xtrans: float = 0,
        ytrans: float = 0,
        angle: float = 0,
        apply: bool = False,
        use_correction: bool = True,
        reset: bool = True,
    ):
        """4. step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            scale (float, optional): Initial value of the scaling slider.
                Defaults to 1.
            xtrans (float, optional): Initial value of the xtrans slider.
                Defaults to 0.
            ytrans (float, optional): Initial value of the ytrans slider.
                Defaults to 0.
            angle (float, optional): Initial value of the angle slider.
                Defaults to 0.
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
            reset (bool, optional):
                Option to reset the correction before transformation. Defaults to True.
        """
        # Generate homography as default if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                raise ValueError(
                    "No slice for corrections and transformations loaded!",
                )
            self.mc.slice_corrected = self.mc.slice

        if not use_correction:
            self.mc.reset_deformation()

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate distortion correction from config values
            self.mc.add_features()
            self.mc.spline_warp_estimate()

        self.mc.pose_adjustment(
            scale=scale,
            xtrans=xtrans,
            ytrans=ytrans,
            angle=angle,
            apply=apply,
            reset=reset,
        )
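
    # Example fine-tuning call, with hypothetical translation/rotation values:
    #
    #   processor.pose_adjustment(xtrans=8, ytrans=7, angle=-1, apply=True)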

    # 5. Apply the momentum correction to the dataframe
    def apply_momentum_correction(
        self,
        preview: bool = False,
    ):
        """Applies the distortion correction and pose adjustment (optional)
        to the dataframe.

        Args:
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding corrected X/Y columns to dataframe:")
            self._dataframe, metadata = self.mc.apply_corrections(
                df=self._dataframe,
            )
            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    def calibrate_momentum_axes(
        self,
        point_a: Union[np.ndarray, List[int]] = None,
        point_b: Union[np.ndarray, List[int]] = None,
        k_distance: float = None,
        k_coord_a: Union[np.ndarray, List[float]] = None,
        k_coord_b: Union[np.ndarray, List[float]] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply=False,
    ):
        """1. step of the momentum calibration workflow. Calibrate momentum
        axes using either provided pixel coordinates of a high-symmetry point and its
        distance to the BZ center, or the k-coordinates of two points in the BZ
        (depending on the equiscale option). Opens an interactive panel for selecting
        the points.

        Args:
            point_a (Union[np.ndarray, List[int]]): Pixel coordinates of the first
                point used for momentum calibration.
            point_b (Union[np.ndarray, List[int]], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between point a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option to apply different scales to kx and ky.
                If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config["momentum"]["center_pixel"]

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )
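
    # Equiscale example, with hypothetical values: pixel coordinates of one
    # high-symmetry point plus its known momentum distance to the BZ center:
    #
    #   processor.calibrate_momentum_axes(point_a=[308, 345], k_distance=1.3, apply=True)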

    # 1a. Save momentum calibration parameters to config file.
    def save_momentum_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated momentum calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key in [
                "kx_scale",
                "ky_scale",
                "x_center",
                "y_center",
                "rstart",
                "cstart",
                "rstep",
                "cstep",
            ]:
                calibration[key] = float(self.mc.calibration[key])
        except KeyError as exc:
            raise KeyError(
                "Momentum calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"momentum": {"calibration": calibration}}
        save_config(config, filename, overwrite)

    # 2. Apply correction and calibration to the dataframe
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
    ):
        """2. step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes exist,
        these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding kx/ky columns to dataframe:")
            self._dataframe, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: Tuple[float, float] = None,
        apply=False,
        **kwds,
    ):
        """1. step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data if
        they are not present yet.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (Tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
            **kwds: Keyword args passed to ``EnergyCalibrator.adjust_energy_correction()``.
        """
        if self._pre_binned is None:
            print(
                "Pre-binned data not present, binning using defaults from config...",
            )
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )
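
    # Example, with hypothetical amplitude and detector-center coordinates:
    #
    #   processor.adjust_energy_correction(
    #       correction_type="Lorentzian",
    #       amplitude=2.5,
    #       center=(730, 730),
    #       apply=True,
    #   )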

    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        correction = {}
        try:
            for key, val in self.ec.correction.items():
                if key == "correction_type":
                    correction[key] = val
                elif key == "center":
                    correction[key] = [float(i) for i in val]
                else:
                    correction[key] = float(val)
        except AttributeError as exc:
            raise AttributeError(
                "Energy correction parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)

    # 2. Apply energy correction to dataframe
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2. step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction``.
        """
        if self._dataframe is not None:
            print("Applying energy correction to dataframe...")
            self._dataframe, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                **kwds,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

    # Energy calibrator workflow
    # 1. Load and normalize data
    def load_bias_series(
        self,
        binned_data: Union[xr.DataArray, Tuple[np.ndarray, np.ndarray, np.ndarray]] = None,
        data_files: List[str] = None,
        axes: List[str] = None,
        bins: List = None,
        ranges: Sequence[Tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1. step of the energy calibration workflow: Load and bin data from
        single-event files, or load binned bias/TOF traces.

        Args:
            binned_data (Union[xr.DataArray, Tuple[np.ndarray, np.ndarray, np.ndarray]], optional):
                Binned data. If provided as a DataArray, it needs to contain the dimensions
                config["dataframe"]["tof_column"] and config["dataframe"]["bias_column"]. If
                provided as a tuple, it needs to contain the elements tof, biases, traces.
            data_files (List[str], optional): list of file paths to bin
            axes (List[str], optional): bin axes.
                Defaults to config["dataframe"]["tof_column"].
            bins (List, optional): number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        if binned_data is not None:
            if isinstance(binned_data, xr.DataArray):
                if (
                    self._config["dataframe"]["tof_column"] not in binned_data.dims
                    or self._config["dataframe"]["bias_column"] not in binned_data.dims
                ):
                    raise ValueError(
                        "If binned_data is provided as an xarray, it needs to contain dimensions "
                        f"'{self._config['dataframe']['tof_column']}' and "
                        f"'{self._config['dataframe']['bias_column']}'!",
                    )
                tof = binned_data.coords[self._config["dataframe"]["tof_column"]].values
                biases = binned_data.coords[self._config["dataframe"]["bias_column"]].values
                traces = binned_data.values[:, :]
            else:
                try:
                    (tof, biases, traces) = binned_data
                except ValueError as exc:
                    raise ValueError(
                        "If binned_data is provided as tuple, it needs to contain "
                        "(tof, biases, traces)!",
                    ) from exc
            self.ec.load_data(biases=biases, traces=traces, tof=tof)

        elif data_files is not None:
            self.ec.bin_data(
                data_files=cast(List[str], self.cpy(data_files)),
                axes=axes,
                bins=bins,
                ranges=ranges,
                biases=biases,
                bias_key=bias_key,
            )

        else:
            raise ValueError("Either binned_data or data_files needs to be provided!")

        if (normalize is not None and normalize is True) or (
            normalize is None and self._config["energy"]["normalize"]
        ):
            if span is None:
                span = self._config["energy"]["normalize_span"]
            if order is None:
                order = self._config["energy"]["normalize_order"]
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="bokeh",
        )
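
    # Example binning a bias series from dedicated calibration files
    # (file names are hypothetical placeholders):
    #
    #   bias_files = [f"bias_scan_{i}.h5" for i in range(10)]
    #   processor.load_bias_series(data_files=bias_files, normalize=True)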

    # 2. extract ranges and get peak positions
    def find_bias_peaks(
        self,
        ranges: Union[List[Tuple], Tuple],
        ref_id: int = 0,
        infer_others: bool = True,
        mode: str = "replace",
        radius: int = None,
        peak_window: int = None,
        apply: bool = False,
    ):
        """2. step of the energy calibration workflow: Find a peak within a given range
        for the indicated reference trace, and try to find the same peak for all
        other traces. Uses fast_dtw to align curves, which may perform poorly if the
        shape of the curves changes qualitatively. Ideally, choose a reference trace in
        the middle of the set, and don't choose the range too narrow around the peak.
        Alternatively, a list of ranges for all traces can be provided.

        Args:
            ranges (Union[List[Tuple], Tuple]): Tuple of TOF values indicating a range.
                Alternatively, a list of ranges for all traces can be given.
            ref_id (int, optional): The id of the trace the range refers to.
                Defaults to 0.
            infer_others (bool, optional): Whether to determine the range for the other
                traces. Defaults to True.
            mode (str, optional): Whether to "add" or "replace" existing ranges.
                Defaults to "replace".
            radius (int, optional): Radius parameter for fast_dtw.
                Defaults to config["energy"]["fastdtw_radius"].
            peak_window (int, optional): peak_window parameter for the peak-detection
                algorithm: the number of points that have to behave monotonously
                around a peak. Defaults to config["energy"]["peak_window"].
            apply (bool, optional): Option to directly apply the provided parameters.
                Defaults to False.
        """
        if radius is None:
            radius = self._config["energy"]["fastdtw_radius"]
        if peak_window is None:
            peak_window = self._config["energy"]["peak_window"]
        if not infer_others:
            self.ec.add_ranges(
                ranges=ranges,
                ref_id=ref_id,
                infer_others=infer_others,
                mode=mode,
                radius=radius,
            )
            print(self.ec.featranges)
            try:
                self.ec.feature_extract(peak_window=peak_window)
                self.ec.view(
                    traces=self.ec.traces_normed,
                    segs=self.ec.featranges,
                    xaxis=self.ec.tof,
                    peaks=self.ec.peaks,
                    backend="bokeh",
                )
            except IndexError:
                print("Could not determine all peaks!")
                raise
        else:
            # New adjustment tool
            assert isinstance(ranges, tuple)
            self.ec.adjust_ranges(
                ranges=ranges,
                ref_id=ref_id,
                traces=self.ec.traces_normed,
                infer_others=infer_others,
                radius=radius,
                peak_window=peak_window,
                apply=apply,
            )
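
    # Example, with a hypothetical TOF range around the feature of interest:
    #
    #   processor.find_bias_peaks(ranges=(64000, 66000), ref_id=5, apply=True)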

    # 3. Fit the energy calibration relation
    def calibrate_energy_axis(
        self,
        ref_id: int,
        ref_energy: float,
        method: str = None,
        energy_scale: str = None,
        **kwds,
    ):
        """3. Step of the energy calibration workflow: Calculate the calibration
        function for the energy axis, and apply it to the dataframe. Two
        approximations are implemented, a (normally 3rd order) polynomial
        approximation, and a d^2/(t-t0)^2 relation.

        Args:
            ref_id (int): id of the trace at the bias where the reference energy is
                given.
            ref_energy (float): Absolute energy of the detected feature at the bias
                of ref_id.
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to config["energy"]["calibration_method"]
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to config["energy"]["energy_scale"]
            **kwds: Keyword args passed to ``EnergyCalibrator.calibrate``.
        """
        if method is None:
            method = self._config["energy"]["calibration_method"]

        if energy_scale is None:
            energy_scale = self._config["energy"]["energy_scale"]

        self.ec.calibrate(
            ref_id=ref_id,
            ref_energy=ref_energy,
            method=method,
            energy_scale=energy_scale,
            **kwds,
        )
        print("Quality of Calibration:")
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.calibration["axis"],
            align=True,
            energy_scale=energy_scale,
            backend="bokeh",
        )
        print("E/TOF relationship:")
        self.ec.view(
            traces=self.ec.calibration["axis"][None, :],
            xaxis=self.ec.tof,
            backend="matplotlib",
            show_legend=False,
        )
        if energy_scale == "kinetic":
            plt.scatter(
                self.ec.peaks[:, 0],
                -(self.ec.biases - self.ec.biases[ref_id]) + ref_energy,
                s=50,
                c="k",
            )
        elif energy_scale == "binding":
            plt.scatter(
                self.ec.peaks[:, 0],
                self.ec.biases - self.ec.biases[ref_id] + ref_energy,
                s=50,
                c="k",
            )
        else:
            raise ValueError(
                'energy_scale needs to be either "binding" or "kinetic"'
                f", got {energy_scale}.",
            )
        plt.xlabel("Time-of-flight", fontsize=15)
        plt.ylabel("Energy (eV)", fontsize=15)
        plt.show()
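
    # Example, with hypothetical reference values: trace 4 is the reference, and
    # the tracked feature lies 0.5 eV below the Fermi level at that bias:
    #
    #   processor.calibrate_energy_axis(ref_id=4, ref_energy=-0.5, method="lmfit")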

    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for (key, value) in self.ec.calibration.items():
                if key in ["axis", "refid", "Tmat", "bvec"]:
                    continue
                if key == "energy_scale":
                    calibration[key] = value
                elif key == "coeffs":
                    calibration[key] = [float(i) for i in value]
                else:
                    calibration[key] = float(value)
        except AttributeError as exc:
            raise AttributeError(
                "Energy calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {
            "energy": {
                "calibration": calibration,
            },
        }
        if isinstance(self.ec.offset, dict):
            config["energy"]["offset"] = self.ec.offset
        save_config(config, filename, overwrite)

    # 4. Apply energy calibration to the dataframe
    def append_energy_axis(
        self,
        calibration: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """4. step of the energy calibration workflow: Apply the calibration function
        to the dataframe. Two approximations are implemented, a (normally 3rd order)
        polynomial approximation, and a d^2/(t-t0)^2 relation. A calibration dictionary
        can be provided.

        Args:
            calibration (dict, optional): Calibration dict containing calibration
                parameters. Overrides calibration from class or config.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.append_energy_axis``.
        """
        if self._dataframe is not None:
            print("Adding energy column to dataframe:")
            self._dataframe, metadata = self.ec.append_energy_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)
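
    # Example:
    #
    #   processor.append_energy_axis(preview=True)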

    def add_energy_offset(
        self,
        constant: float = None,
        columns: Union[str, Sequence[str]] = None,
        signs: Union[int, Sequence[int]] = None,
        reductions: Union[str, Sequence[str]] = None,
        preserve_mean: Union[bool, Sequence[bool]] = None,
    ) -> None:
        """Shift the energy axis of the dataframe by a given amount.

        Args:
            constant (float, optional): The constant to shift the energy axis by.
            columns (Union[str, Sequence[str]]): Name of the column(s) to apply the shift from.
            signs (Union[int, Sequence[int]]): Sign of the shift to apply. (+1 or -1) A positive
                sign shifts the energy axis to higher kinetic energies. Defaults to +1.
            reductions (str): The reduction to apply to the column. Should be an available method
                of dask.dataframe.Series. For example "mean". In this case the function is applied
                to the column to generate a single value for the whole dataset. If None, the shift
                is applied per-dataframe-row. Defaults to None. Currently only "mean" is supported.
            preserve_mean (bool): Whether to subtract the mean of the column before applying the
                shift. Defaults to False.

        Raises:
            ValueError: If the energy column is not in the dataframe.
        """
        energy_column = self._config["dataframe"]["energy_column"]
        if energy_column not in self._dataframe.columns:
            raise ValueError(
                f"Energy column {energy_column} not found in dataframe! "
                "Run `append_energy_axis` first.",
            )
        if self.dataframe is not None:
            df, metadata = self.ec.add_offsets(
                df=self._dataframe,
                constant=constant,
                columns=columns,
                energy_column=energy_column,
                signs=signs,
                reductions=reductions,
                preserve_mean=preserve_mean,
            )
            self._attributes.add(
                metadata,
                "add_energy_offset",
                # TODO: allow only appending when no offset along this column(s) was applied
                # TODO: clear memory of modifications if the energy axis is recalculated
                duplicate_policy="append",
            )
            self._dataframe = df
        else:
            raise ValueError("No dataframe loaded!")
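
    # Example shifting the energy axis by a per-event photon-energy column (the
    # column name, constant, and sign here are hypothetical and loader-dependent):
    #
    #   processor.add_energy_offset(
    #       constant=1.5,
    #       columns="monochromatorPhotonEnergy",
    #       signs=-1,
    #       preserve_mean=True,
    #   )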

    def append_tof_ns_axis(
        self,
        **kwargs,
    ):
        """Convert time-of-flight channel steps to nanoseconds.

        Args:
            tof_ns_column (str, optional): Name of the generated column containing the
                time-of-flight in nanoseconds.
                Defaults to config["dataframe"]["tof_ns_column"].
            kwargs: additional arguments are passed to ``energy.tof_step_to_ns``.
        """
        if self._dataframe is not None:
            print("Adding time-of-flight column in nanoseconds to dataframe:")
            # TODO assert order of execution through metadata

            self._dataframe, metadata = self.ec.append_tof_ns_axis(
                df=self._dataframe,
                **kwargs,
            )
            self._attributes.add(
                metadata,
                "tof_ns_conversion",
                duplicate_policy="append",
            )

    def align_dld_sectors(self, sector_delays: np.ndarray = None, **kwargs):
        """Align the 8s sectors of the HEXTOF endstation.

        Args:
            sector_delays (np.ndarray, optional): Array containing the sector delays. Defaults to
                config["dataframe"]["sector_delays"].
            **kwargs: additional arguments passed to ``EnergyCalibrator.align_dld_sectors``.
        """
        if self._dataframe is not None:
            print("Aligning 8s sectors of dataframe")
            # TODO assert order of execution through metadata
            self._dataframe, metadata = self.ec.align_dld_sectors(
                df=self._dataframe,
                sector_delays=sector_delays,
                **kwargs,
            )
            self._attributes.add(
                metadata,
                "dld_sector_alignment",
                duplicate_policy="raise",
            )

    # Delay calibration function
    def calibrate_delay_axis(
        self,
        time0: float = None,
        flip_time_axis: bool = False,
        delay_range: Tuple[float, float] = None,
        datafile: str = None,
        preview: bool = False,
        **kwds,
    ):
        """Append delay column to dataframe. Either provide delay ranges, or read
        them from a file.

        Args:
            time0 (float, optional): The pump-probe time overlap in picoseconds.
            flip_time_axis (bool, optional): Whether to flip the time axis direction.
            delay_range (Tuple[float, float], optional): The scanned delay range in
                picoseconds. Defaults to None.
            datafile (str, optional): The file from which to read the delay ranges.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
        """
        if self._dataframe is not None:
            print("Adding delay column to dataframe:")

            if delay_range is not None:
                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    delay_range=delay_range,
                    **kwds,
                )
            else:
                if datafile is None:
                    try:
                        datafile = self._files[0]
                    except IndexError:
                        print(
                            "No datafile available, specify either",
                            " 'datafile' or 'delay_range'",
                        )
                        raise

                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    time0=time0,
                    flip_time_axis=flip_time_axis,
                    datafile=datafile,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "delay_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)
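
    # Illustrative sketch (comment only): two typical ways to obtain a delay
    # axis, with made-up numbers. Either give the scanned range explicitly:
    #     sp.calibrate_delay_axis(delay_range=(-100.0, 200.0), preview=True)
    # or let the range be read from the first loaded file (requires files to be
    # loaded), optionally fixing the pump-probe overlap:
    #     sp.calibrate_delay_axis(time0=0.5, flip_time_axis=False)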

    def correct_delay_fluctuations(
        self,
        delay_column: str = None,
        columns: Union[str, Sequence[str]] = None,
        signs: Union[int, Sequence[int]] = None,
        reductions: Union[str, Sequence[str]] = None,
        preserve_mean: Union[bool, Sequence[bool]] = None,
        **kwargs,
    ) -> None:
        """Apply a correction to the delay axis of the dataframe.

        Args:
            delay_column (str): Name of the column containing the delay values.
            columns (Union[str, Sequence[str]]): Name of the column(s) to apply the correction to.
            signs (Union[int, Sequence[int]]): Sign of the correction to apply (+1 or -1).
                A positive sign shifts the delay axis to higher delays. Defaults to +1.
            reductions (Union[str, Sequence[str]]): The reduction to apply to the column.
                Should be an available method of dask.dataframe.Series, e.g. "mean". In this
                case the function is applied to the column to generate a single value for the
                whole dataset. If None, the shift is applied per-dataframe-row. Defaults to
                None. Currently only "mean" is supported.
            preserve_mean (Union[bool, Sequence[bool]]): Whether to subtract the mean of the
                column before applying the correction. Defaults to False.
            **kwargs: Additional arguments passed to ``DelayCalibrator.correct_delay_fluctuations``.

        Raises:
            ValueError: If no dataframe is loaded, or if the delay column is not in the
                dataframe.
        """
        if self._dataframe is None:
            raise ValueError("No dataframe loaded!")
        if delay_column is None:
            delay_column = self._config["dataframe"]["delay_column"]
        if delay_column not in self._dataframe.columns:
            raise ValueError(
                f"Delay column {delay_column} not found in dataframe! "
                "Run `calibrate_delay_axis` first.",
            )
        self._dataframe, metadata = self.dc.correct_delay_fluctuations(
            df=self._dataframe,
            delay_column=delay_column,
            columns=columns,
            signs=signs,
            reductions=reductions,
            preserve_mean=preserve_mean,
            **kwargs,
        )
        self._attributes.add(
            metadata,
            "correct_delay_fluctuations",
            duplicate_policy="raise",
        )
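
    # Illustrative sketch (comment only): correcting delay jitter against a beam
    # arrival monitor column. "bam" is a hypothetical column name; whether your
    # data contains it depends on the loader and beamline.
    #     sp.correct_delay_fluctuations(columns="bam", signs=-1, preserve_mean=True)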

    def save_delay_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Save the generated delay calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        # calibration: Dict[str, Any] = {}
        # try:
        #     for key, val in self.dc.calibration.items():
        #         if key == "delay_range":
        #             calibration[key] = [float(i) for i in val]
        #         else:
        #             calibration[key] = float(val)
        # except AttributeError as exc:
        #     raise AttributeError(
        #         "Delay calibration parameters not found, need to generate parameters first!",
        #     ) from exc

        config = {
            "delay": {
                "calibration": self.dc.calibration,
                "fluctuations": self.dc.fluctuations,
            },
        }
        save_config(config, filename, overwrite)
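
    # Illustrative sketch (comment only): after generating a delay calibration,
    # persist it next to the data so later sessions can pick it up:
    #     sp.save_delay_calibration(filename="sed_config.yaml", overwrite=True)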

    def add_jitter(
        self,
        cols: List[str] = None,
        amps: Union[float, Sequence[float]] = None,
        **kwds,
    ):
        """Add jitter to the selected dataframe columns.

        Args:
            cols (List[str], optional): The columns onto which to apply jitter.
                Defaults to config["dataframe"]["jitter_cols"].
            amps (Union[float, Sequence[float]], optional): Amplitude scalings for the
                jittering noise. If one number is given, the same is used for all axes.
                For uniform noise (default) it will cover the interval [-amp, +amp].
                Defaults to config["dataframe"]["jitter_amps"].
            **kwds: Additional keyword arguments passed to ``apply_jitter``.
        """
        if cols is None:
            cols = self._config["dataframe"]["jitter_cols"]
        for loc, col in enumerate(cols):
            if col.startswith("@"):
                cols[loc] = self._config["dataframe"].get(col.strip("@"))

        if amps is None:
            amps = self._config["dataframe"]["jitter_amps"]

        self._dataframe = self._dataframe.map_partitions(
            apply_jitter,
            cols=cols,
            cols_jittered=cols,
            amps=amps,
            **kwds,
        )
        metadata = []
        for col in cols:
            metadata.append(col)
        self._attributes.add(metadata, "jittering", duplicate_policy="append")
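
    # Illustrative sketch (comment only): de-quantize digitized detector columns
    # before binning. With no arguments, columns and amplitudes come from the
    # config; explicit values (made up here) can also be given:
    #     sp.add_jitter()
    #     sp.add_jitter(cols=["dldPosX", "dldPosY"], amps=0.5)
    # "dldPosX"/"dldPosY" are hypothetical column names for this example.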

    def pre_binning(
        self,
        df_partitions: int = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Function to do an initial binning of the dataframe loaded to the class.

        Args:
            df_partitions (int, optional): Number of dataframe partitions to use for
                the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            **kwds: Keyword arguments passed to ``compute``.

        Returns:
            xr.DataArray: pre-binned data-array.
        """
        if axes is None:
            axes = self._config["momentum"]["axes"]
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))

        if bins is None:
            bins = self._config["momentum"]["bins"]
        if ranges is None:
            ranges_ = list(self._config["momentum"]["ranges"])
            # Scale the time-of-flight range by the configured binning factor
            ranges_[2] = np.asarray(ranges_[2]) / 2 ** (
                self._config["dataframe"]["tof_binning"] - 1
            )
            ranges = [cast(Tuple[float, float], tuple(v)) for v in ranges_]

        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        return self.compute(
            bins=bins,
            axes=axes,
            ranges=ranges,
            df_partitions=df_partitions,
            **kwds,
        )
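
    # Illustrative sketch (comment only): a fast overview binning on a subset of
    # the dataframe partitions, using the defaults from config["momentum"]:
    #     overview = sp.pre_binning(df_partitions=20)
    #     print(overview.dims, overview.shape)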

    def compute(
        self,
        bins: Union[
            int,
            dict,
            tuple,
            List[int],
            List[np.ndarray],
            List[tuple],
        ] = 100,
        axes: Union[str, Sequence[str]] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Compute the histogram along the given dimensions.

        Args:
            bins (int, dict, tuple, List[int], List[np.ndarray], List[tuple], optional):
                Definition of the bins. Can be any of the following cases:

                - an integer describing the number of bins for all dimensions
                - a tuple of 3 numbers describing start, end and step of the binning
                  range
                - a np.ndarray defining the binning edges
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
                - a dictionary made of the axes as keys and any of the above as values.

                This takes priority over the axes and range arguments. Defaults to 100.
            axes (Union[str, Sequence[str]], optional): The names of the axes (columns)
                on which to calculate the histogram. The order will be the order of the
                dimensions in the resulting array. Defaults to None.
            ranges (Sequence[Tuple[float, float]], optional): list of tuples containing
                the start and end point of the binning range. Defaults to None.
            **kwds: Keyword arguments:

                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
                  ``bin_dataframe`` for details. Defaults to
                  config["binning"]["hist_mode"].
                - **mode**: Defines how the results from each partition are combined.
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
                  Defaults to config["binning"]["mode"].
                - **pbar**: Option to show the tqdm progress bar. Defaults to
                  config["binning"]["pbar"].
                - **num_cores**: Number of CPU cores to use for parallelization.
                  Defaults to config["binning"]["num_cores"] or N_CPU-1.
                - **threads_per_worker**: Limit the number of threads that
                  multiprocessing can spawn per binning thread. Defaults to
                  config["binning"]["threads_per_worker"].
                - **threadpool_API**: The API to use for multiprocessing. "blas",
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
                  config["binning"]["threadpool_API"].
                - **df_partitions**: Number of dataframe partitions to bin over.
                  Defaults to all partitions.

                Additional kwds are passed to ``bin_dataframe``.

        Raises:
            AssertionError: Raised when no dataframe has been loaded.

        Returns:
            xr.DataArray: The result of the n-dimensional binning represented in an
            xarray object, combining the data with the axes.
        """
        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
        mode = kwds.pop("mode", self._config["binning"]["mode"])
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
        num_cores = kwds.pop("num_cores", self._config["binning"]["num_cores"])
        threads_per_worker = kwds.pop(
            "threads_per_worker",
            self._config["binning"]["threads_per_worker"],
        )
        threadpool_api = kwds.pop(
            "threadpool_API",
            self._config["binning"]["threadpool_API"],
        )
        df_partitions = kwds.pop("df_partitions", None)
        if df_partitions is not None:
            dataframe = self._dataframe.partitions[
                0 : min(df_partitions, self._dataframe.npartitions)
            ]
        else:
            dataframe = self._dataframe

        self._binned = bin_dataframe(
            df=dataframe,
            bins=bins,
            axes=axes,
            ranges=ranges,
            hist_mode=hist_mode,
            mode=mode,
            pbar=pbar,
            n_cores=num_cores,
            threads_per_worker=threads_per_worker,
            threadpool_api=threadpool_api,
            **kwds,
        )

        for dim in self._binned.dims:
            try:
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
            except KeyError:
                pass

        self._binned.attrs["units"] = "counts"
        self._binned.attrs["long_name"] = "photoelectron counts"
        self._binned.attrs["metadata"] = self._attributes.metadata

        return self._binned
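
    # Illustrative sketch (comment only): binning the calibrated dataframe into
    # a 3D energy/momentum/delay histogram. Axis names and ranges below are
    # made-up placeholders; real names come from the loader and config.
    #     res = sp.compute(
    #         bins=[100, 100, 50],
    #         axes=["energy", "kx", "delay"],
    #         ranges=[(-5.0, 1.0), (-2.0, 2.0), (-100.0, 200.0)],
    #         df_partitions=None,  # bin over all partitions
    #     )
    #     res.attrs["units"]  # -> "counts"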

    def view_event_histogram(
        self,
        dfpid: int,
        ncol: int = 2,
        bins: Sequence[int] = None,
        axes: Sequence[str] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        backend: str = "bokeh",
        legend: bool = True,
        histkwds: dict = None,
        legkwds: dict = None,
        **kwds,
    ):
        """Plot individual histograms of specified dimensions (axes) from a single
        dataframe partition.

        Args:
            dfpid (int): Number of the data frame partition to look at.
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
            bins (Sequence[int], optional): Number of bins to use for the specified
                axes. Defaults to config["histogram"]["bins"].
            axes (Sequence[str], optional): Names of the axes to display.
                Defaults to config["histogram"]["axes"].
            ranges (Sequence[Tuple[float, float]], optional): Value ranges of all
                specified axes. Defaults to config["histogram"]["ranges"].
            backend (str, optional): Backend of the plotting library
                ('matplotlib' or 'bokeh'). Defaults to "bokeh".
            legend (bool, optional): Option to include a legend in the histogram plots.
                Defaults to True.
            histkwds (dict, optional): Keyword arguments for histograms
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
            legkwds (dict, optional): Keyword arguments for legend
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
            **kwds: Extra keyword arguments passed to
                ``sed.diagnostics.grid_histogram()``.

        Raises:
            TypeError: Raised when the input values are not of the correct type.
        """
        if bins is None:
            bins = self._config["histogram"]["bins"]
        if axes is None:
            axes = self._config["histogram"]["axes"]
        axes = list(axes)
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
        if ranges is None:
            ranges = list(self._config["histogram"]["ranges"])
            for loc, axis in enumerate(axes):
                if axis == self._config["dataframe"]["tof_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
                        self._config["dataframe"]["tof_binning"] - 1
                    )
                elif axis == self._config["dataframe"]["adc_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
                        self._config["dataframe"]["adc_binning"] - 1
                    )

        input_types = map(type, [axes, bins, ranges])
        allowed_types = [list, tuple]

        df = self._dataframe

        if not set(input_types).issubset(allowed_types):
            raise TypeError(
                "Inputs of axes, bins, ranges need to be list or tuple!",
            )

        # Read out the values for the specified groups
        group_dict_dd = {}
        dfpart = df.get_partition(dfpid)
        cols = dfpart.columns
        for ax in axes:
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
        group_dict = ddf.compute(group_dict_dd)[0]

        # Plot multiple histograms in a grid
        grid_histogram(
            group_dict,
            ncol=ncol,
            rvs=axes,
            rvbins=bins,
            rvranges=ranges,
            backend=backend,
            legend=legend,
            histkwds=histkwds,
            legkwds=legkwds,
            **kwds,
        )
1703
    def save(
1✔
1704
        self,
1705
        faddr: str,
1706
        **kwds,
1707
    ):
1708
        """Saves the binned data to the provided path and filename.
1709

1710
        Args:
1711
            faddr (str): Path and name of the file to write. Its extension determines
1712
                the file type to write. Valid file types are:
1713

1714
                - "*.tiff", "*.tif": Saves a TIFF stack.
1715
                - "*.h5", "*.hdf5": Saves an HDF5 file.
1716
                - "*.nxs", "*.nexus": Saves a NeXus file.
1717

1718
            **kwds: Keyword argumens, which are passed to the writer functions:
1719
                For TIFF writing:
1720

1721
                - **alias_dict**: Dictionary of dimension aliases to use.
1722

1723
                For HDF5 writing:
1724

1725
                - **mode**: hdf5 read/write mode. Defaults to "w".
1726

1727
                For NeXus:
1728

1729
                - **reader**: Name of the nexustools reader to use.
1730
                  Defaults to config["nexus"]["reader"]
1731
                - **definiton**: NeXus application definition to use for saving.
1732
                  Must be supported by the used ``reader``. Defaults to
1733
                  config["nexus"]["definition"]
1734
                - **input_files**: A list of input files to pass to the reader.
1735
                  Defaults to config["nexus"]["input_files"]
1736
                - **eln_data**: An electronic-lab-notebook file in '.yaml' format
1737
                  to add to the list of files to pass to the reader.
1738
        """
1739
        if self._binned is None:
1✔
1740
            raise NameError("Need to bin data first!")
1✔
1741

1742
        extension = pathlib.Path(faddr).suffix
1✔
1743

1744
        if extension in (".tif", ".tiff"):
1✔
1745
            to_tiff(
1✔
1746
                data=self._binned,
1747
                faddr=faddr,
1748
                **kwds,
1749
            )
1750
        elif extension in (".h5", ".hdf5"):
1✔
1751
            to_h5(
1✔
1752
                data=self._binned,
1753
                faddr=faddr,
1754
                **kwds,
1755
            )
1756
        elif extension in (".nxs", ".nexus"):
1✔
1757
            try:
1✔
1758
                reader = kwds.pop("reader", self._config["nexus"]["reader"])
1✔
1759
                definition = kwds.pop(
1✔
1760
                    "definition",
1761
                    self._config["nexus"]["definition"],
1762
                )
1763
                input_files = kwds.pop(
1✔
1764
                    "input_files",
1765
                    self._config["nexus"]["input_files"],
1766
                )
UNCOV
1767
            except KeyError as exc:
×
UNCOV
1768
                raise ValueError(
×
1769
                    "The nexus reader, definition and input files need to be provide!",
1770
                ) from exc
1771

1772
            if isinstance(input_files, str):
1✔
1773
                input_files = [input_files]
1✔
1774

1775
            if "eln_data" in kwds:
1✔
UNCOV
1776
                input_files.append(kwds.pop("eln_data"))
×
1777

1778
            to_nexus(
1✔
1779
                data=self._binned,
1780
                faddr=faddr,
1781
                reader=reader,
1782
                definition=definition,
1783
                input_files=input_files,
1784
                **kwds,
1785
            )
1786

1787
        else:
1788
            raise NotImplementedError(
1✔
1789
                f"Unrecognized file format: {extension}.",
1790
            )
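
    # Illustrative sketch (comment only): after `sp.compute(...)`, persist the
    # binned result in any of the supported formats (file names are placeholders):
    #     sp.save("binned.h5")        # HDF5
    #     sp.save("binned.tiff")      # TIFF stack
    #     sp.save("binned.nxs")       # NeXus; reader/definition/input_files are
    #                                 # taken from config["nexus"] unless given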