OpenCOMPES / sed / 5802086559 (pending completion)
push, github, rettigl
Merge branch 'mpes-tweaks' into mpes-merged

2 of 2 new or added lines in 1 file covered. (100.0%)
3070 of 4146 relevant lines covered (74.05%)
0.74 hits per line

Source File: /sed/core/processor.py

"""This module contains the core class for the sed package

"""
import pathlib
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import apply_jitter
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader

N_CPU = psutil.cpu_count()


class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a workflow for data
    correction, calibration, and binning.

    Args:
        metadata (dict, optional): Dict of external Metadata. Defaults to None.
        config (Union[dict, str], optional): Config dictionary or config file name.
            Defaults to None.
        dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
            into the class. Defaults to None.
        files (List[str], optional): List of files to pass to the loader defined in
            the config. Defaults to None.
        folder (str, optional): Folder containing files to pass to the loader
            defined in the config. Defaults to None.
        runs (Sequence[str], optional): List of run identifiers to pass to the loader
            defined in the config. Defaults to None.
        collect_metadata (bool): Option to collect metadata from files.
            Defaults to False.
        **kwds: Keyword arguments passed to the reader.
    """

    def __init__(
        self,
        metadata: dict = None,
        config: Union[dict, str] = None,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            config (Union[dict, str], optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
                into the class. Defaults to None.
            files (List[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to the reader.
        """
        self._config = parse_config(config, **kwds)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

        self._dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._files: List[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None

        self._dimensions: List[str] = []
        self._coordinates: Dict[Any, Any] = {}
        self.axis: Dict[Any, Any] = {}
        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
        )

        self.ec = EnergyCalibrator(
            loader=self.loader,
            config=self._config,
        )

        self.mc = MomentumCorrector(
            config=self._config,
        )

        self.dc = DelayCalibrator(
            config=self._config,
        )

        self.use_copy_tool = self._config.get("core", {}).get(
            "use_copy_tool",
            False,
        )
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    source=self._config["core"]["copy_tool_source"],
                    dest=self._config["core"]["copy_tool_dest"],
                    **self._config["core"].get("copy_tool_kwds", {}),
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )

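    # Usage sketch (illustrative only, not part of the class): constructing a
    # processor from a config file and a folder of raw data. The config path and
    # data folder below are hypothetical.
    #
    #   processor = SedProcessor(
    #       config="sed_config.yaml",
    #       folder="/path/to/raw_data",
    #       collect_metadata=False,
    #   )
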
    def __repr__(self):
        if self._dataframe is None:
            df_str = "Data Frame: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        coordinates_str = f"Coordinates: {self._coordinates}"
        dimensions_str = f"Dimensions: {self._dimensions}"
        pretty_str = df_str + "\n" + coordinates_str + "\n" + dimensions_str
        return pretty_str

    def __getitem__(self, val: str) -> pd.DataFrame:
        """Accessor to the underlying data structure.

        Args:
            val (str): Name of the dataframe column to retrieve.

        Returns:
            pd.DataFrame: Selected dataframe column.
        """
        return self._dataframe[val]

    @property
    def config(self) -> Dict[Any, Any]:
        """Getter attribute for the config dictionary

        Returns:
            Dict: The config dictionary.
        """
        return self._config

    @config.setter
    def config(self, config: Union[dict, str]):
        """Setter function for the config dictionary.

        Args:
            config (Union[dict, str]): Config dictionary or path of config file
                to load.
        """
        self._config = parse_config(config)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

    @property
    def dimensions(self) -> list:
        """Getter attribute for the dimensions.

        Returns:
            list: List of dimensions.
        """
        return self._dimensions

    @dimensions.setter
    def dimensions(self, dims: list):
        """Setter function for the dimensions.

        Args:
            dims (list): List of dimensions to set.
        """
        assert isinstance(dims, list)
        self._dimensions = dims

    @property
    def coordinates(self) -> dict:
        """Getter attribute for the coordinates dict.

        Returns:
            dict: Dictionary of coordinates.
        """
        return self._coordinates

    @coordinates.setter
    def coordinates(self, coords: dict):
        """Setter function for the coordinates dict.

        Args:
            coords (dict): Dictionary of coordinates.
        """
        assert isinstance(coords, dict)
        self._coordinates = {}
        for k, v in coords.items():
            self._coordinates[k] = xr.DataArray(v)

    def cpy(self, path: Union[str, List[str]]) -> Union[str, List[str]]:
        """Function to mirror a list of files or a folder from a network drive to
        local storage. Returns either the original or the copied path to the given
        path. The option to use this functionality is set by
        config["core"]["use_copy_tool"].

        Args:
            path (Union[str, List[str]]): Source path or path list.

        Returns:
            Union[str, List[str]]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        return path

    def load(
        self,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        metadata: dict = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            files (List[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to the loader.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            self._dataframe = dataframe
        elif runs is not None:
            # If runs are provided, we only use the copy tool if also a folder is provided.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )

        elif files is not None:
            dataframe, metadata = self.loader.read_dataframe(
                files=cast(List[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )

        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )

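    # Usage sketch (illustrative only): the same processor can be filled from a
    # dataframe, a list of files, a folder, or run identifiers. File names and
    # run numbers below are hypothetical.
    #
    #   processor.load(files=["scan_0001.h5", "scan_0002.h5"])
    #   processor.load(runs=["44824", "44825"], folder="/path/to/beamtime")
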
    # Momentum calibration workflow
    # 1. Bin raw detector data for distortion correction
    def bin_and_load_momentum_calibration(
        self,
        df_partitions: int = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        plane: int = 0,
        width: int = 5,
        apply: bool = False,
        **kwds,
    ):
        """1. Step of the momentum correction workflow: Do an initial binning
        of the dataframe loaded to the class, slice a plane from it using an
        interactive view, and load it into the momentum corrector class.

        Args:
            df_partitions (int, optional): Number of dataframe partitions to use for
                the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (List[Tuple], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
            width (int, optional): Initial value for the width slider. Defaults to 5.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments passed to the pre_binning function.
        """
        self._pre_binned = self.pre_binning(
            df_partitions=df_partitions,
            axes=axes,
            bins=bins,
            ranges=ranges,
            **kwds,
        )

        self.mc.load_data(data=self._pre_binned)
        self.mc.select_slicer(plane=plane, width=width, apply=apply)

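    # Usage sketch (illustrative only): step 1 of the momentum correction workflow
    # on a subset of partitions, applying the slice selection directly. Slider
    # values are hypothetical.
    #
    #   processor.bin_and_load_momentum_calibration(
    #       df_partitions=20, plane=33, width=10, apply=True,
    #   )
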
    # 2. Define feature points for the spline warp correction.
    # Either autoselect features, or input features from view above.
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2. Step of the distortion correction workflow: Define feature points in
        momentum space. They can either be manually selected using a GUI tool, be
        provided as a list of feature points, or be auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the selected features.
                Defaults to False.
            **kwds: Keyword arguments for MomentumCorrector.feature_extract() and
                MomentumCorrector.feature_select().
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop("sigma", self._config["momentum"]["sigma"])
            fwhm = kwds.pop("fwhm", self._config["momentum"]["fwhm"])
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config["momentum"]["sigma_radius"],
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )

    # 3. Generate the spline warp correction from momentum features.
    # If no features have been selected before, use class defaults.
    def generate_splinewarp(
        self,
        include_center: bool = True,
        **kwds,
    ):
        """3. Step of the distortion correction workflow: Generate the correction
        function restoring the symmetry in the image using a splinewarp algorithm.

        Args:
            include_center (bool, optional): Option to include the position of the
                center point in the correction. Defaults to True.
            **kwds: Keyword arguments for MomentumCorrector.spline_warp_estimate().
        """
        self.mc.spline_warp_estimate(include_center=include_center, **kwds)

        if self.mc.slice is not None:
            print("Original slice with reference features")
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)

            print("Corrected slice with target features")
            self.mc.view(
                image=self.mc.slice_corrected,
                annotated=True,
                points={"feats": self.mc.ptargs},
                backend="bokeh",
                crosshair=True,
            )

            print("Original slice with target features")
            self.mc.view(
                image=self.mc.slice,
                points={"feats": self.mc.ptargs},
                annotated=True,
                backend="bokeh",
            )

    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        points = []
        try:
            for point in self.mc.pouter_ord:
                points.append([float(i) for i in point])
            if self.mc.pcent:
                points.append([float(i) for i in self.mc.pcent])
        except AttributeError as exc:
            raise AttributeError(
                "Momentum correction parameters not found, need to generate parameters first!",
            ) from exc
        config = {
            "momentum": {
                "correction": {"rotation_symmetry": self.mc.rotsym, "feature_points": points},
            },
        }
        save_config(config, filename, overwrite)

    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    def pose_adjustment(
        self,
        scale: float = 1,
        xtrans: float = 0,
        ytrans: float = 0,
        angle: float = 0,
        apply: bool = False,
        use_correction: bool = True,
        reset: bool = True,
    ):
        """4. Step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            scale (float, optional): Initial value of the scaling slider.
                Defaults to 1.
            xtrans (float, optional): Initial value of the xtrans slider.
                Defaults to 0.
            ytrans (float, optional): Initial value of the ytrans slider.
                Defaults to 0.
            angle (float, optional): Initial value of the angle slider.
                Defaults to 0.
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
            reset (bool, optional): Option to reset the correction before
                transformation. Defaults to True.
        """
        # Generate a homography as default if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                raise ValueError(
                    "No slice for corrections and transformations loaded!",
                )
            self.mc.slice_corrected = self.mc.slice

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate distortion correction from config values
            self.mc.add_features()
            self.mc.spline_warp_estimate()

        if not use_correction:
            self.mc.reset_deformation()

        self.mc.pose_adjustment(
            scale=scale,
            xtrans=xtrans,
            ytrans=ytrans,
            angle=angle,
            apply=apply,
            reset=reset,
        )

    # 5. Apply the momentum correction to the dataframe
    def apply_momentum_correction(
        self,
        preview: bool = False,
    ):
        """Applies the distortion correction and pose adjustment (optional)
        to the dataframe.

        Args:
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding corrected X/Y columns to dataframe:")
            self._dataframe, metadata = self.mc.apply_corrections(
                df=self._dataframe,
            )
            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

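    # Usage sketch (illustrative only): steps 2-5 of the distortion correction
    # workflow chained together, with features auto-detected. Translation and
    # angle values are hypothetical.
    #
    #   processor.define_features(rotation_symmetry=6, auto_detect=True, apply=True)
    #   processor.generate_splinewarp()
    #   processor.pose_adjustment(xtrans=8, ytrans=7, angle=-1, apply=True)
    #   processor.apply_momentum_correction(preview=True)
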
    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    def calibrate_momentum_axes(
        self,
        point_a: Union[np.ndarray, List[int]] = None,
        point_b: Union[np.ndarray, List[int]] = None,
        k_distance: float = None,
        k_coord_a: Union[np.ndarray, List[float]] = None,
        k_coord_b: Union[np.ndarray, List[float]] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply=False,
    ):
        """1. Step of the momentum calibration workflow: Calibrate the momentum
        axes using either the provided pixel coordinates of a high-symmetry point and
        its distance to the BZ center, or the k-coordinates of two points in the BZ
        (depending on the equiscale option). Opens an interactive panel for selecting
        the points.

        Args:
            point_a (Union[np.ndarray, List[int]]): Pixel coordinates of the first
                point used for momentum calibration.
            point_b (Union[np.ndarray, List[int]], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between points a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option to enforce equal scales for kx and ky.
                If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config["momentum"]["center_pixel"]

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )

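    # Usage sketch (illustrative only): calibrating with a high-symmetry point at
    # hypothetical pixel coordinates and a known k-distance to the BZ center.
    #
    #   processor.calibrate_momentum_axes(point_a=[308, 345], k_distance=1.28, apply=True)
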
    # 1a. Save momentum calibration parameters to config file.
    def save_momentum_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated momentum calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key in [
                "kx_scale",
                "ky_scale",
                "x_center",
                "y_center",
                "rstart",
                "cstart",
                "rstep",
                "cstep",
            ]:
                calibration[key] = float(self.mc.calibration[key])
        except KeyError as exc:
            raise KeyError(
                "Momentum calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"momentum": {"calibration": calibration}}
        save_config(config, filename, overwrite)

    # 2. Apply correction and calibration to the dataframe
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
    ):
        """2. Step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes exist,
        these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding kx/ky columns to dataframe:")
            self._dataframe, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

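    # Usage sketch (illustrative only): apply the calibration stored in the class,
    # or pass an explicit calibration dict (keys as saved by save_momentum_calibration).
    #
    #   processor.apply_momentum_calibration(preview=True)
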
    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: Tuple[float, float] = None,
        apply=False,
        **kwds,
    ):
        """1. Step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data if
        they are not present yet.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (Tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
            **kwds: Keyword args passed to ``EnergyCalibrator.adjust_energy_correction``.
        """
        if self._pre_binned is None:
            print(
                "Pre-binned data not present, binning using defaults from config...",
            )
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )

    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        correction = {}
        try:
            for key in self.ec.correction.keys():
                if key == "correction_type":
                    correction[key] = self.ec.correction[key]
                elif key == "center":
                    correction[key] = [float(i) for i in self.ec.correction[key]]
                else:
                    correction[key] = float(self.ec.correction[key])
        except AttributeError as exc:
            raise AttributeError(
                "Energy correction parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)

    # 2. Apply energy correction to dataframe
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2. Step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction``.
        """
        if self._dataframe is not None:
            print("Applying energy correction to dataframe...")
            self._dataframe, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                **kwds,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

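    # Usage sketch (illustrative only): adjust the TOF correction interactively,
    # then apply it. Amplitude and center values are hypothetical.
    #
    #   processor.adjust_energy_correction(amplitude=2.5, center=(730, 730), apply=True)
    #   processor.apply_energy_correction(preview=True)
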
    # Energy calibrator workflow
    # 1. Load and normalize data
    def load_bias_series(
        self,
        data_files: List[str],
        axes: List[str] = None,
        bins: List = None,
        ranges: Sequence[Tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1. Step of the energy calibration workflow: Load and bin data from
        single-event files.

        Args:
            data_files (List[str]): List of file paths to bin.
            axes (List[str], optional): Bin axes.
                Defaults to config["dataframe"]["tof_column"].
            bins (List, optional): Number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): Bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): Span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): Order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        self.ec.bin_data(
            data_files=cast(List[str], self.cpy(data_files)),
            axes=axes,
            bins=bins,
            ranges=ranges,
            biases=biases,
            bias_key=bias_key,
        )
        if (normalize is not None and normalize is True) or (
            normalize is None and self._config["energy"]["normalize"]
        ):
            if span is None:
                span = self._config["energy"]["normalize_span"]
            if order is None:
                order = self._config["energy"]["normalize_order"]
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="bokeh",
        )

    # 2. Extract ranges and get peak positions
    def find_bias_peaks(
        self,
        ranges: Union[List[Tuple], Tuple],
        ref_id: int = 0,
        infer_others: bool = True,
        mode: str = "replace",
        radius: int = None,
        peak_window: int = None,
        apply: bool = False,
    ):
        """2. Step of the energy calibration workflow: Find a peak within a given range
        for the indicated reference trace, and try to find the same peak for all
        other traces. Uses fast_dtw to align curves, which might not work well if the
        shape of the curves changes qualitatively. Ideally, choose a reference trace in
        the middle of the set, and don't choose the range too narrow around the peak.
        Alternatively, a list of ranges for all traces can be provided.

        Args:
            ranges (Union[List[Tuple], Tuple]): Tuple of TOF values indicating a range.
                Alternatively, a list of ranges for all traces can be given.
            ref_id (int, optional): The id of the trace the range refers to.
                Defaults to 0.
            infer_others (bool, optional): Whether to determine the range for the other
                traces. Defaults to True.
            mode (str, optional): Whether to "add" or "replace" existing ranges.
                Defaults to "replace".
            radius (int, optional): Radius parameter for fast_dtw.
                Defaults to config["energy"]["fastdtw_radius"].
            peak_window (int, optional): Peak_window parameter for the peak detection
                algorithm: the number of points that have to behave monotonically
                around a peak. Defaults to config["energy"]["peak_window"].
            apply (bool, optional): Option to directly apply the provided parameters.
                Defaults to False.
        """
        if radius is None:
            radius = self._config["energy"]["fastdtw_radius"]
        if peak_window is None:
            peak_window = self._config["energy"]["peak_window"]
        if not infer_others:
            self.ec.add_ranges(
                ranges=ranges,
                ref_id=ref_id,
                infer_others=infer_others,
                mode=mode,
                radius=radius,
            )
            print(self.ec.featranges)
            try:
                self.ec.feature_extract(peak_window=peak_window)
                self.ec.view(
                    traces=self.ec.traces_normed,
                    segs=self.ec.featranges,
                    xaxis=self.ec.tof,
                    peaks=self.ec.peaks,
                    backend="bokeh",
                )
            except IndexError:
                print("Could not determine all peaks!")
                raise
        else:
            # New adjustment tool
            assert isinstance(ranges, tuple)
            self.ec.adjust_ranges(
                ranges=ranges,
                ref_id=ref_id,
                traces=self.ec.traces_normed,
                infer_others=infer_others,
                radius=radius,
                peak_window=peak_window,
                apply=apply,
            )

    # 3. Fit the energy calibration relation
    def calibrate_energy_axis(
        self,
        ref_id: int,
        ref_energy: float,
        method: str = None,
        energy_scale: str = None,
        **kwds,
    ):
        """3. Step of the energy calibration workflow: Calculate the calibration
        function for the energy axis. Two approximations are implemented, a
        (normally 3rd order) polynomial approximation, and a d^2/(t-t0)^2 relation.

        Args:
            ref_id (int): id of the trace at the bias where the reference energy is
                given.
            ref_energy (float): Absolute energy of the detected feature at the bias
                of ref_id.
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to config["energy"]["calibration_method"].
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to config["energy"]["energy_scale"].
            **kwds: Keyword args passed to ``EnergyCalibrator.calibrate``.
        """
        if method is None:
            method = self._config["energy"]["calibration_method"]

        if energy_scale is None:
            energy_scale = self._config["energy"]["energy_scale"]

        self.ec.calibrate(
            ref_id=ref_id,
            ref_energy=ref_energy,
            method=method,
            energy_scale=energy_scale,
            **kwds,
        )
        print("Quality of Calibration:")
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.calibration["axis"],
            align=True,
            energy_scale=energy_scale,
            backend="bokeh",
        )
        print("E/TOF relationship:")
        self.ec.view(
            traces=self.ec.calibration["axis"][None, :],
            xaxis=self.ec.tof,
            backend="matplotlib",
            show_legend=False,
        )
        if energy_scale == "kinetic":
            plt.scatter(
                self.ec.peaks[:, 0],
                -(self.ec.biases - self.ec.biases[ref_id]) + ref_energy,
                s=50,
                c="k",
            )
        elif energy_scale == "binding":
            plt.scatter(
                self.ec.peaks[:, 0],
                self.ec.biases - self.ec.biases[ref_id] + ref_energy,
                s=50,
                c="k",
            )
        else:
            raise ValueError(
                'energy_scale needs to be either "binding" or "kinetic"',
                f", got {energy_scale}.",
            )
        plt.xlabel("Time-of-flight", fontsize=15)
        plt.ylabel("Energy (eV)", fontsize=15)
        plt.show()

    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key, value in self.ec.calibration.items():
                if key in ["axis", "refid"]:
                    continue
                if key == "energy_scale":
                    calibration[key] = value
                else:
                    calibration[key] = float(value)
        except AttributeError as exc:
            raise AttributeError(
                "Energy calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"calibration": calibration}}
        save_config(config, filename, overwrite)

    # 4. Apply energy calibration to the dataframe
    def append_energy_axis(
        self,
        calibration: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """4. Step of the energy calibration workflow: Apply the calibration function
        to the dataframe. Two approximations are implemented, a (normally 3rd order)
        polynomial approximation, and a d^2/(t-t0)^2 relation. A calibration dictionary
        can be provided.

        Args:
            calibration (dict, optional): Calibration dict containing calibration
                parameters. Overrides calibration from class or config.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.append_energy_axis``.
        """
        if self._dataframe is not None:
            print("Adding energy column to dataframe:")
            self._dataframe, metadata = self.ec.append_energy_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

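    # Usage sketch (illustrative only): the four steps of the energy calibration
    # workflow on a hypothetical bias series. The bias_files list, TOF range,
    # trace ids, and reference energy are placeholders.
    #
    #   processor.load_bias_series(data_files=bias_files, normalize=True)
    #   processor.find_bias_peaks(ranges=(64000, 66000), ref_id=5, apply=True)
    #   processor.calibrate_energy_axis(ref_id=4, ref_energy=-0.5, method="lmfit")
    #   processor.append_energy_axis(preview=True)
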
    # Delay calibration function
    def calibrate_delay_axis(
        self,
        delay_range: Tuple[float, float] = None,
        datafile: str = None,
        preview: bool = False,
        **kwds,
    ):
        """Append delay column to dataframe. Either provide delay ranges, or read
        them from a file.

        Args:
            delay_range (Tuple[float, float], optional): The scanned delay range in
                picoseconds. Defaults to None.
            datafile (str, optional): The file from which to read the delay ranges.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
        """
        if self._dataframe is not None:
            print("Adding delay column to dataframe:")

            if delay_range is not None:
                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    delay_range=delay_range,
                    **kwds,
                )
            else:
                if datafile is None:
                    try:
                        datafile = self._files[0]
                    except IndexError:
                        print(
                            "No datafile available, specify either",
                            " 'datafile' or 'delay_range'",
                        )
                        raise

                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    datafile=datafile,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "delay_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

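    # Usage sketch (illustrative only): append a delay axis from an explicit scan
    # range in picoseconds; alternatively, pass a (hypothetical) datafile to read
    # the ranges from.
    #
    #   processor.calibrate_delay_axis(delay_range=(-500.0, 1500.0), preview=True)
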
    def add_jitter(self, cols: Sequence[str] = None):
        """Add jitter to the selected dataframe columns.

        Args:
            cols (Sequence[str], optional): The columns onto which to apply jitter.
                Defaults to config["dataframe"]["jitter_cols"].
        """
        if cols is None:
            cols = self._config["dataframe"].get(
                "jitter_cols",
                self._dataframe.columns,
            )  # jitter all columns

        self._dataframe = self._dataframe.map_partitions(
            apply_jitter,
            cols=cols,
            cols_jittered=cols,
        )
        metadata = []
        for col in cols:
            metadata.append(col)
        self._attributes.add(metadata, "jittering", duplicate_policy="append")

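    # Usage sketch (illustrative only): jitter the digitized detector columns
    # before binning to suppress discretization artifacts. Column names are
    # hypothetical.
    #
    #   processor.add_jitter(cols=["X", "Y", "t"])
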
    def pre_binning(
        self,
        df_partitions: int = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Function to do an initial binning of the dataframe loaded to the class.

        Args:
            df_partitions (int, optional): Number of dataframe partitions to use for
                the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (List[Tuple], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            **kwds: Keyword arguments passed to ``compute``.

        Returns:
            xr.DataArray: pre-binned data-array.
        """
        if axes is None:
            axes = self._config["momentum"]["axes"]
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))

        if bins is None:
            bins = self._config["momentum"]["bins"]
        if ranges is None:
            ranges_ = list(self._config["momentum"]["ranges"])
            ranges_[2] = np.asarray(ranges_[2]) / 2 ** (
                self._config["dataframe"]["tof_binning"] - 1
            )
            ranges = [cast(Tuple[float, float], tuple(v)) for v in ranges_]

        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        return self.compute(
            bins=bins,
            axes=axes,
            ranges=ranges,
            df_partitions=df_partitions,
            **kwds,
        )

1254
    def compute(
1✔
1255
        self,
1256
        bins: Union[
1257
            int,
1258
            dict,
1259
            tuple,
1260
            List[int],
1261
            List[np.ndarray],
1262
            List[tuple],
1263
        ] = 100,
1264
        axes: Union[str, Sequence[str]] = None,
1265
        ranges: Sequence[Tuple[float, float]] = None,
1266
        **kwds,
1267
    ) -> xr.DataArray:
1268
        """Compute the histogram along the given dimensions.
1269

1270
        Args:
1271
            bins (int, dict, tuple, List[int], List[np.ndarray], List[tuple], optional):
1272
                Definition of the bins. Can be any of the following cases:
1273

1274
                - an integer describing the number of bins in on all dimensions
1275
                - a tuple of 3 numbers describing start, end and step of the binning
1276
                  range
1277
                - a np.arrays defining the binning edges
1278
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
1279
                - a dictionary made of the axes as keys and any of the above as values.
1280

1281
                This takes priority over the axes and range arguments. Defaults to 100.
1282
            axes (Union[str, Sequence[str]], optional): The names of the axes (columns)
1283
                on which to calculate the histogram. The order will be the order of the
1284
                dimensions in the resulting array. Defaults to None.
1285
            ranges (Sequence[Tuple[float, float]], optional): list of tuples containing
1286
                the start and end point of the binning range. Defaults to None.
1287
            **kwds: Keyword arguments:
1288

1289
                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
1290
                  ``bin_dataframe`` for details. Defaults to
1291
                  config["binning"]["hist_mode"].
1292
                - **mode**: Defines how the results from each partition are combined.
1293
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
1294
                  Defaults to config["binning"]["mode"].
1295
                - **pbar**: Option to show the tqdm progress bar. Defaults to
1296
                  config["binning"]["pbar"].
1297
                - **n_cores**: Number of CPU cores to use for parallelization.
1298
                  Defaults to config["binning"]["num_cores"] or N_CPU-1.
1299
                - **threads_per_worker**: Limit the number of threads that
1300
                  multiprocessing can spawn per binning thread. Defaults to
1301
                  config["binning"]["threads_per_worker"].
1302
                - **threadpool_api**: The API to use for multiprocessing. "blas",
1303
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
1304
                  config["binning"]["threadpool_API"].
1305
                - **df_partitions**: A list of dataframe partitions. Defaults to all
1306
                  partitions.
1307

1308
                Additional kwds are passed to ``bin_dataframe``.
1309

1310
        Raises:
1311
            AssertError: Rises when no dataframe has been loaded.
1312

1313
        Returns:
1314
            xr.DataArray: The result of the n-dimensional binning represented in an
1315
            xarray object, combining the data with the axes.
1316
        """
1317
        assert self._dataframe is not None, "dataframe needs to be loaded first!"
1✔
1318

1319
        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
1✔
1320
        mode = kwds.pop("mode", self._config["binning"]["mode"])
1✔
1321
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
1✔
1322
        num_cores = kwds.pop("num_cores", self._config["binning"]["num_cores"])
1✔
1323
        threads_per_worker = kwds.pop(
1✔
1324
            "threads_per_worker",
1325
            self._config["binning"]["threads_per_worker"],
1326
        )
1327
        threadpool_api = kwds.pop(
1✔
1328
            "threadpool_API",
1329
            self._config["binning"]["threadpool_API"],
1330
        )
1331
        df_partitions = kwds.pop("df_partitions", None)
1✔
1332
        if df_partitions is not None:
1✔
1333
            dataframe = self._dataframe.partitions[
1✔
1334
                0 : min(df_partitions, self._dataframe.npartitions)
1335
            ]
1336
        else:
1337
            dataframe = self._dataframe
×
1338

1339
        self._binned = bin_dataframe(
1✔
1340
            df=dataframe,
1341
            bins=bins,
1342
            axes=axes,
1343
            ranges=ranges,
1344
            hist_mode=hist_mode,
1345
            mode=mode,
1346
            pbar=pbar,
1347
            n_cores=num_cores,
1348
            threads_per_worker=threads_per_worker,
1349
            threadpool_api=threadpool_api,
1350
            **kwds,
1351
        )
1352

1353
        for dim in self._binned.dims:
1✔
1354
            try:
1✔
1355
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
1✔
1356
            except KeyError:
×
1357
                pass
×
1358

1359
        self._binned.attrs["units"] = "counts"
1✔
1360
        self._binned.attrs["long_name"] = "photoelectron counts"
1✔
1361
        self._binned.attrs["metadata"] = self._attributes.metadata
1✔
1362

1363
        return self._binned
1✔
1364
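
    # Illustrative usage sketch (assumptions: "sp" is a SedProcessor with a
    # loaded dataframe, and the axis names below exist as dataframe columns):
    #
    #     res = sp.compute(
    #         bins=[80, 80, 200],
    #         axes=["X", "Y", "t"],
    #         ranges=[(0, 1800), (0, 1800), (65000, 67000)],
    #         df_partitions=10,  # bin only the first 10 partitions
    #     )
    #     print(res.dims, res.attrs["units"])  # dims follow `axes`; units: "counts"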

    def view_event_histogram(
        self,
        dfpid: int,
        ncol: int = 2,
        bins: Sequence[int] = None,
        axes: Sequence[str] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        backend: str = "bokeh",
        legend: bool = True,
        histkwds: dict = None,
        legkwds: dict = None,
        **kwds,
    ):
        """Plot individual histograms of specified dimensions (axes) from a single
        dataframe partition.

        Args:
            dfpid (int): Number of the dataframe partition to look at.
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
            bins (Sequence[int], optional): Number of bins to use for the specified
                axes. Defaults to config["histogram"]["bins"].
            axes (Sequence[str], optional): Names of the axes to display.
                Defaults to config["histogram"]["axes"].
            ranges (Sequence[Tuple[float, float]], optional): Value ranges of all
                specified axes. Defaults to config["histogram"]["ranges"].
            backend (str, optional): Backend of the plotting library
                ('matplotlib' or 'bokeh'). Defaults to "bokeh".
            legend (bool, optional): Option to include a legend in the histogram plots.
                Defaults to True.
            histkwds (dict, optional): Keyword arguments for histograms
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
            legkwds (dict, optional): Keyword arguments for legend
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
            **kwds: Extra keyword arguments passed to
                ``sed.diagnostics.grid_histogram()``.

        Raises:
            TypeError: Raised when the input values are not of the correct type.
        """
        if bins is None:
            bins = self._config["histogram"]["bins"]
        if axes is None:
            axes = self._config["histogram"]["axes"]
        axes = list(axes)
        # Resolve axis aliases of the form "@key" against the dataframe config
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
        if ranges is None:
            ranges = list(self._config["histogram"]["ranges"])
            # Scale the tof and adc ranges by the configured binning factors
            ranges[2] = np.asarray(ranges[2]) / 2 ** (
                self._config["dataframe"]["tof_binning"] - 1
            )
            ranges[3] = np.asarray(ranges[3]) / 2 ** (
                self._config["dataframe"]["adc_binning"] - 1
            )

        input_types = map(type, [axes, bins, ranges])
        allowed_types = [list, tuple]

        df = self._dataframe

        if not set(input_types).issubset(allowed_types):
            raise TypeError(
                "Inputs of axes, bins, ranges need to be list or tuple!",
            )

        # Read out the values for the specified groups
        group_dict_dd = {}
        dfpart = df.get_partition(dfpid)
        cols = dfpart.columns
        for ax in axes:
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
        group_dict = ddf.compute(group_dict_dd)[0]

        # Plot multiple histograms in a grid
        grid_histogram(
            group_dict,
            ncol=ncol,
            rvs=axes,
            rvbins=bins,
            rvranges=ranges,
            backend=backend,
            legend=legend,
            histkwds=histkwds,
            legkwds=legkwds,
            **kwds,
        )
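
    # Illustrative usage sketch (assumption: the config provides default
    # "histogram" bins/axes/ranges, so only the partition id is required):
    #
    #     sp.view_event_histogram(dfpid=0)
    #     # or with explicit axes/bins/ranges:
    #     sp.view_event_histogram(
    #         dfpid=2,
    #         axes=["X", "Y"],
    #         bins=[80, 80],
    #         ranges=[(0, 1800), (0, 1800)],
    #     )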

    def save(
        self,
        faddr: str,
        **kwds,
    ):
        """Saves the binned data to the provided path and filename.

        Args:
            faddr (str): Path and name of the file to write. Its extension determines
                the file type to write. Valid file types are:

                - "*.tiff", "*.tif": Saves a TIFF stack.
                - "*.h5", "*.hdf5": Saves an HDF5 file.
                - "*.nxs", "*.nexus": Saves a NeXus file.

            **kwds: Keyword arguments, which are passed to the writer functions:
                For TIFF writing:

                - **alias_dict**: Dictionary of dimension aliases to use.

                For HDF5 writing:

                - **mode**: hdf5 read/write mode. Defaults to "w".

                For NeXus:

                - **reader**: Name of the nexustools reader to use.
                  Defaults to config["nexus"]["reader"]
                - **definition**: NeXus application definition to use for saving.
                  Must be supported by the used ``reader``. Defaults to
                  config["nexus"]["definition"]
                - **input_files**: A list of input files to pass to the reader.
                  Defaults to config["nexus"]["input_files"]
                - **eln_data**: An electronic-lab-notebook file in '.yaml' format
                  to add to the list of files to pass to the reader.
        """
        if self._binned is None:
            raise NameError("Need to bin data first!")

        extension = pathlib.Path(faddr).suffix

        if extension in (".tif", ".tiff"):
            to_tiff(
                data=self._binned,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".h5", ".hdf5"):
            to_h5(
                data=self._binned,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".nxs", ".nexus"):
            reader = kwds.pop("reader", self._config["nexus"]["reader"])
            definition = kwds.pop(
                "definition",
                self._config["nexus"]["definition"],
            )
            input_files = kwds.pop(
                "input_files",
                self._config["nexus"]["input_files"],
            )
            if isinstance(input_files, str):
                input_files = [input_files]

            if "eln_data" in kwds:
                input_files.append(kwds.pop("eln_data"))

            to_nexus(
                data=self._binned,
                faddr=faddr,
                reader=reader,
                definition=definition,
                input_files=input_files,
                **kwds,
            )

        else:
            raise NotImplementedError(
                f"Unrecognized file format: {extension}.",
            )
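
    # Illustrative usage sketch (file names are hypothetical; the NeXus branch
    # additionally assumes reader/definition/input_files in config["nexus"]):
    #
    #     sp.save("binned.tiff")            # TIFF stack
    #     sp.save("binned.h5", mode="w")    # HDF5 file
    #     sp.save("binned.nxs", eln_data="eln_data.yaml")  # NeXus file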

    def add_dimension(self, name: str, axis_range: Tuple):
        """Add a dimension axis.

        Args:
            name (str): name of the axis
            axis_range (Tuple): range for the axis.

        Raises:
            ValueError: Raised if an axis with that name already exists.
        """
        if name in self._coordinates:
            raise ValueError(f"Axis {name} already exists")

        self.axis[name] = self.make_axis(axis_range)

    def make_axis(self, axis_range: Tuple) -> np.ndarray:
        """Function to make an axis.

        Args:
            axis_range (Tuple): range for the new axis.

        Returns:
            np.ndarray: The axis values, generated with ``np.arange``.
        """
        # TODO: What shall this function do?
        return np.arange(*axis_range)
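
    # Illustrative usage sketch (assumption: ``self._coordinates`` and
    # ``self.axis`` are initialized elsewhere in the class). Since make_axis
    # wraps np.arange, the tuple is (start, stop, step):
    #
    #     sp.add_dimension("delay", (-1.0, 1.0, 0.01))
    #     # equivalent to: sp.axis["delay"] = np.arange(-1.0, 1.0, 0.01)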