
OpenCOMPES / sed / build 5708723573 (push, via github)
Commit by rettigl: "add runs implementation for mpes loader" (status: pending completion)

23 of 23 new or added lines in 2 files covered (100.0%)
2999 of 4065 relevant lines covered (73.78%)
0.74 hits per line

Source File: /sed/core/processor.py (file coverage: 35.64%)

"""This module contains the core class for the sed package."""
import pathlib
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import apply_jitter
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader

N_CPU = psutil.cpu_count()


class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a workflow for data
    correction, calibration, and binning.

    Args:
        metadata (dict, optional): Dict of external metadata. Defaults to None.
        config (Union[dict, str], optional): Config dictionary or config file name.
            Defaults to None.
        dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
            into the class. Defaults to None.
        files (List[str], optional): List of files to pass to the loader defined in
            the config. Defaults to None.
        folder (str, optional): Folder containing files to pass to the loader
            defined in the config. Defaults to None.
        runs (Sequence[str], optional): List of run identifiers to pass to the loader
            defined in the config. Defaults to None.
        collect_metadata (bool): Option to collect metadata from files.
            Defaults to False.
        **kwds: Keyword arguments passed to the reader.
    """

    def __init__(
        self,
        metadata: dict = None,
        config: Union[dict, str] = None,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external metadata. Defaults to None.
            config (Union[dict, str], optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
                into the class. Defaults to None.
            files (List[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to the reader.
        """
        self._config = parse_config(config, **kwds)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

        self._dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._files: List[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None

        self._dimensions: List[str] = []
        self._coordinates: Dict[Any, Any] = {}
        self.axis: Dict[Any, Any] = {}
        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
        )

        self.ec = EnergyCalibrator(
            loader=self.loader,
            config=self._config,
        )

        self.mc = MomentumCorrector(
            config=self._config,
        )

        self.dc = DelayCalibrator(
            config=self._config,
        )

        self.use_copy_tool = self._config.get("core", {}).get(
            "use_copy_tool",
            False,
        )
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    source=self._config["core"]["copy_tool_source"],
                    dest=self._config["core"]["copy_tool_dest"],
                    **self._config["core"].get("copy_tool_kwds", {}),
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )
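
    # Usage sketch (illustrative, not part of the class): construct a processor
    # from a config file and load data in one step. "sed_config.yaml" and
    # "data/scan_001" are placeholder names.
    #
    #     sp = SedProcessor(config="sed_config.yaml", folder="data/scan_001")
    #     print(sp)  # shows the loaded dataframe, coordinates, and dimensions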

    def __repr__(self):
        if self._dataframe is None:
            df_str = "Data Frame: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        coordinates_str = f"Coordinates: {self._coordinates}"
        dimensions_str = f"Dimensions: {self._dimensions}"
        pretty_str = df_str + "\n" + coordinates_str + "\n" + dimensions_str
        return pretty_str

    def __getitem__(self, val: str) -> pd.DataFrame:
        """Accessor to the underlying data structure.

        Args:
            val (str): Name of the dataframe column to retrieve.

        Returns:
            pd.DataFrame: Selected dataframe column.
        """
        return self._dataframe[val]

    @property
    def config(self) -> Dict[Any, Any]:
        """Getter attribute for the config dictionary.

        Returns:
            Dict: The config dictionary.
        """
        return self._config

    @config.setter
    def config(self, config: Union[dict, str]):
        """Setter function for the config dictionary.

        Args:
            config (Union[dict, str]): Config dictionary or path of config file
                to load.
        """
        self._config = parse_config(config)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

    @property
    def dimensions(self) -> list:
        """Getter attribute for the dimensions.

        Returns:
            list: List of dimensions.
        """
        return self._dimensions

    @dimensions.setter
    def dimensions(self, dims: list):
        """Setter function for the dimensions.

        Args:
            dims (list): List of dimensions to set.
        """
        assert isinstance(dims, list)
        self._dimensions = dims

    @property
    def coordinates(self) -> dict:
        """Getter attribute for the coordinates dict.

        Returns:
            dict: Dictionary of coordinates.
        """
        return self._coordinates

    @coordinates.setter
    def coordinates(self, coords: dict):
        """Setter function for the coordinates dict.

        Args:
            coords (dict): Dictionary of coordinates.
        """
        assert isinstance(coords, dict)
        self._coordinates = {}
        for k, v in coords.items():
            self._coordinates[k] = xr.DataArray(v)

    def cpy(self, path: Union[str, List[str]]) -> Union[str, List[str]]:
        """Function to mirror a list of files or a folder from a network drive to
        local storage. Returns either the original or the copied path to the given
        path. The option to use this functionality is set by
        config["core"]["use_copy_tool"].

        Args:
            path (Union[str, List[str]]): Source path or path list.

        Returns:
            Union[str, List[str]]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        if isinstance(path, list):
            return path

        return path

    def load(
        self,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        metadata: dict = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external metadata. Defaults to None.
            files (List[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to the loader.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            self._dataframe = dataframe
        elif runs is not None:
            # If runs are provided, we only use the copy tool if also a folder is
            # provided. In that case, we copy the whole provided base folder tree,
            # and pass the copied version to the loader as base folder to look for
            # the runs.
            if folder is not None:
                dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )

        elif files is not None:
            dataframe, metadata = self.loader.read_dataframe(
                files=cast(List[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )

        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )
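
    # Loading sketch (illustrative): the alternative entry points. Run identifiers,
    # paths, and the dataframe ``df`` below are placeholders.
    #
    #     sp.load(runs=["30", "31"], folder="/path/to/raw")  # runs within a base folder
    #     sp.load(files=["run30.h5", "run31.h5"])            # explicit file list
    #     sp.load(dataframe=df)                              # pre-built (dask) dataframe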

    # Momentum correction workflow
    # 1. Bin raw detector data for distortion correction
    def bin_and_load_momentum_calibration(
        self,
        df_partitions: int = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        plane: int = 0,
        width: int = 5,
        apply: bool = False,
        **kwds,
    ):
        """1st step of the momentum correction workflow. Function to do an initial
        binning of the dataframe loaded to the class, slice a plane from it using an
        interactive view, and load it into the momentum corrector class.

        Args:
            df_partitions (int, optional): Number of dataframe partitions to use for
                the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (List[Tuple], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
            width (int, optional): Initial value for the width slider. Defaults to 5.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments passed to the pre_binning function.
        """
        self._pre_binned = self.pre_binning(
            df_partitions=df_partitions,
            axes=axes,
            bins=bins,
            ranges=ranges,
            **kwds,
        )

        self.mc.load_data(data=self._pre_binned)
        self.mc.select_slicer(plane=plane, width=width, apply=apply)
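
    # Step 1 sketch (illustrative): pre-bin a subset of partitions and select the
    # momentum-space plane; apply=True accepts the slice without interaction.
    # The slider values are placeholders.
    #
    #     sp.bin_and_load_momentum_calibration(df_partitions=20, plane=33, width=10, apply=True)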

    # 2. Generate the spline warp correction from momentum features.
    # Either autoselect features, or input features from view above.
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2nd step of the distortion correction workflow: Define feature points in
        momentum space. They can either be manually selected using a GUI tool,
        provided as a list of feature points, or auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the selected features.
                Defaults to False.
            **kwds: Keyword arguments for MomentumCorrector.feature_extract() and
                MomentumCorrector.feature_select().
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop(
                "sigma",
                self._config.get("momentum", {}).get("sigma", 5),
            )
            fwhm = kwds.pop(
                "fwhm",
                self._config.get("momentum", {}).get("fwhm", 8),
            )
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config.get("momentum", {}).get("sigma_radius", 1),
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )

    # 3. Generate the spline warp correction from momentum features.
    # If no features have been selected before, use class defaults.
    def generate_splinewarp(
        self,
        include_center: bool = True,
        **kwds,
    ):
        """3rd step of the distortion correction workflow: Generate the correction
        function restoring the symmetry in the image using a splinewarp algorithm.

        Args:
            include_center (bool, optional): Option to include the position of the
                center point in the correction. Defaults to True.
            **kwds: Keyword arguments for MomentumCorrector.spline_warp_estimate().
        """
        self.mc.spline_warp_estimate(include_center=include_center, **kwds)

        if self.mc.slice is not None:
            print("Original slice with reference features")
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)

            print("Corrected slice with target features")
            self.mc.view(
                image=self.mc.slice_corrected,
                annotated=True,
                points={"feats": self.mc.ptargs},
                backend="bokeh",
                crosshair=True,
            )

            print("Original slice with target features")
            self.mc.view(
                image=self.mc.slice,
                points={"feats": self.mc.ptargs},
                annotated=True,
                backend="bokeh",
            )

    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        points = []
        try:
            for point in self.mc.pouter_ord:
                points.append([float(i) for i in point])
            if self.mc.pcent:
                points.append([float(i) for i in self.mc.pcent])
        except AttributeError as exc:
            raise AttributeError(
                "Momentum correction parameters not found, need to generate parameters first!",
            ) from exc
        config = {
            "momentum": {
                "correction": {"rotation_symmetry": self.mc.rotsym, "feature_points": points},
            },
        }
        save_config(config, filename, overwrite)

    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    def pose_adjustment(
        self,
        scale: float = 1,
        xtrans: float = 0,
        ytrans: float = 0,
        angle: float = 0,
        apply: bool = False,
        use_correction: bool = True,
        reset: bool = True,
    ):
        """4th step of the distortion correction workflow: Generate an interactive
        panel to adjust affine transformations that are applied to the image. Applies
        first a scaling, next an x/y translation, and last a rotation around the
        center of the image.

        Args:
            scale (float, optional): Initial value of the scaling slider.
                Defaults to 1.
            xtrans (float, optional): Initial value of the xtrans slider.
                Defaults to 0.
            ytrans (float, optional): Initial value of the ytrans slider.
                Defaults to 0.
            angle (float, optional): Initial value of the angle slider.
                Defaults to 0.
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
            reset (bool, optional): Option to reset the correction before
                transformation. Defaults to True.
        """
        # Use the uncorrected slice as default if no distortion correction
        # has been applied yet
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                raise ValueError(
                    "No slice for corrections and transformations loaded!",
                )
            self.mc.slice_corrected = self.mc.slice

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate distortion correction from config values
            self.mc.add_features()
            self.mc.spline_warp_estimate()

        if not use_correction:
            self.mc.reset_deformation()

        self.mc.pose_adjustment(
            scale=scale,
            xtrans=xtrans,
            ytrans=ytrans,
            angle=angle,
            apply=apply,
            reset=reset,
        )

    # 5. Apply the momentum correction to the dataframe
    def apply_momentum_correction(
        self,
        preview: bool = False,
    ):
        """Applies the distortion correction and (optional) pose adjustment stored
        in the class to the dataframe.

        Args:
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding corrected X/Y columns to dataframe:")
            self._dataframe, metadata = self.mc.apply_corrections(
                df=self._dataframe,
            )
            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)
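
    # Distortion-correction sketch (illustrative): feature definition through
    # application to the dataframe. The feature coordinates and slider values
    # below are placeholders.
    #
    #     features = np.array(
    #         [[203.2, 341.96], [299.16, 345.32], [350.25, 243.70],
    #          [304.38, 149.88], [199.52, 152.48], [154.28, 242.27],
    #          [248.29, 248.62]],
    #     )
    #     sp.define_features(features=features, rotation_symmetry=6, apply=True)
    #     sp.generate_splinewarp(include_center=True)
    #     sp.save_splinewarp()
    #     sp.pose_adjustment(xtrans=8, ytrans=7, angle=-1, apply=True)
    #     sp.apply_momentum_correction(preview=True)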

    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    def calibrate_momentum_axes(
        self,
        point_a: Union[np.ndarray, List[int]] = None,
        point_b: Union[np.ndarray, List[int]] = None,
        k_distance: float = None,
        k_coord_a: Union[np.ndarray, List[float]] = None,
        k_coord_b: Union[np.ndarray, List[float]] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply=False,
    ):
        """1st step of the momentum calibration workflow: Calibrate the momentum
        axes using either the provided pixel coordinates of a high-symmetry point
        and its distance to the BZ center, or the k-coordinates of two points in the
        BZ (depending on the equiscale option). Opens an interactive panel for
        selecting the points.

        Args:
            point_a (Union[np.ndarray, List[int]]): Pixel coordinates of the first
                point used for momentum calibration.
            point_b (Union[np.ndarray, List[int]], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between point a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option to apply different scales to kx and ky.
                If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config.get("momentum", {}).get(
                "center_pixel",
                [256, 256],
            )

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )

    # 1a. Save momentum calibration parameters to config file.
    def save_momentum_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated momentum calibration parameters to the folder config
        file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key in [
                "kx_scale",
                "ky_scale",
                "x_center",
                "y_center",
                "rstart",
                "cstart",
                "rstep",
                "cstep",
            ]:
                calibration[key] = float(self.mc.calibration[key])
        except KeyError as exc:
            raise KeyError(
                "Momentum calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"momentum": {"calibration": calibration}}
        save_config(config, filename, overwrite)

    # 2. Apply correction and calibration to the dataframe
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
    ):
        """2nd step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes
        exist, these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding kx/ky columns to dataframe:")
            self._dataframe, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)
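
    # Momentum-calibration sketch (illustrative): calibrate from one high-symmetry
    # point and its known distance to the zone center, then persist and apply.
    # The pixel coordinates and k-distance are placeholders.
    #
    #     sp.calibrate_momentum_axes(point_a=[308, 345], k_distance=1.28, apply=True)
    #     sp.save_momentum_calibration()
    #     sp.apply_momentum_calibration(preview=True)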

    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: Tuple[float, float] = None,
        apply=False,
        **kwds,
    ):
        """1st step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data
        if they are not present yet.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (Tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
            **kwds: Keyword arguments passed to
                ``EnergyCalibrator.adjust_energy_correction()``.
        """
        if self._pre_binned is None:
            print(
                "Pre-binned data not present, binning using defaults from config...",
            )
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )

    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        correction = {}
        try:
            for key in self.ec.correction.keys():
                if key == "correction_type":
                    correction[key] = self.ec.correction[key]
                elif key == "center":
                    correction[key] = [float(i) for i in self.ec.correction[key]]
                else:
                    correction[key] = float(self.ec.correction[key])
        except AttributeError as exc:
            raise AttributeError(
                "Energy correction parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)

    # 2. Apply energy correction to dataframe
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2nd step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction``.
        """
        if self._dataframe is not None:
            print("Applying energy correction to dataframe...")
            self._dataframe, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                **kwds,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)
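
    # Energy-correction sketch (illustrative): adjust the TOF correction on the
    # pre-binned data, persist the parameters, and apply them to the dataframe.
    # The correction values are placeholders.
    #
    #     sp.adjust_energy_correction(correction_type="Lorentzian", amplitude=2.5,
    #                                 center=(730, 730), apply=True)
    #     sp.save_energy_correction()
    #     sp.apply_energy_correction(preview=True)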

    # Energy calibration workflow
    # 1. Load and normalize data
    def load_bias_series(
        self,
        data_files: List[str],
        axes: List[str] = None,
        bins: List = None,
        ranges: Sequence[Tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1st step of the energy calibration workflow: Load and bin data from
        single-event files.

        Args:
            data_files (List[str]): list of file paths to bin.
            axes (List[str], optional): bin axes.
                Defaults to config["dataframe"]["tof_column"].
            bins (List, optional): number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        self.ec.bin_data(
            data_files=cast(List[str], self.cpy(data_files)),
            axes=axes,
            bins=bins,
            ranges=ranges,
            biases=biases,
            bias_key=bias_key,
        )
        if (normalize is not None and normalize is True) or (
            normalize is None and self._config.get("energy", {}).get("normalize", True)
        ):
            if span is None:
                span = self._config.get("energy", {}).get("normalize_span", 7)
            if order is None:
                order = self._config.get("energy", {}).get(
                    "normalize_order",
                    1,
                )
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="bokeh",
        )

    # 2. Extract ranges and get peak positions
    def find_bias_peaks(
        self,
        ranges: Union[List[Tuple], Tuple],
        ref_id: int = 0,
        infer_others: bool = True,
        mode: str = "replace",
        radius: int = None,
        peak_window: int = None,
    ):
        """2nd step of the energy calibration workflow: Find a peak within a given
        range for the indicated reference trace, and try to find the same peak for
        all other traces. Uses fast_dtw to align curves, which might not work well
        if the shape of the curves changes qualitatively. Ideally, choose a
        reference trace in the middle of the set, and don't choose the range too
        narrow around the peak. Alternatively, a list of ranges for all traces can
        be provided.

        Args:
            ranges (Union[List[Tuple], Tuple]): Tuple of TOF values indicating a range.
                Alternatively, a list of ranges for all traces can be given.
            ref_id (int, optional): The id of the trace the range refers to.
                Defaults to 0.
            infer_others (bool, optional): Whether to determine the range for the other
                traces. Defaults to True.
            mode (str, optional): Whether to "add" or "replace" existing ranges.
                Defaults to "replace".
            radius (int, optional): Radius parameter for fast_dtw.
                Defaults to config["energy"]["fastdtw_radius"].
            peak_window (int, optional): peak_window parameter for the peak-detection
                algorithm: the number of points that have to behave monotonically
                around a peak. Defaults to config["energy"]["peak_window"].
        """
        if radius is None:
            radius = self._config.get("energy", {}).get("fastdtw_radius", 2)
        self.ec.add_features(
            ranges=ranges,
            ref_id=ref_id,
            infer_others=infer_others,
            mode=mode,
            radius=radius,
        )
        self.ec.view(
            traces=self.ec.traces_normed,
            segs=self.ec.featranges,
            xaxis=self.ec.tof,
            backend="bokeh",
        )
        print(self.ec.featranges)
        if peak_window is None:
            peak_window = self._config.get("energy", {}).get("peak_window", 7)
        try:
            self.ec.feature_extract(peak_window=peak_window)
            self.ec.view(
                traces=self.ec.traces_normed,
                peaks=self.ec.peaks,
                backend="bokeh",
            )
        except IndexError:
            print("Could not determine all peaks!")
            raise

    # 3. Fit the energy calibration relation
    def calibrate_energy_axis(
        self,
        ref_id: int,
        ref_energy: float,
        method: str = None,
        energy_scale: str = None,
        **kwds,
    ):
        """3rd step of the energy calibration workflow: Calculate the calibration
        function for the energy axis, and apply it to the dataframe. Two
        approximations are implemented, a (normally 3rd order) polynomial
        approximation, and a d^2/(t-t0)^2 relation.

        Args:
            ref_id (int): id of the trace at the bias where the reference energy is
                given.
            ref_energy (float): Absolute energy of the detected feature at the bias
                of ref_id.
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to config["energy"]["calibration_method"].
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to config["energy"]["energy_scale"].
            **kwds: Keyword arguments passed to ``EnergyCalibrator.calibrate()``.
        """
        if method is None:
            method = self._config.get("energy", {}).get(
                "calibration_method",
                "lmfit",
            )

        if energy_scale is None:
            energy_scale = self._config.get("energy", {}).get(
                "energy_scale",
                "kinetic",
            )

        self.ec.calibrate(
            ref_id=ref_id,
            ref_energy=ref_energy,
            method=method,
            energy_scale=energy_scale,
            **kwds,
        )
        print("Quality of Calibration:")
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.calibration["axis"],
            align=True,
            energy_scale=energy_scale,
            backend="bokeh",
        )
        print("E/TOF relationship:")
        self.ec.view(
            traces=self.ec.calibration["axis"][None, :],
            xaxis=self.ec.tof,
            backend="matplotlib",
            show_legend=False,
        )
        if energy_scale == "kinetic":
            plt.scatter(
                self.ec.peaks[:, 0],
                -(self.ec.biases - self.ec.biases[ref_id]) + ref_energy,
                s=50,
                c="k",
            )
        elif energy_scale == "binding":
            plt.scatter(
                self.ec.peaks[:, 0],
                self.ec.biases - self.ec.biases[ref_id] + ref_energy,
                s=50,
                c="k",
            )
        else:
            raise ValueError(
                'energy_scale needs to be either "binding" or "kinetic"'
                f", got {energy_scale}.",
            )
        plt.xlabel("Time-of-flight", fontsize=15)
        plt.ylabel("Energy (eV)", fontsize=15)
        plt.show()

    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for (key, value) in self.ec.calibration.items():
                if key in ["axis", "refid"]:
                    continue
                if key == "energy_scale":
                    calibration[key] = value
                else:
                    calibration[key] = float(value)
        except AttributeError as exc:
            raise AttributeError(
                "Energy calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"calibration": calibration}}
        save_config(config, filename, overwrite)

    # 4. Apply energy calibration to the dataframe
    def append_energy_axis(
        self,
        calibration: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """4th step of the energy calibration workflow: Apply the calibration
        function to the dataframe. Two approximations are implemented, a (normally
        3rd order) polynomial approximation, and a d^2/(t-t0)^2 relation. A
        calibration dictionary can be provided.

        Args:
            calibration (dict, optional): Calibration dict containing calibration
                parameters. Overrides calibration from class or config.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.append_energy_axis``.
        """
        if self._dataframe is not None:
            print("Adding energy column to dataframe:")
            self._dataframe, metadata = self.ec.append_energy_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)
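
    # Energy-calibration sketch (illustrative): bin a bias series, locate the
    # reference peak, fit the energy/TOF relation, and append the energy axis.
    # The file list ``bias_files``, ranges, ids, and energies are placeholders.
    #
    #     sp.load_bias_series(data_files=bias_files, normalize=True)
    #     sp.find_bias_peaks(ranges=(64000, 66000), ref_id=5)
    #     sp.calibrate_energy_axis(ref_id=4, ref_energy=-0.5, method="lmfit")
    #     sp.save_energy_calibration()
    #     sp.append_energy_axis(preview=True)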

    # Delay calibration function
    def calibrate_delay_axis(
        self,
        delay_range: Tuple[float, float] = None,
        datafile: str = None,
        preview: bool = False,
        **kwds,
    ):
        """Append a delay column to the dataframe. Either provide delay ranges, or
        read them from a file.

        Args:
            delay_range (Tuple[float, float], optional): The scanned delay range in
                picoseconds. Defaults to None.
            datafile (str, optional): The file from which to read the delay ranges.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
        """
        if self._dataframe is not None:
            print("Adding delay column to dataframe:")

            if delay_range is not None:
                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    delay_range=delay_range,
                    **kwds,
                )
            else:
                if datafile is None:
                    try:
                        datafile = self._files[0]
                    except IndexError:
                        print(
                            "No datafile available, specify either",
                            " 'datafile' or 'delay_range'",
                        )
                        raise

                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    datafile=datafile,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "delay_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)
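
    # Delay-calibration sketch (illustrative): pass the scanned delay range
    # directly (values are placeholders), or omit it to have the range read
    # from the first loaded data file.
    #
    #     sp.calibrate_delay_axis(delay_range=(-500.0, 1500.0), preview=True)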

    def add_jitter(self, cols: Sequence[str] = None):
        """Add jitter to the selected dataframe columns.

        Args:
            cols (Sequence[str], optional): The columns onto which to apply jitter.
                Defaults to config["dataframe"]["jitter_cols"].
        """
        if cols is None:
            cols = self._config.get("dataframe", {}).get(
                "jitter_cols",
                self._dataframe.columns,
            )  # jitter all columns

        self._dataframe = self._dataframe.map_partitions(
            apply_jitter,
            cols=cols,
            cols_jittered=cols,
        )
        metadata = []
        for col in cols:
            metadata.append(col)
        self._attributes.add(metadata, "jittering", duplicate_policy="append")
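
    # Jitter sketch (illustrative): dither discrete-valued columns before binning;
    # with no argument, config["dataframe"]["jitter_cols"] (or all columns) is used.
    # The column names are placeholders.
    #
    #     sp.add_jitter(cols=["X", "Y", "t"])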
    def pre_binning(
1✔
1212
        self,
1213
        df_partitions: int = 100,
1214
        axes: List[str] = None,
1215
        bins: List[int] = None,
1216
        ranges: Sequence[Tuple[float, float]] = None,
1217
        **kwds,
1218
    ) -> xr.DataArray:
1219
        """Function to do an initial binning of the dataframe loaded to the class.
1220

1221
        Args:
1222
            df_partitions (int, optional): Number of dataframe partitions to use for
1223
                the initial binning. Defaults to 100.
1224
            axes (List[str], optional): Axes to bin.
1225
                Defaults to config["momentum"]["axes"].
1226
            bins (List[int], optional): Bin numbers to use for binning.
1227
                Defaults to config["momentum"]["bins"].
1228
            ranges (List[Tuple], optional): Ranges to use for binning.
1229
                Defaults to config["momentum"]["ranges"].
1230
            **kwds: Keyword argument passed to ``compute``.
1231

1232
        Returns:
1233
            xr.DataArray: pre-binned data-array.
1234
        """
1235
        if axes is None:
1✔
1236
            axes = self._config.get("momentum", {}).get(
1✔
1237
                "axes",
1238
                ["@x_column, @y_column, @tof_column"],
1239
            )
1240
        for loc, axis in enumerate(axes):
1✔
1241
            if axis.startswith("@"):
1✔
1242
                axes[loc] = self._config.get("dataframe").get(axis.strip("@"))
1✔
1243

1244
        if bins is None:
1✔
1245
            bins = self._config.get("momentum", {}).get(
1✔
1246
                "bins",
1247
                [512, 512, 300],
1248
            )
1249
        if ranges is None:
1✔
1250
            ranges_ = self._config.get("momentum", {}).get(
1✔
1251
                "ranges",
1252
                [[-256, 1792], [-256, 1792], [128000, 138000]],
1253
            )
1254
            ranges = [cast(Tuple[float, float], tuple(v)) for v in ranges_]
1✔
1255

1256
        assert self._dataframe is not None, "dataframe needs to be loaded first!"
1✔
1257

1258
        return self.compute(
1✔
1259
            bins=bins,
1260
            axes=axes,
1261
            ranges=ranges,
1262
            df_partitions=df_partitions,
1263
            **kwds,
1264
        )
1265

    def compute(
        self,
        bins: Union[
            int,
            dict,
            tuple,
            List[int],
            List[np.ndarray],
            List[tuple],
        ] = 100,
        axes: Union[str, Sequence[str]] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Compute the histogram along the given dimensions.

        Args:
            bins (int, dict, tuple, List[int], List[np.ndarray], List[tuple], optional):
                Definition of the bins. Can be any of the following cases:

                - an integer describing the number of bins for all dimensions
                - a tuple of 3 numbers describing start, end and step of the binning
                  range
                - an np.ndarray defining the bin edges
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
                - a dictionary made of the axes as keys and any of the above as values.

                This takes priority over the axes and ranges arguments. Defaults to 100.
            axes (Union[str, Sequence[str]], optional): The names of the axes (columns)
                on which to calculate the histogram. The order will be the order of the
                dimensions in the resulting array. Defaults to None.
            ranges (Sequence[Tuple[float, float]], optional): List of tuples containing
                the start and end points of the binning range. Defaults to None.
            **kwds: Keyword arguments:

                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
                  ``bin_dataframe`` for details. Defaults to
                  config["binning"]["hist_mode"].
                - **mode**: Defines how the results from each partition are combined.
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
                  Defaults to config["binning"]["mode"].
                - **pbar**: Option to show the tqdm progress bar. Defaults to
                  config["binning"]["pbar"].
                - **num_cores**: Number of CPU cores to use for parallelization.
                  Defaults to config["binning"]["num_cores"] or N_CPU-1.
                - **threads_per_worker**: Limit the number of threads that
                  multiprocessing can spawn per binning thread. Defaults to
                  config["binning"]["threads_per_worker"].
                - **threadpool_API**: The API to use for multiprocessing. "blas",
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
                  config["binning"]["threadpool_API"].
                - **df_partitions**: Number of dataframe partitions to use for the
                  binning. Defaults to all partitions.

                Additional kwds are passed to ``bin_dataframe``.

        Raises:
            AssertionError: Raised when no dataframe has been loaded.

        Returns:
            xr.DataArray: The result of the n-dimensional binning represented in an
            xarray object, combining the data with the axes.
        """
        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
        mode = kwds.pop("mode", self._config["binning"]["mode"])
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
        num_cores = kwds.pop("num_cores", self._config["binning"]["num_cores"])
        threads_per_worker = kwds.pop(
            "threads_per_worker",
            self._config["binning"]["threads_per_worker"],
        )
        threadpool_api = kwds.pop(
            "threadpool_API",
            self._config["binning"]["threadpool_API"],
        )
        df_partitions = kwds.pop("df_partitions", None)
        if df_partitions is not None:
            dataframe = self._dataframe.partitions[
                0 : min(df_partitions, self._dataframe.npartitions)
            ]
        else:
            dataframe = self._dataframe

        self._binned = bin_dataframe(
            df=dataframe,
            bins=bins,
            axes=axes,
            ranges=ranges,
            hist_mode=hist_mode,
            mode=mode,
            pbar=pbar,
            n_cores=num_cores,
            threads_per_worker=threads_per_worker,
            threadpool_api=threadpool_api,
            **kwds,
        )

        for dim in self._binned.dims:
            try:
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
            except KeyError:
                pass

        self._binned.attrs["units"] = "counts"
        self._binned.attrs["long_name"] = "photoelectron counts"
        self._binned.attrs["metadata"] = self._attributes.metadata

        return self._binned
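
    # Usage sketch (not part of the class): binning a loaded dataframe along
    # three detector columns. The column names, bin counts, and ranges below are
    # hypothetical illustrations, not values taken from this file or its config.
    #
    #     processor = SedProcessor(folder="/path/to/data")  # assumed data folder
    #     binned = processor.compute(
    #         bins=[128, 128, 256],
    #         axes=["X", "Y", "t"],  # assumed dataframe column names
    #         ranges=[(0, 1800), (0, 1800), (64000, 68000)],
    #         df_partitions=10,  # restrict binning to the first 10 partitions
    #     )
    #     binned.sum(dim="t").plot()  # quick look via xarray's built-in plotting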

    def view_event_histogram(
        self,
        dfpid: int,
        ncol: int = 2,
        bins: Sequence[int] = None,
        axes: Sequence[str] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        backend: str = "bokeh",
        legend: bool = True,
        histkwds: dict = None,
        legkwds: dict = None,
        **kwds,
    ):
        """Plot individual histograms of specified dimensions (axes) from a single
        dataframe partition.

        Args:
            dfpid (int): Number of the dataframe partition to look at.
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
            bins (Sequence[int], optional): Number of bins to use for the specified
                axes. Defaults to config["histogram"]["bins"].
            axes (Sequence[str], optional): Names of the axes to display.
                Defaults to config["histogram"]["axes"].
            ranges (Sequence[Tuple[float, float]], optional): Value ranges of all
                specified axes. Defaults to config["histogram"]["ranges"].
            backend (str, optional): Backend of the plotting library
                ('matplotlib' or 'bokeh'). Defaults to "bokeh".
            legend (bool, optional): Option to include a legend in the histogram plots.
                Defaults to True.
            histkwds (dict, optional): Keyword arguments for histograms
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
            legkwds (dict, optional): Keyword arguments for legend
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
            **kwds: Extra keyword arguments passed to
                ``sed.diagnostics.grid_histogram()``.

        Raises:
            TypeError: Raised when the input values are not of the correct type.
        """
        if bins is None:
            bins = self._config["histogram"]["bins"]
        if axes is None:
            axes = self._config["histogram"]["axes"]
        if ranges is None:
            ranges = self._config["histogram"]["ranges"]

        input_types = map(type, [axes, bins, ranges])
        allowed_types = [list, tuple]

        df = self._dataframe

        if not set(input_types).issubset(allowed_types):
            raise TypeError(
                "Inputs of axes, bins, ranges need to be list or tuple!",
            )

        # Read out the values for the specified groups
        group_dict_dd = {}
        dfpart = df.get_partition(dfpid)
        cols = dfpart.columns
        for ax in axes:
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
        group_dict = ddf.compute(group_dict_dd)[0]

        # Plot multiple histograms in a grid
        grid_histogram(
            group_dict,
            ncol=ncol,
            rvs=axes,
            rvbins=bins,
            rvranges=ranges,
            backend=backend,
            legend=legend,
            histkwds=histkwds,
            legkwds=legkwds,
            **kwds,
        )
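
    # Usage sketch (hypothetical values): inspecting the raw event distribution
    # of a single partition before committing to a full binning run. The axis
    # names and ranges are illustrative assumptions; when omitted, the defaults
    # come from the config["histogram"] section.
    #
    #     processor.view_event_histogram(
    #         dfpid=0,                        # first dataframe partition
    #         axes=["X", "Y"],                # assumed column names
    #         bins=[80, 80],
    #         ranges=[(0, 1800), (0, 1800)],
    #         backend="matplotlib",
    #     )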

    def save(
        self,
        faddr: str,
        **kwds,
    ):
        """Saves the binned data to the provided path and filename.

        Args:
            faddr (str): Path and name of the file to write. Its extension determines
                the file type to write. Valid file types are:

                - "*.tiff", "*.tif": Saves a TIFF stack.
                - "*.h5", "*.hdf5": Saves an HDF5 file.
                - "*.nxs", "*.nexus": Saves a NeXus file.

            **kwds: Keyword arguments, which are passed to the writer functions:
                For TIFF writing:

                - **alias_dict**: Dictionary of dimension aliases to use.

                For HDF5 writing:

                - **mode**: hdf5 read/write mode. Defaults to "w".

                For NeXus:

                - **reader**: Name of the nexustools reader to use.
                  Defaults to config["nexus"]["reader"]
                - **definition**: NeXus application definition to use for saving.
                  Must be supported by the used ``reader``. Defaults to
                  config["nexus"]["definition"]
                - **input_files**: A list of input files to pass to the reader.
                  Defaults to config["nexus"]["input_files"]
        """
        if self._binned is None:
            raise NameError("Need to bin data first!")

        extension = pathlib.Path(faddr).suffix

        if extension in (".tif", ".tiff"):
            to_tiff(
                data=self._binned,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".h5", ".hdf5"):
            to_h5(
                data=self._binned,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".nxs", ".nexus"):
            reader = kwds.pop("reader", self._config["nexus"]["reader"])
            definition = kwds.pop(
                "definition",
                self._config["nexus"]["definition"],
            )
            input_files = kwds.pop(
                "input_files",
                self._config["nexus"]["input_files"],
            )
            if isinstance(input_files, str):
                input_files = [input_files]

            to_nexus(
                data=self._binned,
                faddr=faddr,
                reader=reader,
                definition=definition,
                input_files=input_files,
                **kwds,
            )
        else:
            raise NotImplementedError(
                f"Unrecognized file format: {extension}.",
            )
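
    # Usage sketch: the file extension selects the writer. The file names below
    # are placeholders, and "mpes" as a NeXus reader name is an assumption for
    # illustration; the NeXus keywords otherwise default to config["nexus"].
    #
    #     processor.save("binned.tiff")                # TIFF stack via to_tiff
    #     processor.save("binned.h5", mode="w")        # HDF5 file via to_h5
    #     processor.save("binned.nxs", reader="mpes")  # NeXus file via to_nexus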

    def add_dimension(self, name: str, axis_range: Tuple):
        """Add a dimension axis.

        Args:
            name (str): Name of the axis.
            axis_range (Tuple): Range for the axis.

        Raises:
            ValueError: Raised if an axis with that name already exists.
        """
        if name in self._coordinates:
            raise ValueError(f"Axis {name} already exists")

        self.axis[name] = self.make_axis(axis_range)

    def make_axis(self, axis_range: Tuple) -> np.ndarray:
        """Function to make an axis.

        Args:
            axis_range (Tuple): Range for the new axis.

        Returns:
            np.ndarray: The axis values generated from the given range.
        """
        # TODO: What shall this function do?
        return np.arange(*axis_range)
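
    # Sketch of the current make_axis behavior: the range tuple is unpacked into
    # np.arange, so a (start, stop, step) tuple yields a half-open, evenly spaced
    # axis. The values below are purely illustrative.
    #
    #     import numpy as np
    #     np.arange(*(0.0, 1.0, 0.25))  # -> array([0.  , 0.25, 0.5 , 0.75])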