OpenCOMPES / sed — build 6665624252 (push, via github)

27 Oct 2023 09:16AM UTC — coverage: 87.404% (first build)

steinnymir: Merge branch 'hextof_workflow_steps' into hist_testing

415 of 415 new or added lines in 11 files covered (100.0%)
4573 of 5232 relevant lines covered (87.4%)
0.87 hits per line

Source file: /sed/core/processor.py — 85.06% covered

"""This module contains the core class for the sed package."""
import pathlib
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Literal
from typing import Sequence
from typing import Tuple
from typing import Union

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.binning.binning import normalization_histogram_from_timed_dataframe
from sed.binning.binning import normalization_histogram_from_timestamps
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import apply_jitter
from sed.core.dfops import rolling_average_on_acquisition_time
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader

N_CPU = psutil.cpu_count()


42
class SedProcessor:
1✔
43
    """Processor class of sed. Contains wrapper functions defining a work flow for data
44
    correction, calibration and binning.
45

46
    Args:
47
        metadata (dict, optional): Dict of external Metadata. Defaults to None.
48
        config (Union[dict, str], optional): Config dictionary or config file name.
49
            Defaults to None.
50
        dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
51
            into the class. Defaults to None.
52
        files (List[str], optional): List of files to pass to the loader defined in
53
            the config. Defaults to None.
54
        folder (str, optional): Folder containing files to pass to the loader
55
            defined in the config. Defaults to None.
56
        collect_metadata (bool): Option to collect metadata from files.
57
            Defaults to False.
58
        **kwds: Keyword arguments passed to the reader.
59
    """
60

61
    def __init__(
        self,
        metadata: dict = None,
        config: Union[dict, str] = None,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            config (Union[dict, str], optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
                into the class. Defaults to None.
            files (List[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to parse_config and to the reader.
        """
        config_kwds = {
            key: value for key, value in kwds.items() if key in parse_config.__code__.co_varnames
        }
        for key in config_kwds.keys():
            del kwds[key]
        self._config = parse_config(config, **config_kwds)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

        self._dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._timed_dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._files: List[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None
        self._normalization_histogram: xr.DataArray = None
        self._normalized: xr.DataArray = None

        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
        )

        self.ec = EnergyCalibrator(
            loader=self.loader,
            config=self._config,
        )

        self.mc = MomentumCorrector(
            config=self._config,
        )

        self.dc = DelayCalibrator(
            config=self._config,
        )

        self.use_copy_tool = self._config.get("core", {}).get(
            "use_copy_tool",
            False,
        )
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    source=self._config["core"]["copy_tool_source"],
                    dest=self._config["core"]["copy_tool_dest"],
                    **self._config["core"].get("copy_tool_kwds", {}),
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )

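    # Usage sketch (hypothetical paths; a minimal way to set up a processor from a
    # config file and raw data files — the file names below are placeholders, not
    # part of this module):
    #
    #     from sed.core.processor import SedProcessor
    #     sp = SedProcessor(
    #         config="sed_config.yaml",   # parsed via parse_config
    #         files=["data/run_01.h5"],   # handed to the loader set in config["core"]["loader"]
    #     )
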
    def __repr__(self):
        if self._dataframe is None:
            df_str = "Data Frame: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        attributes_str = f"Metadata: {self._attributes.metadata}"
        pretty_str = df_str + "\n" + attributes_str
        return pretty_str

    @property
    def dataframe(self) -> Union[pd.DataFrame, ddf.DataFrame]:
        """Accessor to the underlying dataframe.

        Returns:
            Union[pd.DataFrame, ddf.DataFrame]: Dataframe object.
        """
        return self._dataframe

    @dataframe.setter
    def dataframe(self, dataframe: Union[pd.DataFrame, ddf.DataFrame]):
        """Setter for the underlying dataframe.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame]): The dataframe object to set.
        """
        if not isinstance(dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            dataframe,
            self._dataframe.__class__,
        ):
            raise ValueError(
                "'dataframe' has to be a Pandas or Dask dataframe and has to be of the same kind "
                "as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._dataframe.__class__}, provided type: {dataframe}.",
            )
        self._dataframe = dataframe

    @property
    def timed_dataframe(self) -> Union[pd.DataFrame, ddf.DataFrame]:
        """Accessor to the underlying timed_dataframe.

        Returns:
            Union[pd.DataFrame, ddf.DataFrame]: Timed Dataframe object.
        """
        return self._timed_dataframe

    @timed_dataframe.setter
    def timed_dataframe(self, timed_dataframe: Union[pd.DataFrame, ddf.DataFrame]):
        """Setter for the underlying timed dataframe.

        Args:
            timed_dataframe (Union[pd.DataFrame, ddf.DataFrame]): The timed dataframe
                object to set.
        """
        if not isinstance(timed_dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            timed_dataframe,
            self._timed_dataframe.__class__,
        ):
            raise ValueError(
                "'timed_dataframe' has to be a Pandas or Dask dataframe and has to be of the same "
                "kind as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._timed_dataframe.__class__}, "
                f"provided type: {timed_dataframe}.",
            )
        self._timed_dataframe = timed_dataframe

    @property
    def attributes(self) -> dict:
        """Accessor to the metadata dict.

        Returns:
            dict: The metadata dict.
        """
        return self._attributes.metadata

    def add_attribute(self, attributes: dict, name: str, **kwds):
        """Function to add an element to the attributes dict.

        Args:
            attributes (dict): The attributes dictionary object to add.
            name (str): Key under which to add the dictionary to the attributes.
        """
        self._attributes.add(
            entry=attributes,
            name=name,
            **kwds,
        )

    @property
    def config(self) -> Dict[Any, Any]:
        """Getter attribute for the config dictionary

        Returns:
            Dict: The config dictionary.
        """
        return self._config

    @property
    def files(self) -> List[str]:
        """Getter attribute for the list of files

        Returns:
            List[str]: The list of loaded files
        """
        return self._files

    @property
    def binned(self) -> xr.DataArray:
        """Getter attribute for the binned data array

        Returns:
            xr.DataArray: The binned data array
        """
        if self._binned is None:
            raise ValueError("No binned data available, need to compute histogram first!")
        return self._binned

    @property
    def normalized(self) -> xr.DataArray:
        """Getter attribute for the normalized data array

        Returns:
            xr.DataArray: The normalized data array
        """
        if self._normalized is None:
            raise ValueError(
                "No normalized data available, compute data with normalization enabled!",
            )
        return self._normalized

    @property
    def normalization_histogram(self) -> xr.DataArray:
        """Getter attribute for the normalization histogram

        Returns:
            xr.DataArray: The normalization histogram
        """
        if self._normalization_histogram is None:
            raise ValueError("No normalization histogram available, generate histogram first!")
        return self._normalization_histogram

    def cpy(self, path: Union[str, List[str]]) -> Union[str, List[str]]:
        """Function to mirror a list of files or a folder from a network drive to a
        local storage. Returns either the original or the copied path to the given
        path. The option to use this functionality is set by
        config["core"]["use_copy_tool"].

        Args:
            path (Union[str, List[str]]): Source path or path list.

        Returns:
            Union[str, List[str]]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        return path

    def load(
        self,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        metadata: dict = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            files (List[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to the loader.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            timed_dataframe = kwds.pop("timed_dataframe", None)
        elif runs is not None:
            # If runs are provided, we only use the copy tool if a folder is also provided.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        elif files is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                files=cast(List[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._timed_dataframe = timed_dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )

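    # Usage sketch (run ID and folder are placeholders; assumes the loader set in
    # config["core"]["loader"] understands them):
    #
    #     sp.load(runs=["44797"], folder="data/", collect_metadata=True)
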
    # Momentum calibration workflow
    # 1. Bin raw detector data for distortion correction
    def bin_and_load_momentum_calibration(
        self,
        df_partitions: int = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        plane: int = 0,
        width: int = 5,
        apply: bool = False,
        **kwds,
    ):
        """1st step of the momentum correction workflow. Function to do an initial binning
        of the dataframe loaded to the class, slice a plane from it using an
        interactive view, and load it into the momentum corrector class.

        Args:
            df_partitions (int, optional): Number of dataframe partitions to use for
                the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (List[Tuple], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
            width (int, optional): Initial value for the width slider. Defaults to 5.
            apply (bool, optional): Option to directly apply the values and select the
                slice. Defaults to False.
            **kwds: Keyword arguments passed to the pre_binning function.
        """
        self._pre_binned = self.pre_binning(
            df_partitions=df_partitions,
            axes=axes,
            bins=bins,
            ranges=ranges,
            **kwds,
        )

        self.mc.load_data(data=self._pre_binned)
        self.mc.select_slicer(plane=plane, width=width, apply=apply)

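    # Usage sketch (slider values are illustrative only):
    #
    #     sp.bin_and_load_momentum_calibration(df_partitions=100, plane=33, width=10, apply=True)
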
    # 2. Define the feature points for the spline warp correction.
    # Either autoselect features, or input features from the view above.
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2. Step of the distortion correction workflow: Define feature points in
        momentum space. They can either be selected manually using a GUI tool, be
        provided as a list of feature points, or be auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the selected features.
                Defaults to False.
            **kwds: Keyword arguments for MomentumCorrector.feature_extract() and
                MomentumCorrector.feature_select()
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop("sigma", self._config["momentum"]["sigma"])
            fwhm = kwds.pop("fwhm", self._config["momentum"]["fwhm"])
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config["momentum"]["sigma_radius"],
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )

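    # Usage sketch (the feature coordinates are made-up pixel positions for a
    # sixfold-symmetric pattern plus a center point):
    #
    #     import numpy as np
    #     features = np.array([[203, 341], [299, 386], [391, 335], [386, 234],
    #                          [291, 187], [196, 237], [296, 284]])
    #     sp.define_features(features=features, rotation_symmetry=6,
    #                        include_center=True, apply=True)
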
    # 3. Generate the spline warp correction from momentum features.
    # If no features have been selected before, use class defaults.
    def generate_splinewarp(
        self,
        use_center: bool = None,
        **kwds,
    ):
        """3. Step of the distortion correction workflow: Generate the correction
        function restoring the symmetry in the image using a splinewarp algorithm.

        Args:
            use_center (bool, optional): Option to use the position of the
                center point in the correction. Default is read from config, or set to True.
            **kwds: Keyword arguments for MomentumCorrector.spline_warp_estimate().
        """
        self.mc.spline_warp_estimate(use_center=use_center, **kwds)

        if self.mc.slice is not None:
            print("Original slice with reference features")
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)

            print("Corrected slice with target features")
            self.mc.view(
                image=self.mc.slice_corrected,
                annotated=True,
                points={"feats": self.mc.ptargs},
                backend="bokeh",
                crosshair=True,
            )

            print("Original slice with target features")
            self.mc.view(
                image=self.mc.slice,
                points={"feats": self.mc.ptargs},
                annotated=True,
                backend="bokeh",
            )

    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        points = []
        try:
            for point in self.mc.pouter_ord:
                points.append([float(i) for i in point])
            if self.mc.include_center:
                points.append([float(i) for i in self.mc.pcent])
        except AttributeError as exc:
            raise AttributeError(
                "Momentum correction parameters not found, need to generate parameters first!",
            ) from exc
        config = {
            "momentum": {
                "correction": {
                    "rotation_symmetry": self.mc.rotsym,
                    "feature_points": points,
                    "include_center": self.mc.include_center,
                    "use_center": self.mc.use_center,
                },
            },
        }
        save_config(config, filename, overwrite)

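    # Usage sketch: generate the spline warp from the selected features, then persist
    # the parameters so later sessions can pick them up from the folder config:
    #
    #     sp.generate_splinewarp(use_center=True)
    #     sp.save_splinewarp(filename="sed_config.yaml", overwrite=True)
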
    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    def pose_adjustment(
        self,
        scale: float = 1,
        xtrans: float = 0,
        ytrans: float = 0,
        angle: float = 0,
        apply: bool = False,
        use_correction: bool = True,
        reset: bool = True,
    ):
        """4. Step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            scale (float, optional): Initial value of the scaling slider.
                Defaults to 1.
            xtrans (float, optional): Initial value of the xtrans slider.
                Defaults to 0.
            ytrans (float, optional): Initial value of the ytrans slider.
                Defaults to 0.
            angle (float, optional): Initial value of the angle slider.
                Defaults to 0.
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
            reset (bool, optional): Option to reset the correction before transformation.
                Defaults to True.
        """
        # Use the uncorrected slice as default if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                raise ValueError(
                    "No slice for corrections and transformations loaded!",
                )
            self.mc.slice_corrected = self.mc.slice

        if not use_correction:
            self.mc.reset_deformation()

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate distortion correction from config values
            self.mc.add_features()
            self.mc.spline_warp_estimate()

        self.mc.pose_adjustment(
            scale=scale,
            xtrans=xtrans,
            ytrans=ytrans,
            angle=angle,
            apply=apply,
            reset=reset,
        )

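    # Usage sketch (transformation values are illustrative):
    #
    #     sp.pose_adjustment(xtrans=8, ytrans=7, angle=-1.5, apply=True)
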
    # 5. Apply the momentum correction to the dataframe
    def apply_momentum_correction(
        self,
        preview: bool = False,
    ):
        """Applies the distortion correction and (optional) pose adjustment
        to the dataframe.

        Args:
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding corrected X/Y columns to dataframe:")
            self._dataframe, metadata = self.mc.apply_corrections(
                df=self._dataframe,
            )
            if self._timed_dataframe is not None:
                self._timed_dataframe, _ = self.mc.apply_corrections(
                    self._timed_dataframe,
                )
            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

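    # Usage sketch: once correction and pose adjustment are defined, write the
    # corrected X/Y columns into the dataframe:
    #
    #     sp.apply_momentum_correction(preview=True)
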
    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    def calibrate_momentum_axes(
        self,
        point_a: Union[np.ndarray, List[int]] = None,
        point_b: Union[np.ndarray, List[int]] = None,
        k_distance: float = None,
        k_coord_a: Union[np.ndarray, List[float]] = None,
        k_coord_b: Union[np.ndarray, List[float]] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply=False,
    ):
        """1. Step of the momentum calibration workflow: Calibrate momentum
        axes using either provided pixel coordinates of a high-symmetry point and its
        distance to the BZ center, or the k-coordinates of two points in the BZ
        (depending on the equiscale option). Opens an interactive panel for selecting
        the points.

        Args:
            point_a (Union[np.ndarray, List[int]]): Pixel coordinates of the first
                point used for momentum calibration.
            point_b (Union[np.ndarray, List[int]], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between point a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option to apply different scales to kx and ky.
                If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config["momentum"]["center_pixel"]

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )

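    # Usage sketch (pixel coordinates and k-distance are placeholders for a real
    # high-symmetry point and its known distance to the BZ center):
    #
    #     sp.calibrate_momentum_axes(point_a=[308, 345], k_distance=1.28, apply=True)
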
    # 1a. Save momentum calibration parameters to config file.
    def save_momentum_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated momentum calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key in [
                "kx_scale",
                "ky_scale",
                "x_center",
                "y_center",
                "rstart",
                "cstart",
                "rstep",
                "cstep",
            ]:
                calibration[key] = float(self.mc.calibration[key])
        except KeyError as exc:
            raise KeyError(
                "Momentum calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"momentum": {"calibration": calibration}}
        save_config(config, filename, overwrite)

    # 2. Apply correction and calibration to the dataframe
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
    ):
        """2. Step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes exist,
        these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding kx/ky columns to dataframe:")
            self._dataframe, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
            )
            if self._timed_dataframe is not None:
                self._timed_dataframe, _ = self.mc.append_k_axis(
                    df=self._timed_dataframe,
                    calibration=calibration,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

799
    # 1. Adjust the energy correction parameters
800
    def adjust_energy_correction(
1✔
801
        self,
802
        correction_type: str = None,
803
        amplitude: float = None,
804
        center: Tuple[float, float] = None,
805
        apply=False,
806
        **kwds,
807
    ):
808
        """1. step of the energy crrection workflow: Opens an interactive plot to
809
        adjust the parameters for the TOF/energy correction. Also pre-bins the data if
810
        they are not present yet.
811

812
        Args:
813
            correction_type (str, optional): Type of correction to apply to the TOF
814
                axis. Valid values are:
815

816
                - 'spherical'
817
                - 'Lorentzian'
818
                - 'Gaussian'
819
                - 'Lorentzian_asymmetric'
820

821
                Defaults to config["energy"]["correction_type"].
822
            amplitude (float, optional): Amplitude of the correction.
823
                Defaults to config["energy"]["correction"]["amplitude"].
824
            center (Tuple[float, float], optional): Center X/Y coordinates for the
825
                correction. Defaults to config["energy"]["correction"]["center"].
826
            apply (bool, optional): Option to directly apply the provided or default
827
                correction parameters. Defaults to False.
828
        """
829
        if self._pre_binned is None:
1✔
830
            print(
1✔
831
                "Pre-binned data not present, binning using defaults from config...",
832
            )
833
            self._pre_binned = self.pre_binning()
1✔
834

835
        self.ec.adjust_energy_correction(
1✔
836
            self._pre_binned,
837
            correction_type=correction_type,
838
            amplitude=amplitude,
839
            center=center,
840
            apply=apply,
841
            **kwds,
842
        )
843

844
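    # Usage sketch (correction amplitude and center are illustrative values):
    #
    #     sp.adjust_energy_correction(correction_type="Lorentzian", amplitude=2.5,
    #                                 center=(730, 730), apply=True)
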
    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        correction = {}
        try:
            for key, val in self.ec.correction.items():
                if key == "correction_type":
                    correction[key] = val
                elif key == "center":
                    correction[key] = [float(i) for i in val]
                else:
                    correction[key] = float(val)
        except AttributeError as exc:
            raise AttributeError(
                "Energy correction parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)

    # 2. Apply energy correction to dataframe
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2. Step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction``.
        """
        if self._dataframe is not None:
            print("Applying energy correction to dataframe...")
            self._dataframe, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                **kwds,
            )
            if self._timed_dataframe is not None:
                self._timed_dataframe, _ = self.ec.apply_energy_correction(
                    df=self._timed_dataframe,
                    correction=correction,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

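    # Usage sketch: persist the adjusted correction, then apply it to the dataframe:
    #
    #     sp.save_energy_correction()
    #     sp.apply_energy_correction(preview=True)
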
    # Energy calibrator workflow
    # 1. Load and normalize data
    def load_bias_series(
        self,
        binned_data: Union[xr.DataArray, Tuple[np.ndarray, np.ndarray, np.ndarray]] = None,
        data_files: List[str] = None,
        axes: List[str] = None,
        bins: List = None,
        ranges: Sequence[Tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1. Step of the energy calibration workflow: Load and bin data from
        single-event files, or load binned bias/TOF traces.

        Args:
            binned_data (Union[xr.DataArray, Tuple[np.ndarray, np.ndarray, np.ndarray]], optional):
                Binned data. If provided as a DataArray, it needs to contain the dimensions
                config["dataframe"]["tof_column"] and config["dataframe"]["bias_column"]. If
                provided as a tuple, it needs to contain the elements (tof, biases, traces).
            data_files (List[str], optional): list of file paths to bin
            axes (List[str], optional): bin axes.
                Defaults to config["dataframe"]["tof_column"].
            bins (List, optional): number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        if binned_data is not None:
            if isinstance(binned_data, xr.DataArray):
                if (
                    self._config["dataframe"]["tof_column"] not in binned_data.dims
                    or self._config["dataframe"]["bias_column"] not in binned_data.dims
                ):
                    raise ValueError(
                        "If binned_data is provided as an xarray, it needs to contain dimensions "
                        f"'{self._config['dataframe']['tof_column']}' and "
                        f"'{self._config['dataframe']['bias_column']}'!",
                    )
                tof = binned_data.coords[self._config["dataframe"]["tof_column"]].values
                biases = binned_data.coords[self._config["dataframe"]["bias_column"]].values
                traces = binned_data.values[:, :]
            else:
                try:
                    (tof, biases, traces) = binned_data
                except ValueError as exc:
                    raise ValueError(
                        "If binned_data is provided as tuple, it needs to contain "
                        "(tof, biases, traces)!",
                    ) from exc
            self.ec.load_data(biases=biases, traces=traces, tof=tof)

        elif data_files is not None:
            self.ec.bin_data(
                data_files=cast(List[str], self.cpy(data_files)),
                axes=axes,
                bins=bins,
                ranges=ranges,
                biases=biases,
                bias_key=bias_key,
            )

        else:
            raise ValueError("Either binned_data or data_files needs to be provided!")

        if (normalize is not None and normalize is True) or (
            normalize is None and self._config["energy"]["normalize"]
        ):
            if span is None:
                span = self._config["energy"]["normalize_span"]
            if order is None:
                order = self._config["energy"]["normalize_order"]
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="bokeh",
        )

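    # Usage sketch (file list is a placeholder; bias voltages are read from the files
    # via config["energy"]["bias_key"] when not given explicitly):
    #
    #     sp.load_bias_series(data_files=["data/bias_scan_01.h5"], normalize=True)
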
    # 2. Extract ranges and get peak positions
    def find_bias_peaks(
        self,
        ranges: Union[List[Tuple], Tuple],
        ref_id: int = 0,
        infer_others: bool = True,
        mode: str = "replace",
        radius: int = None,
        peak_window: int = None,
        apply: bool = False,
    ):
        """2. Step of the energy calibration workflow: Find a peak within a given range
        for the indicated reference trace, and try to find the same peak for all
        other traces. Uses fast_dtw to align curves, which might not work well if the
        shape of the curves changes qualitatively. Ideally, choose a reference trace in
        the middle of the set, and don't choose the range too narrow around the peak.
        Alternatively, a list of ranges for all traces can be provided.

        Args:
            ranges (Union[List[Tuple], Tuple]): Tuple of TOF values indicating a range.
                Alternatively, a list of ranges for all traces can be given.
            ref_id (int, optional): The id of the trace the range refers to.
                Defaults to 0.
            infer_others (bool, optional): Whether to determine the range for the other
                traces. Defaults to True.
            mode (str, optional): Whether to "add" or "replace" existing ranges.
                Defaults to "replace".
            radius (int, optional): Radius parameter for fast_dtw.
                Defaults to config["energy"]["fastdtw_radius"].
            peak_window (int, optional): peak_window parameter for the peak-detection
                algorithm: the number of points that have to behave monotonically
                around a peak. Defaults to config["energy"]["peak_window"].
            apply (bool, optional): Option to directly apply the provided parameters.
                Defaults to False.
        """
        if radius is None:
            radius = self._config["energy"]["fastdtw_radius"]
        if peak_window is None:
            peak_window = self._config["energy"]["peak_window"]
        if not infer_others:
            self.ec.add_ranges(
                ranges=ranges,
                ref_id=ref_id,
                infer_others=infer_others,
                mode=mode,
                radius=radius,
            )
            print(self.ec.featranges)
            try:
                self.ec.feature_extract(peak_window=peak_window)
                self.ec.view(
                    traces=self.ec.traces_normed,
                    segs=self.ec.featranges,
                    xaxis=self.ec.tof,
                    peaks=self.ec.peaks,
                    backend="bokeh",
                )
            except IndexError:
                print("Could not determine all peaks!")
                raise
        else:
            # New adjustment tool
            assert isinstance(ranges, tuple)
            self.ec.adjust_ranges(
                ranges=ranges,
                ref_id=ref_id,
                traces=self.ec.traces_normed,
                infer_others=infer_others,
                radius=radius,
                peak_window=peak_window,
                apply=apply,
            )

    # 3. Fit the energy calibration relation
    def calibrate_energy_axis(
        self,
        ref_id: int,
        ref_energy: float,
        method: str = None,
        energy_scale: str = None,
        **kwds,
    ):
        """3. Step of the energy calibration workflow: Calculate the calibration
        function for the energy axis, and apply it to the dataframe. Two
        approximations are implemented, a (normally 3rd order) polynomial
        approximation, and a d^2/(t-t0)^2 relation.

        Args:
            ref_id (int): id of the trace at the bias where the reference energy is
                given.
            ref_energy (float): Absolute energy of the detected feature at the bias
                of ref_id.
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to config["energy"]["calibration_method"]
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to config["energy"]["energy_scale"]
            **kwds: Keyword args passed to ``EnergyCalibrator.calibrate``.
        """
        if method is None:
            method = self._config["energy"]["calibration_method"]

        if energy_scale is None:
            energy_scale = self._config["energy"]["energy_scale"]

        self.ec.calibrate(
            ref_id=ref_id,
            ref_energy=ref_energy,
            method=method,
            energy_scale=energy_scale,
            **kwds,
        )
        print("Quality of Calibration:")
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.calibration["axis"],
            align=True,
            energy_scale=energy_scale,
            backend="bokeh",
        )
        print("E/TOF relationship:")
        self.ec.view(
            traces=self.ec.calibration["axis"][None, :],
            xaxis=self.ec.tof,
            backend="matplotlib",
            show_legend=False,
        )
        if energy_scale == "kinetic":
            plt.scatter(
                self.ec.peaks[:, 0],
                -(self.ec.biases - self.ec.biases[ref_id]) + ref_energy,
                s=50,
                c="k",
            )
        elif energy_scale == "binding":
            plt.scatter(
                self.ec.peaks[:, 0],
                self.ec.biases - self.ec.biases[ref_id] + ref_energy,
                s=50,
                c="k",
            )
        else:
            raise ValueError(
                'energy_scale needs to be either "binding" or "kinetic", '
                f"got {energy_scale}.",
            )
        plt.xlabel("Time-of-flight", fontsize=15)
        plt.ylabel("Energy (eV)", fontsize=15)
        plt.show()

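    # Usage sketch (reference trace id and energy are placeholders; -4.2 eV would be
    # a known energy of the reference feature at the bias of ref_id):
    #
    #     sp.calibrate_energy_axis(ref_id=5, ref_energy=-4.2, method="lmfit",
    #                              energy_scale="kinetic")
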
    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key, value in self.ec.calibration.items():
                if key in ["axis", "refid", "Tmat", "bvec"]:
                    continue
                if key == "energy_scale":
                    calibration[key] = value
                elif key == "coeffs":
                    calibration[key] = [float(i) for i in value]
                else:
                    calibration[key] = float(value)
        except AttributeError as exc:
            raise AttributeError(
                "Energy calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"calibration": calibration}}
        save_config(config, filename, overwrite)

    # 4. Apply energy calibration to the dataframe
    def append_energy_axis(
        self,
        calibration: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """4. Step of the energy calibration workflow: Apply the calibration function
        to the dataframe. Two approximations are implemented, a (normally 3rd order)
        polynomial approximation, and a d^2/(t-t0)^2 relation. A calibration dictionary
        can be provided.

        Args:
            calibration (dict, optional): Calibration dict containing calibration
                parameters. Overrides calibration from class or config.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.append_energy_axis``.
        """
        if self._dataframe is not None:
            print("Adding energy column to dataframe:")
            self._dataframe, metadata = self.ec.append_energy_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )
            if self._timed_dataframe is not None:
                self._timed_dataframe, _ = self.ec.append_energy_axis(
                    df=self._timed_dataframe,
                    calibration=calibration,
                    **kwds,
                )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)

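    # Usage sketch: persist the fitted calibration, then add the energy column:
    #
    #     sp.save_energy_calibration()
    #     sp.append_energy_axis(preview=True)
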
    def apply_energy_offset(
        self,
        constant: float = None,
        columns: Union[str, Sequence[str]] = None,
        signs: Union[int, Sequence[int]] = None,
        reductions: Union[str, Sequence[str]] = None,
        subtract_mean: Union[bool, Sequence[bool]] = None,
    ) -> None:
        """Shift the energy axis of the dataframe by a given amount.

        Args:
            constant (float, optional): The constant to shift the energy axis by.
            columns (Union[str, Sequence[str]]): The columns to shift.
            signs (Union[int, Sequence[int]]): The sign of the shift.
            reductions (Union[str, Sequence[str]]): The reduction to apply to the
                column. If "rolled", it searches for columns with the suffix "_rolled",
                e.g. "sampleBias_rolled", as generated by the
                ``SedProcessor.smooth_columns()`` function. Otherwise, it should be an
                available method of dask.dataframe.Series, e.g. "mean". In that case,
                the function is applied to the column to generate a single value for
                the whole dataset. If None, the shift is applied per-dataframe-row.
                Defaults to None.
            subtract_mean (Union[bool, Sequence[bool]]): Whether to subtract the mean
                of the column before applying the shift. Defaults to None.

        Raises:
            ValueError: If the energy column is not in the dataframe.
        """
        energy_column = self._config["dataframe"]["energy_column"]
        if energy_column not in self._dataframe.columns:
            raise ValueError(
                f"Energy column {energy_column} not found in dataframe! "
                "Run energy calibration first",
            )
        self._dataframe, metadata = self.ec.apply_energy_offset(
            df=self._dataframe,
            constant=constant,
            columns=columns,
            energy_column=energy_column,
            signs=signs,
            reductions=reductions,
            subtract_mean=subtract_mean,
        )
        if len(metadata) > 0:
            self._attributes.add(
                metadata,
                "apply_energy_offset",
                # TODO: allow appending only when no offset along this column(s) was applied
                duplicate_policy="append",
            )
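
    # Usage sketch: shift the energy axis by a constant plus a per-row bias
    # column. "sampleBias" is a hypothetical column name; with reductions=None
    # the shift is applied row by row, and subtract_mean=True removes the
    # column mean before shifting.
    #
    #     processor.apply_energy_offset(
    #         constant=0.5,
    #         columns="sampleBias",
    #         signs=-1,
    #         reductions=None,
    #         subtract_mean=True,
    #     )
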
    def append_tof_ns_axis(
        self,
        **kwargs,
    ):
        """Convert time-of-flight channel steps to nanoseconds.

        Args:
            tof_ns_column (str, optional): Name of the generated column containing the
                time-of-flight in nanoseconds.
                Defaults to config["dataframe"]["tof_ns_column"].
            **kwargs: Additional arguments are passed to ``energy.tof_step_to_ns``.
        """
        if self._dataframe is not None:
            print("Adding time-of-flight column in nanoseconds to dataframe:")
            # TODO assert order of execution through metadata

            self._dataframe, metadata = self.ec.append_tof_ns_axis(
                df=self._dataframe,
                **kwargs,
            )
            self._attributes.add(
                metadata,
                "tof_ns_conversion",
                duplicate_policy="append",
            )
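
    # Usage sketch: add a time-of-flight column in nanoseconds, overriding the
    # config default with a hypothetical column name:
    #
    #     processor.append_tof_ns_axis(tof_ns_column="dldTime_ns")
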
    def align_dld_sectors(self, **kwargs):
        """Align the 8 sectors of the HEXTOF endstation's delay-line detector."""
        if self._dataframe is not None:
            print("Aligning 8 sectors of dataframe")
            # TODO assert order of execution through metadata
            self._dataframe, metadata = self.ec.align_dld_sectors(df=self._dataframe, **kwargs)
            self._attributes.add(
                metadata,
                "dld_sector_alignment",
                duplicate_policy="raise",
            )
    # Delay calibration function
    def calibrate_delay_axis(
        self,
        delay_range: Tuple[float, float] = None,
        datafile: str = None,
        preview: bool = False,
        **kwds,
    ):
        """Append delay column to dataframe. Either provide delay ranges, or read
        them from a file.

        Args:
            delay_range (Tuple[float, float], optional): The scanned delay range in
                picoseconds. Defaults to None.
            datafile (str, optional): The file from which to read the delay ranges.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
        """
        if self._dataframe is not None:
            print("Adding delay column to dataframe:")

            if delay_range is not None:
                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    delay_range=delay_range,
                    **kwds,
                )
                if self._timed_dataframe is not None:
                    self._timed_dataframe, _ = self.dc.append_delay_axis(
                        self._timed_dataframe,
                        delay_range=delay_range,
                        **kwds,
                    )
            else:
                if datafile is None:
                    try:
                        datafile = self._files[0]
                    except IndexError:
                        print(
                            "No datafile available, specify either "
                            "'datafile' or 'delay_range'",
                        )
                        raise

                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    datafile=datafile,
                    **kwds,
                )
                if self._timed_dataframe is not None:
                    self._timed_dataframe, _ = self.dc.append_delay_axis(
                        self._timed_dataframe,
                        datafile=datafile,
                        **kwds,
                    )

            # Add Metadata
            self._attributes.add(
                metadata,
                "delay_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                print(self._dataframe)
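
    # Usage sketch: either pass the scanned delay range explicitly, or let the
    # range be read from a data file (by default the first loaded file). The
    # values and file name below are hypothetical:
    #
    #     processor.calibrate_delay_axis(delay_range=(-500.0, 1500.0), preview=True)
    #     # or, reading the range from a raw file:
    #     processor.calibrate_delay_axis(datafile="raw/scan_0001.h5")
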
    def add_jitter(
        self,
        cols: List[str] = None,
        amps: Union[float, Sequence[float]] = None,
        **kwds,
    ):
        """Add jitter to the selected dataframe columns.

        Args:
            cols (List[str], optional): The columns onto which to apply jitter.
                Defaults to config["dataframe"]["jitter_cols"].
            amps (Union[float, Sequence[float]], optional): Amplitude scalings for the
                jittering noise. If one number is given, the same is used for all axes.
                For uniform noise (default) it will cover the interval [-amp, +amp].
                Defaults to config["dataframe"]["jitter_amps"].
            **kwds: Additional keyword arguments passed to ``apply_jitter``.
        """
        if cols is None:
            cols = self._config["dataframe"]["jitter_cols"]
        for loc, col in enumerate(cols):
            if col.startswith("@"):
                cols[loc] = self._config["dataframe"].get(col.strip("@"))

        if amps is None:
            amps = self._config["dataframe"]["jitter_amps"]

        self._dataframe = self._dataframe.map_partitions(
            apply_jitter,
            cols=cols,
            cols_jittered=cols,
            amps=amps,
            **kwds,
        )
        if self._timed_dataframe is not None:
            # apply the same jitter settings to the timed dataframe
            self._timed_dataframe = self._timed_dataframe.map_partitions(
                apply_jitter,
                cols=cols,
                cols_jittered=cols,
                amps=amps,
                **kwds,
            )
        metadata = []
        for col in cols:
            metadata.append(col)
        self._attributes.add(metadata, "jittering", duplicate_policy="append")
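
    # Usage sketch: jitter the detector columns to smooth out artefacts from
    # discrete channel values. Column names prefixed with "@" are resolved via
    # config["dataframe"]; the names below are illustrative:
    #
    #     processor.add_jitter(cols=["dldPosX", "dldPosY"], amps=0.5)
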
    def smooth_columns(
        self,
        columns: Union[str, Sequence[str]] = None,
        method: Literal["rolling"] = "rolling",
        **kwargs,
    ) -> None:
        """Apply a filter along one or more columns of the dataframe.

        Currently only supports a rolling average on the acquisition time.

        Args:
            columns (Union[str, Sequence[str]]): The columns onto which to apply the filter.
            method (Literal['rolling'], optional): The filter method. Defaults to 'rolling'.
            **kwargs: Keyword arguments passed to the filter method.
        """
        if isinstance(columns, str):
            columns = [columns]
        for column in columns:
            if column not in self._dataframe.columns:
                raise ValueError(f"Cannot smooth {column}. Column not in dataframe!")
        kwargs = {**self._config["smooth"], **kwargs}
        if method == "rolling":
            self._dataframe = rolling_average_on_acquisition_time(
                df=self._dataframe,
                rolling_group_channel=kwargs.get("rolling_group_channel", None),
                columns=columns or kwargs.get("columns", None),
                window=kwargs.get("window", None),
                sigma=kwargs.get("sigma", None),
            )
        else:
            raise ValueError(f"Method {method} not supported!")
        self._attributes.add(
            columns,
            "smooth",
            duplicate_policy="append",
        )
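
    # Usage sketch: apply a rolling average on acquisition time to a slowly
    # varying channel, producing a "<column>_rolled" companion column that
    # ``apply_energy_offset(reductions="rolled")`` can pick up. The column name
    # and window value are hypothetical:
    #
    #     processor.smooth_columns(columns="sampleBias", method="rolling", window=100)
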
    def pre_binning(
        self,
        df_partitions: int = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Function to do an initial binning of the dataframe loaded to the class.

        Args:
            df_partitions (int, optional): Number of dataframe partitions to use for
                the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            **kwds: Keyword arguments passed to ``compute``.

        Returns:
            xr.DataArray: pre-binned data-array.
        """
        if axes is None:
            axes = self._config["momentum"]["axes"]
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))

        if bins is None:
            bins = self._config["momentum"]["bins"]
        if ranges is None:
            ranges_ = list(self._config["momentum"]["ranges"])
            ranges_[2] = np.asarray(ranges_[2]) / 2 ** (
                self._config["dataframe"]["tof_binning"] - 1
            )
            ranges = [cast(Tuple[float, float], tuple(v)) for v in ranges_]

        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        return self.compute(
            bins=bins,
            axes=axes,
            ranges=ranges,
            df_partitions=df_partitions,
            **kwds,
        )
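
    # Usage sketch: a quick overview binning on a subset of partitions, using
    # the momentum axes/bins/ranges from the config:
    #
    #     overview = processor.pre_binning(df_partitions=20)
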
    def compute(
        self,
        bins: Union[
            int,
            dict,
            tuple,
            List[int],
            List[np.ndarray],
            List[tuple],
        ] = 100,
        axes: Union[str, Sequence[str]] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        normalize_to_acquisition_time: Union[bool, str] = False,
        **kwds,
    ) -> xr.DataArray:
        """Compute the histogram along the given dimensions.

        Args:
            bins (int, dict, tuple, List[int], List[np.ndarray], List[tuple], optional):
                Definition of the bins. Can be any of the following cases:

                - an integer describing the number of bins for all dimensions
                - a tuple of 3 numbers describing start, end, and step of the binning
                  range
                - an np.ndarray defining the bin edges
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
                - a dictionary made of the axes as keys and any of the above as values.

                This takes priority over the axes and range arguments. Defaults to 100.
            axes (Union[str, Sequence[str]], optional): The names of the axes (columns)
                on which to calculate the histogram. The order will be the order of the
                dimensions in the resulting array. Defaults to None.
            ranges (Sequence[Tuple[float, float]], optional): list of tuples containing
                the start and end point of the binning range. Defaults to None.
            normalize_to_acquisition_time (Union[bool, str]): Option to normalize the
                result to the acquisition time. If a "slow" axis was scanned, providing
                the name of the scanned axis will compute and apply the corresponding
                normalization histogram. Defaults to False.
            **kwds: Keyword arguments:

                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
                  ``bin_dataframe`` for details. Defaults to
                  config["binning"]["hist_mode"].
                - **mode**: Defines how the results from each partition are combined.
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
                  Defaults to config["binning"]["mode"].
                - **pbar**: Option to show the tqdm progress bar. Defaults to
                  config["binning"]["pbar"].
                - **num_cores**: Number of CPU cores to use for parallelization.
                  Defaults to config["binning"]["num_cores"] or N_CPU-1.
                - **threads_per_worker**: Limit the number of threads that
                  multiprocessing can spawn per binning thread. Defaults to
                  config["binning"]["threads_per_worker"].
                - **threadpool_API**: The API to use for multiprocessing. "blas",
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
                  config["binning"]["threadpool_API"].
                - **df_partitions**: A range or list of dataframe partitions, or the
                  number of the dataframe partitions to use. Defaults to all partitions.

                Additional kwds are passed to ``bin_dataframe``.

        Raises:
            AssertionError: Raised when no dataframe has been loaded.

        Returns:
            xr.DataArray: The result of the n-dimensional binning represented in an
            xarray object, combining the data with the axes.
        """
        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
        mode = kwds.pop("mode", self._config["binning"]["mode"])
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
        num_cores = kwds.pop("num_cores", self._config["binning"]["num_cores"])
        threads_per_worker = kwds.pop(
            "threads_per_worker",
            self._config["binning"]["threads_per_worker"],
        )
        threadpool_api = kwds.pop(
            "threadpool_API",
            self._config["binning"]["threadpool_API"],
        )
        df_partitions = kwds.pop("df_partitions", None)
        if isinstance(df_partitions, int):
            df_partitions = slice(
                0,
                min(df_partitions, self._dataframe.npartitions),
            )
        if df_partitions is not None:
            dataframe = self._dataframe.partitions[df_partitions]
        else:
            dataframe = self._dataframe

        self._binned = bin_dataframe(
            df=dataframe,
            bins=bins,
            axes=axes,
            ranges=ranges,
            hist_mode=hist_mode,
            mode=mode,
            pbar=pbar,
            n_cores=num_cores,
            threads_per_worker=threads_per_worker,
            threadpool_api=threadpool_api,
            **kwds,
        )

        for dim in self._binned.dims:
            try:
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
            except KeyError:
                pass

        self._binned.attrs["units"] = "counts"
        self._binned.attrs["long_name"] = "photoelectron counts"
        self._binned.attrs["metadata"] = self._attributes.metadata

        if normalize_to_acquisition_time:
            if isinstance(normalize_to_acquisition_time, str):
                axis = normalize_to_acquisition_time
                print(
                    f"Calculate normalization histogram for axis '{axis}'...",
                )
                self._normalization_histogram = self.get_normalization_histogram(
                    axis=axis,
                    df_partitions=df_partitions,
                )
                # if the axes are named correctly, xarray figures out the normalization correctly
                self._normalized = self._binned / self._normalization_histogram
                self._attributes.add(
                    self._normalization_histogram.values,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )
            else:
                acquisition_time = self.loader.get_elapsed_time(
                    fids=df_partitions,
                )
                if acquisition_time > 0:
                    self._normalized = self._binned / acquisition_time
                self._attributes.add(
                    acquisition_time,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )

            self._normalized.attrs["units"] = "counts/second"
            self._normalized.attrs["long_name"] = "photoelectron counts per second"
            self._normalized.attrs["metadata"] = self._attributes.metadata

            return self._normalized

        return self._binned
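
    # Usage sketch: a 3D binning over illustrative column names, normalized to
    # the acquisition time per delay bin (the normalization axis must be among
    # the binned axes). Axis names, bin counts, and ranges are hypothetical:
    #
    #     res = processor.compute(
    #         bins=[80, 80, 100],
    #         axes=["kx", "ky", "delay"],
    #         ranges=[(-2.0, 2.0), (-2.0, 2.0), (-500.0, 1500.0)],
    #         normalize_to_acquisition_time="delay",
    #     )
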
    def get_normalization_histogram(
        self,
        axis: str = "delay",
        use_time_stamps: bool = False,
        **kwds,
    ) -> xr.DataArray:
        """Generates a normalization histogram from the timed dataframe. Optionally,
        use the TimeStamps column instead.

        Args:
            axis (str, optional): The axis for which to compute the histogram.
                Defaults to "delay".
            use_time_stamps (bool, optional): Use the TimeStamps column of the
                dataframe, rather than the timed dataframe. Defaults to False.
            **kwds: Keyword arguments:

                - **df_partitions**: Number of dataframe partitions to use.
                  Defaults to all.

        Raises:
            ValueError: Raised if no data are binned.
            ValueError: Raised if 'axis' not in binned coordinates.
            ValueError: Raised if config["dataframe"]["time_stamp_alias"] not found
                in Dataframe.

        Returns:
            xr.DataArray: The computed normalization histogram (in TimeStamp units
            per bin).
        """
        if self._binned is None:
            raise ValueError("Need to bin data first!")
        if axis not in self._binned.coords:
            raise ValueError(f"Axis '{axis}' not found in binned data!")

        df_partitions: Union[int, slice] = kwds.pop("df_partitions", None)
        if isinstance(df_partitions, int):
            df_partitions = slice(
                0,
                min(df_partitions, self._dataframe.npartitions),
            )

        if use_time_stamps or self._timed_dataframe is None:
            if df_partitions is not None:
                self._normalization_histogram = normalization_histogram_from_timestamps(
                    self._dataframe.partitions[df_partitions],
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["time_stamp_alias"],
                )
            else:
                self._normalization_histogram = normalization_histogram_from_timestamps(
                    self._dataframe,
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["time_stamp_alias"],
                )
        else:
            if df_partitions is not None:
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
                    self._timed_dataframe.partitions[df_partitions],
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["timed_dataframe_unit_time"],
                )
            else:
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
                    self._timed_dataframe,
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["timed_dataframe_unit_time"],
                )

        return self._normalization_histogram
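
    # Usage sketch: after binning along "delay", fetch the histogram used for
    # acquisition-time normalization, here computed from the timed dataframe and
    # restricted to 10 partitions. Dividing the binned result by this histogram
    # is what ``compute(..., normalize_to_acquisition_time="delay")`` does:
    #
    #     hist = processor.get_normalization_histogram(axis="delay", df_partitions=10)
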
    def view_event_histogram(
        self,
        dfpid: int,
        ncol: int = 2,
        bins: Sequence[int] = None,
        axes: Sequence[str] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        backend: str = "bokeh",
        legend: bool = True,
        histkwds: dict = None,
        legkwds: dict = None,
        **kwds,
    ):
        """Plot individual histograms of specified dimensions (axes) from a single
        dataframe partition.

        Args:
            dfpid (int): Number of the dataframe partition to look at.
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
            bins (Sequence[int], optional): Number of bins to use for the specified
                axes. Defaults to config["histogram"]["bins"].
            axes (Sequence[str], optional): Names of the axes to display.
                Defaults to config["histogram"]["axes"].
            ranges (Sequence[Tuple[float, float]], optional): Value ranges of all
                specified axes. Defaults to config["histogram"]["ranges"].
            backend (str, optional): Backend of the plotting library
                ('matplotlib' or 'bokeh'). Defaults to "bokeh".
            legend (bool, optional): Option to include a legend in the histogram plots.
                Defaults to True.
            histkwds (dict, optional): Keyword arguments for histograms
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
            legkwds (dict, optional): Keyword arguments for legend
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
            **kwds: Extra keyword arguments passed to
                ``sed.diagnostics.grid_histogram()``.

        Raises:
            TypeError: Raised when the input values are not of the correct type.
        """
        if bins is None:
            bins = self._config["histogram"]["bins"]
        if axes is None:
            axes = self._config["histogram"]["axes"]
        axes = list(axes)
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
        if ranges is None:
            ranges = list(self._config["histogram"]["ranges"])
            for loc, axis in enumerate(axes):
                if axis == self._config["dataframe"]["tof_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
                        self._config["dataframe"]["tof_binning"] - 1
                    )
                elif axis == self._config["dataframe"]["adc_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
                        self._config["dataframe"]["adc_binning"] - 1
                    )

        input_types = map(type, [axes, bins, ranges])
        allowed_types = [list, tuple]

        df = self._dataframe

        if not set(input_types).issubset(allowed_types):
            raise TypeError(
                "Inputs of axes, bins, ranges need to be list or tuple!",
            )

        # Read out the values for the specified groups
        group_dict_dd = {}
        dfpart = df.get_partition(dfpid)
        cols = dfpart.columns
        for ax in axes:
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
        group_dict = ddf.compute(group_dict_dd)[0]

        # Plot multiple histograms in a grid
        grid_histogram(
            group_dict,
            ncol=ncol,
            rvs=axes,
            rvbins=bins,
            rvranges=ranges,
            backend=backend,
            legend=legend,
            histkwds=histkwds,
            legkwds=legkwds,
            **kwds,
        )
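
    # Usage sketch: inspect raw event distributions of the first partition.
    # The "@" prefixes resolve the actual column names from config["dataframe"];
    # the bin counts and ranges below are illustrative:
    #
    #     processor.view_event_histogram(
    #         dfpid=0,
    #         axes=["@tof_column", "@adc_column"],
    #         bins=[100, 100],
    #         ranges=[(0, 90000), (0, 6000)],
    #         backend="matplotlib",
    #     )
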
    def save(
        self,
        faddr: str,
        **kwds,
    ):
        """Saves the binned data to the provided path and filename.

        Args:
            faddr (str): Path and name of the file to write. Its extension determines
                the file type to write. Valid file types are:

                - "*.tiff", "*.tif": Saves a TIFF stack.
                - "*.h5", "*.hdf5": Saves an HDF5 file.
                - "*.nxs", "*.nexus": Saves a NeXus file.

            **kwds: Keyword arguments, which are passed to the writer functions:
                For TIFF writing:

                - **alias_dict**: Dictionary of dimension aliases to use.

                For HDF5 writing:

                - **mode**: hdf5 read/write mode. Defaults to "w".

                For NeXus:

                - **reader**: Name of the nexustools reader to use.
                  Defaults to config["nexus"]["reader"]
                - **definition**: NeXus application definition to use for saving.
                  Must be supported by the used ``reader``. Defaults to
                  config["nexus"]["definition"]
                - **input_files**: A list of input files to pass to the reader.
                  Defaults to config["nexus"]["input_files"]
                - **eln_data**: An electronic-lab-notebook file in '.yaml' format
                  to add to the list of files to pass to the reader.
        """
        if self._binned is None:
            raise NameError("Need to bin data first!")

        if self._normalized is not None:
            data = self._normalized
        else:
            data = self._binned

        extension = pathlib.Path(faddr).suffix

        if extension in (".tif", ".tiff"):
            to_tiff(
                data=data,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".h5", ".hdf5"):
            to_h5(
                data=data,
                faddr=faddr,
                **kwds,
            )
        elif extension in (".nxs", ".nexus"):
            try:
                reader = kwds.pop("reader", self._config["nexus"]["reader"])
                definition = kwds.pop(
                    "definition",
                    self._config["nexus"]["definition"],
                )
                input_files = kwds.pop(
                    "input_files",
                    self._config["nexus"]["input_files"],
                )
            except KeyError as exc:
                raise ValueError(
                    "The nexus reader, definition and input files need to be provided!",
                ) from exc

            if isinstance(input_files, str):
                input_files = [input_files]

            if "eln_data" in kwds:
                input_files.append(kwds.pop("eln_data"))

            to_nexus(
                data=data,
                faddr=faddr,
                reader=reader,
                definition=definition,
                input_files=input_files,
                **kwds,
            )

        else:
            raise NotImplementedError(
                f"Unrecognized file format: {extension}.",
            )
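
    # Usage sketch: the target format follows from the file extension; for NeXus
    # output, reader/definition/input_files fall back to the config. The file
    # names below are hypothetical:
    #
    #     processor.save("binned.h5")
    #     processor.save("binned.nxs", eln_data="eln_data.yaml")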