OpenCOMPES / sed / build 6285960521

23 Sep 2023 08:59PM UTC coverage: 90.481% (+16.6%) from 73.908%

Push (github, web-flow): Merge pull request #143 from OpenCOMPES/processor_tests ("Processor tests")

642 of 642 new or added lines in 13 files covered. (100.0%)
4173 of 4612 relevant lines covered (90.48%)
0.9 hits per line

Source file: /sed/core/processor.py (91.56% covered)

"""This module contains the core class for the sed package

"""
import pathlib
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import apply_jitter
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader

N_CPU = psutil.cpu_count()


class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a workflow for data
    correction, calibration and binning.

    Args:
        metadata (dict, optional): Dict of external Metadata. Defaults to None.
        config (Union[dict, str], optional): Config dictionary or config file name.
            Defaults to None.
        dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
            into the class. Defaults to None.
        files (List[str], optional): List of files to pass to the loader defined in
            the config. Defaults to None.
        folder (str, optional): Folder containing files to pass to the loader
            defined in the config. Defaults to None.
        runs (Sequence[str], optional): List of run identifiers to pass to the loader
            defined in the config. Defaults to None.
        collect_metadata (bool): Option to collect metadata from files.
            Defaults to False.
        **kwds: Keyword arguments passed to parse_config and to the reader.
    """

    def __init__(
        self,
        metadata: dict = None,
        config: Union[dict, str] = None,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            config (Union[dict, str], optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
                into the class. Defaults to None.
            files (List[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to parse_config and to the reader.
        """
        config_kwds = {
            key: value for key, value in kwds.items() if key in parse_config.__code__.co_varnames
        }
        for key in config_kwds.keys():
            del kwds[key]
        self._config = parse_config(config, **config_kwds)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

        self._dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._files: List[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None

        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
        )

        self.ec = EnergyCalibrator(
            loader=self.loader,
            config=self._config,
        )

        self.mc = MomentumCorrector(
            config=self._config,
        )

        self.dc = DelayCalibrator(
            config=self._config,
        )

        self.use_copy_tool = self._config.get("core", {}).get(
            "use_copy_tool",
            False,
        )
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    source=self._config["core"]["copy_tool_source"],
                    dest=self._config["core"]["copy_tool_dest"],
                    **self._config["core"].get("copy_tool_kwds", {}),
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )

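    # A minimal construction sketch (illustrative comment, not part of the module):
    # assuming a folder config "sed_config.yaml" and raw data in "raw_data/" exist
    # (both names are hypothetical), the processor can be built and fed data in one
    # call:
    #
    #   from sed.core.processor import SedProcessor
    #
    #   processor = SedProcessor(
    #       config="sed_config.yaml",   # parsed via parse_config
    #       folder="raw_data/",         # passed to the loader set in config["core"]["loader"]
    #       collect_metadata=False,
    #   )
    #   print(processor)                # __repr__ shows the dataframe and metadata
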
    def __repr__(self):
        if self._dataframe is None:
            df_str = "Data Frame: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        attributes_str = f"Metadata: {self._attributes.metadata}"
        pretty_str = df_str + "\n" + attributes_str
        return pretty_str

    @property
    def dataframe(self) -> Union[pd.DataFrame, ddf.DataFrame]:
        """Accessor to the underlying dataframe.

        Returns:
            Union[pd.DataFrame, ddf.DataFrame]: Dataframe object.
        """
        return self._dataframe

    @dataframe.setter
    def dataframe(self, dataframe: Union[pd.DataFrame, ddf.DataFrame]):
        """Setter for the underlying dataframe.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame]): The dataframe object to set.
        """
        if not isinstance(dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            dataframe,
            self._dataframe.__class__,
        ):
            raise ValueError(
                "'dataframe' has to be a Pandas or Dask dataframe and has to be of the same kind "
                "as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._dataframe.__class__}, provided type: {dataframe.__class__}.",
            )
        self._dataframe = dataframe

    @property
    def attributes(self) -> dict:
        """Accessor to the metadata dict.

        Returns:
            dict: The metadata dict.
        """
        return self._attributes.metadata

    def add_attribute(self, attributes: dict, name: str, **kwds):
        """Function to add an element to the attributes dict.

        Args:
            attributes (dict): The attributes dictionary object to add.
            name (str): Key under which to add the dictionary to the attributes.
        """
        self._attributes.add(
            entry=attributes,
            name=name,
            **kwds,
        )

    @property
    def config(self) -> Dict[Any, Any]:
        """Getter attribute for the config dictionary.

        Returns:
            Dict: The config dictionary.
        """
        return self._config

    @property
    def files(self) -> List[str]:
        """Getter attribute for the list of files.

        Returns:
            List[str]: The list of loaded files.
        """
        return self._files

    def cpy(self, path: Union[str, List[str]]) -> Union[str, List[str]]:
        """Function to mirror a list of files or a folder from a network drive to a
        local storage. Returns either the original or the copied path to the given
        path. The option to use this functionality is set by
        config["core"]["use_copy_tool"].

        Args:
            path (Union[str, List[str]]): Source path or path list.

        Returns:
            Union[str, List[str]]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        return path

    def load(
        self,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        metadata: dict = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            files (List[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            self._dataframe = dataframe
        elif runs is not None:
            # If runs are provided, we only use the copy tool if a folder is also provided.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )

        elif files is not None:
            dataframe, metadata = self.loader.read_dataframe(
                files=cast(List[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )

        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )

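    # A minimal loading sketch (illustrative; the run identifiers and file paths are
    # hypothetical):
    #
    #   processor = SedProcessor(config="sed_config.yaml")
    #   processor.load(runs=["run_0001", "run_0002"], folder="raw_data/")
    #   # or, equivalently, from explicit files:
    #   processor.load(files=["raw_data/scan1.h5", "raw_data/scan2.h5"])
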
    # Momentum calibration workflow
    # 1. Bin raw detector data for distortion correction
    def bin_and_load_momentum_calibration(
        self,
        df_partitions: int = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        plane: int = 0,
        width: int = 5,
        apply: bool = False,
        **kwds,
    ):
        """1st step of the momentum correction workflow. Function to do an initial binning
        of the dataframe loaded to the class, slice a plane from it using an
        interactive view, and load it into the momentum corrector class.

        Args:
            df_partitions (int, optional): Number of dataframe partitions to use for
                the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (List[Tuple], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
            width (int, optional): Initial value for the width slider. Defaults to 5.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments passed to the pre_binning function.
        """
        self._pre_binned = self.pre_binning(
            df_partitions=df_partitions,
            axes=axes,
            bins=bins,
            ranges=ranges,
            **kwds,
        )

        self.mc.load_data(data=self._pre_binned)
        self.mc.select_slicer(plane=plane, width=width, apply=apply)

    # 2. Define feature points for the spline warp correction.
    # Either autoselect features, or input features from the view above.
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2. Step of the distortion correction workflow: Define feature points in
        momentum space. They can either be manually selected using a GUI tool, be
        provided as a list of feature points, or be auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the selected features.
                Defaults to False.
            **kwds: Keyword arguments for MomentumCorrector.feature_extract() and
                MomentumCorrector.feature_select().
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop("sigma", self._config["momentum"]["sigma"])
            fwhm = kwds.pop("fwhm", self._config["momentum"]["fwhm"])
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config["momentum"]["sigma_radius"],
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )

    # 3. Generate the spline warp correction from momentum features.
    # If no features have been selected before, use class defaults.
    def generate_splinewarp(
        self,
        use_center: bool = None,
        **kwds,
    ):
        """3. Step of the distortion correction workflow: Generate the correction
        function restoring the symmetry in the image using a splinewarp algorithm.

        Args:
            use_center (bool, optional): Option to use the position of the
                center point in the correction. Default is read from config, or set to True.
            **kwds: Keyword arguments for MomentumCorrector.spline_warp_estimate().
        """
        self.mc.spline_warp_estimate(use_center=use_center, **kwds)

        if self.mc.slice is not None:
            print("Original slice with reference features")
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)

            print("Corrected slice with target features")
            self.mc.view(
                image=self.mc.slice_corrected,
                annotated=True,
                points={"feats": self.mc.ptargs},
                backend="bokeh",
                crosshair=True,
            )

            print("Original slice with target features")
            self.mc.view(
                image=self.mc.slice,
                points={"feats": self.mc.ptargs},
                annotated=True,
                backend="bokeh",
            )

    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        points = []
        try:
            for point in self.mc.pouter_ord:
                points.append([float(i) for i in point])
            if self.mc.include_center:
                points.append([float(i) for i in self.mc.pcent])
        except AttributeError as exc:
            raise AttributeError(
                "Momentum correction parameters not found, need to generate parameters first!",
            ) from exc
        config = {
            "momentum": {
                "correction": {
                    "rotation_symmetry": self.mc.rotsym,
                    "feature_points": points,
                    "include_center": self.mc.include_center,
                    "use_center": self.mc.use_center,
                },
            },
        }
        save_config(config, filename, overwrite)

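    # A sketch of distortion correction steps 1-3a above (illustrative; the slider
    # values and pixel coordinates of the feature points are hypothetical and
    # sample-dependent):
    #
    #   processor.bin_and_load_momentum_calibration(plane=33, width=10, apply=True)
    #   processor.define_features(
    #       features=np.array([[203.2, 342.0], [299.2, 345.3], [350.3, 243.7],
    #                          [304.4, 149.9], [199.5, 152.5], [154.3, 242.3],
    #                          [248.3, 248.6]]),   # 6 outer points plus center
    #       rotation_symmetry=6,
    #       include_center=True,
    #       apply=True,
    #   )
    #   processor.generate_splinewarp(use_center=True)
    #   processor.save_splinewarp(overwrite=True)
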
    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    def pose_adjustment(
        self,
        scale: float = 1,
        xtrans: float = 0,
        ytrans: float = 0,
        angle: float = 0,
        apply: bool = False,
        use_correction: bool = True,
    ):
        """4. Step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            scale (float, optional): Initial value of the scaling slider.
                Defaults to 1.
            xtrans (float, optional): Initial value of the xtrans slider.
                Defaults to 0.
            ytrans (float, optional): Initial value of the ytrans slider.
                Defaults to 0.
            angle (float, optional): Initial value of the angle slider.
                Defaults to 0.
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
        """
        # Generate homography as default if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                raise ValueError(
                    "No slice for corrections and transformations loaded!",
                )
            self.mc.slice_corrected = self.mc.slice

        if not use_correction:
            self.mc.reset_deformation()

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate default distortion correction
            self.mc.add_features()
            self.mc.spline_warp_estimate()

        self.mc.pose_adjustment(
            scale=scale,
            xtrans=xtrans,
            ytrans=ytrans,
            angle=angle,
            apply=apply,
        )

    # 5. Apply the momentum correction to the dataframe
    def apply_momentum_correction(
        self,
        preview: bool = False,
    ):
        """Applies the distortion correction and pose adjustment (optional)
        to the dataframe.

        Args:
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding corrected X/Y columns to dataframe:")
            self._dataframe, metadata = self.mc.apply_corrections(
                df=self._dataframe,
            )
            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

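    # Applying the pose adjustment and the correction is then two calls (sketch;
    # the translation and angle values are hypothetical):
    #
    #   processor.pose_adjustment(xtrans=8, ytrans=7, angle=-1, apply=True)
    #   processor.apply_momentum_correction(preview=True)
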
    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    def calibrate_momentum_axes(
        self,
        point_a: Union[np.ndarray, List[int]] = None,
        point_b: Union[np.ndarray, List[int]] = None,
        k_distance: float = None,
        k_coord_a: Union[np.ndarray, List[float]] = None,
        k_coord_b: Union[np.ndarray, List[float]] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply=False,
    ):
        """1. Step of the momentum calibration workflow: Calibrate momentum
        axes using either provided pixel coordinates of a high-symmetry point and its
        distance to the BZ center, or the k-coordinates of two points in the BZ
        (depending on the equiscale option). Opens an interactive panel for selecting
        the points.

        Args:
            point_a (Union[np.ndarray, List[int]]): Pixel coordinates of the first
                point used for momentum calibration.
            point_b (Union[np.ndarray, List[int]], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between point a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option to use the same scale for kx and ky.
                If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config["momentum"]["center_pixel"]

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )

    # 1a. Save momentum calibration parameters to config file.
    def save_momentum_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated momentum calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key in [
                "kx_scale",
                "ky_scale",
                "x_center",
                "y_center",
                "rstart",
                "cstart",
                "rstep",
                "cstep",
            ]:
                calibration[key] = float(self.mc.calibration[key])
        except KeyError as exc:
            raise KeyError(
                "Momentum calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"momentum": {"calibration": calibration}}
        save_config(config, filename, overwrite)

    # 2. Apply correction and calibration to the dataframe
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
    ):
        """2. Step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes exist,
        these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding kx/ky columns to dataframe:")
            self._dataframe, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

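    # Momentum calibration sketch (illustrative; the pixel coordinates and k-distance
    # are hypothetical and sample-dependent):
    #
    #   processor.calibrate_momentum_axes(
    #       point_a=[308, 345],   # pixel position of a high-symmetry point
    #       k_distance=1.28,      # its momentum distance to the BZ center
    #       apply=True,
    #   )
    #   processor.save_momentum_calibration()
    #   processor.apply_momentum_calibration(preview=True)
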
    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: Tuple[float, float] = None,
        apply=False,
        **kwds,
    ):
        """1. Step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data if
        they are not present yet.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (Tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
        """
        if self._pre_binned is None:
            print(
                "Pre-binned data not present, binning using defaults from config...",
            )
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )

    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        correction = {}
        try:
            for key, val in self.ec.correction.items():
                if key == "correction_type":
                    correction[key] = val
                elif key == "center":
                    correction[key] = [float(i) for i in val]
                else:
                    correction[key] = float(val)
        except AttributeError as exc:
            raise AttributeError(
                "Energy correction parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)

    # 2. Apply energy correction to dataframe
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2. Step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction``.
        """
        if self._dataframe is not None:
            print("Applying energy correction to dataframe...")
            self._dataframe, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                **kwds,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

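    # Energy correction sketch (illustrative; amplitude and center are hypothetical
    # and detector-dependent):
    #
    #   processor.adjust_energy_correction(
    #       correction_type="Lorentzian",
    #       amplitude=2.5,
    #       center=(730, 730),
    #       apply=True,
    #   )
    #   processor.save_energy_correction()
    #   processor.apply_energy_correction()
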
    # Energy calibrator workflow
    # 1. Load and normalize data
    def load_bias_series(
        self,
        data_files: List[str],
        axes: List[str] = None,
        bins: List = None,
        ranges: Sequence[Tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1. Step of the energy calibration workflow: Load and bin data from
        single-event files.

        Args:
            data_files (List[str]): list of file paths to bin.
            axes (List[str], optional): bin axes.
                Defaults to config["dataframe"]["tof_column"].
            bins (List, optional): number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        self.ec.bin_data(
            data_files=cast(List[str], self.cpy(data_files)),
            axes=axes,
            bins=bins,
            ranges=ranges,
            biases=biases,
            bias_key=bias_key,
        )
        if (normalize is not None and normalize is True) or (
            normalize is None and self._config["energy"]["normalize"]
        ):
            if span is None:
                span = self._config["energy"]["normalize_span"]
            if order is None:
                order = self._config["energy"]["normalize_order"]
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="bokeh",
        )

    # 2. extract ranges and get peak positions
    def find_bias_peaks(
        self,
        ranges: Union[List[Tuple], Tuple],
        ref_id: int = 0,
        infer_others: bool = True,
        mode: str = "replace",
        radius: int = None,
        peak_window: int = None,
        apply: bool = False,
    ):
        """2. Step of the energy calibration workflow: Find a peak within a given range
        for the indicated reference trace, and try to find the same peak for all
        other traces. Uses fast_dtw to align curves, which might not work well if the
        shape of the curves changes qualitatively. Ideally, choose a reference trace in
        the middle of the set, and don't choose the range too narrow around the peak.
        Alternatively, a list of ranges for all traces can be provided.

        Args:
            ranges (Union[List[Tuple], Tuple]): Tuple of TOF values indicating a range.
                Alternatively, a list of ranges for all traces can be given.
            ref_id (int, optional): The id of the trace the range refers to.
                Defaults to 0.
            infer_others (bool, optional): Whether to determine the range for the other
                traces. Defaults to True.
            mode (str, optional): Whether to "add" or "replace" existing ranges.
                Defaults to "replace".
            radius (int, optional): Radius parameter for fast_dtw.
                Defaults to config["energy"]["fastdtw_radius"].
            peak_window (int, optional): peak_window parameter for the peak detection
                algorithm: the number of points that have to behave monotonically
                around a peak. Defaults to config["energy"]["peak_window"].
            apply (bool, optional): Option to directly apply the provided parameters.
                Defaults to False.
        """
        if radius is None:
            radius = self._config["energy"]["fastdtw_radius"]
        if peak_window is None:
            peak_window = self._config["energy"]["peak_window"]
        if not infer_others:
            self.ec.add_ranges(
                ranges=ranges,
                ref_id=ref_id,
                infer_others=infer_others,
                mode=mode,
                radius=radius,
            )
            print(self.ec.featranges)
            try:
                self.ec.feature_extract(peak_window=peak_window)
                self.ec.view(
                    traces=self.ec.traces_normed,
                    segs=self.ec.featranges,
                    xaxis=self.ec.tof,
                    peaks=self.ec.peaks,
                    backend="bokeh",
                )
            except IndexError:
                print("Could not determine all peaks!")
                raise
        else:
            # New adjustment tool
            assert isinstance(ranges, tuple)
            self.ec.adjust_ranges(
                ranges=ranges,
                ref_id=ref_id,
                traces=self.ec.traces_normed,
                infer_others=infer_others,
                radius=radius,
                peak_window=peak_window,
                apply=apply,
            )

    # 3. Fit the energy calibration relation
    def calibrate_energy_axis(
        self,
        ref_id: int,
        ref_energy: float,
        method: str = None,
        energy_scale: str = None,
        **kwds,
    ):
        """3. Step of the energy calibration workflow: Calculate the calibration
        function for the energy axis. Two approximations are implemented, a
        (normally 3rd order) polynomial approximation, and a d^2/(t-t0)^2 relation.

        Args:
            ref_id (int): id of the trace at the bias where the reference energy is
                given.
            ref_energy (float): Absolute energy of the detected feature at the bias
                of ref_id.
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to config["energy"]["calibration_method"].
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to config["energy"]["energy_scale"].
        """
        if method is None:
            method = self._config["energy"]["calibration_method"]

        if energy_scale is None:
            energy_scale = self._config["energy"]["energy_scale"]

        self.ec.calibrate(
            ref_id=ref_id,
            ref_energy=ref_energy,
            method=method,
            energy_scale=energy_scale,
            **kwds,
        )
        print("Quality of Calibration:")
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.calibration["axis"],
            align=True,
            energy_scale=energy_scale,
            backend="bokeh",
        )
        print("E/TOF relationship:")
        self.ec.view(
            traces=self.ec.calibration["axis"][None, :],
            xaxis=self.ec.tof,
            backend="matplotlib",
            show_legend=False,
        )
        if energy_scale == "kinetic":
            plt.scatter(
                self.ec.peaks[:, 0],
                -(self.ec.biases - self.ec.biases[ref_id]) + ref_energy,
                s=50,
                c="k",
            )
        elif energy_scale == "binding":
            plt.scatter(
                self.ec.peaks[:, 0],
                self.ec.biases - self.ec.biases[ref_id] + ref_energy,
                s=50,
                c="k",
            )
        else:
            raise ValueError(
                'energy_scale needs to be either "binding" or "kinetic"'
                f", got {energy_scale}.",
            )
        plt.xlabel("Time-of-flight", fontsize=15)
        plt.ylabel("Energy (eV)", fontsize=15)
        plt.show()

    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key, value in self.ec.calibration.items():
                if key in ["axis", "refid", "Tmat", "bvec"]:
                    continue
                if key == "energy_scale":
                    calibration[key] = value
                elif key == "coeffs":
                    calibration[key] = [float(i) for i in value]
                else:
                    calibration[key] = float(value)
        except AttributeError as exc:
            raise AttributeError(
                "Energy calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"calibration": calibration}}
        save_config(config, filename, overwrite)

    # 4. Apply energy calibration to the dataframe
    def append_energy_axis(
        self,
        calibration: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """4. Step of the energy calibration workflow: Apply the calibration function
        to the dataframe. Two approximations are implemented, a (normally 3rd order)
        polynomial approximation, and a d^2/(t-t0)^2 relation. A calibration dictionary
        can be provided.

        Args:
            calibration (dict, optional): Calibration dict containing calibration
                parameters. Overrides calibration from class or config.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.append_energy_axis``.
        """
        if self._dataframe is not None:
            print("Adding energy column to dataframe:")
            self._dataframe, metadata = self.ec.append_energy_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

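    # Energy calibration sketch (illustrative; the bias file list, TOF range, trace
    # ids, and reference energy are hypothetical):
    #
    #   processor.load_bias_series(data_files=bias_files, normalize=True)
    #   processor.find_bias_peaks(ranges=(64500, 65300), ref_id=5, apply=True)
    #   processor.calibrate_energy_axis(ref_id=4, ref_energy=-0.5, method="lmfit")
    #   processor.save_energy_calibration()
    #   processor.append_energy_axis(preview=True)
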
    # Delay calibration function
    def calibrate_delay_axis(
        self,
        delay_range: Tuple[float, float] = None,
        datafile: str = None,
        preview: bool = False,
        **kwds,
    ):
        """Append delay column to dataframe. Either provide delay ranges, or read
        them from a file.

        Args:
            delay_range (Tuple[float, float], optional): The scanned delay range in
                picoseconds. Defaults to None.
            datafile (str, optional): The file from which to read the delay ranges.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
        """
        if self._dataframe is not None:
            print("Adding delay column to dataframe:")

            if delay_range is not None:
                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    delay_range=delay_range,
                    **kwds,
                )
            else:
                if datafile is None:
                    try:
                        datafile = self._files[0]
                    except IndexError:
                        print(
                            "No datafile available, specify either",
                            " 'datafile' or 'delay_range'",
                        )
                        raise

                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    datafile=datafile,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "delay_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

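    # Delay calibration sketch (illustrative; the delay range in picoseconds is
    # hypothetical):
    #
    #   processor.calibrate_delay_axis(delay_range=(-500, 1500), preview=True)
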
    def add_jitter(self, cols: Sequence[str] = None):
        """Add jitter to the selected dataframe columns.

        Args:
            cols (Sequence[str], optional): The columns to which to apply jitter.
                Defaults to config["dataframe"]["jitter_cols"].
        """
        if cols is None:
            cols = self._config["dataframe"].get(
                "jitter_cols",
                self._dataframe.columns,
            )  # jitter all columns

        self._dataframe = self._dataframe.map_partitions(
            apply_jitter,
            cols=cols,
            cols_jittered=cols,
        )
        metadata = []
        for col in cols:
            metadata.append(col)
        self._attributes.add(metadata, "jittering", duplicate_policy="append")

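    # Jitter is typically added right before binning, to wash out the discrete steps
    # of digitized detector columns. A sketch (the column names are hypothetical and
    # loader-dependent):
    #
    #   processor.add_jitter(cols=["X", "Y", "t"])
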
    def pre_binning(
        self,
        df_partitions: int = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Function to do an initial binning of the dataframe loaded to the class.

        Args:
            df_partitions (int, optional): Number of dataframe partitions to use for
                the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (List[Tuple], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            **kwds: Keyword arguments passed to ``compute``.

        Returns:
            xr.DataArray: pre-binned data-array.
        """
        if axes is None:
            axes = self._config["momentum"]["axes"]
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))

        if bins is None:
            bins = self._config["momentum"]["bins"]
        if ranges is None:
            ranges_ = list(self._config["momentum"]["ranges"])
            ranges_[2] = np.asarray(ranges_[2]) / 2 ** (
                self._config["dataframe"]["tof_binning"] - 1
            )
            ranges = [cast(Tuple[float, float], tuple(v)) for v in ranges_]

        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        return self.compute(
            bins=bins,
            axes=axes,
            ranges=ranges,
            df_partitions=df_partitions,
            **kwds,
        )

1251
    def compute(
1✔
1252
        self,
1253
        bins: Union[
1254
            int,
1255
            dict,
1256
            tuple,
1257
            List[int],
1258
            List[np.ndarray],
1259
            List[tuple],
1260
        ] = 100,
1261
        axes: Union[str, Sequence[str]] = None,
1262
        ranges: Sequence[Tuple[float, float]] = None,
1263
        **kwds,
1264
    ) -> xr.DataArray:
1265
        """Compute the histogram along the given dimensions.
1266

1267
        Args:
1268
            bins (int, dict, tuple, List[int], List[np.ndarray], List[tuple], optional):
1269
                Definition of the bins. Can be any of the following cases:
1270

1271
                - an integer describing the number of bins in on all dimensions
1272
                - a tuple of 3 numbers describing start, end and step of the binning
1273
                  range
1274
                - a np.arrays defining the binning edges
1275
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
1276
                - a dictionary made of the axes as keys and any of the above as values.
1277

1278
                This takes priority over the axes and range arguments. Defaults to 100.
1279
            axes (Union[str, Sequence[str]], optional): The names of the axes (columns)
1280
                on which to calculate the histogram. The order will be the order of the
1281
                dimensions in the resulting array. Defaults to None.
1282
            ranges (Sequence[Tuple[float, float]], optional): list of tuples containing
1283
                the start and end point of the binning range. Defaults to None.
1284
            **kwds: Keyword arguments:
1285

1286
                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
1287
                  ``bin_dataframe`` for details. Defaults to
1288
                  config["binning"]["hist_mode"].
1289
                - **mode**: Defines how the results from each partition are combined.
1290
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
1291
                  Defaults to config["binning"]["mode"].
1292
                - **pbar**: Option to show the tqdm progress bar. Defaults to
1293
                  config["binning"]["pbar"].
1294
                - **n_cores**: Number of CPU cores to use for parallelization.
1295
                  Defaults to config["binning"]["num_cores"] or N_CPU-1.
1296
                - **threads_per_worker**: Limit the number of threads that
1297
                  multiprocessing can spawn per binning thread. Defaults to
1298
                  config["binning"]["threads_per_worker"].
1299
                - **threadpool_api**: The API to use for multiprocessing. "blas",
1300
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
1301
                  config["binning"]["threadpool_API"].
1302
                - **df_partitions**: Number of dataframe partitions to use. Defaults to all
1303
                  partitions.
1304

1305
                Additional kwds are passed to ``bin_dataframe``.
1306

1307
        Raises:
1308
            AssertionError: Raised when no dataframe has been loaded.
1309

1310
        Returns:
1311
            xr.DataArray: The result of the n-dimensional binning represented in an
1312
            xarray object, combining the data with the axes.
1313
        """
1314
        assert self._dataframe is not None, "dataframe needs to be loaded first!"
1✔
1315

1316
        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
1✔
1317
        mode = kwds.pop("mode", self._config["binning"]["mode"])
1✔
1318
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
1✔
1319
        num_cores = kwds.pop("num_cores", self._config["binning"]["num_cores"])
1✔
1320
        threads_per_worker = kwds.pop(
1✔
1321
            "threads_per_worker",
1322
            self._config["binning"]["threads_per_worker"],
1323
        )
1324
        threadpool_api = kwds.pop(
1✔
1325
            "threadpool_API",
1326
            self._config["binning"]["threadpool_API"],
1327
        )
1328
        df_partitions = kwds.pop("df_partitions", None)
1✔
1329
        if df_partitions is not None:
1✔
1330
            dataframe = self._dataframe.partitions[
1✔
1331
                0 : min(df_partitions, self._dataframe.npartitions)
1332
            ]
1333
        else:
1334
            dataframe = self._dataframe
1✔
1335

1336
        self._binned = bin_dataframe(
1✔
1337
            df=dataframe,
1338
            bins=bins,
1339
            axes=axes,
1340
            ranges=ranges,
1341
            hist_mode=hist_mode,
1342
            mode=mode,
1343
            pbar=pbar,
1344
            n_cores=num_cores,
1345
            threads_per_worker=threads_per_worker,
1346
            threadpool_api=threadpool_api,
1347
            **kwds,
1348
        )
1349

1350
        for dim in self._binned.dims:
1✔
1351
            try:
1✔
1352
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
1✔
1353
            except KeyError:
1✔
1354
                pass
1✔
1355

1356
        self._binned.attrs["units"] = "counts"
1✔
1357
        self._binned.attrs["long_name"] = "photoelectron counts"
1✔
1358
        self._binned.attrs["metadata"] = self._attributes.metadata
1✔
1359

1360
        return self._binned
1✔
1361

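    # Usage sketch (illustrative comment, not part of the original source):
    # bin a loaded dataframe into a 2D histogram. The column names "X" and
    # "Y", the ranges, and the config file name are placeholder assumptions.
    #
    #     processor = SedProcessor(config="config.yaml", files=files)
    #     binned = processor.compute(
    #         bins=[100, 100],
    #         axes=["X", "Y"],
    #         ranges=[(0, 2048), (0, 2048)],
    #         df_partitions=10,  # use only the first 10 partitions
    #     )
    #     binned.plot()  # xarray's built-in plotting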
1362
    def view_event_histogram(
1✔
1363
        self,
1364
        dfpid: int,
1365
        ncol: int = 2,
1366
        bins: Sequence[int] = None,
1367
        axes: Sequence[str] = None,
1368
        ranges: Sequence[Tuple[float, float]] = None,
1369
        backend: str = "bokeh",
1370
        legend: bool = True,
1371
        histkwds: dict = None,
1372
        legkwds: dict = None,
1373
        **kwds,
1374
    ):
1375
        """Plot individual histograms of specified dimensions (axes) from a substituent
1376
        dataframe partition.
1377

1378
        Args:
1379
            dfpid (int): Number of the data frame partition to look at.
1380
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
1381
            bins (Sequence[int], optional): Number of bins to use for the specified
1382
                axes. Defaults to config["histogram"]["bins"].
1383
            axes (Sequence[str], optional): Names of the axes to display.
1384
                Defaults to config["histogram"]["axes"].
1385
            ranges (Sequence[Tuple[float, float]], optional): Value ranges of all
1386
                specified axes. Defaults to config["histogram"]["ranges"].
1387
            backend (str, optional): Backend of the plotting library
1388
                ('matplotlib' or 'bokeh'). Defaults to "bokeh".
1389
            legend (bool, optional): Option to include a legend in the histogram plots.
1390
                Defaults to True.
1391
            histkwds (dict, optional): Keyword arguments for histograms
1392
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
1393
            legkwds (dict, optional): Keyword arguments for legend
1394
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
1395
            **kwds: Extra keyword arguments passed to
1396
                ``sed.diagnostics.grid_histogram()``.
1397

1398
        Raises:
1399
            TypeError: Raised when the input values are not of the correct type.
1400
        """
1401
        if bins is None:
1✔
1402
            bins = self._config["histogram"]["bins"]
1✔
1403
        if axes is None:
1✔
1404
            axes = self._config["histogram"]["axes"]
1✔
1405
        axes = list(axes)
1✔
1406
        for loc, axis in enumerate(axes):
1✔
1407
            if axis.startswith("@"):
1✔
1408
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
1✔
1409
        if ranges is None:
1✔
1410
            ranges = list(self._config["histogram"]["ranges"])
1✔
1411
            for loc, axis in enumerate(axes):
1✔
1412
                if axis == self._config["dataframe"]["tof_column"]:
1✔
1413
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
1✔
1414
                        self._config["dataframe"]["tof_binning"] - 1
1415
                    )
1416
                elif axis == self._config["dataframe"]["adc_column"]:
1✔
1417
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
×
1418
                        self._config["dataframe"]["adc_binning"] - 1
1419
                    )
1420

1421
        input_types = map(type, [axes, bins, ranges])
1✔
1422
        allowed_types = [list, tuple]
1✔
1423

1424
        df = self._dataframe
1✔
1425

1426
        if not set(input_types).issubset(allowed_types):
1✔
1427
            raise TypeError(
×
1428
                "Inputs of axes, bins, ranges need to be list or tuple!",
1429
            )
1430

1431
        # Read out the values for the specified groups
1432
        group_dict_dd = {}
1✔
1433
        dfpart = df.get_partition(dfpid)
1✔
1434
        cols = dfpart.columns
1✔
1435
        for ax in axes:
1✔
1436
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
1✔
1437
        group_dict = ddf.compute(group_dict_dd)[0]
1✔
1438

1439
        # Plot multiple histograms in a grid
1440
        grid_histogram(
1✔
1441
            group_dict,
1442
            ncol=ncol,
1443
            rvs=axes,
1444
            rvbins=bins,
1445
            rvranges=ranges,
1446
            backend=backend,
1447
            legend=legend,
1448
            histkwds=histkwds,
1449
            legkwds=legkwds,
1450
            **kwds,
1451
        )
1452

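    # Usage sketch (illustrative comment, not part of the original source):
    # inspect raw event distributions of one dataframe partition. With no
    # arguments beyond dfpid, bins/axes/ranges fall back to the
    # config["histogram"] defaults; the explicit values below are assumptions.
    #
    #     processor.view_event_histogram(dfpid=0)
    #     processor.view_event_histogram(
    #         dfpid=0,
    #         axes=["X", "Y"],
    #         bins=[80, 80],
    #         ranges=[(0, 1800), (0, 1800)],
    #         backend="matplotlib",
    #     )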
1453
    def save(
1✔
1454
        self,
1455
        faddr: str,
1456
        **kwds,
1457
    ):
1458
        """Saves the binned data to the provided path and filename.
1459

1460
        Args:
1461
            faddr (str): Path and name of the file to write. Its extension determines
1462
                the file type to write. Valid file types are:
1463

1464
                - "*.tiff", "*.tif": Saves a TIFF stack.
1465
                - "*.h5", "*.hdf5": Saves an HDF5 file.
1466
                - "*.nxs", "*.nexus": Saves a NeXus file.
1467

1468
            **kwds: Keyword arguments, which are passed to the writer functions:
1469
                For TIFF writing:
1470

1471
                - **alias_dict**: Dictionary of dimension aliases to use.
1472

1473
                For HDF5 writing:
1474

1475
                - **mode**: hdf5 read/write mode. Defaults to "w".
1476

1477
                For NeXus:
1478

1479
                - **reader**: Name of the nexustools reader to use.
1480
                  Defaults to config["nexus"]["reader"]
1481
                - **definition**: NeXus application definition to use for saving.
1482
                  Must be supported by the used ``reader``. Defaults to
1483
                  config["nexus"]["definition"]
1484
                - **input_files**: A list of input files to pass to the reader.
1485
                  Defaults to config["nexus"]["input_files"]
1486
        """
1487
        if self._binned is None:
1✔
1488
            raise NameError("Need to bin data first!")
1✔
1489

1490
        extension = pathlib.Path(faddr).suffix
1✔
1491

1492
        if extension in (".tif", ".tiff"):
1✔
1493
            to_tiff(
1✔
1494
                data=self._binned,
1495
                faddr=faddr,
1496
                **kwds,
1497
            )
1498
        elif extension in (".h5", ".hdf5"):
1✔
1499
            to_h5(
1✔
1500
                data=self._binned,
1501
                faddr=faddr,
1502
                **kwds,
1503
            )
1504
        elif extension in (".nxs", ".nexus"):
1✔
1505
            try:
1✔
1506
                reader = kwds.pop("reader", self._config["nexus"]["reader"])
1✔
1507
                definition = kwds.pop(
1✔
1508
                    "definition",
1509
                    self._config["nexus"]["definition"],
1510
                )
1511
                input_files = kwds.pop(
1✔
1512
                    "input_files",
1513
                    self._config["nexus"]["input_files"],
1514
                )
1515
            except KeyError as exc:
×
1516
                raise ValueError(
×
1517
                    "The nexus reader, definition and input files need to be provide!",
1518
                ) from exc
1519

1520
            if isinstance(input_files, str):
1✔
1521
                input_files = [input_files]
1✔
1522

1523
            to_nexus(
1✔
1524
                data=self._binned,
1525
                faddr=faddr,
1526
                reader=reader,
1527
                definition=definition,
1528
                input_files=input_files,
1529
                **kwds,
1530
            )
1531

1532
        else:
1533
            raise NotImplementedError(
1✔
1534
                f"Unrecognized file format: {extension}.",
1535
            )
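    # Usage sketch (illustrative comment, not part of the original source):
    # the file extension selects the writer. File names and the NeXus
    # reader/definition/input_files values below are assumptions.
    #
    #     processor.save("binned.tiff")          # TIFF stack
    #     processor.save("binned.h5", mode="w")  # HDF5
    #     processor.save(
    #         "binned.nxs",
    #         reader="mpes",
    #         definition="NXmpes",
    #         input_files=["nxmpes_config.json"],
    #     )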