"""This module contains the core class for the sed package."""
import pathlib
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import apply_jitter
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader

N_CPU = psutil.cpu_count()


class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a workflow for data
    correction, calibration, and binning.

    Args:
        metadata (dict, optional): Dict of external Metadata. Defaults to None.
        config (Union[dict, str], optional): Config dictionary or config file name.
            Defaults to None.
        dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
            into the class. Defaults to None.
        files (List[str], optional): List of files to pass to the loader defined in
            the config. Defaults to None.
        folder (str, optional): Folder containing files to pass to the loader
            defined in the config. Defaults to None.
        runs (Sequence[str], optional): List of run identifiers to pass to the loader
            defined in the config. Defaults to None.
        collect_metadata (bool): Option to collect metadata from files.
            Defaults to False.
        **kwds: Keyword arguments passed to the reader.
    """

    def __init__(
        self,
        metadata: dict = None,
        config: Union[dict, str] = None,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            config (Union[dict, str], optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
                into the class. Defaults to None.
            files (List[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to the reader.
        """
        self._config = parse_config(config, **kwds)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

        self._dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._files: List[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None

        self._dimensions: List[str] = []
        self._coordinates: Dict[Any, Any] = {}
        self.axis: Dict[Any, Any] = {}
        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
        )

        self.ec = EnergyCalibrator(
            loader=self.loader,
            config=self._config,
        )

        self.mc = MomentumCorrector(
            config=self._config,
        )

        self.dc = DelayCalibrator(
            config=self._config,
        )

        self.use_copy_tool = self._config.get("core", {}).get(
            "use_copy_tool",
            False,
        )
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    source=self._config["core"]["copy_tool_source"],
                    dest=self._config["core"]["copy_tool_dest"],
                    **self._config["core"].get("copy_tool_kwds", {}),
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )
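
    # A minimal construction sketch (hypothetical config file and data folder;
    # which loader is used, and which keywords it accepts, is set by
    # config["core"]["loader"]):
    #
    #   sp = SedProcessor(
    #       config="sed_config.yaml",
    #       folder="/path/to/raw_data",
    #       collect_metadata=False,
    #   )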

    def __repr__(self):
        if self._dataframe is None:
            df_str = "Data Frame: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        coordinates_str = f"Coordinates: {self._coordinates}"
        dimensions_str = f"Dimensions: {self._dimensions}"
        pretty_str = df_str + "\n" + coordinates_str + "\n" + dimensions_str
        return pretty_str

    def __getitem__(self, val: str) -> pd.DataFrame:
        """Accessor to the underlying data structure.

        Args:
            val (str): Name of the dataframe column to retrieve.

        Returns:
            pd.DataFrame: Selected dataframe column.
        """
        return self._dataframe[val]

    @property
    def config(self) -> Dict[Any, Any]:
        """Getter attribute for the config dictionary.

        Returns:
            Dict: The config dictionary.
        """
        return self._config

    @config.setter
    def config(self, config: Union[dict, str]):
        """Setter function for the config dictionary.

        Args:
            config (Union[dict, str]): Config dictionary or path of config file
                to load.
        """
        self._config = parse_config(config)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

    @property
    def dimensions(self) -> list:
        """Getter attribute for the dimensions.

        Returns:
            list: List of dimensions.
        """
        return self._dimensions

    @dimensions.setter
    def dimensions(self, dims: list):
        """Setter function for the dimensions.

        Args:
            dims (list): List of dimensions to set.
        """
        assert isinstance(dims, list)
        self._dimensions = dims

    @property
    def coordinates(self) -> dict:
        """Getter attribute for the coordinates dict.

        Returns:
            dict: Dictionary of coordinates.
        """
        return self._coordinates

    @coordinates.setter
    def coordinates(self, coords: dict):
        """Setter function for the coordinates dict.

        Args:
            coords (dict): Dictionary of coordinates.
        """
        assert isinstance(coords, dict)
        self._coordinates = {}
        for k, v in coords.items():
            self._coordinates[k] = xr.DataArray(v)

    def cpy(self, path: Union[str, List[str]]) -> Union[str, List[str]]:
        """Function to mirror a list of files or a folder from a network drive to
        local storage. Returns either the original path(s), or the path(s) of the
        copied data. The option to use this functionality is set by
        config["core"]["use_copy_tool"].

        Args:
            path (Union[str, List[str]]): Source path or path list.

        Returns:
            Union[str, List[str]]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        return path

    def load(
        self,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        metadata: dict = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            files (List[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to the loader.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            self._dataframe = dataframe
        elif runs is not None:
            # If runs are provided, we only use the copy tool if also folder is provided.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )

        elif files is not None:
            dataframe, metadata = self.loader.read_dataframe(
                files=cast(List[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )

        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )
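
    # A minimal loading sketch (hypothetical run identifier and folder; all
    # further keywords are forwarded to the configured loader):
    #
    #   sp = SedProcessor(config="sed_config.yaml")
    #   sp.load(runs=["run_0001"], folder="/path/to/raw_data")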

    # Momentum correction workflow
    # 1. Bin raw detector data for distortion correction
    def bin_and_load_momentum_calibration(
        self,
        df_partitions: int = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        plane: int = 0,
        width: int = 5,
        apply: bool = False,
        **kwds,
    ):
        """1st step of the momentum correction workflow. Function to do an initial
        binning of the dataframe loaded to the class, slice a plane from it using an
        interactive view, and load it into the momentum corrector class.

        Args:
            df_partitions (int, optional): Number of dataframe partitions to use for
                the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (List[Tuple], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
            width (int, optional): Initial value for the width slider. Defaults to 5.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments passed to the pre_binning function.
        """
        self._pre_binned = self.pre_binning(
            df_partitions=df_partitions,
            axes=axes,
            bins=bins,
            ranges=ranges,
            **kwds,
        )

        self.mc.load_data(data=self._pre_binned)
        self.mc.select_slicer(plane=plane, width=width, apply=apply)
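
    # A sketch of step 1 (illustrative values; axes, bins and ranges default to
    # the "momentum" section of the config):
    #
    #   sp.bin_and_load_momentum_calibration(df_partitions=50, plane=33, width=10)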

    # 2. Generate the spline warp correction from momentum features.
    # Either autoselect features, or input features from view above.
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2. Step of the distortion correction workflow: Define feature points in
        momentum space. They can either be manually selected using a GUI tool, be
        provided as a list of feature points, or be auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the selected features.
                Defaults to False.
            **kwds: Keyword arguments for MomentumCorrector.feature_extract() and
                MomentumCorrector.feature_select().
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop("sigma", self._config["momentum"]["sigma"])
            fwhm = kwds.pop("fwhm", self._config["momentum"]["fwhm"])
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config["momentum"]["sigma_radius"],
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )
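
    # A sketch of step 2 with manually provided features (hypothetical pixel
    # coordinates for a six-fold symmetric pattern, center point last):
    #
    #   features = np.array(
    #       [[203, 341], [299, 386], [351, 300], [316, 197], [213, 167],
    #        [154, 244], [248, 272]],
    #   )
    #   sp.define_features(features=features, rotation_symmetry=6, apply=True)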

    # 3. Generate the spline warp correction from momentum features.
    # If no features have been selected before, use class defaults.
    def generate_splinewarp(
        self,
        include_center: bool = True,
        **kwds,
    ):
        """3. Step of the distortion correction workflow: Generate the correction
        function restoring the symmetry in the image using a splinewarp algorithm.

        Args:
            include_center (bool, optional): Option to include the position of the
                center point in the correction. Defaults to True.
            **kwds: Keyword arguments for MomentumCorrector.spline_warp_estimate().
        """
        self.mc.spline_warp_estimate(include_center=include_center, **kwds)

        if self.mc.slice is not None:
            print("Original slice with reference features")
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)

            print("Corrected slice with target features")
            self.mc.view(
                image=self.mc.slice_corrected,
                annotated=True,
                points={"feats": self.mc.ptargs},
                backend="bokeh",
                crosshair=True,
            )

            print("Original slice with target features")
            self.mc.view(
                image=self.mc.slice,
                points={"feats": self.mc.ptargs},
                annotated=True,
                backend="bokeh",
            )

    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        points = []
        try:
            for point in self.mc.pouter_ord:
                points.append([float(i) for i in point])
            if self.mc.pcent:
                points.append([float(i) for i in self.mc.pcent])
        except AttributeError as exc:
            raise AttributeError(
                "Momentum correction parameters not found, need to generate parameters first!",
            ) from exc
        config = {
            "momentum": {
                "correction": {"rotation_symmetry": self.mc.rotsym, "feature_points": points},
            },
        }
        save_config(config, filename, overwrite)
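
    # The saved file then contains a block of the following form (values
    # illustrative):
    #
    #   momentum:
    #     correction:
    #       rotation_symmetry: 6
    #       feature_points:
    #         - [203.0, 341.0]
    #         - [299.0, 386.0]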

    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    def pose_adjustment(
        self,
        scale: float = 1,
        xtrans: float = 0,
        ytrans: float = 0,
        angle: float = 0,
        apply: bool = False,
        use_correction: bool = True,
    ):
        """4. step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            scale (float, optional): Initial value of the scaling slider.
                Defaults to 1.
            xtrans (float, optional): Initial value of the xtrans slider.
                Defaults to 0.
            ytrans (float, optional): Initial value of the ytrans slider.
                Defaults to 0.
            angle (float, optional): Initial value of the angle slider.
                Defaults to 0.
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
        """
        # Generate homography as default if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                raise ValueError(
                    "No slice for corrections and transformations loaded!",
                )
            self.mc.slice_corrected = self.mc.slice

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate default distortion correction
            self.mc.add_features()
            self.mc.spline_warp_estimate()

        if not use_correction:
            self.mc.reset_deformation()

        self.mc.pose_adjustment(
            scale=scale,
            xtrans=xtrans,
            ytrans=ytrans,
            angle=angle,
            apply=apply,
        )
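
    # A sketch of step 4 (illustrative slider values; apply=True stores the
    # resulting transformation directly):
    #
    #   sp.pose_adjustment(xtrans=8, ytrans=7, angle=-4, apply=True)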

    # 5. Apply the momentum correction to the dataframe
    def apply_momentum_correction(
        self,
        preview: bool = False,
    ):
        """Applies the distortion correction and pose adjustment (optional)
        to the dataframe.

        Args:
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding corrected X/Y columns to dataframe:")
            self._dataframe, metadata = self.mc.apply_corrections(
                df=self._dataframe,
            )
            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    def calibrate_momentum_axes(
        self,
        point_a: Union[np.ndarray, List[int]] = None,
        point_b: Union[np.ndarray, List[int]] = None,
        k_distance: float = None,
        k_coord_a: Union[np.ndarray, List[float]] = None,
        k_coord_b: Union[np.ndarray, List[float]] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply=False,
    ):
        """1. step of the momentum calibration workflow. Calibrate momentum
        axes using either provided pixel coordinates of a high-symmetry point and its
        distance to the BZ center, or the k-coordinates of two points in the BZ
        (depending on the equiscale option). Opens an interactive panel for selecting
        the points.

        Args:
            point_a (Union[np.ndarray, List[int]]): Pixel coordinates of the first
                point used for momentum calibration.
            point_b (Union[np.ndarray, List[int]], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between point a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option to use the same scale for kx and ky.
                If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config["momentum"]["center_pixel"]

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )
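
    # A calibration sketch (hypothetical pixel coordinates of a high-symmetry
    # point and an illustrative momentum distance to the BZ center):
    #
    #   sp.calibrate_momentum_axes(point_a=[308, 345], k_distance=1.28, apply=True)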

    # 1a. Save momentum calibration parameters to config file.
    def save_momentum_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated momentum calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key in [
                "kx_scale",
                "ky_scale",
                "x_center",
                "y_center",
                "rstart",
                "cstart",
                "rstep",
                "cstep",
            ]:
                calibration[key] = float(self.mc.calibration[key])
        except KeyError as exc:
            raise KeyError(
                "Momentum calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"momentum": {"calibration": calibration}}
        save_config(config, filename, overwrite)

    # 2. Apply correction and calibration to the dataframe
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
    ):
        """2. step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes exist,
        these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding kx/ky columns to dataframe:")
            self._dataframe, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: Tuple[float, float] = None,
        apply=False,
        **kwds,
    ):
        """1. step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data if
        they are not present yet.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (Tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
            **kwds: Keyword args passed to ``EnergyCalibrator.adjust_energy_correction``.
        """
        if self._pre_binned is None:
            print(
                "Pre-binned data not present, binning using defaults from config...",
            )
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )
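
    # An adjustment sketch (illustrative parameters; the valid correction types
    # are listed in the docstring above):
    #
    #   sp.adjust_energy_correction(
    #       correction_type="Lorentzian",
    #       amplitude=2.5,
    #       center=(730, 730),
    #       apply=True,
    #   )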

    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        correction = {}
        try:
            for key in self.ec.correction.keys():
                if key == "correction_type":
                    correction[key] = self.ec.correction[key]
                elif key == "center":
                    correction[key] = [float(i) for i in self.ec.correction[key]]
                else:
                    correction[key] = float(self.ec.correction[key])
        except AttributeError as exc:
            raise AttributeError(
                "Energy correction parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)

    # 2. Apply energy correction to dataframe
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2. step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction``.
        """
        if self._dataframe is not None:
            print("Applying energy correction to dataframe...")
            self._dataframe, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                **kwds,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

    # Energy calibration workflow
    # 1. Load and normalize data
    def load_bias_series(
        self,
        binned_data: Union[xr.DataArray, Tuple[np.ndarray, np.ndarray, np.ndarray]] = None,
        data_files: List[str] = None,
        axes: List[str] = None,
        bins: List = None,
        ranges: Sequence[Tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1. step of the energy calibration workflow: Load and bin data from
        single-event files, or load binned bias/TOF traces.

        Args:
            binned_data (Union[xr.DataArray, Tuple[np.ndarray, np.ndarray, np.ndarray]], optional):
                Binned data. If provided as a DataArray, it needs to contain the
                dimensions config["dataframe"]["tof_column"] and
                config["dataframe"]["bias_column"]. If provided as a tuple, it needs
                to contain the elements (tof, biases, traces).
            data_files (List[str], optional): list of file paths to bin.
            axes (List[str], optional): bin axes.
                Defaults to config["dataframe"]["tof_column"].
            bins (List, optional): number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        if binned_data is not None:
            if isinstance(binned_data, xr.DataArray):
                if (
                    self._config["dataframe"]["tof_column"] not in binned_data.dims
                    or self._config["dataframe"]["bias_column"] not in binned_data.dims
                ):
                    raise ValueError(
                        "If binned_data is provided as an xarray, it needs to contain dimensions "
                        f"'{self._config['dataframe']['tof_column']}' and "
                        f"'{self._config['dataframe']['bias_column']}'!",
                    )
                tof = binned_data.coords[self._config["dataframe"]["tof_column"]].values
                biases = binned_data.coords[self._config["dataframe"]["bias_column"]].values
                traces = binned_data.values[:, :]
            else:
                try:
                    (tof, biases, traces) = binned_data
                except ValueError as exc:
                    raise ValueError(
                        "If binned_data is provided as tuple, it needs to contain "
                        "(tof, biases, traces)!",
                    ) from exc
            self.ec.load_data(biases=biases, traces=traces, tof=tof)

        elif data_files is not None:
            self.ec.bin_data(
                data_files=cast(List[str], self.cpy(data_files)),
                axes=axes,
                bins=bins,
                ranges=ranges,
                biases=biases,
                bias_key=bias_key,
            )

        else:
            raise ValueError("Either binned_data or data_files needs to be provided!")

        if (normalize is not None and normalize is True) or (
            normalize is None and self._config["energy"]["normalize"]
        ):
            if span is None:
                span = self._config["energy"]["normalize_span"]
            if order is None:
                order = self._config["energy"]["normalize_order"]
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="bokeh",
        )
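
    # A sketch of step 1 of the energy calibration (hypothetical bias-series
    # files; bias voltages are read from the files if not given explicitly):
    #
    #   sp.load_bias_series(
    #       data_files=["bias_series_0.h5", "bias_series_1.h5"],
    #       normalize=True,
    #   )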

    # 2. Extract ranges and get peak positions
    def find_bias_peaks(
        self,
        ranges: Union[List[Tuple], Tuple],
        ref_id: int = 0,
        infer_others: bool = True,
        mode: str = "replace",
        radius: int = None,
        peak_window: int = None,
        apply: bool = False,
    ):
        """2. step of the energy calibration workflow: Find a peak within a given range
        for the indicated reference trace, and try to find the same peak for all other
        traces. Uses fast_dtw to align curves, which might not work well if the shape
        of the curves changes qualitatively. Ideally, choose a reference trace in the
        middle of the set, and don't choose the range too narrow around the peak.
        Alternatively, a list of ranges for all traces can be provided.

        Args:
            ranges (Union[List[Tuple], Tuple]): Tuple of TOF values indicating a range.
                Alternatively, a list of ranges for all traces can be given.
            ref_id (int, optional): The id of the trace the range refers to.
                Defaults to 0.
            infer_others (bool, optional): Whether to determine the range for the other
                traces. Defaults to True.
            mode (str, optional): Whether to "add" or "replace" existing ranges.
                Defaults to "replace".
            radius (int, optional): Radius parameter for fast_dtw.
                Defaults to config["energy"]["fastdtw_radius"].
            peak_window (int, optional): peak_window parameter for the peak-detection
                algorithm: number of points that have to behave monotonically
                around a peak. Defaults to config["energy"]["peak_window"].
            apply (bool, optional): Option to directly apply the provided parameters.
                Defaults to False.
        """
        if radius is None:
            radius = self._config["energy"]["fastdtw_radius"]
        if peak_window is None:
            peak_window = self._config["energy"]["peak_window"]
        if not infer_others:
            self.ec.add_ranges(
                ranges=ranges,
                ref_id=ref_id,
                infer_others=infer_others,
                mode=mode,
                radius=radius,
            )
            print(self.ec.featranges)
            try:
                self.ec.feature_extract(peak_window=peak_window)
                self.ec.view(
                    traces=self.ec.traces_normed,
                    segs=self.ec.featranges,
                    xaxis=self.ec.tof,
                    peaks=self.ec.peaks,
                    backend="bokeh",
                )
            except IndexError:
                print("Could not determine all peaks!")
                raise
        else:
            # New adjustment tool
            assert isinstance(ranges, tuple)
            self.ec.adjust_ranges(
                ranges=ranges,
                ref_id=ref_id,
                traces=self.ec.traces_normed,
                infer_others=infer_others,
                radius=radius,
                peak_window=peak_window,
                apply=apply,
            )
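
    # A sketch of step 2 (illustrative TOF range around the feature of interest
    # in the reference trace; ranges for the other traces are inferred):
    #
    #   sp.find_bias_peaks(ranges=(64000, 66000), ref_id=5, apply=True)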

    # 3. Fit the energy calibration relation
    def calibrate_energy_axis(
        self,
        ref_id: int,
        ref_energy: float,
        method: str = None,
        energy_scale: str = None,
        **kwds,
    ):
        """3. Step of the energy calibration workflow: Calculate the calibration
        function for the energy axis. Two approximations are implemented, a
        (normally 3rd order) polynomial approximation, and a d^2/(t-t0)^2 relation.

        Args:
            ref_id (int): id of the trace at the bias where the reference energy is
                given.
            ref_energy (float): Absolute energy of the detected feature at the bias
                of ref_id.
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to config["energy"]["calibration_method"]
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to config["energy"]["energy_scale"]
            **kwds: Keyword args passed to ``EnergyCalibrator.calibrate``.
        """
        if method is None:
            method = self._config["energy"]["calibration_method"]

        if energy_scale is None:
            energy_scale = self._config["energy"]["energy_scale"]

        self.ec.calibrate(
            ref_id=ref_id,
            ref_energy=ref_energy,
            method=method,
            energy_scale=energy_scale,
            **kwds,
        )
        print("Quality of Calibration:")
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.calibration["axis"],
            align=True,
            energy_scale=energy_scale,
            backend="bokeh",
        )
        print("E/TOF relationship:")
        self.ec.view(
            traces=self.ec.calibration["axis"][None, :],
            xaxis=self.ec.tof,
            backend="matplotlib",
            show_legend=False,
        )
        if energy_scale == "kinetic":
            plt.scatter(
                self.ec.peaks[:, 0],
                -(self.ec.biases - self.ec.biases[ref_id]) + ref_energy,
                s=50,
                c="k",
            )
        elif energy_scale == "binding":
            plt.scatter(
                self.ec.peaks[:, 0],
                self.ec.biases - self.ec.biases[ref_id] + ref_energy,
                s=50,
                c="k",
            )
        else:
            raise ValueError(
                'energy_scale needs to be either "binding" or "kinetic"'
                f", got {energy_scale}.",
            )
        plt.xlabel("Time-of-flight", fontsize=15)
        plt.ylabel("Energy (eV)", fontsize=15)
        plt.show()
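
    # A sketch of step 3 (hypothetical reference: trace ref_id=5 contains a
    # feature of known energy, here an illustrative -0.5 eV):
    #
    #   sp.calibrate_energy_axis(ref_id=5, ref_energy=-0.5, method="lmfit")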

    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key, value in self.ec.calibration.items():
                if key in ["axis", "refid"]:
                    continue
                if key == "energy_scale":
                    calibration[key] = value
                else:
                    calibration[key] = float(value)
        except AttributeError as exc:
            raise AttributeError(
                "Energy calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"calibration": calibration}}
        save_config(config, filename, overwrite)

    # 4. Apply energy calibration to the dataframe
    def append_energy_axis(
        self,
        calibration: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """4. step of the energy calibration workflow: Apply the calibration function
        to the dataframe. Two approximations are implemented, a (normally 3rd order)
        polynomial approximation, and a d^2/(t-t0)^2 relation. A calibration dictionary
        can be provided.

        Args:
            calibration (dict, optional): Calibration dict containing calibration
                parameters. Overrides calibration from class or config.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.append_energy_axis``.
        """
        if self._dataframe is not None:
            print("Adding energy column to dataframe:")
            self._dataframe, metadata = self.ec.append_energy_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)
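
    # A sketch of step 4; called without arguments, the calibration stored in
    # the class (or in the config) is used:
    #
    #   sp.append_energy_axis(preview=True)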

    # Delay calibration function
    def calibrate_delay_axis(
        self,
        delay_range: Tuple[float, float] = None,
        datafile: str = None,
        preview: bool = False,
        **kwds,
    ):
        """Append delay column to dataframe. Either provide delay ranges, or read
        them from a file.

        Args:
            delay_range (Tuple[float, float], optional): The scanned delay range in
                picoseconds. Defaults to None.
            datafile (str, optional): The file from which to read the delay ranges.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
        """
        if self._dataframe is not None:
            print("Adding delay column to dataframe:")

            if delay_range is not None:
                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    delay_range=delay_range,
                    **kwds,
                )
            else:
                if datafile is None:
                    try:
                        datafile = self._files[0]
                    except IndexError:
                        print(
                            "No datafile available, specify either",
                            " 'datafile' or 'delay_range'",
                        )
                        raise

                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    datafile=datafile,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "delay_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)
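
    # A delay calibration sketch (illustrative scan range in picoseconds;
    # without delay_range, the range is read from a data file):
    #
    #   sp.calibrate_delay_axis(delay_range=(-1.0, 2.0), preview=True)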

    def add_jitter(self, cols: Sequence[str] = None):
        """Add jitter to the selected dataframe columns.

        Args:
            cols (Sequence[str], optional): The columns onto which to apply jitter.
                Defaults to config["dataframe"]["jitter_cols"].
        """
        if cols is None:
            cols = self._config["dataframe"].get(
                "jitter_cols",
                self._dataframe.columns,
            )  # jitter all columns

        self._dataframe = self._dataframe.map_partitions(
            apply_jitter,
            cols=cols,
            cols_jittered=cols,
        )
        metadata = []
        for col in cols:
            metadata.append(col)
        self._attributes.add(metadata, "jittering", duplicate_policy="append")
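
    # A jitter sketch (hypothetical column names; called without arguments, the
    # columns from config["dataframe"]["jitter_cols"] are used):
    #
    #   sp.add_jitter(cols=["X", "Y", "t"])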

    def pre_binning(
        self,
        df_partitions: int = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Function to do an initial binning of the dataframe loaded to the class.

        Args:
            df_partitions (int, optional): Number of dataframe partitions to use for
                the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (List[Tuple], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            **kwds: Keyword arguments passed to ``compute``.

        Returns:
            xr.DataArray: pre-binned data-array.
        """
        if axes is None:
            axes = self._config["momentum"]["axes"]
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))

        if bins is None:
            bins = self._config["momentum"]["bins"]
        if ranges is None:
            ranges_ = list(self._config["momentum"]["ranges"])
            ranges_[2] = np.asarray(ranges_[2]) / 2 ** (
                self._config["dataframe"]["tof_binning"] - 1
            )
            ranges = [cast(Tuple[float, float], tuple(v)) for v in ranges_]

        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        return self.compute(
            bins=bins,
            axes=axes,
            ranges=ranges,
            df_partitions=df_partitions,
            **kwds,
        )
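
    # A binning sketch using compute() (illustrative axis names and ranges; the
    # available columns depend on the loader and the preceding calibration steps):
    #
    #   res = sp.compute(
    #       bins=[200, 200, 150],
    #       axes=["kx", "ky", "energy"],
    #       ranges=[(-2.0, 2.0), (-2.0, 2.0), (-4.0, 0.5)],
    #   )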
1284

    def compute(
        self,
        bins: Union[
            int,
            dict,
            tuple,
            List[int],
            List[np.ndarray],
            List[tuple],
        ] = 100,
        axes: Union[str, Sequence[str]] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Compute the histogram along the given dimensions.

        Args:
            bins (int, dict, tuple, List[int], List[np.ndarray], List[tuple], optional):
                Definition of the bins. Can be any of the following cases:

                - an integer describing the number of bins on all dimensions
                - a tuple of 3 numbers describing start, end and step of the binning
                  range
                - an np.ndarray defining the bin edges
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
                - a dictionary made of the axes as keys and any of the above as values.

                This takes priority over the axes and range arguments. Defaults to 100.
            axes (Union[str, Sequence[str]], optional): The names of the axes (columns)
                on which to calculate the histogram. The order will be the order of the
                dimensions in the resulting array. Defaults to None.
            ranges (Sequence[Tuple[float, float]], optional): list of tuples containing
                the start and end point of the binning range. Defaults to None.
            **kwds: Keyword arguments:

                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
                  ``bin_dataframe`` for details. Defaults to
                  config["binning"]["hist_mode"].
                - **mode**: Defines how the results from each partition are combined.
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
                  Defaults to config["binning"]["mode"].
                - **pbar**: Option to show the tqdm progress bar. Defaults to
                  config["binning"]["pbar"].
                - **num_cores**: Number of CPU cores to use for parallelization.
                  Defaults to config["binning"]["num_cores"] or N_CPU-1.
                - **threads_per_worker**: Limit the number of threads that
                  multiprocessing can spawn per binning thread. Defaults to
                  config["binning"]["threads_per_worker"].
                - **threadpool_API**: The API to use for multiprocessing. "blas",
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
                  config["binning"]["threadpool_API"].
                - **df_partitions**: Number of dataframe partitions to use for the
                  binning, counted from the start. Defaults to all partitions.

                Additional kwds are passed to ``bin_dataframe``.

        Raises:
            AssertionError: Raised when no dataframe has been loaded.

        Returns:
            xr.DataArray: The result of the n-dimensional binning represented in an
            xarray object, combining the data with the axes.
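
        Example:
            A minimal sketch, assuming a processor with a loaded dataframe; the
            axis names and ranges are hypothetical and depend on the loader
            config::

                res = processor.compute(
                    bins=[80, 80, 200],
                    axes=["kx", "ky", "energy"],
                    ranges=[(-2.0, 2.0), (-2.0, 2.0), (-4.0, 0.5)],
                    df_partitions=20,
                )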
        """
        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
        mode = kwds.pop("mode", self._config["binning"]["mode"])
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
        num_cores = kwds.pop("num_cores", self._config["binning"]["num_cores"])
        threads_per_worker = kwds.pop(
            "threads_per_worker",
            self._config["binning"]["threads_per_worker"],
        )
        threadpool_api = kwds.pop(
            "threadpool_API",
            self._config["binning"]["threadpool_API"],
        )
        # Restrict binning to the first df_partitions partitions if requested
        df_partitions = kwds.pop("df_partitions", None)
        if df_partitions is not None:
            dataframe = self._dataframe.partitions[
                0 : min(df_partitions, self._dataframe.npartitions)
            ]
        else:
            dataframe = self._dataframe

        self._binned = bin_dataframe(
            df=dataframe,
            bins=bins,
            axes=axes,
            ranges=ranges,
            hist_mode=hist_mode,
            mode=mode,
            pbar=pbar,
            n_cores=num_cores,
            threads_per_worker=threads_per_worker,
            threadpool_api=threadpool_api,
            **kwds,
        )

        # Propagate axis units from the config to the binned array where available
        for dim in self._binned.dims:
            try:
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
            except KeyError:
                pass

        self._binned.attrs["units"] = "counts"
        self._binned.attrs["long_name"] = "photoelectron counts"
        self._binned.attrs["metadata"] = self._attributes.metadata

        return self._binned

    def view_event_histogram(
        self,
        dfpid: int,
        ncol: int = 2,
        bins: Sequence[int] = None,
        axes: Sequence[str] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        backend: str = "bokeh",
        legend: bool = True,
        histkwds: dict = None,
        legkwds: dict = None,
        **kwds,
    ):
        """Plot individual histograms of specified dimensions (axes) from a single
        dataframe partition.

        Args:
            dfpid (int): Number of the dataframe partition to look at.
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
            bins (Sequence[int], optional): Number of bins to use for the specified
                axes. Defaults to config["histogram"]["bins"].
            axes (Sequence[str], optional): Names of the axes to display.
                Defaults to config["histogram"]["axes"].
            ranges (Sequence[Tuple[float, float]], optional): Value ranges of all
                specified axes. Defaults to config["histogram"]["ranges"].
            backend (str, optional): Backend of the plotting library
                ('matplotlib' or 'bokeh'). Defaults to "bokeh".
            legend (bool, optional): Option to include a legend in the histogram plots.
                Defaults to True.
            histkwds (dict, optional): Keyword arguments for histograms
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
            legkwds (dict, optional): Keyword arguments for legend
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
            **kwds: Extra keyword arguments passed to
                ``sed.diagnostics.grid_histogram()``.

        Raises:
            TypeError: Raised when the input values are not of the correct type.
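
        Example:
            A minimal sketch, assuming a processor with a loaded dataframe; the
            column names and ranges are hypothetical and depend on the loader::

                processor.view_event_histogram(
                    dfpid=0,
                    axes=["X", "Y", "t"],
                    bins=[80, 80, 80],
                    ranges=[(0, 1800), (0, 1800), (68000, 74000)],
                )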
        """
        if bins is None:
            bins = self._config["histogram"]["bins"]
        if axes is None:
            axes = self._config["histogram"]["axes"]
        axes = list(axes)
        # Resolve axis names starting with "@" to column names from the config
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
        if ranges is None:
            ranges = list(self._config["histogram"]["ranges"])
            # Scale the TOF and ADC ranges by the binning factors applied on loading
            ranges[2] = np.asarray(ranges[2]) / 2 ** (self._config["dataframe"]["tof_binning"] - 1)
            ranges[3] = np.asarray(ranges[3]) / 2 ** (self._config["dataframe"]["adc_binning"] - 1)

        input_types = map(type, [axes, bins, ranges])
        allowed_types = [list, tuple]

        df = self._dataframe

        if not set(input_types).issubset(allowed_types):
            raise TypeError(
                "Inputs of axes, bins, ranges need to be list or tuple!",
            )

        # Read out the values for the specified groups
        group_dict_dd = {}
        dfpart = df.get_partition(dfpid)
        cols = dfpart.columns
        for ax in axes:
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
        group_dict = ddf.compute(group_dict_dd)[0]

        # Plot multiple histograms in a grid
        grid_histogram(
            group_dict,
            ncol=ncol,
            rvs=axes,
            rvbins=bins,
            rvranges=ranges,
            backend=backend,
            legend=legend,
            histkwds=histkwds,
            legkwds=legkwds,
            **kwds,
        )

    def save(
        self,
        faddr: str,
        **kwds,
    ):
        """Saves the binned data to the provided path and filename.

        Args:
            faddr (str): Path and name of the file to write. Its extension determines
                the file type to write. Valid file types are:

                - "*.tiff", "*.tif": Saves a TIFF stack.
                - "*.h5", "*.hdf5": Saves an HDF5 file.
                - "*.nxs", "*.nexus": Saves a NeXus file.

            **kwds: Keyword arguments, which are passed to the writer functions:
                For TIFF writing:

                - **alias_dict**: Dictionary of dimension aliases to use.

                For HDF5 writing:

                - **mode**: hdf5 read/write mode. Defaults to "w".

                For NeXus:

                - **reader**: Name of the nexustools reader to use.
                  Defaults to config["nexus"]["reader"]
                - **definition**: NeXus application definition to use for saving.
                  Must be supported by the used ``reader``. Defaults to
                  config["nexus"]["definition"]
                - **input_files**: A list of input files to pass to the reader.
                  Defaults to config["nexus"]["input_files"]
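
        Example:
            A minimal sketch after binning; the file names, reader, definition,
            and input-file values here are hypothetical::

                processor.save("binned.h5", mode="w")
                processor.save(
                    "binned.nxs",
                    reader="mpes",
                    definition="NXmpes",
                    input_files=["mpes_config.json"],
                )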
        """
        if self._binned is None:
            raise NameError("Need to bin data first!")

        extension = pathlib.Path(faddr).suffix

        if extension in (".tif", ".tiff"):
            to_tiff(
                data=self._binned,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".h5", ".hdf5"):
            to_h5(
                data=self._binned,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".nxs", ".nexus"):
            try:
                reader = kwds.pop("reader", self._config["nexus"]["reader"])
                definition = kwds.pop(
                    "definition",
                    self._config["nexus"]["definition"],
                )
                input_files = kwds.pop(
                    "input_files",
                    self._config["nexus"]["input_files"],
                )
            except KeyError as exc:
                raise ValueError(
                    "The nexus reader, definition and input files need to be provided!",
                ) from exc

            if isinstance(input_files, str):
                input_files = [input_files]

            to_nexus(
                data=self._binned,
                faddr=faddr,
                reader=reader,
                definition=definition,
                input_files=input_files,
                **kwds,
            )

        else:
            raise NotImplementedError(
                f"Unrecognized file format: {extension}.",
            )

    def add_dimension(self, name: str, axis_range: Tuple):
        """Add a dimension axis.

        Args:
            name (str): name of the axis.
            axis_range (Tuple): range for the axis.

        Raises:
            ValueError: Raised if an axis with that name already exists.
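
        Example:
            A minimal sketch; this helper is still work in progress (see the
            TODO in ``make_axis``), and the axis name and range here are
            hypothetical::

                processor.add_dimension("delay", (-1.0, 1.0, 0.01))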
        """
        # Guard against overwriting an existing axis in the same container written below
        if name in self.axis:
            raise ValueError(f"Axis {name} already exists")

        self.axis[name] = self.make_axis(axis_range)

    def make_axis(self, axis_range: Tuple) -> np.ndarray:
        """Function to make an axis.

        Args:
            axis_range (Tuple): range for the new axis, unpacked into ``np.arange``.

        Returns:
            np.ndarray: the generated axis values.
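
        Example:
            A minimal sketch; the tuple follows ``np.arange`` semantics, so a
            (start, stop, step) range yields::

                axis = processor.make_axis((0.0, 1.0, 0.25))
                # array([0.  , 0.25, 0.5 , 0.75])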
        """
        # TODO: What shall this function do?
        return np.arange(*axis_range)