OpenCOMPES / sed / build 6679473184 (push, via github, by rettigl)
Merge branch 'histograms_from_timed_dataframe' into hist_testing
28 Oct 2023 10:09PM UTC coverage: 87.502% (+0.04%) from 87.467%
21 of 21 new or added lines in 1 file covered. (100.0%)
4593 of 5249 relevant lines covered (87.5%)
0.88 hits per line

Source File: /sed/core/processor.py (85.93% covered)
"""This module contains the core class for the sed package."""
import pathlib
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Literal
from typing import Sequence
from typing import Tuple
from typing import Union

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.binning.binning import normalization_histogram_from_timed_dataframe
from sed.binning.binning import normalization_histogram_from_timestamps
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import apply_jitter
from sed.core.dfops import rolling_average_on_acquisition_time
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader

N_CPU = psutil.cpu_count()

class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a workflow for data
    correction, calibration and binning.

    Args:
        metadata (dict, optional): Dict of external Metadata. Defaults to None.
        config (Union[dict, str], optional): Config dictionary or config file name.
            Defaults to None.
        dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
            into the class. Defaults to None.
        files (List[str], optional): List of files to pass to the loader defined in
            the config. Defaults to None.
        folder (str, optional): Folder containing files to pass to the loader
            defined in the config. Defaults to None.
        runs (Sequence[str], optional): List of run identifiers to pass to the loader
            defined in the config. Defaults to None.
        collect_metadata (bool): Option to collect metadata from files.
            Defaults to False.
        **kwds: Keyword arguments passed to the reader.
    """

    def __init__(
        self,
        metadata: dict = None,
        config: Union[dict, str] = None,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            config (Union[dict, str], optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
                into the class. Defaults to None.
            files (List[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to parse_config and to the reader.
        """
        config_kwds = {
            key: value for key, value in kwds.items() if key in parse_config.__code__.co_varnames
        }
        for key in config_kwds.keys():
            del kwds[key]
        self._config = parse_config(config, **config_kwds)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

        self._dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._timed_dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._files: List[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None
        self._normalization_histogram: xr.DataArray = None
        self._normalized: xr.DataArray = None

        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
        )

        self.ec = EnergyCalibrator(
            loader=self.loader,
            config=self._config,
        )

        self.mc = MomentumCorrector(
            config=self._config,
        )

        self.dc = DelayCalibrator(
            config=self._config,
        )

        self.use_copy_tool = self._config.get("core", {}).get(
            "use_copy_tool",
            False,
        )
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    source=self._config["core"]["copy_tool_source"],
                    dest=self._config["core"]["copy_tool_dest"],
                    **self._config["core"].get("copy_tool_kwds", {}),
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )
    def __repr__(self):
        if self._dataframe is None:
            df_str = "Data Frame: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        attributes_str = f"Metadata: {self._attributes.metadata}"
        pretty_str = df_str + "\n" + attributes_str
        return pretty_str

    @property
    def dataframe(self) -> Union[pd.DataFrame, ddf.DataFrame]:
        """Accessor to the underlying dataframe.

        Returns:
            Union[pd.DataFrame, ddf.DataFrame]: Dataframe object.
        """
        return self._dataframe

    @dataframe.setter
    def dataframe(self, dataframe: Union[pd.DataFrame, ddf.DataFrame]):
        """Setter for the underlying dataframe.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame]): The dataframe object to set.
        """
        if not isinstance(dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            dataframe,
            self._dataframe.__class__,
        ):
            raise ValueError(
                "'dataframe' has to be a Pandas or Dask dataframe and has to be of the same kind "
                "as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._dataframe.__class__}, provided type: {dataframe}.",
            )
        self._dataframe = dataframe

    @property
    def timed_dataframe(self) -> Union[pd.DataFrame, ddf.DataFrame]:
        """Accessor to the underlying timed_dataframe.

        Returns:
            Union[pd.DataFrame, ddf.DataFrame]: Timed Dataframe object.
        """
        return self._timed_dataframe

    @timed_dataframe.setter
    def timed_dataframe(self, timed_dataframe: Union[pd.DataFrame, ddf.DataFrame]):
        """Setter for the underlying timed dataframe.

        Args:
            timed_dataframe (Union[pd.DataFrame, ddf.DataFrame]): The timed dataframe object to set.
        """
        if not isinstance(timed_dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            timed_dataframe,
            self._timed_dataframe.__class__,
        ):
            raise ValueError(
                "'timed_dataframe' has to be a Pandas or Dask dataframe and has to be of the same "
                "kind as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._timed_dataframe.__class__}, "
                f"provided type: {timed_dataframe}.",
            )
        self._timed_dataframe = timed_dataframe
    @property
    def attributes(self) -> dict:
        """Accessor to the metadata dict.

        Returns:
            dict: The metadata dict.
        """
        return self._attributes.metadata

    def add_attribute(self, attributes: dict, name: str, **kwds):
        """Function to add an element to the attributes dict.

        Args:
            attributes (dict): The attributes dictionary object to add.
            name (str): Key under which to add the dictionary to the attributes.
        """
        self._attributes.add(
            entry=attributes,
            name=name,
            **kwds,
        )

    @property
    def config(self) -> Dict[Any, Any]:
        """Getter attribute for the config dictionary.

        Returns:
            Dict: The config dictionary.
        """
        return self._config

    @property
    def files(self) -> List[str]:
        """Getter attribute for the list of files.

        Returns:
            List[str]: The list of loaded files.
        """
        return self._files

    @property
    def binned(self) -> xr.DataArray:
        """Getter attribute for the binned data array.

        Returns:
            xr.DataArray: The binned data array.
        """
        if self._binned is None:
            raise ValueError("No binned data available, need to compute histogram first!")
        return self._binned

    @property
    def normalized(self) -> xr.DataArray:
        """Getter attribute for the normalized data array.

        Returns:
            xr.DataArray: The normalized data array.
        """
        if self._normalized is None:
            raise ValueError(
                "No normalized data available, compute data with normalization enabled!",
            )
        return self._normalized

    @property
    def normalization_histogram(self) -> xr.DataArray:
        """Getter attribute for the normalization histogram.

        Returns:
            xr.DataArray: The normalization histogram.
        """
        if self._normalization_histogram is None:
            raise ValueError("No normalization histogram available, generate histogram first!")
        return self._normalization_histogram
    def cpy(self, path: Union[str, List[str]]) -> Union[str, List[str]]:
        """Function to mirror a list of files or a folder from a network drive to a
        local storage. Returns either the original or the copied path to the given
        path. The option to use this functionality is set by
        config["core"]["use_copy_tool"].

        Args:
            path (Union[str, List[str]]): Source path or path list.

        Returns:
            Union[str, List[str]]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        return path
    def load(
        self,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        metadata: dict = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            files (List[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword parameters passed to the loader.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            timed_dataframe = kwds.pop("timed_dataframe", None)
        elif runs is not None:
            # If runs are provided, we only use the copy tool if a folder is also provided.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        elif files is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                files=cast(List[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._timed_dataframe = timed_dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )
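    # --- Example (editorial addition, not part of the module): constructing a
    # processor and loading data. A minimal sketch, assuming a user config file
    # "sed_config.yaml" and a folder of raw data files exist; paths are illustrative:
    #
    #   from sed.core.processor import SedProcessor
    #
    #   processor = SedProcessor(config="sed_config.yaml")
    #   processor.load(folder="/path/to/raw_data", collect_metadata=False)
    #   print(processor)  # shows the loaded dataframe and collected metadata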
    # Momentum calibration workflow
    # 1. Bin raw detector data for distortion correction
    def bin_and_load_momentum_calibration(
        self,
        df_partitions: int = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        plane: int = 0,
        width: int = 5,
        apply: bool = False,
        **kwds,
    ):
        """1st step of the momentum correction workflow. Function to do an initial
        binning of the dataframe loaded to the class, slice a plane from it using an
        interactive view, and load it into the momentum corrector class.

        Args:
            df_partitions (int, optional): Number of dataframe partitions to use for
                the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
            width (int, optional): Initial value for the width slider. Defaults to 5.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments passed to the pre_binning function.
        """
        self._pre_binned = self.pre_binning(
            df_partitions=df_partitions,
            axes=axes,
            bins=bins,
            ranges=ranges,
            **kwds,
        )

        self.mc.load_data(data=self._pre_binned)
        self.mc.select_slicer(plane=plane, width=width, apply=apply)
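    # --- Example (editorial addition, not part of the module): step 1 of the momentum
    # correction workflow. A minimal sketch, assuming data have already been loaded;
    # axes, bins and ranges fall back to the config unless given explicitly:
    #
    #   processor.bin_and_load_momentum_calibration(
    #       df_partitions=100,
    #       plane=33,      # illustrative slider values
    #       width=10,
    #       apply=True,
    #   )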
    # 2. Generate the spline warp correction from momentum features.
    # Either autoselect features, or input features from view above.
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2. Step of the distortion correction workflow: Define feature points in
        momentum space. They can either be selected manually using a GUI tool,
        provided as a list of feature points, or auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the selected features.
                Defaults to False.
            **kwds: Keyword arguments for MomentumCorrector.feature_extract() and
                MomentumCorrector.feature_select().
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop("sigma", self._config["momentum"]["sigma"])
            fwhm = kwds.pop("fwhm", self._config["momentum"]["fwhm"])
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config["momentum"]["sigma_radius"],
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )
    # 3. Generate the spline warp correction from momentum features.
    # If no features have been selected before, use class defaults.
    def generate_splinewarp(
        self,
        use_center: bool = None,
        **kwds,
    ):
        """3. Step of the distortion correction workflow: Generate the correction
        function restoring the symmetry in the image using a splinewarp algorithm.

        Args:
            use_center (bool, optional): Option to use the position of the
                center point in the correction. Default is read from config, or set to True.
            **kwds: Keyword arguments for MomentumCorrector.spline_warp_estimate().
        """
        self.mc.spline_warp_estimate(use_center=use_center, **kwds)

        if self.mc.slice is not None:
            print("Original slice with reference features")
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)

            print("Corrected slice with target features")
            self.mc.view(
                image=self.mc.slice_corrected,
                annotated=True,
                points={"feats": self.mc.ptargs},
                backend="bokeh",
                crosshair=True,
            )

            print("Original slice with target features")
            self.mc.view(
                image=self.mc.slice,
                points={"feats": self.mc.ptargs},
                annotated=True,
                backend="bokeh",
            )

    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        points = []
        try:
            for point in self.mc.pouter_ord:
                points.append([float(i) for i in point])
            if self.mc.include_center:
                points.append([float(i) for i in self.mc.pcent])
        except AttributeError as exc:
            raise AttributeError(
                "Momentum correction parameters not found, need to generate parameters first!",
            ) from exc
        config = {
            "momentum": {
                "correction": {
                    "rotation_symmetry": self.mc.rotsym,
                    "feature_points": points,
                    "include_center": self.mc.include_center,
                    "use_center": self.mc.use_center,
                },
            },
        }
        save_config(config, filename, overwrite)
    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    def pose_adjustment(
        self,
        scale: float = 1,
        xtrans: float = 0,
        ytrans: float = 0,
        angle: float = 0,
        apply: bool = False,
        use_correction: bool = True,
        reset: bool = True,
    ):
        """4. Step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            scale (float, optional): Initial value of the scaling slider.
                Defaults to 1.
            xtrans (float, optional): Initial value of the xtrans slider.
                Defaults to 0.
            ytrans (float, optional): Initial value of the ytrans slider.
                Defaults to 0.
            angle (float, optional): Initial value of the angle slider.
                Defaults to 0.
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
            reset (bool, optional):
                Option to reset the correction before transformation. Defaults to True.
        """
        # Generate identity correction as default if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                raise ValueError(
                    "No slice for corrections and transformations loaded!",
                )
            self.mc.slice_corrected = self.mc.slice

        if not use_correction:
            self.mc.reset_deformation()

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate distortion correction from config values
            self.mc.add_features()
            self.mc.spline_warp_estimate()

        self.mc.pose_adjustment(
            scale=scale,
            xtrans=xtrans,
            ytrans=ytrans,
            angle=angle,
            apply=apply,
            reset=reset,
        )

    # 5. Apply the momentum correction to the dataframe
    def apply_momentum_correction(
        self,
        preview: bool = False,
    ):
        """Applies the distortion correction and pose adjustment (optional)
        to the dataframe.

        Args:
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding corrected X/Y columns to dataframe:")
            self._dataframe, metadata = self.mc.apply_corrections(
                df=self._dataframe,
            )
            if self._timed_dataframe is not None:
                if (
                    self._config["dataframe"]["x_column"] in self._timed_dataframe.columns
                    and self._config["dataframe"]["y_column"] in self._timed_dataframe.columns
                ):
                    self._timed_dataframe, _ = self.mc.apply_corrections(
                        self._timed_dataframe,
                    )
            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)
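    # --- Example (editorial addition, not part of the module): distortion correction
    # chain, steps 2-5. A minimal sketch, assuming step 1 has run; the feature pixel
    # coordinates below are illustrative placeholders, one [x, y] point per symmetry
    # axis:
    #
    #   import numpy as np
    #
    #   features = np.array([
    #       [203.2, 341.96], [299.16, 345.32], [350.25, 243.70],
    #       [304.38, 149.88], [199.52, 152.48], [154.28, 242.27],
    #   ])
    #   processor.define_features(features=features, rotation_symmetry=6, apply=True)
    #   processor.generate_splinewarp(use_center=True)
    #   processor.pose_adjustment(xtrans=8, ytrans=7, angle=-1, apply=True)
    #   processor.apply_momentum_correction(preview=True)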
    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    def calibrate_momentum_axes(
        self,
        point_a: Union[np.ndarray, List[int]] = None,
        point_b: Union[np.ndarray, List[int]] = None,
        k_distance: float = None,
        k_coord_a: Union[np.ndarray, List[float]] = None,
        k_coord_b: Union[np.ndarray, List[float]] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply: bool = False,
    ):
        """1. Step of the momentum calibration workflow. Calibrate momentum
        axes using either provided pixel coordinates of a high-symmetry point and its
        distance to the BZ center, or the k-coordinates of two points in the BZ
        (depending on the equiscale option). Opens an interactive panel for selecting
        the points.

        Args:
            point_a (Union[np.ndarray, List[int]]): Pixel coordinates of the first
                point used for momentum calibration.
            point_b (Union[np.ndarray, List[int]], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between point a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option to use the same scale for kx and ky.
                If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config["momentum"]["center_pixel"]

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )

    # 1a. Save momentum calibration parameters to config file.
    def save_momentum_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated momentum calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key in [
                "kx_scale",
                "ky_scale",
                "x_center",
                "y_center",
                "rstart",
                "cstart",
                "rstep",
                "cstep",
            ]:
                calibration[key] = float(self.mc.calibration[key])
        except KeyError as exc:
            raise KeyError(
                "Momentum calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"momentum": {"calibration": calibration}}
        save_config(config, filename, overwrite)
    # 2. Apply correction and calibration to the dataframe
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
    ):
        """2. Step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes exist,
        these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding kx/ky columns to dataframe:")
            self._dataframe, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
            )
            if self._timed_dataframe is not None:
                if (
                    self._config["dataframe"]["x_column"] in self._timed_dataframe.columns
                    and self._config["dataframe"]["y_column"] in self._timed_dataframe.columns
                ):
                    self._timed_dataframe, _ = self.mc.append_k_axis(
                        df=self._timed_dataframe,
                        calibration=calibration,
                    )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)
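    # --- Example (editorial addition, not part of the module): momentum calibration
    # workflow. A minimal sketch, assuming a corrected momentum image is available;
    # the pixel coordinates and k-distance are illustrative placeholders:
    #
    #   processor.calibrate_momentum_axes(
    #       point_a=[308, 345],   # pixel position of a high-symmetry point
    #       k_distance=1.3,       # its momentum distance to the BZ center
    #       equiscale=True,
    #       apply=True,
    #   )
    #   processor.save_momentum_calibration()   # persists to "sed_config.yaml"
    #   processor.apply_momentum_calibration(preview=True)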
    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: Tuple[float, float] = None,
        apply: bool = False,
        **kwds,
    ):
        """1. Step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data if
        they are not present yet.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (Tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
            **kwds: Keyword arguments passed to EnergyCalibrator.adjust_energy_correction().
        """
        if self._pre_binned is None:
            print(
                "Pre-binned data not present, binning using defaults from config...",
            )
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )

    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        correction = {}
        try:
            for key, val in self.ec.correction.items():
                if key == "correction_type":
                    correction[key] = val
                elif key == "center":
                    correction[key] = [float(i) for i in val]
                else:
                    correction[key] = float(val)
        except AttributeError as exc:
            raise AttributeError(
                "Energy correction parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)
    # 2. Apply energy correction to dataframe
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2. Step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction``.
        """
        if self._dataframe is not None:
            print("Applying energy correction to dataframe...")
            self._dataframe, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                **kwds,
            )
            if self._timed_dataframe is not None:
                if self._config["dataframe"]["tof_column"] in self._timed_dataframe.columns:
                    self._timed_dataframe, _ = self.ec.apply_energy_correction(
                        df=self._timed_dataframe,
                        correction=correction,
                        **kwds,
                    )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)
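    # --- Example (editorial addition, not part of the module): energy correction
    # workflow. A minimal sketch; the amplitude and detector-center values are
    # illustrative placeholders:
    #
    #   processor.adjust_energy_correction(
    #       correction_type="Lorentzian",
    #       amplitude=2.5,
    #       center=(730, 730),
    #       apply=True,
    #   )
    #   processor.save_energy_correction()
    #   processor.apply_energy_correction(preview=True)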
    # Energy calibrator workflow
    # 1. Load and normalize data
    def load_bias_series(
        self,
        binned_data: Union[xr.DataArray, Tuple[np.ndarray, np.ndarray, np.ndarray]] = None,
        data_files: List[str] = None,
        axes: List[str] = None,
        bins: List = None,
        ranges: Sequence[Tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1. Step of the energy calibration workflow: Load and bin data from
        single-event files, or load binned bias/TOF traces.

        Args:
            binned_data (Union[xr.DataArray, Tuple[np.ndarray, np.ndarray, np.ndarray]], optional):
                Binned data. If provided as a DataArray, needs to contain the dimensions
                config["dataframe"]["tof_column"] and config["dataframe"]["bias_column"]. If
                provided as a tuple, needs to contain the elements (tof, biases, traces).
            data_files (List[str], optional): List of file paths to bin.
            axes (List[str], optional): Bin axes.
                Defaults to config["dataframe"]["tof_column"].
            bins (List, optional): Number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): Bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): Span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): Order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        if binned_data is not None:
            if isinstance(binned_data, xr.DataArray):
                if (
                    self._config["dataframe"]["tof_column"] not in binned_data.dims
                    or self._config["dataframe"]["bias_column"] not in binned_data.dims
                ):
                    raise ValueError(
                        "If binned_data is provided as an xarray, it needs to contain dimensions "
                        f"'{self._config['dataframe']['tof_column']}' and "
                        f"'{self._config['dataframe']['bias_column']}'!",
                    )
                tof = binned_data.coords[self._config["dataframe"]["tof_column"]].values
                biases = binned_data.coords[self._config["dataframe"]["bias_column"]].values
                traces = binned_data.values[:, :]
            else:
                try:
                    (tof, biases, traces) = binned_data
                except ValueError as exc:
                    raise ValueError(
                        "If binned_data is provided as tuple, it needs to contain "
                        "(tof, biases, traces)!",
                    ) from exc
            self.ec.load_data(biases=biases, traces=traces, tof=tof)

        elif data_files is not None:
            self.ec.bin_data(
                data_files=cast(List[str], self.cpy(data_files)),
                axes=axes,
                bins=bins,
                ranges=ranges,
                biases=biases,
                bias_key=bias_key,
            )

        else:
            raise ValueError("Either binned_data or data_files needs to be provided!")

        if (normalize is not None and normalize is True) or (
            normalize is None and self._config["energy"]["normalize"]
        ):
            if span is None:
                span = self._config["energy"]["normalize_span"]
            if order is None:
                order = self._config["energy"]["normalize_order"]
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="bokeh",
        )
    # 2. Extract ranges and get peak positions
    def find_bias_peaks(
        self,
        ranges: Union[List[Tuple], Tuple],
        ref_id: int = 0,
        infer_others: bool = True,
        mode: str = "replace",
        radius: int = None,
        peak_window: int = None,
        apply: bool = False,
    ):
        """2. Step of the energy calibration workflow: Find a peak within a given range
        for the indicated reference trace, and try to find the same peak for all
        other traces. Uses fast_dtw to align curves, which may perform poorly if the
        shape of the curves changes qualitatively. Ideally, choose a reference trace in
        the middle of the set, and don't choose the range too narrow around the peak.
        Alternatively, a list of ranges for all traces can be provided.

        Args:
            ranges (Union[List[Tuple], Tuple]): Tuple of TOF values indicating a range.
                Alternatively, a list of ranges for all traces can be given.
            ref_id (int, optional): The id of the trace the range refers to.
                Defaults to 0.
            infer_others (bool, optional): Whether to determine the range for the other
                traces. Defaults to True.
            mode (str, optional): Whether to "add" or "replace" existing ranges.
                Defaults to "replace".
            radius (int, optional): Radius parameter for fast_dtw.
                Defaults to config["energy"]["fastdtw_radius"].
            peak_window (int, optional): peak_window parameter for the peak detection
                algorithm: number of points that have to behave monotonically around a
                peak. Defaults to config["energy"]["peak_window"].
            apply (bool, optional): Option to directly apply the provided parameters.
                Defaults to False.
        """
        if radius is None:
            radius = self._config["energy"]["fastdtw_radius"]
        if peak_window is None:
            peak_window = self._config["energy"]["peak_window"]
        if not infer_others:
            self.ec.add_ranges(
                ranges=ranges,
                ref_id=ref_id,
                infer_others=infer_others,
                mode=mode,
                radius=radius,
            )
            print(self.ec.featranges)
            try:
                self.ec.feature_extract(peak_window=peak_window)
                self.ec.view(
                    traces=self.ec.traces_normed,
                    segs=self.ec.featranges,
                    xaxis=self.ec.tof,
                    peaks=self.ec.peaks,
                    backend="bokeh",
                )
            except IndexError:
                print("Could not determine all peaks!")
                raise
        else:
            # New adjustment tool
            assert isinstance(ranges, tuple)
            self.ec.adjust_ranges(
                ranges=ranges,
                ref_id=ref_id,
                traces=self.ec.traces_normed,
                infer_others=infer_others,
                radius=radius,
                peak_window=peak_window,
                apply=apply,
            )
    # 3. Fit the energy calibration relation
    def calibrate_energy_axis(
        self,
        ref_id: int,
        ref_energy: float,
        method: str = None,
        energy_scale: str = None,
        **kwds,
    ):
        """3. Step of the energy calibration workflow: Calculate the calibration
        function for the energy axis. Two approximations are implemented, a (normally
        3rd order) polynomial approximation, and a d^2/(t-t0)^2 relation.

        Args:
            ref_id (int): id of the trace at the bias where the reference energy is
                given.
            ref_energy (float): Absolute energy of the detected feature at the bias
                of ref_id.
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to config["energy"]["calibration_method"].
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to config["energy"]["energy_scale"].
            **kwds: Keyword arguments passed to ``EnergyCalibrator.calibrate``.
        """
        if method is None:
            method = self._config["energy"]["calibration_method"]

        if energy_scale is None:
            energy_scale = self._config["energy"]["energy_scale"]

        self.ec.calibrate(
            ref_id=ref_id,
            ref_energy=ref_energy,
            method=method,
            energy_scale=energy_scale,
            **kwds,
        )
        print("Quality of Calibration:")
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.calibration["axis"],
            align=True,
            energy_scale=energy_scale,
            backend="bokeh",
        )
        print("E/TOF relationship:")
        self.ec.view(
            traces=self.ec.calibration["axis"][None, :],
            xaxis=self.ec.tof,
            backend="matplotlib",
            show_legend=False,
        )
        if energy_scale == "kinetic":
            plt.scatter(
                self.ec.peaks[:, 0],
                -(self.ec.biases - self.ec.biases[ref_id]) + ref_energy,
                s=50,
                c="k",
            )
        elif energy_scale == "binding":
            plt.scatter(
                self.ec.peaks[:, 0],
                self.ec.biases - self.ec.biases[ref_id] + ref_energy,
                s=50,
                c="k",
            )
        else:
            raise ValueError(
                'energy_scale needs to be either "binding" or "kinetic"'
                f", got {energy_scale}.",
            )
        plt.xlabel("Time-of-flight", fontsize=15)
        plt.ylabel("Energy (eV)", fontsize=15)
        plt.show()
    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key, value in self.ec.calibration.items():
                if key in ["axis", "refid", "Tmat", "bvec"]:
                    continue
                if key == "energy_scale":
                    calibration[key] = value
                elif key == "coeffs":
                    calibration[key] = [float(i) for i in value]
                else:
                    calibration[key] = float(value)
        except AttributeError as exc:
            raise AttributeError(
                "Energy calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"calibration": calibration}}
        save_config(config, filename, overwrite)
    # 4. Apply energy calibration to the dataframe
    def append_energy_axis(
        self,
        calibration: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """4. Step of the energy calibration workflow: Apply the calibration function
        to the dataframe. Two approximations are implemented, a (normally 3rd order)
        polynomial approximation, and a d^2/(t-t0)^2 relation. A calibration dictionary
        can be provided.

        Args:
            calibration (dict, optional): Calibration dict containing calibration
                parameters. Overrides calibration from class or config.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.append_energy_axis``.
        """
        if self._dataframe is not None:
            print("Adding energy column to dataframe:")
            self._dataframe, metadata = self.ec.append_energy_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )
            if self._timed_dataframe is not None:
                if self._config["dataframe"]["tof_column"] in self._timed_dataframe.columns:
                    self._timed_dataframe, _ = self.ec.append_energy_axis(
                        df=self._timed_dataframe,
                        calibration=calibration,
                        **kwds,
                    )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)
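    # --- Example (editorial addition, not part of the module): energy calibration
    # workflow, steps 1-4. A minimal sketch, assuming a set of bias-series files;
    # file names, TOF range and reference energy are illustrative placeholders:
    #
    #   processor.load_bias_series(
    #       data_files=["Scan_bias_-1V.h5", "Scan_bias_0V.h5", "Scan_bias_1V.h5"],
    #       normalize=True,
    #   )
    #   processor.find_bias_peaks(ranges=(64000, 66000), ref_id=1, apply=True)
    #   processor.calibrate_energy_axis(ref_id=1, ref_energy=-0.5, method="lmfit")
    #   processor.save_energy_calibration()
    #   processor.append_energy_axis(preview=True)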
    def apply_energy_offset(
        self,
        constant: float = None,
        columns: Union[str, Sequence[str]] = None,
        signs: Union[int, Sequence[int]] = None,
        reductions: Union[str, Sequence[str]] = None,
        subtract_mean: Union[bool, Sequence[bool]] = None,
    ) -> None:
        """Shift the energy axis of the dataframe by a given amount.

        Args:
            constant (float, optional): The constant to shift the energy axis by.
            columns (Union[str, Sequence[str]]): The columns to shift.
            signs (Union[int, Sequence[int]]): The sign of the shift.
            reductions (Union[str, Sequence[str]]): The reduction to apply to the column.
                If "rolled" it searches for columns with suffix "_rolled", e.g.
                "sampleBias_rolled", as those generated by the
                ``SedProcessor.smooth_columns()`` function. Otherwise should be an
                available method of dask.dataframe.Series. For example "mean". In this
                case the function is applied to the column to generate a single value
                for the whole dataset. If None, the shift is applied per-dataframe-row.
                Defaults to None.
            subtract_mean (Union[bool, Sequence[bool]]): Whether to subtract the mean of
                the column before applying the shift. Defaults to False.

        Raises:
            ValueError: If the energy column is not in the dataframe.
        """
        energy_column = self._config["dataframe"]["energy_column"]
        if energy_column not in self._dataframe.columns:
            raise ValueError(
                f"Energy column {energy_column} not found in dataframe! "
                "Run energy calibration first",
            )
        self._dataframe, metadata = self.ec.apply_energy_offset(
            df=self._dataframe,
            constant=constant,
            columns=columns,
            energy_column=energy_column,
            signs=signs,
            reductions=reductions,
            subtract_mean=subtract_mean,
        )
        if len(metadata) > 0:
            self._attributes.add(
                metadata,
                "apply_energy_offset",
                # TODO: allow only appending when no offset along this column(s) was applied
                duplicate_policy="append",
            )
1311
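    # Usage sketch (illustrative values; "sampleBias" is the example column name
    # from the docstring above and must exist in the dataframe):
    #
    #   sp.apply_energy_offset(
    #       constant=0.5,           # rigid shift of the energy axis
    #       columns="sampleBias",   # additional per-row shift by this column
    #       signs=-1,
    #       subtract_mean=True,
    #   )
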
    def append_tof_ns_axis(
        self,
        **kwargs,
    ):
        """Convert time-of-flight channel steps to nanoseconds.

        Args:
            tof_ns_column (str, optional): Name of the generated column containing the
                time-of-flight in nanoseconds.
                Defaults to config["dataframe"]["tof_ns_column"].
            kwargs: additional arguments are passed to ``energy.tof_step_to_ns``.

        """
        if self._dataframe is not None:
            print("Adding time-of-flight column in nanoseconds to dataframe:")
            # TODO assert order of execution through metadata

            self._dataframe, metadata = self.ec.append_tof_ns_axis(
                df=self._dataframe,
                **kwargs,
            )
            self._attributes.add(
                metadata,
                "tof_ns_conversion",
                duplicate_policy="append",
            )

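    # Usage sketch (conversion parameters are expected from the config; the
    # explicit column name below is illustrative only):
    #
    #   sp.append_tof_ns_axis()
    #   # or with an explicit target column:
    #   # sp.append_tof_ns_axis(tof_ns_column="dldTime_ns")
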
    def align_dld_sectors(self, **kwargs):
        """Align the 8s sectors of the HEXTOF endstation."""
        if self._dataframe is not None:
            print("Aligning 8s sectors of dataframe")
            # TODO assert order of execution through metadata
            self._dataframe, metadata = self.ec.align_dld_sectors(df=self._dataframe, **kwargs)
            self._attributes.add(
                metadata,
                "dld_sector_alignment",
                duplicate_policy="raise",
            )

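    # Usage sketch (HEXTOF data only; any required sector parameters are expected
    # to come from the config or from **kwargs):
    #
    #   sp.align_dld_sectors()
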
    # Delay calibration function
    def calibrate_delay_axis(
        self,
        delay_range: Tuple[float, float] = None,
        datafile: str = None,
        preview: bool = False,
        **kwds,
    ):
        """Append delay column to dataframe. Either provide delay ranges, or read
        them from a file.

        Args:
            delay_range (Tuple[float, float], optional): The scanned delay range in
                picoseconds. Defaults to None.
            datafile (str, optional): The file from which to read the delay ranges.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
        """
        if self._dataframe is not None:
            print("Adding delay column to dataframe:")

            if delay_range is not None:
                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    delay_range=delay_range,
                    **kwds,
                )
                if self._timed_dataframe is not None:
                    if self._config["dataframe"]["adc_column"] in self._timed_dataframe.columns:
                        self._timed_dataframe, _ = self.dc.append_delay_axis(
                            self._timed_dataframe,
                            delay_range=delay_range,
                            **kwds,
                        )
            else:
                if datafile is None:
                    try:
                        datafile = self._files[0]
                    except IndexError:
                        print(
                            "No datafile available, specify either",
                            " 'datafile' or 'delay_range'",
                        )
                        raise

                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    datafile=datafile,
                    **kwds,
                )
                if self._timed_dataframe is not None:
                    if self._config["dataframe"]["adc_column"] in self._timed_dataframe.columns:
                        self._timed_dataframe, _ = self.dc.append_delay_axis(
                            self._timed_dataframe,
                            datafile=datafile,
                            **kwds,
                        )

            # Add Metadata
            self._attributes.add(
                metadata,
                "delay_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

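    # Usage sketch (illustrative numbers and file name):
    #
    #   sp.calibrate_delay_axis(delay_range=(-100.0, 200.0), preview=True)  # ps
    #   # or read the scanned range from a data file:
    #   # sp.calibrate_delay_axis(datafile="path/to/run_file.h5")
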
    def add_jitter(
        self,
        cols: List[str] = None,
        amps: Union[float, Sequence[float]] = None,
        **kwds,
    ):
        """Add jitter to the selected dataframe columns.

        Args:
            cols (List[str], optional): The columns onto which to apply jitter.
                Defaults to config["dataframe"]["jitter_cols"].
            amps (Union[float, Sequence[float]], optional): Amplitude scalings for the
                jittering noise. If one number is given, the same is used for all axes.
                For uniform noise (default) it will cover the interval [-amp, +amp].
                Defaults to config["dataframe"]["jitter_amps"].
            **kwds: additional keyword arguments passed to ``apply_jitter``.
        """
        if cols is None:
            cols = self._config["dataframe"]["jitter_cols"]
        for loc, col in enumerate(cols):
            if col.startswith("@"):
                cols[loc] = self._config["dataframe"].get(col.strip("@"))

        if amps is None:
            amps = self._config["dataframe"]["jitter_amps"]

        self._dataframe = self._dataframe.map_partitions(
            apply_jitter,
            cols=cols,
            cols_jittered=cols,
            amps=amps,
            **kwds,
        )
        if self._timed_dataframe is not None:
            cols_timed = cols.copy()
            for col in cols:
                if col not in self._timed_dataframe.columns:
                    cols_timed.remove(col)

            if cols_timed:
                self._timed_dataframe = self._timed_dataframe.map_partitions(
                    apply_jitter,
                    cols=cols_timed,
                    cols_jittered=cols_timed,
                )
        metadata = []
        for col in cols:
            metadata.append(col)
        self._attributes.add(metadata, "jittering", duplicate_policy="append")

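    # Usage sketch: with no arguments, columns and amplitudes come from
    # config["dataframe"]["jitter_cols"] and config["dataframe"]["jitter_amps"];
    # the explicit column names below are illustrative.
    #
    #   sp.add_jitter()
    #   # sp.add_jitter(cols=["X", "Y", "t"], amps=0.5)
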
    def smooth_columns(
        self,
        columns: Union[str, Sequence[str]] = None,
        method: Literal["rolling"] = "rolling",
        **kwargs,
    ) -> None:
        """Apply a filter along one or more columns of the dataframe.

        Currently only supports rolling average on acquisition time.

        Args:
            columns (Union[str, Sequence[str]]): The columns onto which to apply
                the filter.
            method (Literal['rolling'], optional): The filter method. Defaults to 'rolling'.
            **kwargs: Keyword arguments passed to the filter method.
        """
        if isinstance(columns, str):
            columns = [columns]
        for column in columns:
            if column not in self._dataframe.columns:
                raise ValueError(f"Cannot smooth {column}. Column not in dataframe!")
        kwargs = {**self._config["smooth"], **kwargs}
        if method == "rolling":
            self._dataframe = rolling_average_on_acquisition_time(
                df=self._dataframe,
                rolling_group_channel=kwargs.get("rolling_group_channel", None),
                columns=columns or kwargs.get("columns", None),
                window=kwargs.get("window", None),
                sigma=kwargs.get("sigma", None),
            )
        else:
            raise ValueError(f"Method {method} not supported!")
        self._attributes.add(
            columns,
            "smooth",
            duplicate_policy="append",
        )

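    # Usage sketch (illustrative column name and window; remaining parameters
    # default to the config["smooth"] section). Per the apply_energy_offset()
    # docstring above, the smoothed values land in a "_rolled" companion column,
    # e.g. "sampleBias_rolled", usable there via reductions="rolled".
    #
    #   sp.smooth_columns(columns="sampleBias", window=1000)
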
    def pre_binning(
        self,
        df_partitions: int = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Function to do an initial binning of the dataframe loaded to the class.

        Args:
            df_partitions (int, optional): Number of dataframe partitions to use for
                the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (List[Tuple], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            **kwds: Keyword arguments passed to ``compute``.

        Returns:
            xr.DataArray: pre-binned data-array.
        """
        if axes is None:
            axes = self._config["momentum"]["axes"]
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))

        if bins is None:
            bins = self._config["momentum"]["bins"]
        if ranges is None:
            ranges_ = list(self._config["momentum"]["ranges"])
            ranges_[2] = np.asarray(ranges_[2]) / 2 ** (
                self._config["dataframe"]["tof_binning"] - 1
            )
            ranges = [cast(Tuple[float, float], tuple(v)) for v in ranges_]

        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        return self.compute(
            bins=bins,
            axes=axes,
            ranges=ranges,
            df_partitions=df_partitions,
            **kwds,
        )

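    # Usage sketch: a quick look at the first 50 partitions, binned over the
    # default momentum axes/bins/ranges from the config:
    #
    #   pre = sp.pre_binning(df_partitions=50)
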
    def compute(
        self,
        bins: Union[
            int,
            dict,
            tuple,
            List[int],
            List[np.ndarray],
            List[tuple],
        ] = 100,
        axes: Union[str, Sequence[str]] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        normalize_to_acquisition_time: Union[bool, str] = False,
        **kwds,
    ) -> xr.DataArray:
        """Compute the histogram along the given dimensions.

        Args:
            bins (int, dict, tuple, List[int], List[np.ndarray], List[tuple], optional):
                Definition of the bins. Can be any of the following cases:

                - an integer describing the number of bins in all dimensions
                - a tuple of 3 numbers describing start, end and step of the binning
                  range
                - a np.ndarray defining the binning edges
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
                - a dictionary made of the axes as keys and any of the above as values.

                This takes priority over the axes and range arguments. Defaults to 100.
            axes (Union[str, Sequence[str]], optional): The names of the axes (columns)
                on which to calculate the histogram. The order will be the order of the
                dimensions in the resulting array. Defaults to None.
            ranges (Sequence[Tuple[float, float]], optional): list of tuples containing
                the start and end point of the binning range. Defaults to None.
            normalize_to_acquisition_time (Union[bool, str]): Option to normalize the
                result to the acquisition time. If a "slow" axis was scanned, providing
                the name of the scanned axis will compute and apply the corresponding
                normalization histogram. Defaults to False.
            **kwds: Keyword arguments:

                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
                  ``bin_dataframe`` for details. Defaults to
                  config["binning"]["hist_mode"].
                - **mode**: Defines how the results from each partition are combined.
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
                  Defaults to config["binning"]["mode"].
                - **pbar**: Option to show the tqdm progress bar. Defaults to
                  config["binning"]["pbar"].
                - **num_cores**: Number of CPU cores to use for parallelization.
                  Defaults to config["binning"]["num_cores"] or N_CPU-1.
                - **threads_per_worker**: Limit the number of threads that
                  multiprocessing can spawn per binning thread. Defaults to
                  config["binning"]["threads_per_worker"].
                - **threadpool_API**: The API to use for multiprocessing. "blas",
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
                  config["binning"]["threadpool_API"].
                - **df_partitions**: A range or list of dataframe partitions, or the
                  number of the dataframe partitions to use. Defaults to all partitions.

                Additional kwds are passed to ``bin_dataframe``.

        Raises:
            AssertionError: Raised when no dataframe has been loaded.

        Returns:
            xr.DataArray: The result of the n-dimensional binning represented in an
            xarray object, combining the data with the axes.
        """
        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
        mode = kwds.pop("mode", self._config["binning"]["mode"])
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
        num_cores = kwds.pop("num_cores", self._config["binning"]["num_cores"])
        threads_per_worker = kwds.pop(
            "threads_per_worker",
            self._config["binning"]["threads_per_worker"],
        )
        threadpool_api = kwds.pop(
            "threadpool_API",
            self._config["binning"]["threadpool_API"],
        )
        df_partitions = kwds.pop("df_partitions", None)
        if isinstance(df_partitions, int):
            df_partitions = slice(
                0,
                min(df_partitions, self._dataframe.npartitions),
            )
        if df_partitions is not None:
            dataframe = self._dataframe.partitions[df_partitions]
        else:
            dataframe = self._dataframe

        self._binned = bin_dataframe(
            df=dataframe,
            bins=bins,
            axes=axes,
            ranges=ranges,
            hist_mode=hist_mode,
            mode=mode,
            pbar=pbar,
            n_cores=num_cores,
            threads_per_worker=threads_per_worker,
            threadpool_api=threadpool_api,
            **kwds,
        )

        for dim in self._binned.dims:
            try:
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
            except KeyError:
                pass

        self._binned.attrs["units"] = "counts"
        self._binned.attrs["long_name"] = "photoelectron counts"
        self._binned.attrs["metadata"] = self._attributes.metadata

        if normalize_to_acquisition_time:
            if isinstance(normalize_to_acquisition_time, str):
                axis = normalize_to_acquisition_time
                print(
                    f"Calculate normalization histogram for axis '{axis}'...",
                )
                self._normalization_histogram = self.get_normalization_histogram(
                    axis=axis,
                    df_partitions=df_partitions,
                )
                # if the axes are named correctly, xarray figures out the normalization correctly
                self._normalized = self._binned / self._normalization_histogram
                self._attributes.add(
                    self._normalization_histogram.values,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )
            else:
                acquisition_time = self.loader.get_elapsed_time(
                    fids=df_partitions,
                )
                if acquisition_time > 0:
                    self._normalized = self._binned / acquisition_time
                self._attributes.add(
                    acquisition_time,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )

            self._normalized.attrs["units"] = "counts/second"
            self._normalized.attrs["long_name"] = "photoelectron counts per second"
            self._normalized.attrs["metadata"] = self._attributes.metadata

            return self._normalized

        return self._binned

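    # Usage sketch (axis names, bin numbers and ranges are illustrative and must
    # match columns present in the dataframe):
    #
    #   res = sp.compute(
    #       bins=[100, 100, 200],
    #       axes=["kx", "ky", "energy"],
    #       ranges=[(-2.0, 2.0), (-2.0, 2.0), (-5.0, 1.0)],
    #   )
    #   # for a delay scan, normalize per delay bin to the acquisition time:
    #   # res = sp.compute(bins=[200, 50], axes=["energy", "delay"],
    #   #                  ranges=[(-5.0, 1.0), (-100.0, 200.0)],
    #   #                  normalize_to_acquisition_time="delay")
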
    def get_normalization_histogram(
        self,
        axis: str = "delay",
        use_time_stamps: bool = False,
        **kwds,
    ) -> xr.DataArray:
        """Generates a normalization histogram from the timed dataframe. Optionally,
        use the TimeStamps column instead.

        Args:
            axis (str, optional): The axis for which to compute histogram.
                Defaults to "delay".
            use_time_stamps (bool, optional): Use the TimeStamps column of the
                dataframe, rather than the timed dataframe. Defaults to False.
            **kwds: Keyword arguments:

                - **df_partitions** (int, optional): Number of dataframe partitions
                  to use. Defaults to all.

        Raises:
            ValueError: Raised if no data are binned.
            ValueError: Raised if 'axis' not in binned coordinates.
            ValueError: Raised if config["dataframe"]["time_stamp_alias"] not found
                in Dataframe.

        Returns:
            xr.DataArray: The computed normalization histogram (in TimeStamp units
            per bin).
        """

        if self._binned is None:
            raise ValueError("Need to bin data first!")
        if axis not in self._binned.coords:
            raise ValueError(f"Axis '{axis}' not found in binned data!")

        df_partitions: Union[int, slice] = kwds.pop("df_partitions", None)
        if isinstance(df_partitions, int):
            df_partitions = slice(
                0,
                min(df_partitions, self._dataframe.npartitions),
            )

        if use_time_stamps or self._timed_dataframe is None:
            if df_partitions is not None:
                self._normalization_histogram = normalization_histogram_from_timestamps(
                    self._dataframe.partitions[df_partitions],
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["time_stamp_alias"],
                )
            else:
                self._normalization_histogram = normalization_histogram_from_timestamps(
                    self._dataframe,
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["time_stamp_alias"],
                )
        else:
            if df_partitions is not None:
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
                    self._timed_dataframe.partitions[df_partitions],
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["timed_dataframe_unit_time"],
                )
            else:
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
                    self._timed_dataframe,
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["timed_dataframe_unit_time"],
                )

        return self._normalization_histogram

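    # Usage sketch (requires a previous compute() call whose axes included
    # "delay"; the partition count is illustrative):
    #
    #   hist = sp.get_normalization_histogram(axis="delay", df_partitions=10)
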
    def view_event_histogram(
        self,
        dfpid: int,
        ncol: int = 2,
        bins: Sequence[int] = None,
        axes: Sequence[str] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        backend: str = "bokeh",
        legend: bool = True,
        histkwds: dict = None,
        legkwds: dict = None,
        **kwds,
    ):
        """Plot individual histograms of specified dimensions (axes) from a single
        dataframe partition.

        Args:
            dfpid (int): Number of the data frame partition to look at.
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
            bins (Sequence[int], optional): Number of bins to use for the specified
                axes. Defaults to config["histogram"]["bins"].
            axes (Sequence[str], optional): Names of the axes to display.
                Defaults to config["histogram"]["axes"].
            ranges (Sequence[Tuple[float, float]], optional): Value ranges of all
                specified axes. Defaults to config["histogram"]["ranges"].
            backend (str, optional): Backend of the plotting library
                ('matplotlib' or 'bokeh'). Defaults to "bokeh".
            legend (bool, optional): Option to include a legend in the histogram plots.
                Defaults to True.
            histkwds (dict, optional): Keyword arguments for histograms
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
            legkwds (dict, optional): Keyword arguments for legend
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
            **kwds: Extra keyword arguments passed to
                ``sed.diagnostics.grid_histogram()``.

        Raises:
            TypeError: Raised when the input values are not of the correct type.
        """
        if bins is None:
            bins = self._config["histogram"]["bins"]
        if axes is None:
            axes = self._config["histogram"]["axes"]
        axes = list(axes)
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
        if ranges is None:
            ranges = list(self._config["histogram"]["ranges"])
            for loc, axis in enumerate(axes):
                if axis == self._config["dataframe"]["tof_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
                        self._config["dataframe"]["tof_binning"] - 1
                    )
                elif axis == self._config["dataframe"]["adc_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
                        self._config["dataframe"]["adc_binning"] - 1
                    )

        input_types = map(type, [axes, bins, ranges])
        allowed_types = [list, tuple]

        df = self._dataframe

        if not set(input_types).issubset(allowed_types):
            raise TypeError(
                "Inputs of axes, bins, ranges need to be list or tuple!",
            )

        # Read out the values for the specified groups
        group_dict_dd = {}
        dfpart = df.get_partition(dfpid)
        cols = dfpart.columns
        for ax in axes:
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
        group_dict = ddf.compute(group_dict_dd)[0]

        # Plot multiple histograms in a grid
        grid_histogram(
            group_dict,
            ncol=ncol,
            rvs=axes,
            rvbins=bins,
            rvranges=ranges,
            backend=backend,
            legend=legend,
            histkwds=histkwds,
            legkwds=legkwds,
            **kwds,
        )

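    # Usage sketch: inspect the raw event distributions of partition 2 with the
    # default histogram axes/bins/ranges from the config:
    #
    #   sp.view_event_histogram(dfpid=2)
    #   # sp.view_event_histogram(dfpid=2, backend="matplotlib", ncol=3)
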
    def save(
        self,
        faddr: str,
        **kwds,
    ):
        """Saves the binned data to the provided path and filename.

        Args:
            faddr (str): Path and name of the file to write. Its extension determines
                the file type to write. Valid file types are:

                - "*.tiff", "*.tif": Saves a TIFF stack.
                - "*.h5", "*.hdf5": Saves an HDF5 file.
                - "*.nxs", "*.nexus": Saves a NeXus file.

            **kwds: Keyword arguments, which are passed to the writer functions:
                For TIFF writing:

                - **alias_dict**: Dictionary of dimension aliases to use.

                For HDF5 writing:

                - **mode**: hdf5 read/write mode. Defaults to "w".

                For NeXus:

                - **reader**: Name of the nexustools reader to use.
                  Defaults to config["nexus"]["reader"]
                - **definition**: NeXus application definition to use for saving.
                  Must be supported by the used ``reader``. Defaults to
                  config["nexus"]["definition"]
                - **input_files**: A list of input files to pass to the reader.
                  Defaults to config["nexus"]["input_files"]
                - **eln_data**: An electronic-lab-notebook file in '.yaml' format
                  to add to the list of files to pass to the reader.
        """
        if self._binned is None:
            raise NameError("Need to bin data first!")

        if self._normalized is not None:
            data = self._normalized
        else:
            data = self._binned

        extension = pathlib.Path(faddr).suffix

        if extension in (".tif", ".tiff"):
            to_tiff(
                data=data,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".h5", ".hdf5"):
            to_h5(
                data=data,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".nxs", ".nexus"):
            try:
                reader = kwds.pop("reader", self._config["nexus"]["reader"])
                definition = kwds.pop(
                    "definition",
                    self._config["nexus"]["definition"],
                )
                input_files = kwds.pop(
                    "input_files",
                    self._config["nexus"]["input_files"],
                )
            except KeyError as exc:
                raise ValueError(
                    "The nexus reader, definition and input files need to be provided!",
                ) from exc

            if isinstance(input_files, str):
                input_files = [input_files]

            if "eln_data" in kwds:
                input_files.append(kwds.pop("eln_data"))

            to_nexus(
                data=data,
                faddr=faddr,
                reader=reader,
                definition=definition,
                input_files=input_files,
                **kwds,
            )

        else:
            raise NotImplementedError(
                f"Unrecognized file format: {extension}.",
            )
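
    # Usage sketch (file names are illustrative; the extension selects the
    # writer, and NeXus export assumes config["nexus"] provides reader,
    # definition, and input_files):
    #
    #   sp.save("binned.h5")                              # HDF5
    #   sp.save("binned.tiff")                            # TIFF stack
    #   sp.save("binned.nxs", eln_data="eln_data.yaml")   # NeXus with ELN metadata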