/sed/core/processor.py
"""This module contains the core class for the sed package

"""
import pathlib
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.binning.binning import normalization_histogram_from_timed_dataframe
from sed.binning.binning import normalization_histogram_from_timestamps
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import apply_jitter
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader

N_CPU = psutil.cpu_count()


class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a workflow for data
    correction, calibration, and binning.

    Args:
        metadata (dict, optional): Dict of external Metadata. Defaults to None.
        config (Union[dict, str], optional): Config dictionary or config file name.
            Defaults to None.
        dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
            into the class. Defaults to None.
        files (List[str], optional): List of files to pass to the loader defined in
            the config. Defaults to None.
        folder (str, optional): Folder containing files to pass to the loader
            defined in the config. Defaults to None.
        runs (Sequence[str], optional): List of run identifiers to pass to the loader
            defined in the config. Defaults to None.
        collect_metadata (bool): Option to collect metadata from files.
            Defaults to False.
        **kwds: Keyword arguments passed to the reader.
    """

    def __init__(
        self,
        metadata: dict = None,
        config: Union[dict, str] = None,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        verbose: bool = False,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            config (Union[dict, str], optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
                into the class. Defaults to None.
            files (List[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool): Option to collect metadata from files.
                Defaults to False.
            verbose (bool, optional): Option to print out additional diagnostic
                information. Defaults to False.
            **kwds: Keyword arguments passed to parse_config and to the reader.
        """
        config_kwds = {
            key: value for key, value in kwds.items() if key in parse_config.__code__.co_varnames
        }
        for key in config_kwds.keys():
            del kwds[key]
        self._config = parse_config(config, **config_kwds)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

        self.verbose = verbose

        self._dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._timed_dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._files: List[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None
        self._normalization_histogram: xr.DataArray = None
        self._normalized: xr.DataArray = None

        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
        )

        self.ec = EnergyCalibrator(
            loader=self.loader,
            config=self._config,
        )

        self.mc = MomentumCorrector(
            config=self._config,
        )

        self.dc = DelayCalibrator(
            config=self._config,
        )

        self.use_copy_tool = self._config.get("core", {}).get(
            "use_copy_tool",
            False,
        )
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    source=self._config["core"]["copy_tool_source"],
                    dest=self._config["core"]["copy_tool_dest"],
                    **self._config["core"].get("copy_tool_kwds", {}),
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )

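    # --- Illustrative usage (a sketch, not part of the class) ---
    # Constructing a processor from a config file and a data folder; the file and
    # folder names are assumptions for illustration, and the package-level import
    # of SedProcessor is assumed:
    #
    #   >>> from sed import SedProcessor
    #   >>> sp = SedProcessor(
    #   ...     config="sed_config.yaml",
    #   ...     folder="/path/to/raw/data",
    #   ...     verbose=True,
    #   ... )
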
    def __repr__(self):
        if self._dataframe is None:
            df_str = "Data Frame: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        attributes_str = f"Metadata: {self._attributes.metadata}"
        pretty_str = df_str + "\n" + attributes_str
        return pretty_str

    @property
    def dataframe(self) -> Union[pd.DataFrame, ddf.DataFrame]:
        """Accessor to the underlying dataframe.

        Returns:
            Union[pd.DataFrame, ddf.DataFrame]: Dataframe object.
        """
        return self._dataframe

    @dataframe.setter
    def dataframe(self, dataframe: Union[pd.DataFrame, ddf.DataFrame]):
        """Setter for the underlying dataframe.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame]): The dataframe object to set.
        """
        if not isinstance(dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            dataframe,
            self._dataframe.__class__,
        ):
            raise ValueError(
                "'dataframe' has to be a Pandas or Dask dataframe and has to be of the same kind "
                "as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._dataframe.__class__}, provided type: {type(dataframe)}.",
            )
        self._dataframe = dataframe

    @property
    def timed_dataframe(self) -> Union[pd.DataFrame, ddf.DataFrame]:
        """Accessor to the underlying timed_dataframe.

        Returns:
            Union[pd.DataFrame, ddf.DataFrame]: Timed Dataframe object.
        """
        return self._timed_dataframe

    @timed_dataframe.setter
    def timed_dataframe(self, timed_dataframe: Union[pd.DataFrame, ddf.DataFrame]):
        """Setter for the underlying timed dataframe.

        Args:
            timed_dataframe (Union[pd.DataFrame, ddf.DataFrame]): The timed dataframe
                object to set.
        """
        if not isinstance(timed_dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            timed_dataframe,
            self._timed_dataframe.__class__,
        ):
            raise ValueError(
                "'timed_dataframe' has to be a Pandas or Dask dataframe and has to be of the same "
                "kind as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._timed_dataframe.__class__}, "
                f"provided type: {type(timed_dataframe)}.",
            )
        self._timed_dataframe = timed_dataframe

    @property
    def attributes(self) -> dict:
        """Accessor to the metadata dict.

        Returns:
            dict: The metadata dict.
        """
        return self._attributes.metadata

    def add_attribute(self, attributes: dict, name: str, **kwds):
        """Function to add an element to the attributes dict.

        Args:
            attributes (dict): The attributes dictionary object to add.
            name (str): Key under which to add the dictionary to the attributes.
            **kwds: Additional keyword arguments passed to ``MetaHandler.add()``.
        """
        self._attributes.add(
            entry=attributes,
            name=name,
            **kwds,
        )

    @property
    def config(self) -> Dict[Any, Any]:
        """Getter attribute for the config dictionary.

        Returns:
            Dict: The config dictionary.
        """
        return self._config

    @property
    def files(self) -> List[str]:
        """Getter attribute for the list of files.

        Returns:
            List[str]: The list of loaded files.
        """
        return self._files

    @property
    def binned(self) -> xr.DataArray:
        """Getter attribute for the binned data array.

        Returns:
            xr.DataArray: The binned data array.
        """
        if self._binned is None:
            raise ValueError("No binned data available, need to compute histogram first!")
        return self._binned

    @property
    def normalized(self) -> xr.DataArray:
        """Getter attribute for the normalized data array.

        Returns:
            xr.DataArray: The normalized data array.
        """
        if self._normalized is None:
            raise ValueError(
                "No normalized data available, compute data with normalization enabled!",
            )
        return self._normalized

    @property
    def normalization_histogram(self) -> xr.DataArray:
        """Getter attribute for the normalization histogram.

        Returns:
            xr.DataArray: The normalization histogram.
        """
        if self._normalization_histogram is None:
            raise ValueError("No normalization histogram available, generate histogram first!")
        return self._normalization_histogram

    def cpy(self, path: Union[str, List[str]]) -> Union[str, List[str]]:
        """Function to mirror a list of files or a folder from a network drive to a
        local storage. Returns either the original or the copied path to the given
        path. The option to use this functionality is set by
        config["core"]["use_copy_tool"].

        Args:
            path (Union[str, List[str]]): Source path or path list.

        Returns:
            Union[str, List[str]]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                path_out = []
                for file in path:
                    path_out.append(self.ct.copy(file))
                return path_out

            return self.ct.copy(path)

        return path

    def load(
        self,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        metadata: dict = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            files (List[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to the loader.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            timed_dataframe = kwds.pop("timed_dataframe", None)
        elif runs is not None:
            # If runs are provided, we only use the copy tool if a folder is provided as well.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
        elif folder is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        elif files is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                files=cast(List[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._timed_dataframe = timed_dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )

406
    # 1. Bin raw detector data for distortion correction
407
    def bin_and_load_momentum_calibration(
1✔
408
        self,
409
        df_partitions: int = 100,
410
        axes: List[str] = None,
411
        bins: List[int] = None,
412
        ranges: Sequence[Tuple[float, float]] = None,
413
        plane: int = 0,
414
        width: int = 5,
415
        apply: bool = False,
416
        **kwds,
417
    ):
418
        """1st step of momentum correction work flow. Function to do an initial binning
419
        of the dataframe loaded to the class, slice a plane from it using an
420
        interactive view, and load it into the momentum corrector class.
421

422
        Args:
423
            df_partitions (int, optional): Number of dataframe partitions to use for
424
                the initial binning. Defaults to 100.
425
            axes (List[str], optional): Axes to bin.
426
                Defaults to config["momentum"]["axes"].
427
            bins (List[int], optional): Bin numbers to use for binning.
428
                Defaults to config["momentum"]["bins"].
429
            ranges (List[Tuple], optional): Ranges to use for binning.
430
                Defaults to config["momentum"]["ranges"].
431
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
432
            width (int, optional): Initial value for the width slider. Defaults to 5.
433
            apply (bool, optional): Option to directly apply the values and select the
434
                slice. Defaults to False.
435
            **kwds: Keyword argument passed to the pre_binning function.
436
        """
437
        self._pre_binned = self.pre_binning(
1✔
438
            df_partitions=df_partitions,
439
            axes=axes,
440
            bins=bins,
441
            ranges=ranges,
442
            **kwds,
443
        )
444

445
        self.mc.load_data(data=self._pre_binned)
1✔
446
        self.mc.select_slicer(plane=plane, width=width, apply=apply)
1✔
447

448
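    # --- Illustrative usage (sketch) ---
    # Pre-bin the detector data and select a momentum-space plane; the slider
    # values are assumptions:
    #
    #   >>> sp.bin_and_load_momentum_calibration(
    #   ...     df_partitions=100, plane=33, width=10, apply=True,
    #   ... )
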
    # 2. Generate the spline warp correction from momentum features.
    # Either autoselect features, or input features from view above.
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2. Step of the distortion correction workflow: Define feature points in
        momentum space. They can either be manually selected using a GUI tool, be
        provided as a list of feature points, or be auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the provided or selected
                features. Defaults to False.
            **kwds: Keyword arguments for MomentumCorrector.feature_extract() and
                MomentumCorrector.feature_select().
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop("sigma", self._config["momentum"]["sigma"])
            fwhm = kwds.pop("fwhm", self._config["momentum"]["fwhm"])
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config["momentum"]["sigma_radius"],
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )

    # 3. Generate the spline warp correction from momentum features.
    # If no features have been selected before, use class defaults.
    def generate_splinewarp(
        self,
        use_center: bool = None,
        **kwds,
    ):
        """3. Step of the distortion correction workflow: Generate the correction
        function restoring the symmetry in the image using a splinewarp algorithm.

        Args:
            use_center (bool, optional): Option to use the position of the
                center point in the correction. Default is read from config, or set to True.
            **kwds: Keyword arguments for MomentumCorrector.spline_warp_estimate().
        """
        self.mc.spline_warp_estimate(use_center=use_center, **kwds)

        if self.mc.slice is not None:
            print("Original slice with reference features")
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)

            print("Corrected slice with target features")
            self.mc.view(
                image=self.mc.slice_corrected,
                annotated=True,
                points={"feats": self.mc.ptargs},
                backend="bokeh",
                crosshair=True,
            )

            print("Original slice with target features")
            self.mc.view(
                image=self.mc.slice,
                points={"feats": self.mc.ptargs},
                annotated=True,
                backend="bokeh",
            )

    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        points = []
        if self.mc.pouter_ord is not None:  # if there is any calibration info
            try:
                for point in self.mc.pouter_ord:
                    points.append([float(i) for i in point])
                if self.mc.include_center:
                    points.append([float(i) for i in self.mc.pcent])
            except AttributeError as exc:
                raise AttributeError(
                    "Momentum correction parameters not found, need to generate parameters first!",
                ) from exc
            config = {
                "momentum": {
                    "correction": {
                        "rotation_symmetry": self.mc.rotsym,
                        "feature_points": points,
                        "include_center": self.mc.include_center,
                        "use_center": self.mc.use_center,
                    },
                },
            }
            save_config(config, filename, overwrite)

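    # --- Illustrative distortion-correction sequence (sketch) ---
    # Steps 2 to 3a chained together; the symmetry order and the choice to use the
    # center point are assumptions:
    #
    #   >>> sp.define_features(rotation_symmetry=6, include_center=True, apply=True)
    #   >>> sp.generate_splinewarp(use_center=True)
    #   >>> sp.save_splinewarp(filename="sed_config.yaml", overwrite=True)
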
    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    def pose_adjustment(
        self,
        scale: float = 1,
        xtrans: float = 0,
        ytrans: float = 0,
        angle: float = 0,
        apply: bool = False,
        use_correction: bool = True,
        reset: bool = True,
    ):
        """4. step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            scale (float, optional): Initial value of the scaling slider.
                Defaults to 1.
            xtrans (float, optional): Initial value of the xtrans slider.
                Defaults to 0.
            ytrans (float, optional): Initial value of the ytrans slider.
                Defaults to 0.
            angle (float, optional): Initial value of the angle slider.
                Defaults to 0.
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
            reset (bool, optional):
                Option to reset the correction before transformation. Defaults to True.
        """
        # Generate homography as default if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                raise ValueError(
                    "No slice for corrections and transformations loaded!",
                )
            self.mc.slice_corrected = self.mc.slice

        if not use_correction:
            self.mc.reset_deformation()

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate distortion correction from config values
            self.mc.add_features()
            self.mc.spline_warp_estimate()

        self.mc.pose_adjustment(
            scale=scale,
            xtrans=xtrans,
            ytrans=ytrans,
            angle=angle,
            apply=apply,
            reset=reset,
        )

    # 5. Apply the momentum correction to the dataframe
    def apply_momentum_correction(
        self,
        preview: bool = False,
    ):
        """Applies the distortion correction and pose adjustment (optional)
        to the dataframe.

        Args:
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding corrected X/Y columns to dataframe:")
            self._dataframe, metadata = self.mc.apply_corrections(
                df=self._dataframe,
            )
            if self._timed_dataframe is not None:
                if (
                    self._config["dataframe"]["x_column"] in self._timed_dataframe.columns
                    and self._config["dataframe"]["y_column"] in self._timed_dataframe.columns
                ):
                    self._timed_dataframe, _ = self.mc.apply_corrections(
                        self._timed_dataframe,
                    )
            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            elif self.verbose:
                print(self._dataframe)

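    # --- Illustrative pose adjustment and application (sketch) ---
    # Fine-tune the affine transformation, then write the corrected X/Y columns
    # into the dataframe; the slider values are assumptions:
    #
    #   >>> sp.pose_adjustment(xtrans=8, ytrans=7, angle=-1.5, apply=True)
    #   >>> sp.apply_momentum_correction(preview=True)
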
    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    def calibrate_momentum_axes(
        self,
        point_a: Union[np.ndarray, List[int]] = None,
        point_b: Union[np.ndarray, List[int]] = None,
        k_distance: float = None,
        k_coord_a: Union[np.ndarray, List[float]] = None,
        k_coord_b: Union[np.ndarray, List[float]] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply=False,
    ):
        """1. step of the momentum calibration workflow. Calibrate momentum
        axes using either provided pixel coordinates of a high-symmetry point and its
        distance to the BZ center, or the k-coordinates of two points in the BZ
        (depending on the equiscale option). Opens an interactive panel for selecting
        the points.

        Args:
            point_a (Union[np.ndarray, List[int]]): Pixel coordinates of the first
                point used for momentum calibration.
            point_b (Union[np.ndarray, List[int]], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between point a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option to apply different scales to kx and ky.
                If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config["momentum"]["center_pixel"]

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )

730
    def save_momentum_calibration(
1✔
731
        self,
732
        filename: str = None,
733
        overwrite: bool = False,
734
    ):
735
        """Save the generated momentum calibration parameters to the folder config file.
736

737
        Args:
738
            filename (str, optional): Filename of the config dictionary to save to.
739
                Defaults to "sed_config.yaml" in the current folder.
740
            overwrite (bool, optional): Option to overwrite the present dictionary.
741
                Defaults to False.
742
        """
743
        if filename is None:
1✔
744
            filename = "sed_config.yaml"
×
745
        calibration = {}
1✔
746
        try:
1✔
747
            for key in [
1✔
748
                "kx_scale",
749
                "ky_scale",
750
                "x_center",
751
                "y_center",
752
                "rstart",
753
                "cstart",
754
                "rstep",
755
                "cstep",
756
            ]:
757
                calibration[key] = float(self.mc.calibration[key])
1✔
758
        except KeyError as exc:
×
759
            raise KeyError(
×
760
                "Momentum calibration parameters not found, need to generate parameters first!",
761
            ) from exc
762

763
        config = {"momentum": {"calibration": calibration}}
1✔
764
        save_config(config, filename, overwrite)
1✔
765
        print(f"Saved momentum calibration parameters to {filename}")
1✔
766

767
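    # --- Illustrative momentum calibration (sketch) ---
    # Calibrate via a high-symmetry point at a known momentum distance from the
    # BZ center; the pixel coordinates and k_distance are assumptions:
    #
    #   >>> sp.calibrate_momentum_axes(point_a=[308, 345], k_distance=1.3, apply=True)
    #   >>> sp.save_momentum_calibration()
    #   >>> sp.apply_momentum_calibration()
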
    # 2. Apply correction and calibration to the dataframe
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
    ):
        """2. step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes exist,
        these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding kx/ky columns to dataframe:")
            self._dataframe, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
            )
            if self._timed_dataframe is not None:
                if (
                    self._config["dataframe"]["x_column"] in self._timed_dataframe.columns
                    and self._config["dataframe"]["y_column"] in self._timed_dataframe.columns
                ):
                    self._timed_dataframe, _ = self.mc.append_k_axis(
                        df=self._timed_dataframe,
                        calibration=calibration,
                    )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            elif self.verbose:
                print(self._dataframe)

    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: Tuple[float, float] = None,
        apply=False,
        **kwds,
    ):
        """1. step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data if
        they are not present yet.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (Tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
            **kwds: Keyword arguments passed to
                ``EnergyCalibrator.adjust_energy_correction()``.
        """
        if self._pre_binned is None:
            print(
                "Pre-binned data not present, binning using defaults from config...",
            )
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )

    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        correction = {}
        try:
            for key, val in self.ec.correction.items():
                if key == "correction_type":
                    correction[key] = val
                elif key == "center":
                    correction[key] = [float(i) for i in val]
                else:
                    correction[key] = float(val)
        except AttributeError as exc:
            raise AttributeError(
                "Energy correction parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)
        print(f"Saved energy correction parameters to {filename}")

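    # --- Illustrative energy correction (sketch) ---
    # Adjust the TOF correction on the pre-binned data, persist the parameters,
    # and apply them; correction type, amplitude, and center are assumptions:
    #
    #   >>> sp.adjust_energy_correction(
    #   ...     correction_type="Lorentzian", amplitude=2.5, center=(730, 730), apply=True,
    #   ... )
    #   >>> sp.save_energy_correction()
    #   >>> sp.apply_energy_correction()
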
    # 2. Apply energy correction to dataframe
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2. step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["calibration"].
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction``.
        """
        if self._dataframe is not None:
            print("Applying energy correction to dataframe...")
            self._dataframe, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                **kwds,
            )
            if self._timed_dataframe is not None:
                if self._config["dataframe"]["tof_column"] in self._timed_dataframe.columns:
                    self._timed_dataframe, _ = self.ec.apply_energy_correction(
                        df=self._timed_dataframe,
                        correction=correction,
                        **kwds,
                    )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            if preview:
                print(self._dataframe.head(10))
            elif self.verbose:
                print(self._dataframe)

    # Energy calibration workflow
    # 1. Load and normalize data
    def load_bias_series(
        self,
        binned_data: Union[xr.DataArray, Tuple[np.ndarray, np.ndarray, np.ndarray]] = None,
        data_files: List[str] = None,
        axes: List[str] = None,
        bins: List = None,
        ranges: Sequence[Tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1. step of the energy calibration workflow: Load and bin data from
        single-event files, or load binned bias/TOF traces.

        Args:
            binned_data (Union[xr.DataArray, Tuple[np.ndarray, np.ndarray, np.ndarray]], optional):
                Binned data. If provided as a DataArray, it needs to contain the dimensions
                config["dataframe"]["tof_column"] and config["dataframe"]["bias_column"]. If
                provided as a tuple, it needs to contain the elements tof, biases, traces.
            data_files (List[str], optional): list of file paths to bin.
            axes (List[str], optional): bin axes.
                Defaults to config["dataframe"]["tof_column"].
            bins (List, optional): number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        if binned_data is not None:
            if isinstance(binned_data, xr.DataArray):
                if (
                    self._config["dataframe"]["tof_column"] not in binned_data.dims
                    or self._config["dataframe"]["bias_column"] not in binned_data.dims
                ):
                    raise ValueError(
                        "If binned_data is provided as an xarray, it needs to contain dimensions "
                        f"'{self._config['dataframe']['tof_column']}' and "
                        f"'{self._config['dataframe']['bias_column']}'!",
                    )
                tof = binned_data.coords[self._config["dataframe"]["tof_column"]].values
                biases = binned_data.coords[self._config["dataframe"]["bias_column"]].values
                traces = binned_data.values[:, :]
            else:
                try:
                    (tof, biases, traces) = binned_data
                except ValueError as exc:
                    raise ValueError(
                        "If binned_data is provided as tuple, it needs to contain "
                        "(tof, biases, traces)!",
                    ) from exc
            self.ec.load_data(biases=biases, traces=traces, tof=tof)

        elif data_files is not None:
            self.ec.bin_data(
                data_files=cast(List[str], self.cpy(data_files)),
                axes=axes,
                bins=bins,
                ranges=ranges,
                biases=biases,
                bias_key=bias_key,
            )

        else:
            raise ValueError("Either binned_data or data_files needs to be provided!")

        if (normalize is not None and normalize is True) or (
            normalize is None and self._config["energy"]["normalize"]
        ):
            if span is None:
                span = self._config["energy"]["normalize_span"]
            if order is None:
                order = self._config["energy"]["normalize_order"]
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="bokeh",
        )

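    # --- Illustrative bias series loading (sketch) ---
    # Either bin single-event files, or pass already-binned traces; the file names
    # are hypothetical:
    #
    #   >>> sp.load_bias_series(
    #   ...     data_files=["bias_scan_00.h5", "bias_scan_01.h5"],
    #   ...     normalize=True,
    #   ... )
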
    # 2. extract ranges and get peak positions
    def find_bias_peaks(
        self,
        ranges: Union[List[Tuple], Tuple],
        ref_id: int = 0,
        infer_others: bool = True,
        mode: str = "replace",
        radius: int = None,
        peak_window: int = None,
        apply: bool = False,
    ):
        """2. step of the energy calibration workflow: Find a peak within a given range
        for the indicated reference trace, and try to find the same peak for all
        other traces. Uses fast_dtw to align curves, which might not work well if the
        shape of the curves changes qualitatively. Ideally, choose a reference trace
        in the middle of the set, and don't choose the range too narrow around the
        peak. Alternatively, a list of ranges for all traces can be provided.

        Args:
            ranges (Union[List[Tuple], Tuple]): Tuple of TOF values indicating a range.
                Alternatively, a list of ranges for all traces can be given.
            ref_id (int, optional): The id of the trace the range refers to.
                Defaults to 0.
            infer_others (bool, optional): Whether to determine the range for the other
                traces. Defaults to True.
            mode (str, optional): Whether to "add" or "replace" existing ranges.
                Defaults to "replace".
            radius (int, optional): Radius parameter for fast_dtw.
                Defaults to config["energy"]["fastdtw_radius"].
            peak_window (int, optional): Peak_window parameter for the peak detection
                algorithm: the number of points that have to behave monotonically
                around a peak. Defaults to config["energy"]["peak_window"].
            apply (bool, optional): Option to directly apply the provided parameters.
                Defaults to False.
        """
        if radius is None:
            radius = self._config["energy"]["fastdtw_radius"]
        if peak_window is None:
            peak_window = self._config["energy"]["peak_window"]
        if not infer_others:
            self.ec.add_ranges(
                ranges=ranges,
                ref_id=ref_id,
                infer_others=infer_others,
                mode=mode,
                radius=radius,
            )
            print(self.ec.featranges)
            try:
                self.ec.feature_extract(peak_window=peak_window)
                self.ec.view(
                    traces=self.ec.traces_normed,
                    segs=self.ec.featranges,
                    xaxis=self.ec.tof,
                    peaks=self.ec.peaks,
                    backend="bokeh",
                )
            except IndexError:
                print("Could not determine all peaks!")
                raise
        else:
            # New adjustment tool
            assert isinstance(ranges, tuple)
            self.ec.adjust_ranges(
                ranges=ranges,
                ref_id=ref_id,
                traces=self.ec.traces_normed,
                infer_others=infer_others,
                radius=radius,
                peak_window=peak_window,
                apply=apply,
            )

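    # --- Illustrative peak finding (sketch) ---
    # Select a TOF range around a feature in a reference trace and let the other
    # ranges be inferred; the range and ref_id values are assumptions:
    #
    #   >>> sp.find_bias_peaks(ranges=(64000, 66000), ref_id=5, apply=True)
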
    # 3. Fit the energy calibration relation
    def calibrate_energy_axis(
        self,
        ref_id: int,
        ref_energy: float,
        method: str = None,
        energy_scale: str = None,
        **kwds,
    ):
        """3. Step of the energy calibration workflow: Calculate the calibration
        function for the energy axis, and apply it to the dataframe. Two
        approximations are implemented, a (normally 3rd order) polynomial
        approximation, and a d^2/(t-t0)^2 relation.

        Args:
            ref_id (int): id of the trace at the bias where the reference energy is
                given.
            ref_energy (float): Absolute energy of the detected feature at the bias
                of ref_id.
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to config["energy"]["calibration_method"].
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to config["energy"]["energy_scale"].
            **kwds: Keyword arguments passed to ``EnergyCalibrator.calibrate``.
        """
        if method is None:
            method = self._config["energy"]["calibration_method"]

        if energy_scale is None:
            energy_scale = self._config["energy"]["energy_scale"]

        self.ec.calibrate(
            ref_id=ref_id,
            ref_energy=ref_energy,
            method=method,
            energy_scale=energy_scale,
            **kwds,
        )
        print("Quality of Calibration:")
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.calibration["axis"],
            align=True,
            energy_scale=energy_scale,
            backend="bokeh",
        )
        print("E/TOF relationship:")
        self.ec.view(
            traces=self.ec.calibration["axis"][None, :],
            xaxis=self.ec.tof,
            backend="matplotlib",
            show_legend=False,
        )
        if energy_scale == "kinetic":
            plt.scatter(
                self.ec.peaks[:, 0],
                -(self.ec.biases - self.ec.biases[ref_id]) + ref_energy,
                s=50,
                c="k",
            )
        elif energy_scale == "binding":
            plt.scatter(
                self.ec.peaks[:, 0],
                self.ec.biases - self.ec.biases[ref_id] + ref_energy,
                s=50,
                c="k",
            )
        else:
            raise ValueError(
                f'energy_scale needs to be either "binding" or "kinetic", got {energy_scale}.',
            )
        plt.xlabel("Time-of-flight", fontsize=15)
        plt.ylabel("Energy (eV)", fontsize=15)
        plt.show()

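    # --- Illustrative calibration fit (sketch) ---
    # Fit the energy/TOF relation against a known reference feature; ref_id and
    # the reference energy are assumptions:
    #
    #   >>> sp.calibrate_energy_axis(
    #   ...     ref_id=4, ref_energy=-0.5, method="lmfit", energy_scale="kinetic",
    #   ... )
    #   >>> sp.save_energy_calibration()
    #   >>> sp.append_energy_axis(preview=True)
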
    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key, value in self.ec.calibration.items():
                if key in ["axis", "refid", "Tmat", "bvec"]:
                    continue
                if key == "energy_scale":
                    calibration[key] = value
                elif key == "coeffs":
                    calibration[key] = [float(i) for i in value]
                else:
                    calibration[key] = float(value)
        except AttributeError as exc:
            raise AttributeError(
                "Energy calibration parameters not found, need to generate parameters first!",
            ) from exc
        config = {"energy": {"calibration": calibration}}
        save_config(config, filename, overwrite)
        print(f'Saved energy calibration parameters to "{filename}".')

    # 4. Apply energy calibration to the dataframe
    def append_energy_axis(
        self,
        calibration: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """4. step of the energy calibration workflow: Apply the calibration function
        to the dataframe. Two approximations are implemented, a (normally 3rd order)
        polynomial approximation, and a d^2/(t-t0)^2 relation. A calibration dictionary
        can be provided.

        Args:
            calibration (dict, optional): Calibration dict containing calibration
                parameters. Overrides calibration from class or config.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.append_energy_axis``.
        """
        if self._dataframe is not None:
            print("Adding energy column to dataframe:")
            self._dataframe, metadata = self.ec.append_energy_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )
            if self._timed_dataframe is not None:
                if self._config["dataframe"]["tof_column"] in self._timed_dataframe.columns:
                    self._timed_dataframe, _ = self.ec.append_energy_axis(
                        df=self._timed_dataframe,
                        calibration=calibration,
                        **kwds,
                    )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            elif self.verbose:
                print(self._dataframe)

    def add_energy_offset(
1✔
1271
        self,
1272
        constant: float = None,
1273
        columns: Union[str, Sequence[str]] = None,
1274
        weights: Union[float, Sequence[float]] = None,
1275
        reductions: Union[str, Sequence[str]] = None,
1276
        preserve_mean: Union[bool, Sequence[bool]] = None,
1277
    ) -> None:
1278
        """Shift the energy axis of the dataframe by a given amount.
1279

1280
        Args:
1281
            constant (float, optional): The constant to shift the energy axis by.
1282
            columns (Union[str, Sequence[str]]): Name of the column(s) to apply the shift from.
1283
            weights (Union[float, Sequence[float]]): Weights to apply to the columns.
1284
                Can also be used to flip the sign (e.g. -1). Defaults to 1.
1285
            reductions (str): The reduction to apply to the column. Should be an available method
1286
                of dask.dataframe.Series. For example "mean". In this case the function is applied
1287
                to the column to generate a single value for the whole dataset. If None, the shift
1288
                is applied per-dataframe-row. Defaults to None. Currently only "mean" is supported.
1289
            preserve_mean (bool): Whether to subtract the mean of the column before applying the
1290
                shift. Defaults to False.
1291

1292
        Raises:
1293
            ValueError: If the energy column is not in the dataframe.
1294
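
        Example:
            Hedged sketch for an assumed processor instance ``sp`` with an energy axis
            appended; the column name and weight are illustrative only::

                sp.add_energy_offset(
                    constant=0.2,
                    columns="monochromatorPhotonEnergy",
                    weights=-1,
                    preserve_mean=True,
                )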
        """
1295
        print("Adding energy offset to dataframe:")
1✔
1296
        energy_column = self._config["dataframe"]["energy_column"]
1✔
1297
        if self.dataframe is not None:
1✔
1298
            if energy_column not in self._dataframe.columns:
1✔
1299
                raise ValueError(
1✔
1300
                    f"Energy column {energy_column} not found in dataframe! "
1301
                    "Run `append energy axis` first.",
1302
                )
1303
            df, metadata = self.ec.add_offsets(
1✔
1304
                df=self._dataframe,
1305
                constant=constant,
1306
                columns=columns,
1307
                energy_column=energy_column,
1308
                weights=weights,
1309
                reductions=reductions,
1310
                preserve_mean=preserve_mean,
1311
            )
1312
            if self._timed_dataframe is not None:
1✔
1313
                if energy_column in self._timed_dataframe.columns:
1✔
1314
                    self._timed_dataframe, _ = self.ec.add_offsets(
1✔
1315
                        df=self._timed_dataframe,
1316
                        constant=constant,
1317
                        columns=columns,
1318
                        energy_column=energy_column,
1319
                        weights=weights,
1320
                        reductions=reductions,
1321
                        preserve_mean=preserve_mean,
1322
                    )
1323
            self._attributes.add(
1✔
1324
                metadata,
1325
                "add_energy_offset",
1326
                # TODO: allow only appending when no offset along this column(s) was applied
1327
                # TODO: clear memory of modifications if the energy axis is recalculated
1328
                duplicate_policy="append",
1329
            )
1330
            self._dataframe = df
1✔
1331
        else:
1332
            raise ValueError("No dataframe loaded!")
×
1333

1334
    def save_energy_offset(
1✔
1335
        self,
1336
        filename: str = None,
1337
        overwrite: bool = False,
1338
    ):
1339
        """Save the generated energy calibration parameters to the folder config file.
1340

1341
        Args:
1342
            filename (str, optional): Filename of the config dictionary to save to.
1343
                Defaults to "sed_config.yaml" in the current folder.
1344
            overwrite (bool, optional): Option to overwrite the present dictionary.
1345
                Defaults to False.
1346
        """
1347
        if filename is None:
×
1348
            filename = "sed_config.yaml"
×
1349
        if len(self.ec.offsets) == 0:
×
1350
            raise ValueError("No energy offset parameters to save!")
×
1351
        config = {"energy": {"offsets": self.ec.offsets}}
×
1352
        save_config(config, filename, overwrite)
×
1353
        print(f'Saved energy offset parameters to "{filename}".')
×
1354

1355
    def append_tof_ns_axis(
1✔
1356
        self,
1357
        **kwargs,
1358
    ):
1359
        """Convert time-of-flight channel steps to nanoseconds.
1360

1361
        Args:
1362
            tof_ns_column (str, optional): Name of the generated column containing the
1363
                time-of-flight in nanosecond.
1364
                Defaults to config["dataframe"]["tof_ns_column"].
1365
            **kwargs: Additional arguments passed to ``energy.tof_step_to_ns``.
1366

1367
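
        Example:
            Minimal sketch (``sp`` is an assumed ``SedProcessor`` instance with a
            loaded dataframe)::

                sp.append_tof_ns_axis()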
        """
1368
        if self._dataframe is not None:
1✔
1369
            print("Adding time-of-flight column in nanoseconds to dataframe:")
1✔
1370
            # TODO assert order of execution through metadata
1371

1372
            self._dataframe, metadata = self.ec.append_tof_ns_axis(
1✔
1373
                df=self._dataframe,
1374
                **kwargs,
1375
            )
1376
            if self._timed_dataframe is not None:
1✔
1377
                if self._config["dataframe"]["tof_column"] in self._timed_dataframe.columns:
1✔
1378
                    self._timed_dataframe, _ = self.ec.append_tof_ns_axis(
1✔
1379
                        df=self._timed_dataframe,
1380
                        **kwargs,
1381
                    )
1382
            self._attributes.add(
1✔
1383
                metadata,
1384
                "tof_ns_conversion",
1385
                duplicate_policy="append",
1386
            )
1387

1388
    def align_dld_sectors(self, sector_delays: np.ndarray = None, **kwargs):
1✔
1389
        """Align the 8s sectors of the HEXTOF endstation.
1390

1391
        Args:
1392
            sector_delays (np.ndarray, optional): Array containing the sector delays. Defaults to
1393
                config["dataframe"]["sector_delays"].
1394
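
        Example:
            Minimal sketch (``sp`` is an assumed processor instance; the zero delays
            are illustrative, one entry per detector sector)::

                import numpy as np

                sp.align_dld_sectors(sector_delays=np.zeros(8))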
        """
1395
        if self._dataframe is not None:
1✔
1396
            print("Aligning 8s sectors of dataframe")
1✔
1397
            # TODO assert order of execution through metadata
1398
            self._dataframe, metadata = self.ec.align_dld_sectors(
1✔
1399
                df=self._dataframe,
1400
                sector_delays=sector_delays,
1401
                **kwargs,
1402
            )
1403
            if self._timed_dataframe is not None:
1✔
1404
                if self._config["dataframe"]["tof_column"] in self._timed_dataframe.columns:
1✔
1405
                    self._timed_dataframe, _ = self.ec.align_dld_sectors(
×
1406
                        df=self._timed_dataframe,
1407
                        sector_delays=sector_delays,
1408
                        **kwargs,
1409
                    )
1410
            self._attributes.add(
1✔
1411
                metadata,
1412
                "dld_sector_alignment",
1413
                duplicate_policy="raise",
1414
            )
1415

1416
    # Delay calibration function
1417
    def calibrate_delay_axis(
1✔
1418
        self,
1419
        delay_range: Tuple[float, float] = None,
1420
        datafile: str = None,
1421
        preview: bool = False,
1422
        **kwds,
1423
    ):
1424
        """Append delay column to dataframe. Either provide delay ranges, or read
1425
        them from a file.
1426

1427
        Args:
1428
            delay_range (Tuple[float, float], optional): The scanned delay range in
1429
                picoseconds. Defaults to None.
1430
            datafile (str, optional): The file from which to read the delay ranges.
1431
                Defaults to None.
1432
            preview (bool): Option to preview the first elements of the data frame.
1433
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
1434
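
        Example:
            Two hedged sketches for an assumed processor instance ``sp``: pass the
            scanned range directly, or let it be read from the first loaded file::

                sp.calibrate_delay_axis(delay_range=(-500.0, 1500.0), preview=True)
                # or:
                sp.calibrate_delay_axis()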
        """
1435
        if self._dataframe is not None:
1✔
1436
            print("Adding delay column to dataframe:")
1✔
1437

1438
            if delay_range is not None:
1✔
1439
                self._dataframe, metadata = self.dc.append_delay_axis(
1✔
1440
                    self._dataframe,
1441
                    delay_range=delay_range,
1442
                    **kwds,
1443
                )
1444
                if self._timed_dataframe is not None:
1✔
1445
                    if self._config["dataframe"]["adc_column"] in self._timed_dataframe.columns:
1✔
1446
                        self._timed_dataframe, _ = self.dc.append_delay_axis(
1✔
1447
                            self._timed_dataframe,
1448
                            delay_range=delay_range,
1449
                            **kwds,
1450
                        )
1451
            else:
1452
                if datafile is None:
1✔
1453
                    try:
1✔
1454
                        datafile = self._files[0]
1✔
1455
                    except IndexError:
×
1456
                        print(
×
1457
                            "No datafile available, specify either",
1458
                            " 'datafile' or 'delay_range'",
1459
                        )
1460
                        raise
×
1461

1462
                self._dataframe, metadata = self.dc.append_delay_axis(
1✔
1463
                    self._dataframe,
1464
                    datafile=datafile,
1465
                    **kwds,
1466
                )
1467
                if self._timed_dataframe is not None:
1✔
1468
                    if self._config["dataframe"]["adc_column"] in self._timed_dataframe.columns:
1✔
1469
                        self._timed_dataframe, _ = self.dc.append_delay_axis(
1✔
1470
                            self._timed_dataframe,
1471
                            datafile=datafile,
1472
                            **kwds,
1473
                        )
1474

1475
            # Add Metadata
1476
            self._attributes.add(
1✔
1477
                metadata,
1478
                "delay_calibration",
1479
                duplicate_policy="merge",
1480
            )
1481
            if preview:
1✔
1482
                print(self._dataframe.head(10))
1✔
1483
            else:
1484
                if self.verbose:
1✔
1485
                    print(self._dataframe)
×
1486

1487
    def save_delay_calibration(
1✔
1488
        self,
1489
        filename: str = None,
1490
        overwrite: bool = False,
1491
    ) -> None:
1492
        """Save the generated delay calibration parameters to the folder config file.
1493

1494
        Args:
1495
            filename (str, optional): Filename of the config dictionary to save to.
1496
                Defaults to "sed_config.yaml" in the current folder.
1497
            overwrite (bool, optional): Option to overwrite the present dictionary.
1498
                Defaults to False.
1499
        """
1500
        if filename is None:
×
1501
            filename = "sed_config.yaml"
×
1502

1503
        config = {
×
1504
            "delay": {
1505
                "calibration": self.dc.calibration,
1506
            },
1507
        }
1508
        save_config(config, filename, overwrite)
×
1509

1510
    def add_delay_offset(
1✔
1511
        self,
1512
        constant: float = None,
1513
        flip_delay_axis: bool = None,
1514
        columns: Union[str, Sequence[str]] = None,
1515
        weights: Union[float, Sequence[float]] = None,
1516
        reductions: Union[str, Sequence[str]] = None,
1517
        preserve_mean: Union[bool, Sequence[bool]] = None,
1518
    ) -> None:
1519
        """Shift the delay axis of the dataframe by a constant or other columns.
1520

1521
        Args:
1522
            constant (float, optional): The constant to shift the delay axis by.
1523
            flip_delay_axis (bool, optional): Option to flip the sign of the delay axis.
1524
            columns (Union[str, Sequence[str]]): Name of the column(s) to apply the shift from.
1525
            weights (Union[float, Sequence[float]]): Weights to apply to the columns.
1526
                Can also be used to flip the sign (e.g. -1). Defaults to 1.
1527
            reductions (str): The reduction to apply to the column. Should be an available method
1528
                of dask.dataframe.Series. For example "mean". In this case the function is applied
1529
                to the column to generate a single value for the whole dataset. If None, the shift
1530
                is applied per-dataframe-row. Defaults to None. Currently only "mean" is supported.
1531
            preserve_mean (bool): Whether to subtract the mean of the column before applying the
                shift. Defaults to False.
1532

1533
        Returns:
1534
            None
1535
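
        Example:
            Hedged sketch for an assumed processor instance ``sp`` with a calibrated
            delay axis; the column name and weight are illustrative only::

                sp.add_delay_offset(
                    constant=1.5,
                    flip_delay_axis=True,
                    columns="bam",
                    weights=0.001,
                    preserve_mean=True,
                )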
        """
1536
        print("Adding delay offset to dataframe:")
×
1537
        delay_column = self._config["dataframe"]["delay_column"]
×
1538
        if delay_column not in self._dataframe.columns:
×
1539
            raise ValueError(f"Delay column {delay_column} not found in dataframe! ")
×
1540

1541
        if self.dataframe is not None:
×
1542
            df, metadata = self.dc.add_offsets(
×
1543
                df=self._dataframe,
1544
                constant=constant,
1545
                flip_delay_axis=flip_delay_axis,
1546
                columns=columns,
1547
                delay_column=delay_column,
1548
                weights=weights,
1549
                reductions=reductions,
1550
                preserve_mean=preserve_mean,
1551
            )
1552
        if self._timed_dataframe is not None:
×
1553
            if delay_column in self._timed_dataframe.columns:
×
1554
                tdf, _ = self.dc.add_offsets(
×
1555
                    df=self._timed_dataframe,
1556
                    constant=constant,
1557
                    flip_delay_axis=flip_delay_axis,
1558
                    columns=columns,
1559
                    delay_column=delay_column,
1560
                    weights=weights,
1561
                    reductions=reductions,
1562
                    preserve_mean=preserve_mean,
1563
                )
1564
            self._attributes.add(
×
1565
                metadata,
1566
                "add_delay_offset",
1567
                duplicate_policy="append",
1568
            )
1569
            self._dataframe = df
×
1570
            if self._timed_dataframe is not None and delay_column in self._timed_dataframe.columns:
×
1571
                self._timed_dataframe = tdf
×
1572
        else:
1573
            raise ValueError("No dataframe loaded!")
×
1574

1575
    def save_delay_offsets(
1✔
1576
        self,
1577
        filename: str = None,
1578
        overwrite: bool = False,
1579
    ) -> None:
1580
        """Save the generated delay calibration parameters to the folder config file.
1581

1582
        Args:
1583
            filename (str, optional): Filename of the config dictionary to save to.
1584
                Defaults to "sed_config.yaml" in the current folder.
1585
            overwrite (bool, optional): Option to overwrite the present dictionary.
1586
                Defaults to False.
1587
        """
1588
        if filename is None:
×
1589
            filename = "sed_config.yaml"
×
1590
        if len(self.dc.offsets) == 0:
×
1591
            raise ValueError("No delay offset parameters to save!")
×
1592
        config = {
×
1593
            "delay": {
1594
                "offsets": self.dc.offsets,
1595
            },
1596
        }
1597
        save_config(config, filename, overwrite)
×
1598
        print(f'Saved delay offset parameters to "{filename}".')
×
1599

1600
    def save_workflow_params(
1✔
1601
        self,
1602
        filename: str = None,
1603
        overwrite: bool = False,
1604
    ) -> None:
1605
        """run all save calibration parameter methods
1606

1607
        Args:
1608
            filename (str, optional): Filename of the config dictionary to save to.
1609
                Defaults to "sed_config.yaml" in the current folder.
1610
            overwrite (bool, optional): Option to overwrite the present dictionary.
1611
                Defaults to False.
1612
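
        Example:
            Minimal sketch (``sp`` is an assumed processor instance with calibration
            and offset parameters generated in this session)::

                sp.save_workflow_params(filename="sed_config.yaml", overwrite=True)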
        """
1613
        for method in [
×
1614
            self.save_momentum_calibration,
1615
            self.save_splinewarp,
1616
            self.save_energy_correction,
1617
            self.save_energy_calibration,
1618
            self.save_energy_offset,
1619
            self.save_delay_calibration,
1620
            self.save_delay_offsets,
1621
        ]:
1622
            try:
×
1623
                method(filename, overwrite)
×
1624
            except (ValueError, AttributeError, KeyError):
×
1625
                pass
×
1626

1627
    def add_jitter(
1✔
1628
        self,
1629
        cols: List[str] = None,
1630
        amps: Union[float, Sequence[float]] = None,
1631
        **kwds,
1632
    ):
1633
        """Add jitter to the selected dataframe columns.
1634

1635
        Args:
1636
            cols (List[str], optional): The columns onto which to apply jitter.
1637
                Defaults to config["dataframe"]["jitter_cols"].
1638
            amps (Union[float, Sequence[float]], optional): Amplitude scalings for the
1639
                jittering noise. If one number is given, the same is used for all axes.
1640
                For uniform noise (default) it will cover the interval [-amp, +amp].
1641
                Defaults to config["dataframe"]["jitter_amps"].
1642
            **kwds: additional keyword arguments passed to apply_jitter
1643
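
        Example:
            Minimal sketch (``sp`` is an assumed processor instance; column names and
            amplitude are illustrative, and normally come from the config)::

                sp.add_jitter(cols=["X", "Y", "t"], amps=0.5)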
        """
1644
        if cols is None:
1✔
1645
            cols = self._config["dataframe"]["jitter_cols"]
1✔
1646
        for loc, col in enumerate(cols):
1✔
1647
            if col.startswith("@"):
1✔
1648
                cols[loc] = self._config["dataframe"].get(col.strip("@"))
1✔
1649

1650
        if amps is None:
1✔
1651
            amps = self._config["dataframe"]["jitter_amps"]
1✔
1652

1653
        self._dataframe = self._dataframe.map_partitions(
1✔
1654
            apply_jitter,
1655
            cols=cols,
1656
            cols_jittered=cols,
1657
            amps=amps,
1658
            **kwds,
1659
        )
1660
        if self._timed_dataframe is not None:
1✔
1661
            cols_timed = cols.copy()
1✔
1662
            for col in cols:
1✔
1663
                if col not in self._timed_dataframe.columns:
1✔
1664
                    cols_timed.remove(col)
×
1665

1666
            if cols_timed:
1✔
1667
                self._timed_dataframe = self._timed_dataframe.map_partitions(
1✔
1668
                    apply_jitter,
1669
                    cols=cols_timed,
1670
                    cols_jittered=cols_timed,
1671
                )
1672
        metadata = []
1✔
1673
        for col in cols:
1✔
1674
            metadata.append(col)
1✔
1675
        self._attributes.add(metadata, "jittering", duplicate_policy="append")
1✔
1676

1677
    def pre_binning(
1✔
1678
        self,
1679
        df_partitions: int = 100,
1680
        axes: List[str] = None,
1681
        bins: List[int] = None,
1682
        ranges: Sequence[Tuple[float, float]] = None,
1683
        **kwds,
1684
    ) -> xr.DataArray:
1685
        """Function to do an initial binning of the dataframe loaded to the class.
1686

1687
        Args:
1688
            df_partitions (int, optional): Number of dataframe partitions to use for
1689
                the initial binning. Defaults to 100.
1690
            axes (List[str], optional): Axes to bin.
1691
                Defaults to config["momentum"]["axes"].
1692
            bins (List[int], optional): Bin numbers to use for binning.
1693
                Defaults to config["momentum"]["bins"].
1694
            ranges (List[Tuple], optional): Ranges to use for binning.
1695
                Defaults to config["momentum"]["ranges"].
1696
            **kwds: Keyword arguments passed to ``compute``.
1697

1698
        Returns:
1699
            xr.DataArray: pre-binned data-array.
1700
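
        Example:
            Minimal sketch (``sp`` is an assumed processor instance with a loaded
            dataframe; axes, bins and ranges fall back to the momentum section of
            the config)::

                pre_binned = sp.pre_binning(df_partitions=50)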
        """
1701
        if axes is None:
1✔
1702
            axes = self._config["momentum"]["axes"]
1✔
1703
        for loc, axis in enumerate(axes):
1✔
1704
            if axis.startswith("@"):
1✔
1705
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
1✔
1706

1707
        if bins is None:
1✔
1708
            bins = self._config["momentum"]["bins"]
1✔
1709
        if ranges is None:
1✔
1710
            ranges_ = list(self._config["momentum"]["ranges"])
1✔
1711
            ranges_[2] = np.asarray(ranges_[2]) / 2 ** (
1✔
1712
                self._config["dataframe"]["tof_binning"] - 1
1713
            )
1714
            ranges = [cast(Tuple[float, float], tuple(v)) for v in ranges_]
1✔
1715

1716
        assert self._dataframe is not None, "dataframe needs to be loaded first!"
1✔
1717

1718
        return self.compute(
1✔
1719
            bins=bins,
1720
            axes=axes,
1721
            ranges=ranges,
1722
            df_partitions=df_partitions,
1723
            **kwds,
1724
        )
1725

1726
    def compute(
1✔
1727
        self,
1728
        bins: Union[
1729
            int,
1730
            dict,
1731
            tuple,
1732
            List[int],
1733
            List[np.ndarray],
1734
            List[tuple],
1735
        ] = 100,
1736
        axes: Union[str, Sequence[str]] = None,
1737
        ranges: Sequence[Tuple[float, float]] = None,
1738
        normalize_to_acquisition_time: Union[bool, str] = False,
1739
        **kwds,
1740
    ) -> xr.DataArray:
1741
        """Compute the histogram along the given dimensions.
1742

1743
        Args:
1744
            bins (int, dict, tuple, List[int], List[np.ndarray], List[tuple], optional):
1745
                Definition of the bins. Can be any of the following cases:
1746

1747
                - an integer describing the number of bins on all dimensions
1748
                - a tuple of 3 numbers describing start, end and step of the binning
1749
                  range
1750
                - a np.ndarray defining the binning edges
1751
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
1752
                - a dictionary made of the axes as keys and any of the above as values.
1753

1754
                This takes priority over the axes and range arguments. Defaults to 100.
1755
            axes (Union[str, Sequence[str]], optional): The names of the axes (columns)
1756
                on which to calculate the histogram. The order will be the order of the
1757
                dimensions in the resulting array. Defaults to None.
1758
            ranges (Sequence[Tuple[float, float]], optional): list of tuples containing
1759
                the start and end point of the binning range. Defaults to None.
1760
            normalize_to_acquisition_time (Union[bool, str]): Option to normalize the
1761
                result to the acquisition time. If a "slow" axis was scanned, providing
1762
                the name of the scanned axis will compute and apply the corresponding
1763
                normalization histogram. Defaults to False.
1764
            **kwds: Keyword arguments:
1765

1766
                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
1767
                  ``bin_dataframe`` for details. Defaults to
1768
                  config["binning"]["hist_mode"].
1769
                - **mode**: Defines how the results from each partition are combined.
1770
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
1771
                  Defaults to config["binning"]["mode"].
1772
                - **pbar**: Option to show the tqdm progress bar. Defaults to
1773
                  config["binning"]["pbar"].
1774
                - **n_cores**: Number of CPU cores to use for parallelization.
1775
                  Defaults to config["binning"]["num_cores"] or N_CPU-1.
1776
                - **threads_per_worker**: Limit the number of threads that
1777
                  multiprocessing can spawn per binning thread. Defaults to
1778
                  config["binning"]["threads_per_worker"].
1779
                - **threadpool_api**: The API to use for multiprocessing. "blas",
1780
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
1781
                  config["binning"]["threadpool_API"].
1782
                - **df_partitions**: A range or list of dataframe partitions, or the
1783
                  number of the dataframe partitions to use. Defaults to all partitions.
1784

1785
                Additional kwds are passed to ``bin_dataframe``.
1786

1787
        Raises:
1788
            AssertionError: Raised when no dataframe has been loaded.
1789

1790
        Returns:
1791
            xr.DataArray: The result of the n-dimensional binning represented in an
1792
            xarray object, combining the data with the axes.
1793
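
        Example:
            Hedged sketch of a four-dimensional binning call for an assumed processor
            instance ``sp``; axis names and ranges are illustrative only::

                res = sp.compute(
                    bins=[100, 100, 200, 50],
                    axes=["kx", "ky", "energy", "delay"],
                    ranges=[(-2, 2), (-2, 2), (-4, 2), (-600, 1600)],
                )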
        """
1794
        assert self._dataframe is not None, "dataframe needs to be loaded first!"
1✔
1795

1796
        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
1✔
1797
        mode = kwds.pop("mode", self._config["binning"]["mode"])
1✔
1798
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
1✔
1799
        num_cores = kwds.pop("num_cores", self._config["binning"]["num_cores"])
1✔
1800
        threads_per_worker = kwds.pop(
1✔
1801
            "threads_per_worker",
1802
            self._config["binning"]["threads_per_worker"],
1803
        )
1804
        threadpool_api = kwds.pop(
1✔
1805
            "threadpool_API",
1806
            self._config["binning"]["threadpool_API"],
1807
        )
1808
        df_partitions = kwds.pop("df_partitions", None)
1✔
1809
        if isinstance(df_partitions, int):
1✔
1810
            df_partitions = slice(
1✔
1811
                0,
1812
                min(df_partitions, self._dataframe.npartitions),
1813
            )
1814
        if df_partitions is not None:
1✔
1815
            dataframe = self._dataframe.partitions[df_partitions]
1✔
1816
        else:
1817
            dataframe = self._dataframe
1✔
1818

1819
        self._binned = bin_dataframe(
1✔
1820
            df=dataframe,
1821
            bins=bins,
1822
            axes=axes,
1823
            ranges=ranges,
1824
            hist_mode=hist_mode,
1825
            mode=mode,
1826
            pbar=pbar,
1827
            n_cores=num_cores,
1828
            threads_per_worker=threads_per_worker,
1829
            threadpool_api=threadpool_api,
1830
            **kwds,
1831
        )
1832

1833
        for dim in self._binned.dims:
1✔
1834
            try:
1✔
1835
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
1✔
1836
            except KeyError:
1✔
1837
                pass
1✔
1838

1839
        self._binned.attrs["units"] = "counts"
1✔
1840
        self._binned.attrs["long_name"] = "photoelectron counts"
1✔
1841
        self._binned.attrs["metadata"] = self._attributes.metadata
1✔
1842

1843
        if normalize_to_acquisition_time:
1✔
1844
            if isinstance(normalize_to_acquisition_time, str):
1✔
1845
                axis = normalize_to_acquisition_time
1✔
1846
                print(
1✔
1847
                    f"Calculate normalization histogram for axis '{axis}'...",
1848
                )
1849
                self._normalization_histogram = self.get_normalization_histogram(
1✔
1850
                    axis=axis,
1851
                    df_partitions=df_partitions,
1852
                )
1853
                # if the axes are named correctly, xarray figures out the normalization correctly
1854
                self._normalized = self._binned / self._normalization_histogram
1✔
1855
                self._attributes.add(
1✔
1856
                    self._normalization_histogram.values,
1857
                    name="normalization_histogram",
1858
                    duplicate_policy="overwrite",
1859
                )
1860
            else:
1861
                acquisition_time = self.loader.get_elapsed_time(
×
1862
                    fids=df_partitions,
1863
                )
1864
                if acquisition_time > 0:
×
1865
                    self._normalized = self._binned / acquisition_time
×
1866
                self._attributes.add(
×
1867
                    acquisition_time,
1868
                    name="normalization_histogram",
1869
                    duplicate_policy="overwrite",
1870
                )
1871

1872
            self._normalized.attrs["units"] = "counts/second"
1✔
1873
            self._normalized.attrs["long_name"] = "photoelectron counts per second"
1✔
1874
            self._normalized.attrs["metadata"] = self._attributes.metadata
1✔
1875

1876
            return self._normalized
1✔
1877

1878
        return self._binned
1✔
1879

1880
    def get_normalization_histogram(
1✔
1881
        self,
1882
        axis: str = "delay",
1883
        use_time_stamps: bool = False,
1884
        **kwds,
1885
    ) -> xr.DataArray:
1886
        """Generates a normalization histogram from the timed dataframe. Optionally,
1887
        use the TimeStamps column instead.
1888

1889
        Args:
1890
            axis (str, optional): The axis for which to compute histogram.
1891
                Defaults to "delay".
1892
            use_time_stamps (bool, optional): Use the TimeStamps column of the
1893
                dataframe, rather than the timed dataframe. Defaults to False.
1894
            **kwds: Keyword arguments:
1895

1896
                - df_partitions (int, optional): Number of dataframe partitions to use.
1897
                  Defaults to all.
1898

1899
        Raises:
1900
            ValueError: Raised if no data are binned.
1901
            ValueError: Raised if 'axis' not in binned coordinates.
1902
            ValueError: Raised if config["dataframe"]["time_stamp_alias"] not found
1903
                in the dataframe.
1904

1905
        Returns:
1906
            xr.DataArray: The computed normalization histogram (in TimeStamp units
1907
            per bin).
1908
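
        Example:
            Minimal sketch (``sp`` is an assumed processor instance on which
            ``compute`` was already run including a "delay" axis)::

                hist = sp.get_normalization_histogram(axis="delay", df_partitions=10)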
        """
1909

1910
        if self._binned is None:
1✔
1911
            raise ValueError("Need to bin data first!")
1✔
1912
        if axis not in self._binned.coords:
1✔
1913
            raise ValueError(f"Axis '{axis}' not found in binned data!")
1✔
1914

1915
        df_partitions: Union[int, slice] = kwds.pop("df_partitions", None)
1✔
1916
        if isinstance(df_partitions, int):
1✔
1917
            df_partitions = slice(
1✔
1918
                0,
1919
                min(df_partitions, self._dataframe.npartitions),
1920
            )
1921

1922
        if use_time_stamps or self._timed_dataframe is None:
1✔
1923
            if df_partitions is not None:
1✔
1924
                self._normalization_histogram = normalization_histogram_from_timestamps(
1✔
1925
                    self._dataframe.partitions[df_partitions],
1926
                    axis,
1927
                    self._binned.coords[axis].values,
1928
                    self._config["dataframe"]["time_stamp_alias"],
1929
                )
1930
            else:
1931
                self._normalization_histogram = normalization_histogram_from_timestamps(
×
1932
                    self._dataframe,
1933
                    axis,
1934
                    self._binned.coords[axis].values,
1935
                    self._config["dataframe"]["time_stamp_alias"],
1936
                )
1937
        else:
1938
            if df_partitions is not None:
1✔
1939
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
1✔
1940
                    self._timed_dataframe.partitions[df_partitions],
1941
                    axis,
1942
                    self._binned.coords[axis].values,
1943
                    self._config["dataframe"]["timed_dataframe_unit_time"],
1944
                )
1945
            else:
1946
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
×
1947
                    self._timed_dataframe,
1948
                    axis,
1949
                    self._binned.coords[axis].values,
1950
                    self._config["dataframe"]["timed_dataframe_unit_time"],
1951
                )
1952

1953
        return self._normalization_histogram
1✔
1954

1955
    def view_event_histogram(
1✔
1956
        self,
1957
        dfpid: int,
1958
        ncol: int = 2,
1959
        bins: Sequence[int] = None,
1960
        axes: Sequence[str] = None,
1961
        ranges: Sequence[Tuple[float, float]] = None,
1962
        backend: str = "bokeh",
1963
        legend: bool = True,
1964
        histkwds: dict = None,
1965
        legkwds: dict = None,
1966
        **kwds,
1967
    ):
1968
        """Plot individual histograms of specified dimensions (axes) from a substituent
1969
        dataframe partition.
1970

1971
        Args:
1972
            dfpid (int): Number of the data frame partition to look at.
1973
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
1974
            bins (Sequence[int], optional): Number of bins to use for the specified
1975
                axes. Defaults to config["histogram"]["bins"].
1976
            axes (Sequence[str], optional): Names of the axes to display.
1977
                Defaults to config["histogram"]["axes"].
1978
            ranges (Sequence[Tuple[float, float]], optional): Value ranges of all
1979
                specified axes. Defaults to config["histogram"]["ranges"].
1980
            backend (str, optional): Backend of the plotting library
1981
                ('matplotlib' or 'bokeh'). Defaults to "bokeh".
1982
            legend (bool, optional): Option to include a legend in the histogram plots.
1983
                Defaults to True.
1984
            histkwds (dict, optional): Keyword arguments for histograms
1985
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
1986
            legkwds (dict, optional): Keyword arguments for legend
1987
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
1988
            **kwds: Extra keyword arguments passed to
1989
                ``sed.diagnostics.grid_histogram()``.
1990

1991
        Raises:
1992
            TypeError: Raises when the input values are not of the correct type.
1993
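
        Example:
            Minimal sketch (``sp`` is an assumed processor instance; without explicit
            arguments, axes, bins and ranges fall back to the histogram section of
            the config)::

                sp.view_event_histogram(dfpid=0, ncol=2, backend="matplotlib")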
        """
1994
        if bins is None:
1✔
1995
            bins = self._config["histogram"]["bins"]
1✔
1996
        if axes is None:
1✔
1997
            axes = self._config["histogram"]["axes"]
1✔
1998
        axes = list(axes)
1✔
1999
        for loc, axis in enumerate(axes):
1✔
2000
            if axis.startswith("@"):
1✔
2001
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
1✔
2002
        if ranges is None:
1✔
2003
            ranges = list(self._config["histogram"]["ranges"])
1✔
2004
            for loc, axis in enumerate(axes):
1✔
2005
                if axis == self._config["dataframe"]["tof_column"]:
1✔
2006
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
1✔
2007
                        self._config["dataframe"]["tof_binning"] - 1
2008
                    )
2009
                elif axis == self._config["dataframe"]["adc_column"]:
1✔
2010
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
×
2011
                        self._config["dataframe"]["adc_binning"] - 1
2012
                    )
2013

2014
        input_types = map(type, [axes, bins, ranges])
1✔
2015
        allowed_types = [list, tuple]
1✔
2016

2017
        df = self._dataframe
1✔
2018

2019
        if not set(input_types).issubset(allowed_types):
1✔
2020
            raise TypeError(
×
2021
                "Inputs of axes, bins, ranges need to be list or tuple!",
2022
            )
2023

2024
        # Read out the values for the specified groups
2025
        group_dict_dd = {}
1✔
2026
        dfpart = df.get_partition(dfpid)
1✔
2027
        cols = dfpart.columns
1✔
2028
        for ax in axes:
1✔
2029
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
1✔
2030
        group_dict = ddf.compute(group_dict_dd)[0]
1✔
2031

2032
        # Plot multiple histograms in a grid
2033
        grid_histogram(
1✔
2034
            group_dict,
2035
            ncol=ncol,
2036
            rvs=axes,
2037
            rvbins=bins,
2038
            rvranges=ranges,
2039
            backend=backend,
2040
            legend=legend,
2041
            histkwds=histkwds,
2042
            legkwds=legkwds,
2043
            **kwds,
2044
        )
2045

2046
    def save(
1✔
2047
        self,
2048
        faddr: str,
2049
        **kwds,
2050
    ):
2051
        """Saves the binned data to the provided path and filename.
2052

2053
        Args:
2054
            faddr (str): Path and name of the file to write. Its extension determines
2055
                the file type to write. Valid file types are:
2056

2057
                - "*.tiff", "*.tif": Saves a TIFF stack.
2058
                - "*.h5", "*.hdf5": Saves an HDF5 file.
2059
                - "*.nxs", "*.nexus": Saves a NeXus file.
2060

2061
            **kwds: Keyword arguments, which are passed to the writer functions:
2062
                For TIFF writing:
2063

2064
                - **alias_dict**: Dictionary of dimension aliases to use.
2065

2066
                For HDF5 writing:
2067

2068
                - **mode**: hdf5 read/write mode. Defaults to "w".
2069

2070
                For NeXus:
2071

2072
                - **reader**: Name of the nexustools reader to use.
2073
                  Defaults to config["nexus"]["reader"]
2074
                - **definition**: NeXus application definition to use for saving.
2075
                  Must be supported by the used ``reader``. Defaults to
2076
                  config["nexus"]["definition"]
2077
                - **input_files**: A list of input files to pass to the reader.
2078
                  Defaults to config["nexus"]["input_files"]
2079
                - **eln_data**: An electronic-lab-notebook file in '.yaml' format
2080
                  to add to the list of files to pass to the reader.
2081
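
        Example:
            Minimal sketches (``sp`` is an assumed processor instance with binned
            data; the file names are illustrative)::

                sp.save("binned.h5")      # HDF5
                sp.save("binned.tiff")    # TIFF stack
                sp.save("binned.nxs")     # NeXus, using the config defaults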
        """
2082
        if self._binned is None:
1✔
2083
            raise NameError("Need to bin data first!")
1✔
2084

2085
        if self._normalized is not None:
1✔
2086
            data = self._normalized
×
2087
        else:
2088
            data = self._binned
1✔
2089

2090
        extension = pathlib.Path(faddr).suffix
1✔
2091

2092
        if extension in (".tif", ".tiff"):
1✔
2093
            to_tiff(
1✔
2094
                data=data,
2095
                faddr=faddr,
2096
                **kwds,
2097
            )
2098
        elif extension in (".h5", ".hdf5"):
1✔
2099
            to_h5(
1✔
2100
                data=data,
2101
                faddr=faddr,
2102
                **kwds,
2103
            )
2104
        elif extension in (".nxs", ".nexus"):
1✔
2105
            try:
1✔
2106
                reader = kwds.pop("reader", self._config["nexus"]["reader"])
1✔
2107
                definition = kwds.pop(
1✔
2108
                    "definition",
2109
                    self._config["nexus"]["definition"],
2110
                )
2111
                input_files = kwds.pop(
1✔
2112
                    "input_files",
2113
                    self._config["nexus"]["input_files"],
2114
                )
2115
            except KeyError as exc:
×
2116
                raise ValueError(
×
2117
                    "The nexus reader, definition and input files need to be provide!",
2118
                ) from exc
2119

2120
            if isinstance(input_files, str):
1✔
2121
                input_files = [input_files]
1✔
2122

2123
            if "eln_data" in kwds:
1✔
2124
                input_files.append(kwds.pop("eln_data"))
×
2125

2126
            to_nexus(
1✔
2127
                data=data,
2128
                faddr=faddr,
2129
                reader=reader,
2130
                definition=definition,
2131
                input_files=input_files,
2132
                **kwds,
2133
            )
2134

2135
        else:
2136
            raise NotImplementedError(
1✔
2137
                f"Unrecognized file format: {extension}.",
2138
            )