OpenCOMPES / sed — build 5867450023 (push, via github web-flow), pending completion
Merge pull request #134 from OpenCOMPES/tof_range_selector: Tof range selector

79 of 79 new or added lines in 3 files covered (100.0%)
3026 of 4093 relevant lines covered (73.93%)
2.22 hits per line

Source file: /sed/core/processor.py (35.46% covered)
"""This module contains the core class for the sed package

"""
import pathlib
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import apply_jitter
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader

N_CPU = psutil.cpu_count()


class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a workflow for data
    correction, calibration and binning.

    Args:
        metadata (dict, optional): Dict of external Metadata. Defaults to None.
        config (Union[dict, str], optional): Config dictionary or config file name.
            Defaults to None.
        dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
            into the class. Defaults to None.
        files (List[str], optional): List of files to pass to the loader defined in
            the config. Defaults to None.
        folder (str, optional): Folder containing files to pass to the loader
            defined in the config. Defaults to None.
        runs (Sequence[str], optional): List of run identifiers to pass to the loader
            defined in the config. Defaults to None.
        collect_metadata (bool): Option to collect metadata from files.
            Defaults to False.
        **kwds: Keyword arguments passed to the reader.
    """

    def __init__(
        self,
        metadata: dict = None,
        config: Union[dict, str] = None,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            config (Union[dict, str], optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
                into the class. Defaults to None.
            files (List[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to the reader.
        """
        self._config = parse_config(config, **kwds)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

        self._dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._files: List[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None

        self._dimensions: List[str] = []
        self._coordinates: Dict[Any, Any] = {}
        self.axis: Dict[Any, Any] = {}
        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
        )

        self.ec = EnergyCalibrator(
            loader=self.loader,
            config=self._config,
        )

        self.mc = MomentumCorrector(
            config=self._config,
        )

        self.dc = DelayCalibrator(
            config=self._config,
        )

        self.use_copy_tool = self._config.get("core", {}).get(
            "use_copy_tool",
            False,
        )
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    source=self._config["core"]["copy_tool_source"],
                    dest=self._config["core"]["copy_tool_dest"],
                    **self._config["core"].get("copy_tool_kwds", {}),
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )

    def __repr__(self):
        if self._dataframe is None:
            df_str = "Data Frame: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        coordinates_str = f"Coordinates: {self._coordinates}"
        dimensions_str = f"Dimensions: {self._dimensions}"
        pretty_str = df_str + "\n" + coordinates_str + "\n" + dimensions_str
        return pretty_str

    def __getitem__(self, val: str) -> pd.DataFrame:
        """Accessor to the underlying data structure.

        Args:
            val (str): Name of the dataframe column to retrieve.

        Returns:
            pd.DataFrame: Selected dataframe column.
        """
        return self._dataframe[val]

    @property
    def config(self) -> Dict[Any, Any]:
        """Getter attribute for the config dictionary

        Returns:
            Dict: The config dictionary.
        """
        return self._config

    @config.setter
    def config(self, config: Union[dict, str]):
        """Setter function for the config dictionary.

        Args:
            config (Union[dict, str]): Config dictionary or path of config file
                to load.
        """
        self._config = parse_config(config)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

    @property
    def dimensions(self) -> list:
        """Getter attribute for the dimensions.

        Returns:
            list: List of dimensions.
        """
        return self._dimensions

    @dimensions.setter
    def dimensions(self, dims: list):
        """Setter function for the dimensions.

        Args:
            dims (list): List of dimensions to set.
        """
        assert isinstance(dims, list)
        self._dimensions = dims

    @property
    def coordinates(self) -> dict:
        """Getter attribute for the coordinates dict.

        Returns:
            dict: Dictionary of coordinates.
        """
        return self._coordinates

    @coordinates.setter
    def coordinates(self, coords: dict):
        """Setter function for the coordinates dict

        Args:
            coords (dict): Dictionary of coordinates.
        """
        assert isinstance(coords, dict)
        self._coordinates = {}
        for k, v in coords.items():
            self._coordinates[k] = xr.DataArray(v)

    def cpy(self, path: Union[str, List[str]]) -> Union[str, List[str]]:
        """Function to mirror a list of files or a folder from a network drive to a
        local storage. Returns either the original or the copied path to the given
        path. The option to use this functionality is set by
        config["core"]["use_copy_tool"].

        Args:
            path (Union[str, List[str]]): Source path or path list.

        Returns:
            Union[str, List[str]]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        if isinstance(path, list):
            return path

        return path

    def load(
        self,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        metadata: dict = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            files (List[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            self._dataframe = dataframe
        elif runs is not None:
            # If runs are provided, we only use the copy tool if a folder is also provided.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )

        elif files is not None:
            dataframe, metadata = self.loader.read_dataframe(
                files=cast(List[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )

        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )

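    # Example (hypothetical usage sketch, not part of the class): constructing a
    # processor and loading single-event data. The config path and file names are
    # placeholders; any loader configured under config["core"]["loader"] is used
    # the same way.
    #
    #   from sed import SedProcessor  # assuming the package-level re-export
    #
    #   sp = SedProcessor(config="sed_config.yaml")
    #   sp.load(files=["scan_0001.h5", "scan_0002.h5"])
    #   # or, equivalently, from a folder or a list of runs:
    #   # sp.load(folder="/path/to/data")
    #   # sp.load(runs=["44824", "44825"], folder="/path/to/base_folder")
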
    # Momentum calibration workflow
    # 1. Bin raw detector data for distortion correction
    def bin_and_load_momentum_calibration(
        self,
        df_partitions: int = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        plane: int = 0,
        width: int = 5,
        apply: bool = False,
        **kwds,
    ):
        """1st step of the momentum correction workflow. Function to do an initial binning
        of the dataframe loaded to the class, slice a plane from it using an
        interactive view, and load it into the momentum corrector class.

        Args:
            df_partitions (int, optional): Number of dataframe partitions to use for
                the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (List[Tuple], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
            width (int, optional): Initial value for the width slider. Defaults to 5.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments passed to the pre_binning function.
        """
        self._pre_binned = self.pre_binning(
            df_partitions=df_partitions,
            axes=axes,
            bins=bins,
            ranges=ranges,
            **kwds,
        )

        self.mc.load_data(data=self._pre_binned)
        self.mc.select_slicer(plane=plane, width=width, apply=apply)

    # 2. Generate the spline warp correction from momentum features.
    # Either autoselect features, or input features from view above.
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2nd step of the distortion correction workflow: Define feature points in
        momentum space. They can either be selected manually using a GUI tool, be
        provided as a list of feature points, or be auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the selected features.
                Defaults to False.
            **kwds: Keyword arguments for MomentumCorrector.feature_extract() and
                MomentumCorrector.feature_select()
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop(
                "sigma",
                self._config.get("momentum", {}).get("sigma", 5),
            )
            fwhm = kwds.pop(
                "fwhm",
                self._config.get("momentum", {}).get("fwhm", 8),
            )
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config.get("momentum", {}).get("sigma_radius", 1),
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )

    # 3. Generate the spline warp correction from momentum features.
    # If no features have been selected before, use class defaults.
    def generate_splinewarp(
        self,
        include_center: bool = True,
        **kwds,
    ):
        """3rd step of the distortion correction workflow: Generate the correction
        function restoring the symmetry in the image using a splinewarp algorithm.

        Args:
            include_center (bool, optional): Option to include the position of the
                center point in the correction. Defaults to True.
            **kwds: Keyword arguments for MomentumCorrector.spline_warp_estimate().
        """
        self.mc.spline_warp_estimate(include_center=include_center, **kwds)

        if self.mc.slice is not None:
            print("Original slice with reference features")
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)

            print("Corrected slice with target features")
            self.mc.view(
                image=self.mc.slice_corrected,
                annotated=True,
                points={"feats": self.mc.ptargs},
                backend="bokeh",
                crosshair=True,
            )

            print("Original slice with target features")
            self.mc.view(
                image=self.mc.slice,
                points={"feats": self.mc.ptargs},
                annotated=True,
                backend="bokeh",
            )

    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        points = []
        try:
            for point in self.mc.pouter_ord:
                points.append([float(i) for i in point])
            if self.mc.pcent:
                points.append([float(i) for i in self.mc.pcent])
        except AttributeError as exc:
            raise AttributeError(
                "Momentum correction parameters not found, need to generate parameters first!",
            ) from exc
        config = {
            "momentum": {
                "correction": {"rotation_symmetry": self.mc.rotsym, "feature_points": points},
            },
        }
        save_config(config, filename, overwrite)

    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    def pose_adjustment(
        self,
        scale: float = 1,
        xtrans: float = 0,
        ytrans: float = 0,
        angle: float = 0,
        apply: bool = False,
        use_correction: bool = True,
    ):
        """4th step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            scale (float, optional): Initial value of the scaling slider.
                Defaults to 1.
            xtrans (float, optional): Initial value of the xtrans slider.
                Defaults to 0.
            ytrans (float, optional): Initial value of the ytrans slider.
                Defaults to 0.
            angle (float, optional): Initial value of the angle slider.
                Defaults to 0.
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
        """
        # Generate homography as default if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                raise ValueError(
                    "No slice for corrections and transformations loaded!",
                )
            self.mc.slice_corrected = self.mc.slice

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate default distortion correction
            self.mc.add_features()
            self.mc.spline_warp_estimate()

        if not use_correction:
            self.mc.reset_deformation()

        self.mc.pose_adjustment(
            scale=scale,
            xtrans=xtrans,
            ytrans=ytrans,
            angle=angle,
            apply=apply,
        )

    # 5. Apply the momentum correction to the dataframe
    def apply_momentum_correction(
        self,
        preview: bool = False,
    ):
        """Applies the distortion correction and pose adjustment (optional)
        to the dataframe.

        Args:
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding corrected X/Y columns to dataframe:")
            self._dataframe, metadata = self.mc.apply_corrections(
                df=self._dataframe,
            )
            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

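    # Example (hypothetical sketch): the distortion-correction steps above chained
    # together in an interactive session. The slider values and the feature array
    # are placeholders for experiment-specific inputs.
    #
    #   sp.bin_and_load_momentum_calibration(df_partitions=20, plane=33, width=10)
    #   sp.define_features(features=feature_pixels, rotation_symmetry=6,
    #                      include_center=True, apply=True)  # feature_pixels: np.ndarray
    #   sp.generate_splinewarp(include_center=True)
    #   sp.save_splinewarp()
    #   sp.pose_adjustment(xtrans=8, ytrans=7, angle=-4, apply=True)
    #   sp.apply_momentum_correction(preview=True)
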
    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    def calibrate_momentum_axes(
        self,
        point_a: Union[np.ndarray, List[int]] = None,
        point_b: Union[np.ndarray, List[int]] = None,
        k_distance: float = None,
        k_coord_a: Union[np.ndarray, List[float]] = None,
        k_coord_b: Union[np.ndarray, List[float]] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply=False,
    ):
        """1st step of the momentum calibration workflow. Calibrate momentum
        axes using either provided pixel coordinates of a high-symmetry point and its
        distance to the BZ center, or the k-coordinates of two points in the BZ
        (depending on the equiscale option). Opens an interactive panel for selecting
        the points.

        Args:
            point_a (Union[np.ndarray, List[int]]): Pixel coordinates of the first
                point used for momentum calibration.
            point_b (Union[np.ndarray, List[int]], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between point a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option to apply the same scale to kx and ky.
                If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config.get("momentum", {}).get(
                "center_pixel",
                [256, 256],
            )

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )

    # 1a. Save momentum calibration parameters to config file.
    def save_momentum_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated momentum calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key in [
                "kx_scale",
                "ky_scale",
                "x_center",
                "y_center",
                "rstart",
                "cstart",
                "rstep",
                "cstep",
            ]:
                calibration[key] = float(self.mc.calibration[key])
        except KeyError as exc:
            raise KeyError(
                "Momentum calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"momentum": {"calibration": calibration}}
        save_config(config, filename, overwrite)

    # 2. Apply correction and calibration to the dataframe
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
    ):
        """2nd step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes exist,
        these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding kx/ky columns to dataframe:")
            self._dataframe, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

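    # Example (hypothetical sketch): calibrating the momentum axes from one
    # high-symmetry point and its known distance to the BZ center, then applying
    # the calibration. The pixel coordinates and k-distance are placeholders.
    #
    #   sp.calibrate_momentum_axes(point_a=[308, 345], k_distance=1.28, apply=True)
    #   sp.save_momentum_calibration()
    #   sp.apply_momentum_calibration(preview=True)
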
    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: Tuple[float, float] = None,
        apply=False,
        **kwds,
    ):
        """1st step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data if
        they are not present yet.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (Tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
            **kwds: Keyword args passed to ``EnergyCalibrator.adjust_energy_correction``.
        """
        if self._pre_binned is None:
            print(
                "Pre-binned data not present, binning using defaults from config...",
            )
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )

    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        correction = {}
        try:
            for key in self.ec.correction.keys():
                if key == "correction_type":
                    correction[key] = self.ec.correction[key]
                elif key == "center":
                    correction[key] = [float(i) for i in self.ec.correction[key]]
                else:
                    correction[key] = float(self.ec.correction[key])
        except AttributeError as exc:
            raise AttributeError(
                "Energy correction parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)

    # 2. Apply energy correction to dataframe
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2nd step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction``.
        """
        if self._dataframe is not None:
            print("Applying energy correction to dataframe...")
            self._dataframe, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                **kwds,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

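    # Example (hypothetical sketch): adjusting and applying the TOF energy
    # correction; the correction type, amplitude, and center values are placeholders.
    #
    #   sp.adjust_energy_correction(correction_type="Lorentzian", amplitude=2.5,
    #                               center=(730, 730), apply=True)
    #   sp.save_energy_correction()
    #   sp.apply_energy_correction(preview=True)
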
    # Energy calibrator workflow
    # 1. Load and normalize data
    def load_bias_series(
        self,
        data_files: List[str],
        axes: List[str] = None,
        bins: List = None,
        ranges: Sequence[Tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1st step of the energy calibration workflow: Load and bin data from
        single-event files.

        Args:
            data_files (List[str]): list of file paths to bin
            axes (List[str], optional): bin axes.
                Defaults to config["dataframe"]["tof_column"].
            bins (List, optional): number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        self.ec.bin_data(
            data_files=cast(List[str], self.cpy(data_files)),
            axes=axes,
            bins=bins,
            ranges=ranges,
            biases=biases,
            bias_key=bias_key,
        )
        if (normalize is not None and normalize is True) or (
            normalize is None and self._config.get("energy", {}).get("normalize", True)
        ):
            if span is None:
                span = self._config.get("energy", {}).get("normalize_span", 7)
            if order is None:
                order = self._config.get("energy", {}).get(
                    "normalize_order",
                    1,
                )
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="bokeh",
        )

    # 2. Extract ranges and get peak positions
    def find_bias_peaks(
        self,
        ranges: Union[List[Tuple], Tuple],
        ref_id: int = 0,
        infer_others: bool = True,
        mode: str = "replace",
        radius: int = None,
        peak_window: int = None,
        apply: bool = False,
    ):
        """2nd step of the energy calibration workflow: Find a peak within a given range
        for the indicated reference trace, and try to find the same peak for all
        other traces. Uses fast_dtw to align curves, which might not work well if the
        shape of the curves changes qualitatively. Ideally, choose a reference trace in
        the middle of the set, and don't choose the range too narrow around the peak.
        Alternatively, a list of ranges for all traces can be provided.

        Args:
            ranges (Union[List[Tuple], Tuple]): Tuple of TOF values indicating a range.
                Alternatively, a list of ranges for all traces can be given.
            ref_id (int, optional): The id of the trace the range refers to.
                Defaults to 0.
            infer_others (bool, optional): Whether to determine the range for the other
                traces. Defaults to True.
            mode (str, optional): Whether to "add" or "replace" existing ranges.
                Defaults to "replace".
            radius (int, optional): Radius parameter for fast_dtw.
                Defaults to config["energy"]["fastdtw_radius"].
            peak_window (int, optional): peak_window parameter for the peak-detection
                algorithm: the number of points that have to behave monotonously
                around a peak. Defaults to config["energy"]["peak_window"].
            apply (bool, optional): Option to directly apply the provided parameters.
                Defaults to False.
        """
        if radius is None:
            radius = self._config.get("energy", {}).get("fastdtw_radius", 2)
        if peak_window is None:
            peak_window = self._config.get("energy", {}).get("peak_window", 7)
        if not infer_others:
            self.ec.add_ranges(
                ranges=ranges,
                ref_id=ref_id,
                infer_others=infer_others,
                mode=mode,
                radius=radius,
            )
            print(self.ec.featranges)
            try:
                self.ec.feature_extract(peak_window=peak_window)
                self.ec.view(
                    traces=self.ec.traces_normed,
                    segs=self.ec.featranges,
                    xaxis=self.ec.tof,
                    peaks=self.ec.peaks,
                    backend="bokeh",
                )
            except IndexError:
                print("Could not determine all peaks!")
                raise
        else:
            # New adjustment tool
            assert isinstance(ranges, tuple)
            self.ec.adjust_ranges(
                ranges=ranges,
                ref_id=ref_id,
                traces=self.ec.traces_normed,
                infer_others=infer_others,
                radius=radius,
                peak_window=peak_window,
                apply=apply,
            )

    # 3. Fit the energy calibration relation
    def calibrate_energy_axis(
        self,
        ref_id: int,
        ref_energy: float,
        method: str = None,
        energy_scale: str = None,
        **kwds,
    ):
        """3rd step of the energy calibration workflow: Calculate the calibration
        function for the energy axis, and apply it to the dataframe. Two
        approximations are implemented, a (normally 3rd order) polynomial
        approximation, and a d^2/(t-t0)^2 relation.

        Args:
            ref_id (int): id of the trace at the bias where the reference energy is
                given.
            ref_energy (float): Absolute energy of the detected feature at the bias
                of ref_id.
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to config["energy"]["calibration_method"]
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to config["energy"]["energy_scale"]
            **kwds: Keyword args passed to ``EnergyCalibrator.calibrate``.
        """
        if method is None:
            method = self._config.get("energy", {}).get(
                "calibration_method",
                "lmfit",
            )

        if energy_scale is None:
            energy_scale = self._config.get("energy", {}).get(
                "energy_scale",
                "kinetic",
            )

        self.ec.calibrate(
            ref_id=ref_id,
            ref_energy=ref_energy,
            method=method,
            energy_scale=energy_scale,
            **kwds,
        )
        print("Quality of Calibration:")
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.calibration["axis"],
            align=True,
            energy_scale=energy_scale,
            backend="bokeh",
        )
        print("E/TOF relationship:")
        self.ec.view(
            traces=self.ec.calibration["axis"][None, :],
            xaxis=self.ec.tof,
            backend="matplotlib",
            show_legend=False,
        )
        if energy_scale == "kinetic":
            plt.scatter(
                self.ec.peaks[:, 0],
                -(self.ec.biases - self.ec.biases[ref_id]) + ref_energy,
                s=50,
                c="k",
            )
        elif energy_scale == "binding":
            plt.scatter(
                self.ec.peaks[:, 0],
                self.ec.biases - self.ec.biases[ref_id] + ref_energy,
                s=50,
                c="k",
            )
        else:
            raise ValueError(
                'energy_scale needs to be either "binding" or "kinetic"',
                f", got {energy_scale}.",
            )
        plt.xlabel("Time-of-flight", fontsize=15)
        plt.ylabel("Energy (eV)", fontsize=15)
        plt.show()

    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for (key, value) in self.ec.calibration.items():
                if key in ["axis", "refid"]:
                    continue
                if key == "energy_scale":
                    calibration[key] = value
                else:
                    calibration[key] = float(value)
        except AttributeError as exc:
            raise AttributeError(
                "Energy calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"calibration": calibration}}
        save_config(config, filename, overwrite)

    # 4. Apply energy calibration to the dataframe
    def append_energy_axis(
        self,
        calibration: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """4th step of the energy calibration workflow: Apply the calibration function
        to the dataframe. Two approximations are implemented, a (normally 3rd order)
        polynomial approximation, and a d^2/(t-t0)^2 relation. A calibration dictionary
        can be provided.

        Args:
            calibration (dict, optional): Calibration dict containing calibration
                parameters. Overrides calibration from class or config.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.append_energy_axis``.
        """
        if self._dataframe is not None:
            print("Adding energy column to dataframe:")
            self._dataframe, metadata = self.ec.append_energy_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

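    # Example (hypothetical sketch): the full energy-calibration chain from a bias
    # series. File names, the TOF range, and the reference energy are placeholders.
    #
    #   sp.load_bias_series(data_files=["bias_22V.h5", "bias_23V.h5", "bias_24V.h5"])
    #   sp.find_bias_peaks(ranges=(64000, 66000), ref_id=1, apply=True)
    #   sp.calibrate_energy_axis(ref_id=1, ref_energy=-0.5, method="lmfit",
    #                            energy_scale="kinetic")
    #   sp.save_energy_calibration()
    #   sp.append_energy_axis(preview=True)
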
    # Delay calibration function
    def calibrate_delay_axis(
        self,
        delay_range: Tuple[float, float] = None,
        datafile: str = None,
        preview: bool = False,
        **kwds,
    ):
        """Append a delay column to the dataframe. Either provide delay ranges, or read
        them from a file.

        Args:
            delay_range (Tuple[float, float], optional): The scanned delay range in
                picoseconds. Defaults to None.
            datafile (str, optional): The file from which to read the delay ranges.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
        """
        if self._dataframe is not None:
            print("Adding delay column to dataframe:")

            if delay_range is not None:
                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    delay_range=delay_range,
                    **kwds,
                )
            else:
                if datafile is None:
                    try:
                        datafile = self._files[0]
                    except IndexError:
                        print(
                            "No datafile available, specify either",
                            " 'datafile' or 'delay_range'",
                        )
                        raise

                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    datafile=datafile,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "delay_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

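    # Example (hypothetical sketch): appending the delay axis either from an
    # explicit range or from the first loaded data file; the range is a placeholder.
    #
    #   sp.calibrate_delay_axis(delay_range=(-500.0, 1500.0), preview=True)
    #   # or, reading the range from the first loaded file:
    #   # sp.calibrate_delay_axis()
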
    def add_jitter(self, cols: Sequence[str] = None):
        """Add jitter to the selected dataframe columns.

        Args:
            cols (Sequence[str], optional): The columns onto which to apply jitter.
                Defaults to config["dataframe"]["jitter_cols"].
        """
        if cols is None:
            cols = self._config.get("dataframe", {}).get(
                "jitter_cols",
                self._dataframe.columns,
            )  # jitter all columns

        self._dataframe = self._dataframe.map_partitions(
            apply_jitter,
            cols=cols,
            cols_jittered=cols,
        )
        metadata = []
        for col in cols:
            metadata.append(col)
        self._attributes.add(metadata, "jittering", duplicate_policy="append")

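    # Example (hypothetical sketch): dithering the digitized detector coordinates
    # before binning; the column names depend on the loader and are placeholders.
    #
    #   sp.add_jitter(cols=["X", "Y", "t"])
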
    def pre_binning(
3✔
1220
        self,
1221
        df_partitions: int = 100,
1222
        axes: List[str] = None,
1223
        bins: List[int] = None,
1224
        ranges: Sequence[Tuple[float, float]] = None,
1225
        **kwds,
1226
    ) -> xr.DataArray:
1227
        """Function to do an initial binning of the dataframe loaded to the class.
1228

1229
        Args:
1230
            df_partitions (int, optional): Number of dataframe partitions to use for
1231
                the initial binning. Defaults to 100.
1232
            axes (List[str], optional): Axes to bin.
1233
                Defaults to config["momentum"]["axes"].
1234
            bins (List[int], optional): Bin numbers to use for binning.
1235
                Defaults to config["momentum"]["bins"].
1236
            ranges (List[Tuple], optional): Ranges to use for binning.
1237
                Defaults to config["momentum"]["ranges"].
1238
            **kwds: Keyword argument passed to ``compute``.
1239

1240
        Returns:
1241
            xr.DataArray: pre-binned data-array.
1242
        """
1243
        if axes is None:
3✔
1244
            axes = self._config.get("momentum", {}).get(
3✔
1245
                "axes",
1246
                ["@x_column, @y_column, @tof_column"],
1247
            )
1248
        for loc, axis in enumerate(axes):
3✔
1249
            if axis.startswith("@"):
3✔
1250
                axes[loc] = self._config.get("dataframe").get(axis.strip("@"))
3✔
1251

1252
        if bins is None:
3✔
1253
            bins = self._config.get("momentum", {}).get(
3✔
1254
                "bins",
1255
                [512, 512, 300],
1256
            )
1257
        if ranges is None:
3✔
1258
            ranges_ = self._config.get("momentum", {}).get(
3✔
1259
                "ranges",
1260
                [[-256, 1792], [-256, 1792], [128000, 138000]],
1261
            )
1262
            ranges = [cast(Tuple[float, float], tuple(v)) for v in ranges_]
3✔
1263

1264
        assert self._dataframe is not None, "dataframe needs to be loaded first!"
3✔
1265

1266
        return self.compute(
3✔
1267
            bins=bins,
1268
            axes=axes,
1269
            ranges=ranges,
1270
            df_partitions=df_partitions,
1271
            **kwds,
1272
        )
1273

    def compute(
        self,
        bins: Union[
            int,
            dict,
            tuple,
            List[int],
            List[np.ndarray],
            List[tuple],
        ] = 100,
        axes: Union[str, Sequence[str]] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Compute the histogram along the given dimensions.

        Args:
            bins (int, dict, tuple, List[int], List[np.ndarray], List[tuple], optional):
                Definition of the bins. Can be any of the following cases:

                - an integer describing the number of bins for all dimensions
                - a tuple of 3 numbers describing start, end and step of the binning
                  range
                - an np.ndarray defining the binning edges
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
                - a dictionary made of the axes as keys and any of the above as values.

                This takes priority over the axes and range arguments. Defaults to 100.
            axes (Union[str, Sequence[str]], optional): The names of the axes (columns)
                on which to calculate the histogram. The order will be the order of the
                dimensions in the resulting array. Defaults to None.
            ranges (Sequence[Tuple[float, float]], optional): list of tuples containing
                the start and end point of the binning range. Defaults to None.
            **kwds: Keyword arguments:

                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
                  ``bin_dataframe`` for details. Defaults to
                  config["binning"]["hist_mode"].
                - **mode**: Defines how the results from each partition are combined.
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
                  Defaults to config["binning"]["mode"].
                - **pbar**: Option to show the tqdm progress bar. Defaults to
                  config["binning"]["pbar"].
                - **num_cores**: Number of CPU cores to use for parallelization.
                  Defaults to config["binning"]["num_cores"] or N_CPU-1.
                - **threads_per_worker**: Limit the number of threads that
                  multiprocessing can spawn per binning thread. Defaults to
                  config["binning"]["threads_per_worker"].
                - **threadpool_API**: The API to use for multiprocessing. "blas",
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
                  config["binning"]["threadpool_API"].
                - **df_partitions**: Number of dataframe partitions to use, counted
                  from the first. Defaults to all partitions.

                Additional kwds are passed to ``bin_dataframe``.

        Raises:
            AssertionError: Raised when no dataframe has been loaded.

        Returns:
            xr.DataArray: The result of the n-dimensional binning represented in an
            xarray object, combining the data with the axes.
        """
        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        # Resolve binning options from kwds, falling back to the config defaults
        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
        mode = kwds.pop("mode", self._config["binning"]["mode"])
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
        num_cores = kwds.pop("num_cores", self._config["binning"]["num_cores"])
        threads_per_worker = kwds.pop(
            "threads_per_worker",
            self._config["binning"]["threads_per_worker"],
        )
        threadpool_api = kwds.pop(
            "threadpool_API",
            self._config["binning"]["threadpool_API"],
        )
        df_partitions = kwds.pop("df_partitions", None)
        # Optionally restrict binning to the first df_partitions partitions
        if df_partitions is not None:
            dataframe = self._dataframe.partitions[
                0 : min(df_partitions, self._dataframe.npartitions)
            ]
        else:
            dataframe = self._dataframe

        self._binned = bin_dataframe(
            df=dataframe,
            bins=bins,
            axes=axes,
            ranges=ranges,
            hist_mode=hist_mode,
            mode=mode,
            pbar=pbar,
            n_cores=num_cores,
            threads_per_worker=threads_per_worker,
            threadpool_api=threadpool_api,
            **kwds,
        )

        # Propagate axis units from the config to the binned array, where defined
        for dim in self._binned.dims:
            try:
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
            except KeyError:
                pass

        self._binned.attrs["units"] = "counts"
        self._binned.attrs["long_name"] = "photoelectron counts"
        self._binned.attrs["metadata"] = self._attributes.metadata

        return self._binned

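    # Usage sketch for ``compute`` (hypothetical column names and ranges; assumes
    # ``processor`` is a SedProcessor instance with a dataframe loaded):
    #
    #     res = processor.compute(
    #         bins=[80, 80, 100],
    #         axes=["kx", "ky", "energy"],
    #         ranges=[(-2.0, 2.0), (-2.0, 2.0), (-4.0, 2.0)],
    #         hist_mode="numba",
    #         mode="fast",
    #         df_partitions=10,  # bin only the first 10 partitions
    #     )
    #     res.dims            # ("kx", "ky", "energy")
    #     res.attrs["units"]  # "counts"
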
    def view_event_histogram(
        self,
        dfpid: int,
        ncol: int = 2,
        bins: Sequence[int] = None,
        axes: Sequence[str] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        backend: str = "bokeh",
        legend: bool = True,
        histkwds: dict = None,
        legkwds: dict = None,
        **kwds,
    ):
        """Plot individual histograms of specified dimensions (axes) from a selected
        dataframe partition.

        Args:
            dfpid (int): Number of the dataframe partition to look at.
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
            bins (Sequence[int], optional): Number of bins to use for the specified
                axes. Defaults to config["histogram"]["bins"].
            axes (Sequence[str], optional): Names of the axes to display.
                Defaults to config["histogram"]["axes"].
            ranges (Sequence[Tuple[float, float]], optional): Value ranges of all
                specified axes. Defaults to config["histogram"]["ranges"].
            backend (str, optional): Backend of the plotting library
                ('matplotlib' or 'bokeh'). Defaults to "bokeh".
            legend (bool, optional): Option to include a legend in the histogram plots.
                Defaults to True.
            histkwds (dict, optional): Keyword arguments for histograms
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
            legkwds (dict, optional): Keyword arguments for the legend
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
            **kwds: Extra keyword arguments passed to
                ``sed.diagnostics.grid_histogram()``.

        Raises:
            TypeError: Raised when the input values are not of the correct type.
        """
        if bins is None:
            bins = self._config["histogram"]["bins"]
        if axes is None:
            axes = self._config["histogram"]["axes"]
        if ranges is None:
            ranges = self._config["histogram"]["ranges"]

        input_types = map(type, [axes, bins, ranges])
        allowed_types = [list, tuple]

        df = self._dataframe

        if not set(input_types).issubset(allowed_types):
            raise TypeError(
                "Inputs of axes, bins, ranges need to be list or tuple!",
            )

        # Read out the values for the specified groups
        group_dict_dd = {}
        dfpart = df.get_partition(dfpid)
        cols = dfpart.columns
        for ax in axes:
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
        group_dict = ddf.compute(group_dict_dd)[0]

        # Plot multiple histograms in a grid
        grid_histogram(
            group_dict,
            ncol=ncol,
            rvs=axes,
            rvbins=bins,
            rvranges=ranges,
            backend=backend,
            legend=legend,
            histkwds=histkwds,
            legkwds=legkwds,
            **kwds,
        )

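    # Usage sketch (hypothetical axis names; they must exist as columns in the
    # selected dataframe partition):
    #
    #     processor.view_event_histogram(
    #         dfpid=0,
    #         axes=["X", "Y", "t"],
    #         bins=[100, 100, 100],
    #         ranges=[(0, 1800), (0, 1800), (120000, 140000)],
    #         backend="matplotlib",
    #     )
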
    def save(
        self,
        faddr: str,
        **kwds,
    ):
        """Saves the binned data to the provided path and filename.

        Args:
            faddr (str): Path and name of the file to write. Its extension determines
                the file type to write. Valid file types are:

                - "*.tiff", "*.tif": Saves a TIFF stack.
                - "*.h5", "*.hdf5": Saves an HDF5 file.
                - "*.nxs", "*.nexus": Saves a NeXus file.

            **kwds: Keyword arguments, which are passed to the writer functions:
                For TIFF writing:

                - **alias_dict**: Dictionary of dimension aliases to use.

                For HDF5 writing:

                - **mode**: hdf5 read/write mode. Defaults to "w".

                For NeXus:

                - **reader**: Name of the nexustools reader to use.
                  Defaults to config["nexus"]["reader"].
                - **definition**: NeXus application definition to use for saving.
                  Must be supported by the used ``reader``. Defaults to
                  config["nexus"]["definition"].
                - **input_files**: A list of input files to pass to the reader.
                  Defaults to config["nexus"]["input_files"].

        Raises:
            NameError: Raised when no binned data is available.
            NotImplementedError: Raised for unrecognized file extensions.
        """
        if self._binned is None:
            raise NameError("Need to bin data first!")

        extension = pathlib.Path(faddr).suffix

        if extension in (".tif", ".tiff"):
            to_tiff(
                data=self._binned,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".h5", ".hdf5"):
            to_h5(
                data=self._binned,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".nxs", ".nexus"):
            reader = kwds.pop("reader", self._config["nexus"]["reader"])
            definition = kwds.pop(
                "definition",
                self._config["nexus"]["definition"],
            )
            input_files = kwds.pop(
                "input_files",
                self._config["nexus"]["input_files"],
            )
            if isinstance(input_files, str):
                input_files = [input_files]

            to_nexus(
                data=self._binned,
                faddr=faddr,
                reader=reader,
                definition=definition,
                input_files=input_files,
                **kwds,
            )

        else:
            raise NotImplementedError(
                f"Unrecognized file format: {extension}.",
            )

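    # Usage sketch for ``save`` (hypothetical file names; requires a prior
    # ``compute`` call so binned data exists; the NeXus values are assumptions,
    # not taken from this source):
    #
    #     processor.save("binned.tiff")          # TIFF stack
    #     processor.save("binned.h5", mode="w")  # HDF5 file
    #     processor.save(
    #         "binned.nxs",
    #         reader="mpes",                     # assumed reader name
    #         definition="NXmpes",               # assumed application definition
    #         input_files=["mpes_config.json"],  # assumed input file
    #     )
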
    def add_dimension(self, name: str, axis_range: Tuple):
        """Add a dimension axis.

        Args:
            name (str): name of the axis.
            axis_range (Tuple): range for the axis.

        Raises:
            ValueError: Raised if an axis with that name already exists.
        """
        if name in self.axis:
            raise ValueError(f"Axis {name} already exists")

        self.axis[name] = self.make_axis(axis_range)

    def make_axis(self, axis_range: Tuple) -> np.ndarray:
        """Function to make an axis.

        Args:
            axis_range (Tuple): range for the new axis.

        Returns:
            np.ndarray: The axis values, generated with ``np.arange``.
        """
        # TODO: What shall this function do?
        return np.arange(*axis_range)
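
    # Usage sketch for the axis helpers above (hypothetical axis name and range;
    # note that ``make_axis`` forwards its range tuple directly to ``np.arange``):
    #
    #     processor.add_dimension("delay", (-100.0, 100.0, 0.5))
    #     processor.axis["delay"].shape  # (400,)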