OpenCOMPES / sed / build 6216173306

17 Sep 2023 10:21PM UTC coverage: 90.198% (+16.2%) from 74.035%

Pull Request #143: Processor tests (github web-flow merge of ae3e42ece into 9b8cae0cf)

515 of 515 new or added lines in 11 files covered (100.0%)
4104 of 4550 relevant lines covered (90.2%)
2.7 hits per line

Source file: /sed/core/processor.py (91.52% covered)

"""This module contains the core class for the sed package

"""
import pathlib
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import apply_jitter
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader

N_CPU = psutil.cpu_count()


class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a workflow for data
    correction, calibration and binning.

    Args:
        metadata (dict, optional): Dict of external Metadata. Defaults to None.
        config (Union[dict, str], optional): Config dictionary or config file name.
            Defaults to None.
        dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
            into the class. Defaults to None.
        files (List[str], optional): List of files to pass to the loader defined in
            the config. Defaults to None.
        folder (str, optional): Folder containing files to pass to the loader
            defined in the config. Defaults to None.
        runs (Sequence[str], optional): List of run identifiers to pass to the loader
            defined in the config. Defaults to None.
        collect_metadata (bool): Option to collect metadata from files.
            Defaults to False.
        **kwds: Keyword arguments passed to parse_config and to the reader.
    """

    def __init__(
        self,
        metadata: dict = None,
        config: Union[dict, str] = None,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            config (Union[dict, str], optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
                into the class. Defaults to None.
            files (List[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to parse_config and to the reader.
        """
        config_kwds = {
            key: value for key, value in kwds.items() if key in parse_config.__code__.co_varnames
        }
        for key in config_kwds.keys():
            del kwds[key]
        self._config = parse_config(config, **config_kwds)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

        self._dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._files: List[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None

        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
        )

        self.ec = EnergyCalibrator(
            loader=self.loader,
            config=self._config,
        )

        self.mc = MomentumCorrector(
            config=self._config,
        )

        self.dc = DelayCalibrator(
            config=self._config,
        )

        self.use_copy_tool = self._config.get("core", {}).get(
            "use_copy_tool",
            False,
        )
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    source=self._config["core"]["copy_tool_source"],
                    dest=self._config["core"]["copy_tool_dest"],
                    **self._config["core"].get("copy_tool_kwds", {}),
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )

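    # A minimal usage sketch (illustrative, not part of the module): construct a
    # processor from a config file and a folder of raw data. The file and folder
    # names are hypothetical placeholders.
    #
    #   from sed.core.processor import SedProcessor
    #
    #   processor = SedProcessor(
    #       config="sed_config.yaml",
    #       folder="/path/to/raw/data",
    #       collect_metadata=False,
    #   )
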
    def __repr__(self):
        if self._dataframe is None:
            df_str = "Data Frame: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        attributes_str = f"Metadata: {self._attributes.metadata}"
        pretty_str = df_str + "\n" + attributes_str
        return pretty_str

    @property
    def dataframe(self) -> Union[pd.DataFrame, ddf.DataFrame]:
        """Accessor to the underlying dataframe.

        Returns:
            Union[pd.DataFrame, ddf.DataFrame]: Dataframe object.
        """
        return self._dataframe

    @dataframe.setter
    def dataframe(self, dataframe: Union[pd.DataFrame, ddf.DataFrame]):
        """Setter for the underlying dataframe.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame]): The dataframe object to set.
        """
        self._dataframe = dataframe

    @property
    def attributes(self) -> dict:
        """Accessor to the metadata dict.

        Returns:
            dict: The metadata dict.
        """
        return self._attributes.metadata

    def add_attribute(self, attributes: dict, name: str, **kwds):
        """Function to add an element to the attributes dict.

        Args:
            attributes (dict): The attributes dictionary object to add.
            name (str): Key under which to add the dictionary to the attributes.
            **kwds: Additional keyword arguments passed to ``MetaHandler.add``.
        """
        self._attributes.add(
            entry=attributes,
            name=name,
            **kwds,
        )

    @property
    def config(self) -> Dict[Any, Any]:
        """Getter attribute for the config dictionary.

        Returns:
            Dict: The config dictionary.
        """
        return self._config

    @property
    def files(self) -> List[str]:
        """Getter attribute for the list of files.

        Returns:
            List[str]: The list of loaded files.
        """
        return self._files

    def cpy(self, path: Union[str, List[str]]) -> Union[str, List[str]]:
        """Function to mirror a list of files or a folder from a network drive to
        local storage. Returns either the original or the copied path to the given
        path. The option to use this functionality is set by
        config["core"]["use_copy_tool"].

        Args:
            path (Union[str, List[str]]): Source path or path list.

        Returns:
            Union[str, List[str]]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        return path

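    # Sketch of the copy-tool configuration consumed by __init__ and cpy (keys as
    # read in __init__ above; the paths are hypothetical placeholders):
    #
    #   core:
    #     use_copy_tool: True
    #     copy_tool_source: "/network/drive/data/"
    #     copy_tool_dest: "/local/scratch/data/"
    #     copy_tool_kwds: {}
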
    def load(
        self,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        metadata: dict = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            files (List[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to the loader.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            self._dataframe = dataframe
        elif runs is not None:
            # If runs are provided, we only use the copy tool if a folder is also provided.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )

        elif files is not None:
            dataframe, metadata = self.loader.read_dataframe(
                files=cast(List[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )

        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )

327
    # 1. Bin raw detector data for distortion correction
328
    def bin_and_load_momentum_calibration(
3✔
329
        self,
330
        df_partitions: int = 100,
331
        axes: List[str] = None,
332
        bins: List[int] = None,
333
        ranges: Sequence[Tuple[float, float]] = None,
334
        plane: int = 0,
335
        width: int = 5,
336
        apply: bool = False,
337
        **kwds,
338
    ):
339
        """1st step of momentum correction work flow. Function to do an initial binning
340
        of the dataframe loaded to the class, slice a plane from it using an
341
        interactive view, and load it into the momentum corrector class.
342

343
        Args:
344
            df_partitions (int, optional): Number of dataframe partitions to use for
345
                the initial binning. Defaults to 100.
346
            axes (List[str], optional): Axes to bin.
347
                Defaults to config["momentum"]["axes"].
348
            bins (List[int], optional): Bin numbers to use for binning.
349
                Defaults to config["momentum"]["bins"].
350
            ranges (List[Tuple], optional): Ranges to use for binning.
351
                Defaults to config["momentum"]["ranges"].
352
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
353
            width (int, optional): Initial value for the width slider. Defaults to 5.
354
            apply (bool, optional): Option to directly apply the values and select the
355
                slice. Defaults to False.
356
            **kwds: Keyword argument passed to the pre_binning function.
357
        """
358
        self._pre_binned = self.pre_binning(
3✔
359
            df_partitions=df_partitions,
360
            axes=axes,
361
            bins=bins,
362
            ranges=ranges,
363
            **kwds,
364
        )
365

366
        self.mc.load_data(data=self._pre_binned)
3✔
367
        self.mc.select_slicer(plane=plane, width=width, apply=apply)
3✔
368

369
    # 2. Generate the spline warp correction from momentum features.
370
    # Either autoselect features, or input features from view above.
371
    def define_features(
3✔
372
        self,
373
        features: np.ndarray = None,
374
        rotation_symmetry: int = 6,
375
        auto_detect: bool = False,
376
        include_center: bool = True,
377
        apply: bool = False,
378
        **kwds,
379
    ):
380
        """2. Step of the distortion correction workflow: Define feature points in
381
        momentum space. They can be either manually selected using a GUI tool, be
382
        ptovided as list of feature points, or auto-generated using a
383
        feature-detection algorithm.
384

385
        Args:
386
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
387
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
388
                Defaults to 6.
389
            auto_detect (bool, optional): Whether to auto-detect the features.
390
                Defaults to False.
391
            include_center (bool, optional): Option to include a point at the center
392
                in the feature list. Defaults to True.
393
            ***kwds: Keyword arguments for MomentumCorrector.feature_extract() and
394
                MomentumCorrector.feature_select()
395
        """
396
        if auto_detect:  # automatic feature selection
3✔
397
            sigma = kwds.pop("sigma", self._config["momentum"]["sigma"])
×
398
            fwhm = kwds.pop("fwhm", self._config["momentum"]["fwhm"])
×
399
            sigma_radius = kwds.pop(
×
400
                "sigma_radius",
401
                self._config["momentum"]["sigma_radius"],
402
            )
403
            self.mc.feature_extract(
×
404
                sigma=sigma,
405
                fwhm=fwhm,
406
                sigma_radius=sigma_radius,
407
                rotsym=rotation_symmetry,
408
                **kwds,
409
            )
410
            features = self.mc.peaks
×
411

412
        self.mc.feature_select(
3✔
413
            rotsym=rotation_symmetry,
414
            include_center=include_center,
415
            features=features,
416
            apply=apply,
417
            **kwds,
418
        )
419

420
    # 3. Generate the spline warp correction from momentum features.
421
    # If no features have been selected before, use class defaults.
422
    def generate_splinewarp(
3✔
423
        self,
424
        use_center: bool = None,
425
        **kwds,
426
    ):
427
        """3. Step of the distortion correction workflow: Generate the correction
428
        function restoring the symmetry in the image using a splinewarp algortihm.
429

430
        Args:
431
            use_center (bool, optional): Option to use the position of the
432
                center point in the correction. Default is read from config, or set to True.
433
            **kwds: Keyword arguments for MomentumCorrector.spline_warp_estimate().
434
        """
435
        self.mc.spline_warp_estimate(use_center=use_center, **kwds)
3✔
436

437
        if self.mc.slice is not None:
3✔
438
            print("Original slice with reference features")
3✔
439
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)
3✔
440

441
            print("Corrected slice with target features")
3✔
442
            self.mc.view(
3✔
443
                image=self.mc.slice_corrected,
444
                annotated=True,
445
                points={"feats": self.mc.ptargs},
446
                backend="bokeh",
447
                crosshair=True,
448
            )
449

450
            print("Original slice with target features")
3✔
451
            self.mc.view(
3✔
452
                image=self.mc.slice,
453
                points={"feats": self.mc.ptargs},
454
                annotated=True,
455
                backend="bokeh",
456
            )
457

458
    # 3a. Save spline-warp parameters to config file.
459
    def save_splinewarp(
3✔
460
        self,
461
        filename: str = None,
462
        overwrite: bool = False,
463
    ):
464
        """Save the generated spline-warp parameters to the folder config file.
465

466
        Args:
467
            filename (str, optional): Filename of the config dictionary to save to.
468
                Defaults to "sed_config.yaml" in the current folder.
469
            overwrite (bool, optional): Option to overwrite the present dictionary.
470
                Defaults to False.
471
        """
472
        if filename is None:
3✔
473
            filename = "sed_config.yaml"
×
474
        points = []
3✔
475
        try:
3✔
476
            for point in self.mc.pouter_ord:
3✔
477
                points.append([float(i) for i in point])
3✔
478
            if self.mc.include_center:
3✔
479
                points.append([float(i) for i in self.mc.pcent])
3✔
480
        except AttributeError as exc:
×
481
            raise AttributeError(
×
482
                "Momentum correction parameters not found, need to generate parameters first!",
483
            ) from exc
484
        config = {
3✔
485
            "momentum": {
486
                "correction": {
487
                    "rotation_symmetry": self.mc.rotsym,
488
                    "feature_points": points,
489
                    "include_center": self.mc.include_center,
490
                    "use_center": self.mc.use_center,
491
                },
492
            },
493
        }
494
        save_config(config, filename, overwrite)
3✔
495

496
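    # save_splinewarp writes a block of this shape into the config file (values
    # shown are illustrative, mirroring the config dict built above):
    #
    #   momentum:
    #     correction:
    #       rotation_symmetry: 6
    #       feature_points: [[x0, y0], [x1, y1], ...]
    #       include_center: True
    #       use_center: True
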
    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    def pose_adjustment(
        self,
        scale: float = 1,
        xtrans: float = 0,
        ytrans: float = 0,
        angle: float = 0,
        apply: bool = False,
        use_correction: bool = True,
    ):
        """4. step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            scale (float, optional): Initial value of the scaling slider.
                Defaults to 1.
            xtrans (float, optional): Initial value of the xtrans slider.
                Defaults to 0.
            ytrans (float, optional): Initial value of the ytrans slider.
                Defaults to 0.
            angle (float, optional): Initial value of the angle slider.
                Defaults to 0.
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
        """
        # Use the uncorrected slice as default if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                raise ValueError(
                    "No slice for corrections and transformations loaded!",
                )
            self.mc.slice_corrected = self.mc.slice

        if not use_correction:
            self.mc.reset_deformation()

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate default distortion correction
            self.mc.add_features()
            self.mc.spline_warp_estimate()

        self.mc.pose_adjustment(
            scale=scale,
            xtrans=xtrans,
            ytrans=ytrans,
            angle=angle,
            apply=apply,
        )

    # 5. Apply the momentum correction to the dataframe
    def apply_momentum_correction(
        self,
        preview: bool = False,
    ):
        """5. step of the distortion correction workflow: Apply the distortion
        correction and pose adjustment (optional) to the dataframe.

        Args:
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding corrected X/Y columns to dataframe:")
            self._dataframe, metadata = self.mc.apply_corrections(
                df=self._dataframe,
            )
            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

584
    # 1. Calculate momentum calibration
585
    def calibrate_momentum_axes(
3✔
586
        self,
587
        point_a: Union[np.ndarray, List[int]] = None,
588
        point_b: Union[np.ndarray, List[int]] = None,
589
        k_distance: float = None,
590
        k_coord_a: Union[np.ndarray, List[float]] = None,
591
        k_coord_b: Union[np.ndarray, List[float]] = np.array([0.0, 0.0]),
592
        equiscale: bool = True,
593
        apply=False,
594
    ):
595
        """1. step of the momentum calibration workflow. Calibrate momentum
596
        axes using either provided pixel coordinates of a high-symmetry point and its
597
        distance to the BZ center, or the k-coordinates of two points in the BZ
598
        (depending on the equiscale option). Opens an interactive panel for selecting
599
        the points.
600

601
        Args:
602
            point_a (Union[np.ndarray, List[int]]): Pixel coordinates of the first
603
                point used for momentum calibration.
604
            point_b (Union[np.ndarray, List[int]], optional): Pixel coordinates of the
605
                second point used for momentum calibration.
606
                Defaults to config["momentum"]["center_pixel"].
607
            k_distance (float, optional): Momentum distance between point a and b.
608
                Needs to be provided if no specific k-koordinates for the two points
609
                are given. Defaults to None.
610
            k_coord_a (Union[np.ndarray, List[float]], optional): Momentum coordinate
611
                of the first point used for calibration. Used if equiscale is False.
612
                Defaults to None.
613
            k_coord_b (Union[np.ndarray, List[float]], optional): Momentum coordinate
614
                of the second point used for calibration. Defaults to [0.0, 0.0].
615
            equiscale (bool, optional): Option to apply different scales to kx and ky.
616
                If True, the distance between points a and b, and the absolute
617
                position of point a are used for defining the scale. If False, the
618
                scale is calculated from the k-positions of both points a and b.
619
                Defaults to True.
620
            apply (bool, optional): Option to directly store the momentum calibration
621
                in the class. Defaults to False.
622
        """
623
        if point_b is None:
3✔
624
            point_b = self._config["momentum"]["center_pixel"]
3✔
625

626
        self.mc.select_k_range(
3✔
627
            point_a=point_a,
628
            point_b=point_b,
629
            k_distance=k_distance,
630
            k_coord_a=k_coord_a,
631
            k_coord_b=k_coord_b,
632
            equiscale=equiscale,
633
            apply=apply,
634
        )
635

636
    # 1a. Save momentum calibration parameters to config file.
637
    def save_momentum_calibration(
3✔
638
        self,
639
        filename: str = None,
640
        overwrite: bool = False,
641
    ):
642
        """Save the generated momentum calibration parameters to the folder config file.
643

644
        Args:
645
            filename (str, optional): Filename of the config dictionary to save to.
646
                Defaults to "sed_config.yaml" in the current folder.
647
            overwrite (bool, optional): Option to overwrite the present dictionary.
648
                Defaults to False.
649
        """
650
        if filename is None:
3✔
651
            filename = "sed_config.yaml"
3✔
652
        calibration = {}
3✔
653
        try:
3✔
654
            for key in [
3✔
655
                "kx_scale",
656
                "ky_scale",
657
                "x_center",
658
                "y_center",
659
                "rstart",
660
                "cstart",
661
                "rstep",
662
                "cstep",
663
            ]:
664
                calibration[key] = float(self.mc.calibration[key])
3✔
665
        except KeyError as exc:
×
666
            raise KeyError(
×
667
                "Momentum calibration parameters not found, need to generate parameters first!",
668
            ) from exc
669

670
        config = {"momentum": {"calibration": calibration}}
3✔
671
        save_config(config, filename, overwrite)
3✔
672

673
    # 2. Apply correction and calibration to the dataframe
674
    def apply_momentum_calibration(
3✔
675
        self,
676
        calibration: dict = None,
677
        preview: bool = False,
678
    ):
679
        """2. step of the momentum calibration work flow: Apply the momentum
680
        calibration stored in the class to the dataframe. If corrected X/Y axis exist,
681
        these are used.
682

683
        Args:
684
            calibration (dict, optional): Optional dictionary with calibration data to
685
                use. Defaults to None.
686
            preview (bool): Option to preview the first elements of the data frame.
687
        """
688
        if self._dataframe is not None:
3✔
689

690
            print("Adding kx/ky columns to dataframe:")
3✔
691
            self._dataframe, metadata = self.mc.append_k_axis(
3✔
692
                df=self._dataframe,
693
                calibration=calibration,
694
            )
695

696
            # Add Metadata
697
            self._attributes.add(
3✔
698
                metadata,
699
                "momentum_calibration",
700
                duplicate_policy="merge",
701
            )
702
            if preview:
3✔
703
                print(self._dataframe.head(10))
×
704
            else:
705
                print(self._dataframe)
3✔
706

707
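    # Momentum-calibration sketch (illustrative; the pixel coordinates and
    # k-distance are placeholders):
    #
    #   processor.calibrate_momentum_axes(point_a=[308, 345], k_distance=1.3, apply=True)
    #   processor.save_momentum_calibration()
    #   processor.apply_momentum_calibration()
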
    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: Tuple[float, float] = None,
        apply=False,
        **kwds,
    ):
        """1. step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data if
        they are not present yet.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (Tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
            **kwds: Keyword args passed to ``EnergyCalibrator.adjust_energy_correction``.
        """
        if self._pre_binned is None:
            print(
                "Pre-binned data not present, binning using defaults from config...",
            )
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )

    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        correction = {}
        try:
            for key, val in self.ec.correction.items():
                if key == "correction_type":
                    correction[key] = val
                elif key == "center":
                    correction[key] = [float(i) for i in val]
                else:
                    correction[key] = float(val)
        except AttributeError as exc:
            raise AttributeError(
                "Energy correction parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)

    # 2. Apply energy correction to dataframe
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2. step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction``.
        """
        if self._dataframe is not None:
            print("Applying energy correction to dataframe...")
            self._dataframe, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                **kwds,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

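    # Energy-correction sketch (illustrative; amplitude and center values are
    # placeholders):
    #
    #   processor.adjust_energy_correction(
    #       correction_type="Lorentzian",
    #       amplitude=2.5,
    #       center=(730, 730),
    #       apply=True,
    #   )
    #   processor.save_energy_correction()
    #   processor.apply_energy_correction()
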
    # Energy calibration workflow
    # 1. Load and normalize data
    def load_bias_series(
        self,
        data_files: List[str],
        axes: List[str] = None,
        bins: List = None,
        ranges: Sequence[Tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1. step of the energy calibration workflow: Load and bin data from
        single-event files.

        Args:
            data_files (List[str]): list of file paths to bin
            axes (List[str], optional): bin axes.
                Defaults to config["dataframe"]["tof_column"].
            bins (List, optional): number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        self.ec.bin_data(
            data_files=cast(List[str], self.cpy(data_files)),
            axes=axes,
            bins=bins,
            ranges=ranges,
            biases=biases,
            bias_key=bias_key,
        )
        if (normalize is not None and normalize is True) or (
            normalize is None and self._config["energy"]["normalize"]
        ):
            if span is None:
                span = self._config["energy"]["normalize_span"]
            if order is None:
                order = self._config["energy"]["normalize_order"]
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="bokeh",
        )

    # 2. Extract ranges and get peak positions
    def find_bias_peaks(
        self,
        ranges: Union[List[Tuple], Tuple],
        ref_id: int = 0,
        infer_others: bool = True,
        mode: str = "replace",
        radius: int = None,
        peak_window: int = None,
        apply: bool = False,
    ):
        """2. step of the energy calibration workflow: Find a peak within a given range
        for the indicated reference trace, and try to find the same peak for all
        other traces. Uses fast_dtw to align curves, which may perform poorly if the
        shape of the curves changes qualitatively. Ideally, choose a reference trace in
        the middle of the set, and don't choose the range too narrow around the peak.
        Alternatively, a list of ranges for all traces can be provided.

        Args:
            ranges (Union[List[Tuple], Tuple]): Tuple of TOF values indicating a range.
                Alternatively, a list of ranges for all traces can be given.
            ref_id (int, optional): The id of the trace the range refers to.
                Defaults to 0.
            infer_others (bool, optional): Whether to determine the range for the other
                traces. Defaults to True.
            mode (str, optional): Whether to "add" or "replace" existing ranges.
                Defaults to "replace".
            radius (int, optional): Radius parameter for fast_dtw.
                Defaults to config["energy"]["fastdtw_radius"].
            peak_window (int, optional): Peak_window parameter for the peak detection
                algorithm: the number of points that have to behave monotonically
                around a peak. Defaults to config["energy"]["peak_window"].
            apply (bool, optional): Option to directly apply the provided parameters.
                Defaults to False.
        """
        if radius is None:
            radius = self._config["energy"]["fastdtw_radius"]
        if peak_window is None:
            peak_window = self._config["energy"]["peak_window"]
        if not infer_others:
            self.ec.add_ranges(
                ranges=ranges,
                ref_id=ref_id,
                infer_others=infer_others,
                mode=mode,
                radius=radius,
            )
            print(self.ec.featranges)
            try:
                self.ec.feature_extract(peak_window=peak_window)
                self.ec.view(
                    traces=self.ec.traces_normed,
                    segs=self.ec.featranges,
                    xaxis=self.ec.tof,
                    peaks=self.ec.peaks,
                    backend="bokeh",
                )
            except IndexError:
                print("Could not determine all peaks!")
                raise
        else:
            # New adjustment tool
            assert isinstance(ranges, tuple)
            self.ec.adjust_ranges(
                ranges=ranges,
                ref_id=ref_id,
                traces=self.ec.traces_normed,
                infer_others=infer_others,
                radius=radius,
                peak_window=peak_window,
                apply=apply,
            )

    # 3. Fit the energy calibration relation
    def calibrate_energy_axis(
        self,
        ref_id: int,
        ref_energy: float,
        method: str = None,
        energy_scale: str = None,
        **kwds,
    ):
        """3. step of the energy calibration workflow: Calculate the calibration
        function for the energy axis, and apply it to the dataframe. Two
        approximations are implemented, a (normally 3rd order) polynomial
        approximation, and a d^2/(t-t0)^2 relation.

        Args:
            ref_id (int): id of the trace at the bias where the reference energy is
                given.
            ref_energy (float): Absolute energy of the detected feature at the bias
                of ref_id.
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to config["energy"]["calibration_method"].
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to config["energy"]["energy_scale"].
            **kwds: Keyword args passed to ``EnergyCalibrator.calibrate``.
        """
        if method is None:
            method = self._config["energy"]["calibration_method"]

        if energy_scale is None:
            energy_scale = self._config["energy"]["energy_scale"]

        self.ec.calibrate(
            ref_id=ref_id,
            ref_energy=ref_energy,
            method=method,
            energy_scale=energy_scale,
            **kwds,
        )
        print("Quality of Calibration:")
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.calibration["axis"],
            align=True,
            energy_scale=energy_scale,
            backend="bokeh",
        )
        print("E/TOF relationship:")
        self.ec.view(
            traces=self.ec.calibration["axis"][None, :],
            xaxis=self.ec.tof,
            backend="matplotlib",
            show_legend=False,
        )
        if energy_scale == "kinetic":
            plt.scatter(
                self.ec.peaks[:, 0],
                -(self.ec.biases - self.ec.biases[ref_id]) + ref_energy,
                s=50,
                c="k",
            )
        elif energy_scale == "binding":
            plt.scatter(
                self.ec.peaks[:, 0],
                self.ec.biases - self.ec.biases[ref_id] + ref_energy,
                s=50,
                c="k",
            )
        else:
            raise ValueError(
                'energy_scale needs to be either "binding" or "kinetic"',
                f", got {energy_scale}.",
            )
        plt.xlabel("Time-of-flight", fontsize=15)
        plt.ylabel("Energy (eV)", fontsize=15)
        plt.show()

    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key, value in self.ec.calibration.items():
                if key in ["axis", "refid", "Tmat", "bvec"]:
                    continue
                if key == "energy_scale":
                    calibration[key] = value
                elif key == "coeffs":
                    calibration[key] = [float(i) for i in value]
                else:
                    calibration[key] = float(value)
        except AttributeError as exc:
            raise AttributeError(
                "Energy calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"calibration": calibration}}
        save_config(config, filename, overwrite)

    # 4. Apply energy calibration to the dataframe
    def append_energy_axis(
        self,
        calibration: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """4. step of the energy calibration workflow: Apply the calibration function
        to the dataframe. Two approximations are implemented, a (normally 3rd order)
        polynomial approximation, and a d^2/(t-t0)^2 relation. A calibration dictionary
        can be provided.

        Args:
            calibration (dict, optional): Calibration dict containing calibration
                parameters. Overrides calibration from class or config.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.append_energy_axis``.
        """
        if self._dataframe is not None:
            print("Adding energy column to dataframe:")
            self._dataframe, metadata = self.ec.append_energy_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

1115
    def calibrate_delay_axis(
3✔
1116
        self,
1117
        delay_range: Tuple[float, float] = None,
1118
        datafile: str = None,
1119
        preview: bool = False,
1120
        **kwds,
1121
    ):
1122
        """Append delay column to dataframe. Either provide delay ranges, or read
1123
        them from a file.
1124

1125
        Args:
1126
            delay_range (Tuple[float, float], optional): The scanned delay range in
1127
                picoseconds. Defaults to None.
1128
            datafile (str, optional): The file from which to read the delay ranges.
1129
                Defaults to None.
1130
            preview (bool): Option to preview the first elements of the data frame.
1131
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
1132
        """
1133
        if self._dataframe is not None:
3✔
1134
            print("Adding delay column to dataframe:")
3✔
1135

1136
            if delay_range is not None:
3✔
1137
                self._dataframe, metadata = self.dc.append_delay_axis(
3✔
1138
                    self._dataframe,
1139
                    delay_range=delay_range,
1140
                    **kwds,
1141
                )
1142
            else:
1143
                if datafile is None:
3✔
1144
                    try:
3✔
1145
                        datafile = self._files[0]
3✔
1146
                    except IndexError:
×
1147
                        print(
×
1148
                            "No datafile available, specify eihter",
1149
                            " 'datafile' or 'delay_range'",
1150
                        )
1151
                        raise
×
1152

1153
                self._dataframe, metadata = self.dc.append_delay_axis(
3✔
1154
                    self._dataframe,
1155
                    datafile=datafile,
1156
                    **kwds,
1157
                )
1158

1159
            # Add Metadata
1160
            self._attributes.add(
3✔
1161
                metadata,
1162
                "delay_calibration",
1163
                duplicate_policy="merge",
1164
            )
1165
            if preview:
3✔
1166
                print(self._dataframe.head(10))
3✔
1167
            else:
1168
                print(self._dataframe)
3✔
1169

1170
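    # Delay-calibration sketch (illustrative; the range values and file name are
    # placeholders):
    #
    #   processor.calibrate_delay_axis(delay_range=(-500.0, 1500.0))
    #   # or, reading the range from a data file:
    #   processor.calibrate_delay_axis(datafile="scan_001.h5")
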
    def add_jitter(self, cols: Sequence[str] = None):
        """Add jitter to the selected dataframe columns.

        Args:
            cols (Sequence[str], optional): The columns to which to apply jitter.
                Defaults to config["dataframe"]["jitter_cols"].
        """
        if cols is None:
            cols = self._config["dataframe"].get(
                "jitter_cols",
                self._dataframe.columns,
            )  # jitter all columns

        self._dataframe = self._dataframe.map_partitions(
            apply_jitter,
            cols=cols,
            cols_jittered=cols,
        )
        metadata = []
        for col in cols:
            metadata.append(col)
        self._attributes.add(metadata, "jittering", duplicate_policy="append")

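    # Jitter sketch (illustrative; the column names are placeholders): dithering
    # digitized columns before binning suppresses artificial binning structure.
    #
    #   processor.add_jitter()  # uses config["dataframe"]["jitter_cols"]
    #   processor.add_jitter(cols=["X", "Y", "t"])
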
    def pre_binning(
3✔
1194
        self,
1195
        df_partitions: int = 100,
1196
        axes: List[str] = None,
1197
        bins: List[int] = None,
1198
        ranges: Sequence[Tuple[float, float]] = None,
1199
        **kwds,
1200
    ) -> xr.DataArray:
1201
        """Function to do an initial binning of the dataframe loaded to the class.
1202

1203
        Args:
1204
            df_partitions (int, optional): Number of dataframe partitions to use for
1205
                the initial binning. Defaults to 100.
1206
            axes (List[str], optional): Axes to bin.
1207
                Defaults to config["momentum"]["axes"].
1208
            bins (List[int], optional): Bin numbers to use for binning.
1209
                Defaults to config["momentum"]["bins"].
1210
            ranges (List[Tuple], optional): Ranges to use for binning.
1211
                Defaults to config["momentum"]["ranges"].
1212
            **kwds: Keyword argument passed to ``compute``.
1213

1214
        Returns:
1215
            xr.DataArray: pre-binned data-array.
1216
        """
1217
        if axes is None:
3✔
1218
            axes = self._config["momentum"]["axes"]
3✔
1219
        for loc, axis in enumerate(axes):
3✔
1220
            if axis.startswith("@"):
3✔
1221
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
3✔
1222

1223
        if bins is None:
3✔
1224
            bins = self._config["momentum"]["bins"]
3✔
1225
        if ranges is None:
3✔
1226
            ranges_ = list(self._config["momentum"]["ranges"])
3✔
1227
            ranges_[2] = np.asarray(ranges_[2]) / 2 ** (
3✔
1228
                self._config["dataframe"]["tof_binning"] - 1
1229
            )
1230
            ranges = [cast(Tuple[float, float], tuple(v)) for v in ranges_]
3✔
1231

1232
        assert self._dataframe is not None, "dataframe needs to be loaded first!"
3✔
1233

1234
        return self.compute(
3✔
1235
            bins=bins,
1236
            axes=axes,
1237
            ranges=ranges,
            df_partitions=df_partitions,
            **kwds,
        )

    def compute(
        self,
        bins: Union[
            int,
            dict,
            tuple,
            List[int],
            List[np.ndarray],
            List[tuple],
        ] = 100,
        axes: Union[str, Sequence[str]] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Compute the histogram along the given dimensions.

        Args:
            bins (int, dict, tuple, List[int], List[np.ndarray], List[tuple], optional):
                Definition of the bins. Can be any of the following cases:

                - an integer describing the number of bins on all dimensions
                - a tuple of 3 numbers describing start, end and step of the binning
                  range
                - an np.ndarray defining the binning edges
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
                - a dictionary made of the axes as keys and any of the above as values.

                This takes priority over the axes and ranges arguments. Defaults to 100.
            axes (Union[str, Sequence[str]], optional): The names of the axes (columns)
                on which to calculate the histogram. The order will be the order of the
                dimensions in the resulting array. Defaults to None.
            ranges (Sequence[Tuple[float, float]], optional): list of tuples containing
                the start and end point of the binning range. Defaults to None.
            **kwds: Keyword arguments:

                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
                  ``bin_dataframe`` for details. Defaults to
                  config["binning"]["hist_mode"].
                - **mode**: Defines how the results from each partition are combined.
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
                  Defaults to config["binning"]["mode"].
                - **pbar**: Option to show the tqdm progress bar. Defaults to
                  config["binning"]["pbar"].
                - **num_cores**: Number of CPU cores to use for parallelization.
                  Defaults to config["binning"]["num_cores"] or N_CPU-1.
                - **threads_per_worker**: Limit the number of threads that
                  multiprocessing can spawn per binning thread. Defaults to
                  config["binning"]["threads_per_worker"].
                - **threadpool_API**: The API to use for multiprocessing. "blas",
                  "openmp" or None. See ``threadpool_limits`` for details. Defaults to
                  config["binning"]["threadpool_API"].
                - **df_partitions**: Number of dataframe partitions to use. Defaults to
                  all partitions.

                Additional kwds are passed to ``bin_dataframe``.

        Raises:
            AssertionError: Raised when no dataframe has been loaded.

        Returns:
            xr.DataArray: The result of the n-dimensional binning represented in an
            xarray object, combining the data with the axes.
        """
        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
        mode = kwds.pop("mode", self._config["binning"]["mode"])
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
        num_cores = kwds.pop("num_cores", self._config["binning"]["num_cores"])
        threads_per_worker = kwds.pop(
            "threads_per_worker",
            self._config["binning"]["threads_per_worker"],
        )
        threadpool_api = kwds.pop(
            "threadpool_API",
            self._config["binning"]["threadpool_API"],
        )
        df_partitions = kwds.pop("df_partitions", None)
        # Restrict binning to the first df_partitions partitions, if requested
        if df_partitions is not None:
            dataframe = self._dataframe.partitions[
                0 : min(df_partitions, self._dataframe.npartitions)
            ]
        else:
            dataframe = self._dataframe

        self._binned = bin_dataframe(
            df=dataframe,
            bins=bins,
            axes=axes,
            ranges=ranges,
            hist_mode=hist_mode,
            mode=mode,
            pbar=pbar,
            n_cores=num_cores,
            threads_per_worker=threads_per_worker,
            threadpool_api=threadpool_api,
            **kwds,
        )

        # Attach axis units from the config, where available
        for dim in self._binned.dims:
            try:
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
            except KeyError:
                pass

        self._binned.attrs["units"] = "counts"
        self._binned.attrs["long_name"] = "photoelectron counts"
        self._binned.attrs["metadata"] = self._attributes.metadata

        return self._binned
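
    # A minimal usage sketch of ``compute`` (not part of the class body),
    # assuming a processor that has already loaded data; the folder path and
    # the column names "X"/"Y" are hypothetical placeholders:
    #
    #   import numpy as np
    #   from sed.core.processor import SedProcessor
    #
    #   sp = SedProcessor(folder="path/to/data")
    #   # integer bins with explicit axes and ranges:
    #   binned = sp.compute(
    #       bins=[100, 100],
    #       axes=["X", "Y"],
    #       ranges=[(0, 2048), (0, 2048)],
    #   )
    #   # dictionary form: axis names as keys, explicit bin edges as values
    #   # (the edges carry their own ranges, so no ranges argument is needed):
    #   binned = sp.compute(
    #       bins={"X": np.linspace(0, 2048, 257), "Y": np.linspace(0, 2048, 257)},
    #   )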

    def view_event_histogram(
        self,
        dfpid: int,
        ncol: int = 2,
        bins: Sequence[int] = None,
        axes: Sequence[str] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        backend: str = "bokeh",
        legend: bool = True,
        histkwds: dict = None,
        legkwds: dict = None,
        **kwds,
    ):
        """Plot individual histograms of specified dimensions (axes) from a single
        dataframe partition.

        Args:
            dfpid (int): Number of the dataframe partition to look at.
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
            bins (Sequence[int], optional): Number of bins to use for the specified
                axes. Defaults to config["histogram"]["bins"].
            axes (Sequence[str], optional): Names of the axes to display.
                Defaults to config["histogram"]["axes"].
            ranges (Sequence[Tuple[float, float]], optional): Value ranges of all
                specified axes. Defaults to config["histogram"]["ranges"].
            backend (str, optional): Backend of the plotting library
                ("matplotlib" or "bokeh"). Defaults to "bokeh".
            legend (bool, optional): Option to include a legend in the histogram plots.
                Defaults to True.
            histkwds (dict, optional): Keyword arguments for histograms
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
            legkwds (dict, optional): Keyword arguments for the legend
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
            **kwds: Extra keyword arguments passed to
                ``sed.diagnostics.grid_histogram()``.

        Raises:
            TypeError: Raised when the input values are not of the correct type.
        """
        if bins is None:
            bins = self._config["histogram"]["bins"]
        if axes is None:
            axes = self._config["histogram"]["axes"]
        axes = list(axes)
        # Resolve axis aliases: "@key" refers to config["dataframe"][key]
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
        if ranges is None:
            ranges = list(self._config["histogram"]["ranges"])
            # Scale the default ranges to the binning of the tof and adc columns
            for loc, axis in enumerate(axes):
                if axis == self._config["dataframe"]["tof_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
                        self._config["dataframe"]["tof_binning"] - 1
                    )
                elif axis == self._config["dataframe"]["adc_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
                        self._config["dataframe"]["adc_binning"] - 1
                    )

        input_types = map(type, [axes, bins, ranges])
        allowed_types = [list, tuple]

        df = self._dataframe

        if not set(input_types).issubset(allowed_types):
            raise TypeError(
                "Inputs of axes, bins, ranges need to be list or tuple!",
            )

        # Read out the values for the specified groups
        group_dict_dd = {}
        dfpart = df.get_partition(dfpid)
        cols = dfpart.columns
        for ax in axes:
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
        group_dict = ddf.compute(group_dict_dd)[0]

        # Plot multiple histograms in a grid
        grid_histogram(
            group_dict,
            ncol=ncol,
            rvs=axes,
            rvbins=bins,
            rvranges=ranges,
            backend=backend,
            legend=legend,
            histkwds=histkwds,
            legkwds=legkwds,
            **kwds,
        )
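
    # A short sketch of inspecting raw event distributions, assuming a loaded
    # dataframe; the first call relies entirely on the config["histogram"]
    # defaults, while the second uses hypothetical bin and range values. The
    # "@" prefix resolves an axis name through config["dataframe"], as
    # implemented above:
    #
    #   sp.view_event_histogram(dfpid=0)
    #   sp.view_event_histogram(
    #       dfpid=2,
    #       axes=["@tof_column"],
    #       bins=[200],
    #       ranges=[(0, 32000)],
    #       backend="matplotlib",
    #   )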

    def save(
        self,
        faddr: str,
        **kwds,
    ):
        """Saves the binned data to the provided path and filename.

        Args:
            faddr (str): Path and name of the file to write. Its extension determines
                the file type to write. Valid file types are:

                - "*.tiff", "*.tif": Saves a TIFF stack.
                - "*.h5", "*.hdf5": Saves an HDF5 file.
                - "*.nxs", "*.nexus": Saves a NeXus file.

            **kwds: Keyword arguments, which are passed to the writer functions:
                For TIFF writing:

                - **alias_dict**: Dictionary of dimension aliases to use.

                For HDF5 writing:

                - **mode**: hdf5 read/write mode. Defaults to "w".

                For NeXus:

                - **reader**: Name of the nexustools reader to use.
                  Defaults to config["nexus"]["reader"]
                - **definition**: NeXus application definition to use for saving.
                  Must be supported by the used ``reader``. Defaults to
                  config["nexus"]["definition"]
                - **input_files**: A list of input files to pass to the reader.
                  Defaults to config["nexus"]["input_files"]
        """
        if self._binned is None:
            raise NameError("Need to bin data first!")

        extension = pathlib.Path(faddr).suffix

        if extension in (".tif", ".tiff"):
            to_tiff(
                data=self._binned,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".h5", ".hdf5"):
            to_h5(
                data=self._binned,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".nxs", ".nexus"):
            try:
                reader = kwds.pop("reader", self._config["nexus"]["reader"])
                definition = kwds.pop(
                    "definition",
                    self._config["nexus"]["definition"],
                )
                input_files = kwds.pop(
                    "input_files",
                    self._config["nexus"]["input_files"],
                )
            except KeyError as exc:
                raise ValueError(
                    "The nexus reader, definition and input files need to be provided!",
                ) from exc

            if isinstance(input_files, str):
                input_files = [input_files]

            to_nexus(
                data=self._binned,
                faddr=faddr,
                reader=reader,
                definition=definition,
                input_files=input_files,
                **kwds,
            )

        else:
            raise NotImplementedError(
                f"Unrecognized file format: {extension}.",
            )
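
    # A sketch of the supported output formats, assuming ``compute`` has been
    # called first (otherwise ``save`` raises a NameError); the file names are
    # placeholders, and "mpes" stands in for whatever NeXus reader the config
    # provides:
    #
    #   sp.save("binned.tiff")                # TIFF stack
    #   sp.save("binned.h5", mode="w")        # HDF5 file
    #   sp.save("binned.nxs", reader="mpes")  # NeXus file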