OpenCOMPES / sed / build 6286144954

23 Sep 2023 09:41PM UTC coverage: 90.487% (-0.04%) from 90.527%

Pull Request #152: Flash energy calibration (github / web-flow)
Merge 2322fb7cd into 2b8114c57

25 of 25 new or added lines in 2 files covered. (100.0%)

4195 of 4636 relevant lines covered (90.49%)

2.71 hits per line
Source File: /sed/core/processor.py (90.86% covered)

"""This module contains the core class for the sed package

"""
import pathlib
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import apply_jitter
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader

N_CPU = psutil.cpu_count()


class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a workflow for data
    correction, calibration and binning.

    Args:
        metadata (dict, optional): Dict of external Metadata. Defaults to None.
        config (Union[dict, str], optional): Config dictionary or config file name.
            Defaults to None.
        dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
            into the class. Defaults to None.
        files (List[str], optional): List of files to pass to the loader defined in
            the config. Defaults to None.
        folder (str, optional): Folder containing files to pass to the loader
            defined in the config. Defaults to None.
        runs (Sequence[str], optional): List of run identifiers to pass to the loader
            defined in the config. Defaults to None.
        collect_metadata (bool): Option to collect metadata from files.
            Defaults to False.
        **kwds: Keyword arguments passed to the reader.
    """

    def __init__(
        self,
        metadata: dict = None,
        config: Union[dict, str] = None,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            config (Union[dict, str], optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
                into the class. Defaults to None.
            files (List[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to parse_config and to the reader.
        """
        config_kwds = {
            key: value for key, value in kwds.items() if key in parse_config.__code__.co_varnames
        }
        for key in config_kwds.keys():
            del kwds[key]
        self._config = parse_config(config, **config_kwds)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

        self._dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._files: List[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None

        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
        )

        self.ec = EnergyCalibrator(
            loader=self.loader,
            config=self._config,
        )

        self.mc = MomentumCorrector(
            config=self._config,
        )

        self.dc = DelayCalibrator(
            config=self._config,
        )

        self.use_copy_tool = self._config.get("core", {}).get(
            "use_copy_tool",
            False,
        )
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    source=self._config["core"]["copy_tool_source"],
                    dest=self._config["core"]["copy_tool_dest"],
                    **self._config["core"].get("copy_tool_kwds", {}),
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )

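    # Example (hypothetical usage sketch, not part of the class): the config file
    # name and data folder below are illustrative placeholders. Extra keyword
    # arguments are split between parse_config and the loader, as described above.
    #
    #     from sed.core.processor import SedProcessor
    #
    #     processor = SedProcessor(config="sed_config.yaml", folder="data/")
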
    def __repr__(self):
        if self._dataframe is None:
            df_str = "Data Frame: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        attributes_str = f"Metadata: {self._attributes.metadata}"
        pretty_str = df_str + "\n" + attributes_str
        return pretty_str

    @property
    def dataframe(self) -> Union[pd.DataFrame, ddf.DataFrame]:
        """Accessor to the underlying dataframe.

        Returns:
            Union[pd.DataFrame, ddf.DataFrame]: Dataframe object.
        """
        return self._dataframe

    @dataframe.setter
    def dataframe(self, dataframe: Union[pd.DataFrame, ddf.DataFrame]):
        """Setter for the underlying dataframe.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame]): The dataframe object to set.
        """
        if not isinstance(dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            dataframe,
            self._dataframe.__class__,
        ):
            raise ValueError(
                "'dataframe' has to be a Pandas or Dask dataframe and has to be of the same kind "
                "as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._dataframe.__class__}, provided type: {dataframe.__class__}.",
            )
        self._dataframe = dataframe

    @property
    def attributes(self) -> dict:
        """Accessor to the metadata dict.

        Returns:
            dict: The metadata dict.
        """
        return self._attributes.metadata

    def add_attribute(self, attributes: dict, name: str, **kwds):
        """Function to add an element to the attributes dict.

        Args:
            attributes (dict): The attributes dictionary object to add.
            name (str): Key under which to add the dictionary to the attributes.
            **kwds: Additional keywords passed to ``MetaHandler.add``.
        """
        self._attributes.add(
            entry=attributes,
            name=name,
            **kwds,
        )

    @property
    def config(self) -> Dict[Any, Any]:
        """Getter attribute for the config dictionary.

        Returns:
            Dict: The config dictionary.
        """
        return self._config

    @property
    def files(self) -> List[str]:
        """Getter attribute for the list of files.

        Returns:
            List[str]: The list of loaded files.
        """
        return self._files

    def cpy(self, path: Union[str, List[str]]) -> Union[str, List[str]]:
        """Function to mirror a list of files or a folder from a network drive to a
        local storage. Returns either the original or the copied path to the given
        path. The option to use this functionality is set by
        config["core"]["use_copy_tool"].

        Args:
            path (Union[str, List[str]]): Source path or path list.

        Returns:
            Union[str, List[str]]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        return path

    def load(
        self,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        metadata: dict = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            files (List[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to the loader.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            self._dataframe = dataframe
        elif runs is not None:
            # If runs are provided, we only use the copy tool if a folder is also provided.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )

        elif files is not None:
            dataframe, metadata = self.loader.read_dataframe(
                files=cast(List[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )

        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )

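    # Example (hypothetical usage sketch): one of the data sources below is needed
    # (runs may be combined with a base folder); all paths and run ids are
    # illustrative placeholders.
    #
    #     processor.load(files=["scan_001.h5", "scan_002.h5"])
    #     processor.load(folder="data/")
    #     processor.load(runs=["44824"], folder="raw/")  # runs within a base folder
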
    # Momentum calibration workflow
    # 1. Bin raw detector data for distortion correction
    def bin_and_load_momentum_calibration(
        self,
        df_partitions: int = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        plane: int = 0,
        width: int = 5,
        apply: bool = False,
        **kwds,
    ):
        """1st step of the momentum correction workflow. Function to do an initial
        binning of the dataframe loaded into the class, slice a plane from it using
        an interactive view, and load it into the momentum corrector class.

        Args:
            df_partitions (int, optional): Number of dataframe partitions to use for
                the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (List[Tuple], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
            width (int, optional): Initial value for the width slider. Defaults to 5.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments passed to the pre_binning function.
        """
        self._pre_binned = self.pre_binning(
            df_partitions=df_partitions,
            axes=axes,
            bins=bins,
            ranges=ranges,
            **kwds,
        )

        self.mc.load_data(data=self._pre_binned)
        self.mc.select_slicer(plane=plane, width=width, apply=apply)

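    # Example (hypothetical usage sketch): bin a subset of partitions and select
    # the slice non-interactively; the plane and width values are illustrative.
    #
    #     processor.bin_and_load_momentum_calibration(
    #         df_partitions=20,
    #         plane=33,
    #         width=10,
    #         apply=True,
    #     )
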
    # 2. Generate the spline warp correction from momentum features.
    # Either autoselect features, or input features from view above.
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2. Step of the distortion correction workflow: Define feature points in
        momentum space. They can either be manually selected using a GUI tool, be
        provided as a list of feature points, or be auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the selected features.
                Defaults to False.
            **kwds: Keyword arguments for MomentumCorrector.feature_extract() and
                MomentumCorrector.feature_select()
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop("sigma", self._config["momentum"]["sigma"])
            fwhm = kwds.pop("fwhm", self._config["momentum"]["fwhm"])
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config["momentum"]["sigma_radius"],
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )

    # 3. Generate the spline warp correction from momentum features.
    # If no features have been selected before, use class defaults.
    def generate_splinewarp(
        self,
        use_center: bool = None,
        **kwds,
    ):
        """3. Step of the distortion correction workflow: Generate the correction
        function restoring the symmetry in the image using a splinewarp algorithm.

        Args:
            use_center (bool, optional): Option to use the position of the
                center point in the correction. Default is read from config, or set to True.
            **kwds: Keyword arguments for MomentumCorrector.spline_warp_estimate().
        """
        self.mc.spline_warp_estimate(use_center=use_center, **kwds)

        if self.mc.slice is not None:
            print("Original slice with reference features")
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)

            print("Corrected slice with target features")
            self.mc.view(
                image=self.mc.slice_corrected,
                annotated=True,
                points={"feats": self.mc.ptargs},
                backend="bokeh",
                crosshair=True,
            )

            print("Original slice with target features")
            self.mc.view(
                image=self.mc.slice,
                points={"feats": self.mc.ptargs},
                annotated=True,
                backend="bokeh",
            )

    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        points = []
        try:
            for point in self.mc.pouter_ord:
                points.append([float(i) for i in point])
            if self.mc.include_center:
                points.append([float(i) for i in self.mc.pcent])
        except AttributeError as exc:
            raise AttributeError(
                "Momentum correction parameters not found, need to generate parameters first!",
            ) from exc
        config = {
            "momentum": {
                "correction": {
                    "rotation_symmetry": self.mc.rotsym,
                    "feature_points": points,
                    "include_center": self.mc.include_center,
                    "use_center": self.mc.use_center,
                },
            },
        }
        save_config(config, filename, overwrite)

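    # For reference, the entry written to sed_config.yaml then has roughly this
    # shape (values are illustrative, for a six-fold symmetric feature set with
    # center):
    #
    #     momentum:
    #       correction:
    #         rotation_symmetry: 6
    #         feature_points: [[203.0, 345.0], [550.0, 432.0], ...]
    #         include_center: true
    #         use_center: true
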
    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    def pose_adjustment(
        self,
        scale: float = 1,
        xtrans: float = 0,
        ytrans: float = 0,
        angle: float = 0,
        apply: bool = False,
        use_correction: bool = True,
    ):
        """4. Step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            scale (float, optional): Initial value of the scaling slider.
                Defaults to 1.
            xtrans (float, optional): Initial value of the xtrans slider.
                Defaults to 0.
            ytrans (float, optional): Initial value of the ytrans slider.
                Defaults to 0.
            angle (float, optional): Initial value of the angle slider.
                Defaults to 0.
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
        """
        # Generate homography as default if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                raise ValueError(
                    "No slice for corrections and transformations loaded!",
                )
            self.mc.slice_corrected = self.mc.slice

        if not use_correction:
            self.mc.reset_deformation()

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate default distortion correction
            self.mc.add_features()
            self.mc.spline_warp_estimate()

        self.mc.pose_adjustment(
            scale=scale,
            xtrans=xtrans,
            ytrans=ytrans,
            angle=angle,
            apply=apply,
        )

    # 5. Apply the momentum correction to the dataframe
    def apply_momentum_correction(
        self,
        preview: bool = False,
    ):
        """Applies the distortion correction and pose adjustment (optional)
        to the dataframe.

        Args:
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding corrected X/Y columns to dataframe:")
            self._dataframe, metadata = self.mc.apply_corrections(
                df=self._dataframe,
            )
            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

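    # Example (hypothetical usage sketch): a non-interactive distortion correction
    # pass; the feature coordinates are illustrative placeholders (six outer
    # points plus the center).
    #
    #     import numpy as np
    #
    #     features = np.array(
    #         [[203.0, 345.0], [550.0, 432.0], [475.0, 610.0], [330.0, 615.0],
    #          [215.0, 495.0], [310.0, 355.0], [385.0, 475.0]],
    #     )
    #     processor.define_features(features=features, rotation_symmetry=6, apply=True)
    #     processor.generate_splinewarp()
    #     processor.pose_adjustment(xtrans=8, ytrans=7, angle=-1, apply=True)
    #     processor.apply_momentum_correction(preview=True)
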
    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    def calibrate_momentum_axes(
        self,
        point_a: Union[np.ndarray, List[int]] = None,
        point_b: Union[np.ndarray, List[int]] = None,
        k_distance: float = None,
        k_coord_a: Union[np.ndarray, List[float]] = None,
        k_coord_b: Union[np.ndarray, List[float]] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply=False,
    ):
        """1. Step of the momentum calibration workflow. Calibrate momentum
        axes using either provided pixel coordinates of a high-symmetry point and its
        distance to the BZ center, or the k-coordinates of two points in the BZ
        (depending on the equiscale option). Opens an interactive panel for selecting
        the points.

        Args:
            point_a (Union[np.ndarray, List[int]]): Pixel coordinates of the first
                point used for momentum calibration.
            point_b (Union[np.ndarray, List[int]], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between point a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option to use an equal scale for kx and ky.
                If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config["momentum"]["center_pixel"]

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )

646
    def save_momentum_calibration(
3✔
647
        self,
648
        filename: str = None,
649
        overwrite: bool = False,
650
    ):
651
        """Save the generated momentum calibration parameters to the folder config file.
652

653
        Args:
654
            filename (str, optional): Filename of the config dictionary to save to.
655
                Defaults to "sed_config.yaml" in the current folder.
656
            overwrite (bool, optional): Option to overwrite the present dictionary.
657
                Defaults to False.
658
        """
659
        if filename is None:
3✔
660
            filename = "sed_config.yaml"
3✔
661
        calibration = {}
3✔
662
        try:
3✔
663
            for key in [
3✔
664
                "kx_scale",
665
                "ky_scale",
666
                "x_center",
667
                "y_center",
668
                "rstart",
669
                "cstart",
670
                "rstep",
671
                "cstep",
672
            ]:
673
                calibration[key] = float(self.mc.calibration[key])
3✔
674
        except KeyError as exc:
×
675
            raise KeyError(
×
676
                "Momentum calibration parameters not found, need to generate parameters first!",
677
            ) from exc
678

679
        config = {"momentum": {"calibration": calibration}}
3✔
680
        save_config(config, filename, overwrite)
3✔
681

682
    # 2. Apply correction and calibration to the dataframe
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
    ):
        """2. Step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes exist,
        these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding kx/ky columns to dataframe:")
            self._dataframe, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

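    # Example (hypothetical usage sketch): calibrate against a high-symmetry point
    # at a known distance from the zone center; pixel coordinates and k-distance
    # are illustrative placeholders.
    #
    #     processor.calibrate_momentum_axes(
    #         point_a=[308, 345],   # pixel position of the high-symmetry point
    #         k_distance=1.28,      # distance to the BZ center, in inverse Angstrom
    #         apply=True,
    #     )
    #     processor.save_momentum_calibration()
    #     processor.apply_momentum_calibration(preview=True)
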
    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: Tuple[float, float] = None,
        apply=False,
        **kwds,
    ):
        """1. Step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data if
        they are not present yet.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (Tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
            **kwds: Keyword arguments passed to ``EnergyCalibrator.adjust_energy_correction``.
        """
        if self._pre_binned is None:
            print(
                "Pre-binned data not present, binning using defaults from config...",
            )
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )

    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        correction = {}
        try:
            for key, val in self.ec.correction.items():
                if key == "correction_type":
                    correction[key] = val
                elif key == "center":
                    correction[key] = [float(i) for i in val]
                else:
                    correction[key] = float(val)
        except AttributeError as exc:
            raise AttributeError(
                "Energy correction parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)

    # 2. Apply energy correction to dataframe
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2. Step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction``.
        """
        if self._dataframe is not None:
            print("Applying energy correction to dataframe...")
            self._dataframe, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                **kwds,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

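    # Example (hypothetical usage sketch): adjust and apply a spherical TOF
    # correction; amplitude and center values are illustrative placeholders.
    #
    #     processor.adjust_energy_correction(
    #         correction_type="spherical",
    #         amplitude=2.5,
    #         center=(730, 730),
    #         apply=True,
    #     )
    #     processor.save_energy_correction()
    #     processor.apply_energy_correction(preview=True)
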
    # Energy calibration workflow
    # 1. Load and normalize data
    def load_bias_series(
        self,
        binned_data: Union[xr.DataArray, Tuple[np.ndarray, np.ndarray, np.ndarray]] = None,
        data_files: List[str] = None,
        axes: List[str] = None,
        bins: List = None,
        ranges: Sequence[Tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1. Step of the energy calibration workflow: Load and bin data from
        single-event files, or load binned bias/TOF traces.

        Args:
            binned_data (Union[xr.DataArray, Tuple[np.ndarray, np.ndarray, np.ndarray]], optional):
                Binned data. If provided as a DataArray, it needs to contain the dimensions
                config["dataframe"]["tof_column"] and config["dataframe"]["bias_column"]. If
                provided as a tuple, it needs to contain the elements tof, biases, traces.
            data_files (List[str], optional): list of file paths to bin
            axes (List[str], optional): bin axes.
                Defaults to config["dataframe"]["tof_column"].
            bins (List, optional): number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        if binned_data is not None:
            if isinstance(binned_data, xr.DataArray):
                if (
                    self._config["dataframe"]["tof_column"] not in binned_data.dims
                    or self._config["dataframe"]["bias_column"] not in binned_data.dims
                ):
                    raise ValueError(
                        "If binned_data is provided as an xarray, it needs to contain dimensions "
                        f"'{self._config['dataframe']['tof_column']}' and "
                        f"'{self._config['dataframe']['bias_column']}'!",
                    )
                tof = binned_data.coords[self._config["dataframe"]["tof_column"]].values
                biases = binned_data.coords[self._config["dataframe"]["bias_column"]].values
                traces = binned_data.values[:, :]
            else:
                try:
                    (tof, biases, traces) = binned_data
                except ValueError as exc:
                    raise ValueError(
                        "If binned_data is provided as tuple, it needs to contain "
                        "(tof, biases, traces)!",
                    ) from exc
            self.ec.load_data(biases=biases, traces=traces, tof=tof)

        elif data_files is not None:
            self.ec.bin_data(
                data_files=cast(List[str], self.cpy(data_files)),
                axes=axes,
                bins=bins,
                ranges=ranges,
                biases=biases,
                bias_key=bias_key,
            )

        else:
            raise ValueError("Either binned_data or data_files needs to be provided!")

        if (normalize is not None and normalize is True) or (
            normalize is None and self._config["energy"]["normalize"]
        ):
            if span is None:
                span = self._config["energy"]["normalize_span"]
            if order is None:
                order = self._config["energy"]["normalize_order"]
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="bokeh",
        )

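    # Example (hypothetical usage sketch): bin a bias series from single-event
    # files; the file names are illustrative placeholders, and bias voltages are
    # read from the files via config["energy"]["bias_key"].
    #
    #     processor.load_bias_series(
    #         data_files=["bias_22V.h5", "bias_23V.h5", "bias_24V.h5"],
    #         normalize=True,
    #     )
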
    # 2. Extract ranges and get peak positions
    def find_bias_peaks(
        self,
        ranges: Union[List[Tuple], Tuple],
        ref_id: int = 0,
        infer_others: bool = True,
        mode: str = "replace",
        radius: int = None,
        peak_window: int = None,
        apply: bool = False,
    ):
        """2. Step of the energy calibration workflow: Find a peak within a given range
        for the indicated reference trace, and try to find the same peak for all
        other traces. Uses fast_dtw to align curves, which might not work well if the
        shape of the curves changes qualitatively. Ideally, choose a reference trace in
        the middle of the set, and don't choose the range too narrowly around the peak.
        Alternatively, a list of ranges for all traces can be provided.

        Args:
            ranges (Union[List[Tuple], Tuple]): Tuple of TOF values indicating a range.
                Alternatively, a list of ranges for all traces can be given.
            ref_id (int, optional): The id of the trace the range refers to.
                Defaults to 0.
            infer_others (bool, optional): Whether to determine the range for the other
                traces. Defaults to True.
            mode (str, optional): Whether to "add" or "replace" existing ranges.
                Defaults to "replace".
            radius (int, optional): Radius parameter for fast_dtw.
                Defaults to config["energy"]["fastdtw_radius"].
            peak_window (int, optional): peak_window parameter for the peak detection
                algorithm: number of points that have to behave monotonically
                around a peak. Defaults to config["energy"]["peak_window"].
            apply (bool, optional): Option to directly apply the provided parameters.
                Defaults to False.
        """
        if radius is None:
            radius = self._config["energy"]["fastdtw_radius"]
        if peak_window is None:
            peak_window = self._config["energy"]["peak_window"]
        if not infer_others:
            self.ec.add_ranges(
                ranges=ranges,
                ref_id=ref_id,
                infer_others=infer_others,
                mode=mode,
                radius=radius,
            )
            print(self.ec.featranges)
            try:
                self.ec.feature_extract(peak_window=peak_window)
                self.ec.view(
                    traces=self.ec.traces_normed,
                    segs=self.ec.featranges,
                    xaxis=self.ec.tof,
                    peaks=self.ec.peaks,
                    backend="bokeh",
                )
            except IndexError:
                print("Could not determine all peaks!")
                raise
        else:
            # New adjustment tool
            assert isinstance(ranges, tuple)
            self.ec.adjust_ranges(
                ranges=ranges,
                ref_id=ref_id,
                traces=self.ec.traces_normed,
                infer_others=infer_others,
                radius=radius,
                peak_window=peak_window,
                apply=apply,
            )

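    # Example (hypothetical usage sketch): locate the reference peak in a TOF
    # window and infer the windows for all other traces; the range and ref_id
    # are illustrative values.
    #
    #     processor.find_bias_peaks(ranges=(64000, 66000), ref_id=5, apply=True)
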
    # 3. Fit the energy calibration relation
    def calibrate_energy_axis(
        self,
        ref_id: int,
        ref_energy: float,
        method: str = None,
        energy_scale: str = None,
        **kwds,
    ):
        """3. Step of the energy calibration workflow: Calculate the calibration
        function for the energy axis. Two approximations are implemented, a
        (normally 3rd order) polynomial approximation, and a d^2/(t-t0)^2 relation.

        Args:
            ref_id (int): id of the trace at the bias where the reference energy is
                given.
            ref_energy (float): Absolute energy of the detected feature at the bias
                of ref_id
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to config["energy"]["calibration_method"]
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to config["energy"]["energy_scale"]
            **kwds: Keyword arguments passed to ``EnergyCalibrator.calibrate``.
        """
        if method is None:
            method = self._config["energy"]["calibration_method"]

        if energy_scale is None:
            energy_scale = self._config["energy"]["energy_scale"]

        self.ec.calibrate(
            ref_id=ref_id,
            ref_energy=ref_energy,
            method=method,
            energy_scale=energy_scale,
            **kwds,
        )
        print("Quality of Calibration:")
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.calibration["axis"],
            align=True,
            energy_scale=energy_scale,
            backend="bokeh",
        )
        print("E/TOF relationship:")
        self.ec.view(
            traces=self.ec.calibration["axis"][None, :],
            xaxis=self.ec.tof,
            backend="matplotlib",
            show_legend=False,
        )
        if energy_scale == "kinetic":
            plt.scatter(
                self.ec.peaks[:, 0],
                -(self.ec.biases - self.ec.biases[ref_id]) + ref_energy,
                s=50,
                c="k",
            )
        elif energy_scale == "binding":
            plt.scatter(
                self.ec.peaks[:, 0],
                self.ec.biases - self.ec.biases[ref_id] + ref_energy,
                s=50,
                c="k",
            )
        else:
            raise ValueError(
                f'energy_scale needs to be either "binding" or "kinetic", got {energy_scale}.',
            )
        plt.xlabel("Time-of-flight", fontsize=15)
        plt.ylabel("Energy (eV)", fontsize=15)
        plt.show()

    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key, value in self.ec.calibration.items():
                if key in ["axis", "refid", "Tmat", "bvec"]:
                    continue
                if key == "energy_scale":
                    calibration[key] = value
                elif key == "coeffs":
                    calibration[key] = [float(i) for i in value]
                else:
                    calibration[key] = float(value)
        except AttributeError as exc:
            raise AttributeError(
                "Energy calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"calibration": calibration}}
        save_config(config, filename, overwrite)

    # 4. Apply energy calibration to the dataframe
    def append_energy_axis(
        self,
        calibration: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """4. Step of the energy calibration workflow: Apply the calibration function
        to the dataframe. Two approximations are implemented, a (normally 3rd order)
        polynomial approximation, and a d^2/(t-t0)^2 relation. A calibration dictionary
        can be provided.

        Args:
            calibration (dict, optional): Calibration dict containing calibration
                parameters. Overrides calibration from class or config.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.append_energy_axis``.
        """
        if self._dataframe is not None:
            print("Adding energy column to dataframe:")
            self._dataframe, metadata = self.ec.append_energy_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

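    # Example (hypothetical usage sketch): fit the calibration against a known
    # reference feature and append the energy axis; the reference values are
    # illustrative placeholders.
    #
    #     processor.calibrate_energy_axis(
    #         ref_id=5,          # trace whose bias defines the reference
    #         ref_energy=-0.5,   # absolute energy of the feature at that bias
    #         method="lmfit",
    #         energy_scale="kinetic",
    #     )
    #     processor.save_energy_calibration()
    #     processor.append_energy_axis(preview=True)
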
    # Delay calibration function
    def calibrate_delay_axis(
        self,
        delay_range: Tuple[float, float] = None,
        datafile: str = None,
        preview: bool = False,
        **kwds,
    ):
        """Append a delay column to the dataframe. Either provide delay ranges, or read
        them from a file.

        Args:
            delay_range (Tuple[float, float], optional): The scanned delay range in
                picoseconds. Defaults to None.
            datafile (str, optional): The file from which to read the delay ranges.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
        """
        if self._dataframe is not None:
            print("Adding delay column to dataframe:")

            if delay_range is not None:
                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    delay_range=delay_range,
                    **kwds,
                )
            else:
                if datafile is None:
                    try:
                        datafile = self._files[0]
                    except IndexError:
                        print(
                            "No datafile available, specify either",
                            " 'datafile' or 'delay_range'",
                        )
                        raise

                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    datafile=datafile,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "delay_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

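    # Example (hypothetical usage sketch): provide the scanned delay range
    # directly; the values (in picoseconds) are illustrative placeholders.
    #
    #     processor.calibrate_delay_axis(delay_range=(-0.5, 1.5), preview=True)
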
    def add_jitter(self, cols: Sequence[str] = None):
        """Add jitter to the selected dataframe columns.

        Args:
            cols (Sequence[str], optional): The columns onto which to apply jitter.
                Defaults to config["dataframe"]["jitter_cols"].
        """
        if cols is None:
            cols = self._config["dataframe"].get(
                "jitter_cols",
                self._dataframe.columns,
            )  # jitter all columns

        self._dataframe = self._dataframe.map_partitions(
            apply_jitter,
            cols=cols,
            cols_jittered=cols,
        )
        metadata = []
        for col in cols:
            metadata.append(col)
        self._attributes.add(metadata, "jittering", duplicate_policy="append")

    def pre_binning(
        self,
        df_partitions: int = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Function to do an initial binning of the dataframe loaded into the class.

        Args:
            df_partitions (int, optional): Number of dataframe partitions to use for
                the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (List[Tuple], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            **kwds: Keyword arguments passed to ``compute``.

        Returns:
            xr.DataArray: pre-binned data-array.
        """
        if axes is None:
            axes = self._config["momentum"]["axes"]
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))

        if bins is None:
            bins = self._config["momentum"]["bins"]
        if ranges is None:
            ranges_ = list(self._config["momentum"]["ranges"])
            ranges_[2] = np.asarray(ranges_[2]) / 2 ** (
                self._config["dataframe"]["tof_binning"] - 1
            )
            ranges = [cast(Tuple[float, float], tuple(v)) for v in ranges_]

        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        return self.compute(
            bins=bins,
            axes=axes,
            ranges=ranges,
            df_partitions=df_partitions,
            **kwds,
        )

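    # Example (hypothetical usage sketch) for ``compute`` (defined below): bin a
    # 3D histogram over calibrated axes; axis names, bin counts, and ranges are
    # illustrative placeholders.
    #
    #     res = processor.compute(
    #         bins=[100, 100, 200],
    #         axes=["kx", "ky", "energy"],
    #         ranges=[(-2.0, 2.0), (-2.0, 2.0), (-4.0, 1.0)],
    #     )
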
    def compute(
3✔
        self,
        bins: Union[
            int,
            dict,
            tuple,
            List[int],
            List[np.ndarray],
            List[tuple],
        ] = 100,
        axes: Union[str, Sequence[str]] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Compute the histogram along the given dimensions.

        Args:
            bins (int, dict, tuple, List[int], List[np.ndarray], List[tuple], optional):
                Definition of the bins. Can be any of the following cases:

                - an integer describing the number of bins in all dimensions
                - a tuple of 3 numbers describing start, end and step of the binning
                  range
                - a np.ndarray defining the binning edges
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
                - a dictionary made of the axes as keys and any of the above as values.

                This takes priority over the axes and range arguments. Defaults to 100.
            axes (Union[str, Sequence[str]], optional): The names of the axes (columns)
                on which to calculate the histogram. The order will be the order of the
                dimensions in the resulting array. Defaults to None.
            ranges (Sequence[Tuple[float, float]], optional): list of tuples containing
                the start and end point of the binning range. Defaults to None.
            **kwds: Keyword arguments:

                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
                  ``bin_dataframe`` for details. Defaults to
                  config["binning"]["hist_mode"].
                - **mode**: Defines how the results from each partition are combined.
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
                  Defaults to config["binning"]["mode"].
                - **pbar**: Option to show the tqdm progress bar. Defaults to
                  config["binning"]["pbar"].
                - **num_cores**: Number of CPU cores to use for parallelization.
                  Defaults to config["binning"]["num_cores"] or N_CPU-1.
                - **threads_per_worker**: Limit the number of threads that
                  multiprocessing can spawn per binning thread. Defaults to
                  config["binning"]["threads_per_worker"].
                - **threadpool_API**: The API to use for multiprocessing. "blas",
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
                  config["binning"]["threadpool_API"].
                - **df_partitions**: Number of dataframe partitions to use for binning.
                  Defaults to all partitions.

                Additional kwds are passed to ``bin_dataframe``.

        Raises:
            AssertionError: Raised when no dataframe has been loaded.

        Returns:
            xr.DataArray: The result of the n-dimensional binning represented in an
            xarray object, combining the data with the axes.
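
        Example:
            A minimal sketch, assuming a loaded dataframe with calibrated
            "kx", "ky" and "energy" columns (column names are illustrative):

            >>> res = processor.compute(
            ...     bins=[100, 100, 200],
            ...     axes=["kx", "ky", "energy"],
            ...     ranges=[(-2, 2), (-2, 2), (-4, 0.5)],
            ... )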
        """
1349
        assert self._dataframe is not None, "dataframe needs to be loaded first!"
3✔
1350

1351
        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
3✔
1352
        mode = kwds.pop("mode", self._config["binning"]["mode"])
3✔
1353
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
3✔
1354
        num_cores = kwds.pop("num_cores", self._config["binning"]["num_cores"])
3✔
1355
        threads_per_worker = kwds.pop(
3✔
1356
            "threads_per_worker",
1357
            self._config["binning"]["threads_per_worker"],
1358
        )
1359
        threadpool_api = kwds.pop(
3✔
1360
            "threadpool_API",
1361
            self._config["binning"]["threadpool_API"],
1362
        )
1363
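        # If df_partitions is given, bin only the first df_partitions
        # partitions of the dataframe; otherwise use the full dataframe.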
        df_partitions = kwds.pop("df_partitions", None)
3✔
        if df_partitions is not None:
3✔
            dataframe = self._dataframe.partitions[
3✔
                0 : min(df_partitions, self._dataframe.npartitions)
            ]
        else:
            dataframe = self._dataframe
3✔

        self._binned = bin_dataframe(
3✔
            df=dataframe,
            bins=bins,
            axes=axes,
            ranges=ranges,
            hist_mode=hist_mode,
            mode=mode,
            pbar=pbar,
            n_cores=num_cores,
            threads_per_worker=threads_per_worker,
            threadpool_api=threadpool_api,
            **kwds,
        )

        for dim in self._binned.dims:
3✔
            try:
3✔
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
3✔
            except KeyError:
3✔
                pass
3✔

        self._binned.attrs["units"] = "counts"
3✔
        self._binned.attrs["long_name"] = "photoelectron counts"
3✔
        self._binned.attrs["metadata"] = self._attributes.metadata
3✔

        return self._binned
3✔

    def view_event_histogram(
3✔
        self,
        dfpid: int,
        ncol: int = 2,
        bins: Sequence[int] = None,
        axes: Sequence[str] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        backend: str = "bokeh",
        legend: bool = True,
        histkwds: dict = None,
        legkwds: dict = None,
        **kwds,
    ):
        """Plot individual histograms of specified dimensions (axes) from a single
        dataframe partition.

        Args:
            dfpid (int): Number of the data frame partition to look at.
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
            bins (Sequence[int], optional): Number of bins to use for the specified
                axes. Defaults to config["histogram"]["bins"].
            axes (Sequence[str], optional): Names of the axes to display.
                Defaults to config["histogram"]["axes"].
            ranges (Sequence[Tuple[float, float]], optional): Value ranges of all
                specified axes. Defaults to config["histogram"]["ranges"].
            backend (str, optional): Backend of the plotting library
                ('matplotlib' or 'bokeh'). Defaults to "bokeh".
            legend (bool, optional): Option to include a legend in the histogram plots.
                Defaults to True.
            histkwds (dict, optional): Keyword arguments for histograms
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
            legkwds (dict, optional): Keyword arguments for legend
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
            **kwds: Extra keyword arguments passed to
                ``sed.diagnostics.grid_histogram()``.

        Raises:
            TypeError: Raised when the input values are not of the correct type.
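
        Example:
            A minimal sketch, assuming a loaded dataframe and histogram defaults
            in the config:

            >>> processor.view_event_histogram(dfpid=0, ncol=3)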
        """
1436
        if bins is None:
3✔
1437
            bins = self._config["histogram"]["bins"]
3✔
1438
        if axes is None:
3✔
1439
            axes = self._config["histogram"]["axes"]
3✔
1440
        axes = list(axes)
3✔
1441
        for loc, axis in enumerate(axes):
3✔
1442
            if axis.startswith("@"):
3✔
1443
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
3✔
1444
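        # Scale default ranges of the TOF and ADC columns down by their
        # respective per-column binning factors from the config.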
        if ranges is None:
3✔
            ranges = list(self._config["histogram"]["ranges"])
3✔
            for loc, axis in enumerate(axes):
3✔
                if axis == self._config["dataframe"]["tof_column"]:
3✔
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
3✔
                        self._config["dataframe"]["tof_binning"] - 1
                    )
                elif axis == self._config["dataframe"]["adc_column"]:
3✔
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
×
                        self._config["dataframe"]["adc_binning"] - 1
                    )

        input_types = map(type, [axes, bins, ranges])
3✔
        allowed_types = [list, tuple]
3✔

        df = self._dataframe
3✔

        if not set(input_types).issubset(allowed_types):
3✔
            raise TypeError(
×
                "Inputs of axes, bins, ranges need to be list or tuple!",
            )

        # Read out the values for the specified groups
        group_dict_dd = {}
3✔
        dfpart = df.get_partition(dfpid)
3✔
        cols = dfpart.columns
3✔
        for ax in axes:
3✔
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
3✔
        group_dict = ddf.compute(group_dict_dd)[0]
3✔

        # Plot multiple histograms in a grid
        grid_histogram(
3✔
            group_dict,
            ncol=ncol,
            rvs=axes,
            rvbins=bins,
            rvranges=ranges,
            backend=backend,
            legend=legend,
            histkwds=histkwds,
            legkwds=legkwds,
            **kwds,
        )

    def save(
3✔
        self,
        faddr: str,
        **kwds,
    ):
        """Saves the binned data to the provided path and filename.

        Args:
            faddr (str): Path and name of the file to write. Its extension determines
                the file type to write. Valid file types are:

                - "*.tiff", "*.tif": Saves a TIFF stack.
                - "*.h5", "*.hdf5": Saves an HDF5 file.
                - "*.nxs", "*.nexus": Saves a NeXus file.

            **kwds: Keyword arguments, which are passed to the writer functions:
                For TIFF writing:

                - **alias_dict**: Dictionary of dimension aliases to use.

                For HDF5 writing:

                - **mode**: hdf5 read/write mode. Defaults to "w".

                For NeXus:

                - **reader**: Name of the nexustools reader to use.
                  Defaults to config["nexus"]["reader"]
                - **definition**: NeXus application definition to use for saving.
                  Must be supported by the used ``reader``. Defaults to
                  config["nexus"]["definition"]
                - **input_files**: A list of input files to pass to the reader.
                  Defaults to config["nexus"]["input_files"]
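
        Example:
            A minimal sketch, assuming binned data is available on the
            processor ("binned.h5" is an illustrative file name):

            >>> processor.save("binned.h5")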
        """
1522
        if self._binned is None:
3✔
1523
            raise NameError("Need to bin data first!")
3✔
1524

1525
        extension = pathlib.Path(faddr).suffix
3✔
1526

1527
        if extension in (".tif", ".tiff"):
3✔
1528
            to_tiff(
3✔
1529
                data=self._binned,
1530
                faddr=faddr,
1531
                **kwds,
1532
            )
1533
        elif extension in (".h5", ".hdf5"):
3✔
1534
            to_h5(
3✔
1535
                data=self._binned,
1536
                faddr=faddr,
1537
                **kwds,
1538
            )
1539
        elif extension in (".nxs", ".nexus"):
3✔
1540
            try:
3✔
1541
                reader = kwds.pop("reader", self._config["nexus"]["reader"])
3✔
1542
                definition = kwds.pop(
3✔
1543
                    "definition",
1544
                    self._config["nexus"]["definition"],
1545
                )
1546
                input_files = kwds.pop(
3✔
1547
                    "input_files",
1548
                    self._config["nexus"]["input_files"],
1549
                )
1550
            except KeyError as exc:
×
1551
                raise ValueError(
×
1552
                    "The nexus reader, definition and input files need to be provide!",
1553
                ) from exc
1554

1555
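            # The NeXus writer expects a list of input files; wrap a single path.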
            if isinstance(input_files, str):
3✔
                input_files = [input_files]
3✔

            to_nexus(
3✔
                data=self._binned,
                faddr=faddr,
                reader=reader,
                definition=definition,
                input_files=input_files,
                **kwds,
            )

        else:
            raise NotImplementedError(
3✔
                f"Unrecognized file format: {extension}.",
            )