OpenCOMPES / sed / build 6811966001

09 Nov 2023 12:45PM UTC, coverage: 90.355% (-0.3%) from 90.677%

push (github, web-flow): Merge pull request #250 from OpenCOMPES/fix_save_energy_offset

    fix saving energy offset and add global save

22 of 42 new or added lines in 2 files covered. (52.38%)

1 existing line in 1 file now uncovered.

4918 of 5443 relevant lines covered (90.35%)

0.9 hits per line

Source File: /sed/core/processor.py (87.52% covered)

"""This module contains the core class for the sed package."""
import pathlib
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Sequence
from typing import Tuple
from typing import Union

import dask.dataframe as ddf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr

from sed.binning import bin_dataframe
from sed.binning.binning import normalization_histogram_from_timed_dataframe
from sed.binning.binning import normalization_histogram_from_timestamps
from sed.calibrator import DelayCalibrator
from sed.calibrator import EnergyCalibrator
from sed.calibrator import MomentumCorrector
from sed.core.config import parse_config
from sed.core.config import save_config
from sed.core.dfops import apply_jitter
from sed.core.metadata import MetaHandler
from sed.diagnostics import grid_histogram
from sed.io import to_h5
from sed.io import to_nexus
from sed.io import to_tiff
from sed.loader import CopyTool
from sed.loader import get_loader

N_CPU = psutil.cpu_count()

class SedProcessor:
    """Processor class of sed. Contains wrapper functions defining a workflow for data
    correction, calibration and binning.

    Args:
        metadata (dict, optional): Dict of external Metadata. Defaults to None.
        config (Union[dict, str], optional): Config dictionary or config file name.
            Defaults to None.
        dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
            into the class. Defaults to None.
        files (List[str], optional): List of files to pass to the loader defined in
            the config. Defaults to None.
        folder (str, optional): Folder containing files to pass to the loader
            defined in the config. Defaults to None.
        runs (Sequence[str], optional): List of run identifiers to pass to the loader
            defined in the config. Defaults to None.
        collect_metadata (bool): Option to collect metadata from files.
            Defaults to False.
        verbose (bool, optional): Option to print out additional information.
            Defaults to False.
        **kwds: Keyword arguments passed to the reader.
    """

    def __init__(
        self,
        metadata: dict = None,
        config: Union[dict, str] = None,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        verbose: bool = False,
        **kwds,
    ):
        """Processor class of sed. Contains wrapper functions defining a workflow
        for data correction, calibration, and binning.

        Args:
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            config (Union[dict, str], optional): Config dictionary or config file name.
                Defaults to None.
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): dataframe to load
                into the class. Defaults to None.
            files (List[str], optional): List of files to pass to the loader defined in
                the config. Defaults to None.
            folder (str, optional): Folder containing files to pass to the loader
                defined in the config. Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the loader
                defined in the config. Defaults to None.
            collect_metadata (bool): Option to collect metadata from files.
                Defaults to False.
            verbose (bool, optional): Option to print out additional information.
                Defaults to False.
            **kwds: Keyword arguments passed to parse_config and to the reader.
        """
        config_kwds = {
            key: value for key, value in kwds.items() if key in parse_config.__code__.co_varnames
        }
        for key in config_kwds.keys():
            del kwds[key]
        self._config = parse_config(config, **config_kwds)
        num_cores = self._config.get("binning", {}).get("num_cores", N_CPU - 1)
        if num_cores >= N_CPU:
            num_cores = N_CPU - 1
        self._config["binning"]["num_cores"] = num_cores

        self.verbose = verbose

        self._dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._timed_dataframe: Union[pd.DataFrame, ddf.DataFrame] = None
        self._files: List[str] = []

        self._binned: xr.DataArray = None
        self._pre_binned: xr.DataArray = None
        self._normalization_histogram: xr.DataArray = None
        self._normalized: xr.DataArray = None

        self._attributes = MetaHandler(meta=metadata)

        loader_name = self._config["core"]["loader"]
        self.loader = get_loader(
            loader_name=loader_name,
            config=self._config,
        )

        self.ec = EnergyCalibrator(
            loader=self.loader,
            config=self._config,
        )

        self.mc = MomentumCorrector(
            config=self._config,
        )

        self.dc = DelayCalibrator(
            config=self._config,
        )

        self.use_copy_tool = self._config.get("core", {}).get(
            "use_copy_tool",
            False,
        )
        if self.use_copy_tool:
            try:
                self.ct = CopyTool(
                    source=self._config["core"]["copy_tool_source"],
                    dest=self._config["core"]["copy_tool_dest"],
                    **self._config["core"].get("copy_tool_kwds", {}),
                )
            except KeyError:
                self.use_copy_tool = False

        # Load data if provided:
        if dataframe is not None or files is not None or folder is not None or runs is not None:
            self.load(
                dataframe=dataframe,
                metadata=metadata,
                files=files,
                folder=folder,
                runs=runs,
                collect_metadata=collect_metadata,
                **kwds,
            )

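    # Minimal usage sketch (hypothetical config file and data folder; the
    # accepted inputs depend on the loader set in config["core"]["loader"]):
    #
    #     processor = SedProcessor(
    #         config="sed_config.yaml",
    #         folder="/path/to/data",
    #         collect_metadata=False,
    #         verbose=True,
    #     )
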
    def __repr__(self):
        if self._dataframe is None:
            df_str = "Data Frame: No Data loaded"
        else:
            df_str = self._dataframe.__repr__()
        attributes_str = f"Metadata: {self._attributes.metadata}"
        pretty_str = df_str + "\n" + attributes_str
        return pretty_str

    @property
    def dataframe(self) -> Union[pd.DataFrame, ddf.DataFrame]:
        """Accessor to the underlying dataframe.

        Returns:
            Union[pd.DataFrame, ddf.DataFrame]: Dataframe object.
        """
        return self._dataframe

    @dataframe.setter
    def dataframe(self, dataframe: Union[pd.DataFrame, ddf.DataFrame]):
        """Setter for the underlying dataframe.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame]): The dataframe object to set.
        """
        if not isinstance(dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            dataframe,
            self._dataframe.__class__,
        ):
            raise ValueError(
                "'dataframe' has to be a Pandas or Dask dataframe and has to be of the same kind "
                "as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._dataframe.__class__}, provided type: {dataframe.__class__}.",
            )
        self._dataframe = dataframe

    @property
    def timed_dataframe(self) -> Union[pd.DataFrame, ddf.DataFrame]:
        """Accessor to the underlying timed_dataframe.

        Returns:
            Union[pd.DataFrame, ddf.DataFrame]: Timed Dataframe object.
        """
        return self._timed_dataframe

    @timed_dataframe.setter
    def timed_dataframe(self, timed_dataframe: Union[pd.DataFrame, ddf.DataFrame]):
        """Setter for the underlying timed dataframe.

        Args:
            timed_dataframe (Union[pd.DataFrame, ddf.DataFrame]): The timed dataframe
                object to set.
        """
        if not isinstance(timed_dataframe, (pd.DataFrame, ddf.DataFrame)) or not isinstance(
            timed_dataframe,
            self._timed_dataframe.__class__,
        ):
            raise ValueError(
                "'timed_dataframe' has to be a Pandas or Dask dataframe and has to be of the same "
                "kind as the dataframe loaded into the SedProcessor!\n"
                f"Loaded type: {self._timed_dataframe.__class__}, "
                f"provided type: {timed_dataframe.__class__}.",
            )
        self._timed_dataframe = timed_dataframe

    @property
    def attributes(self) -> dict:
        """Accessor to the metadata dict.

        Returns:
            dict: The metadata dict.
        """
        return self._attributes.metadata

    def add_attribute(self, attributes: dict, name: str, **kwds):
        """Function to add an element to the attributes dict.

        Args:
            attributes (dict): The attributes dictionary object to add.
            name (str): Key under which to add the dictionary to the attributes.
            **kwds: Additional keyword arguments passed to ``MetaHandler.add()``.
        """
        self._attributes.add(
            entry=attributes,
            name=name,
            **kwds,
        )

    @property
    def config(self) -> Dict[Any, Any]:
        """Getter attribute for the config dictionary.

        Returns:
            Dict: The config dictionary.
        """
        return self._config

    @property
    def files(self) -> List[str]:
        """Getter attribute for the list of files.

        Returns:
            List[str]: The list of loaded files.
        """
        return self._files

    @property
    def binned(self) -> xr.DataArray:
        """Getter attribute for the binned data array.

        Returns:
            xr.DataArray: The binned data array.
        """
        if self._binned is None:
            raise ValueError("No binned data available, need to compute histogram first!")
        return self._binned

    @property
    def normalized(self) -> xr.DataArray:
        """Getter attribute for the normalized data array.

        Returns:
            xr.DataArray: The normalized data array.
        """
        if self._normalized is None:
            raise ValueError(
                "No normalized data available, compute data with normalization enabled!",
            )
        return self._normalized

    @property
    def normalization_histogram(self) -> xr.DataArray:
        """Getter attribute for the normalization histogram.

        Returns:
            xr.DataArray: The normalization histogram.
        """
        if self._normalization_histogram is None:
            raise ValueError("No normalization histogram available, generate histogram first!")
        return self._normalization_histogram

    def cpy(self, path: Union[str, List[str]]) -> Union[str, List[str]]:
        """Function to mirror a list of files or a folder from a network drive to a
        local storage. Returns either the original or the copied path to the given
        path. The option to use this functionality is set by
        config["core"]["use_copy_tool"].

        Args:
            path (Union[str, List[str]]): Source path or path list.

        Returns:
            Union[str, List[str]]: Source or destination path or path list.
        """
        if self.use_copy_tool:
            if isinstance(path, list):
                return [self.ct.copy(file) for file in path]

            return self.ct.copy(path)

        return path

    def load(
        self,
        dataframe: Union[pd.DataFrame, ddf.DataFrame] = None,
        metadata: dict = None,
        files: List[str] = None,
        folder: str = None,
        runs: Sequence[str] = None,
        collect_metadata: bool = False,
        **kwds,
    ):
        """Load tabular data of single events into the dataframe object in the class.

        Args:
            dataframe (Union[pd.DataFrame, ddf.DataFrame], optional): data in tabular
                format. Accepts anything which can be interpreted by pd.DataFrame as
                an input. Defaults to None.
            metadata (dict, optional): Dict of external Metadata. Defaults to None.
            files (List[str], optional): List of file paths to pass to the loader.
                Defaults to None.
            folder (str, optional): Folder path to pass to the loader.
                Defaults to None.
            runs (Sequence[str], optional): List of run identifiers to pass to the
                loader. Defaults to None.
            collect_metadata (bool, optional): Option to collect metadata from files.
                Defaults to False.
            **kwds: Keyword arguments passed to the loader.

        Raises:
            ValueError: Raised if no valid input is provided.
        """
        if metadata is None:
            metadata = {}
        if dataframe is not None:
            timed_dataframe = kwds.pop("timed_dataframe", None)
        elif runs is not None:
            # If runs are provided, we only use the copy tool if a folder is also provided.
            # In that case, we copy the whole provided base folder tree, and pass the copied
            # version to the loader as base folder to look for the runs.
            if folder is not None:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    folders=cast(str, self.cpy(folder)),
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )
            else:
                dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                    runs=runs,
                    metadata=metadata,
                    collect_metadata=collect_metadata,
                    **kwds,
                )

        elif folder is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                folders=cast(str, self.cpy(folder)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        elif files is not None:
            dataframe, timed_dataframe, metadata = self.loader.read_dataframe(
                files=cast(List[str], self.cpy(files)),
                metadata=metadata,
                collect_metadata=collect_metadata,
                **kwds,
            )
        else:
            raise ValueError(
                "Either 'dataframe', 'files', 'folder', or 'runs' needs to be provided!",
            )

        self._dataframe = dataframe
        self._timed_dataframe = timed_dataframe
        self._files = self.loader.files

        for key in metadata:
            self._attributes.add(
                entry=metadata[key],
                name=key,
                duplicate_policy="merge",
            )

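    # Usage sketch (hypothetical file names; the accepted formats depend on
    # the configured loader):
    #
    #     processor.load(files=["scan_001.h5", "scan_002.h5"])
    #     # or, resolving run identifiers relative to a base folder:
    #     processor.load(runs=["44797"], folder="/path/to/beamtime")
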
    # Momentum calibration workflow
406
    # 1. Bin raw detector data for distortion correction
407
    def bin_and_load_momentum_calibration(
1✔
408
        self,
409
        df_partitions: int = 100,
410
        axes: List[str] = None,
411
        bins: List[int] = None,
412
        ranges: Sequence[Tuple[float, float]] = None,
413
        plane: int = 0,
414
        width: int = 5,
415
        apply: bool = False,
416
        **kwds,
417
    ):
418
        """1st step of momentum correction work flow. Function to do an initial binning
419
        of the dataframe loaded to the class, slice a plane from it using an
420
        interactive view, and load it into the momentum corrector class.
421

422
        Args:
423
            df_partitions (int, optional): Number of dataframe partitions to use for
424
                the initial binning. Defaults to 100.
425
            axes (List[str], optional): Axes to bin.
426
                Defaults to config["momentum"]["axes"].
427
            bins (List[int], optional): Bin numbers to use for binning.
428
                Defaults to config["momentum"]["bins"].
429
            ranges (List[Tuple], optional): Ranges to use for binning.
430
                Defaults to config["momentum"]["ranges"].
431
            plane (int, optional): Initial value for the plane slider. Defaults to 0.
432
            width (int, optional): Initial value for the width slider. Defaults to 5.
433
            apply (bool, optional): Option to directly apply the values and select the
434
                slice. Defaults to False.
435
            **kwds: Keyword argument passed to the pre_binning function.
436
        """
437
        self._pre_binned = self.pre_binning(
1✔
438
            df_partitions=df_partitions,
439
            axes=axes,
440
            bins=bins,
441
            ranges=ranges,
442
            **kwds,
443
        )
444

445
        self.mc.load_data(data=self._pre_binned)
1✔
446
        self.mc.select_slicer(plane=plane, width=width, apply=apply)
1✔
447

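    # Usage sketch (slider values are hypothetical and depend on the data):
    #
    #     processor.bin_and_load_momentum_calibration(
    #         df_partitions=100,
    #         plane=33,
    #         width=10,
    #         apply=True,
    #     )
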
    # 2. Generate the spline warp correction from momentum features.
    # Either autoselect features, or input features from view above.
    def define_features(
        self,
        features: np.ndarray = None,
        rotation_symmetry: int = 6,
        auto_detect: bool = False,
        include_center: bool = True,
        apply: bool = False,
        **kwds,
    ):
        """2. step of the distortion correction workflow: Define feature points in
        momentum space. They can either be manually selected using a GUI tool, be
        provided as a list of feature points, or be auto-generated using a
        feature-detection algorithm.

        Args:
            features (np.ndarray, optional): np.ndarray of features. Defaults to None.
            rotation_symmetry (int, optional): Number of rotational symmetry axes.
                Defaults to 6.
            auto_detect (bool, optional): Whether to auto-detect the features.
                Defaults to False.
            include_center (bool, optional): Option to include a point at the center
                in the feature list. Defaults to True.
            apply (bool, optional): Option to directly apply the selected features.
                Defaults to False.
            **kwds: Keyword arguments for MomentumCorrector.feature_extract() and
                MomentumCorrector.feature_select().
        """
        if auto_detect:  # automatic feature selection
            sigma = kwds.pop("sigma", self._config["momentum"]["sigma"])
            fwhm = kwds.pop("fwhm", self._config["momentum"]["fwhm"])
            sigma_radius = kwds.pop(
                "sigma_radius",
                self._config["momentum"]["sigma_radius"],
            )
            self.mc.feature_extract(
                sigma=sigma,
                fwhm=fwhm,
                sigma_radius=sigma_radius,
                rotsym=rotation_symmetry,
                **kwds,
            )
            features = self.mc.peaks

        self.mc.feature_select(
            rotsym=rotation_symmetry,
            include_center=include_center,
            features=features,
            apply=apply,
            **kwds,
        )

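    # Usage sketch (pixel coordinates are hypothetical; for rotation_symmetry=6
    # with include_center=True, six outer points plus the center are expected):
    #
    #     features = np.array(
    #         [
    #             [203.2, 341.9], [299.2, 386.8], [351.0, 243.4],
    #             [291.7, 111.2], [164.7, 91.3], [101.1, 233.6],
    #             [248.3, 235.8],  # center point
    #         ],
    #     )
    #     processor.define_features(
    #         features=features,
    #         rotation_symmetry=6,
    #         include_center=True,
    #         apply=True,
    #     )
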
    # 3. Generate the spline warp correction from momentum features.
    # If no features have been selected before, use class defaults.
    def generate_splinewarp(
        self,
        use_center: bool = None,
        **kwds,
    ):
        """3. step of the distortion correction workflow: Generate the correction
        function restoring the symmetry in the image using a splinewarp algorithm.

        Args:
            use_center (bool, optional): Option to use the position of the
                center point in the correction. Default is read from config, or set to True.
            **kwds: Keyword arguments for MomentumCorrector.spline_warp_estimate().
        """
        self.mc.spline_warp_estimate(use_center=use_center, **kwds)

        if self.mc.slice is not None:
            print("Original slice with reference features")
            self.mc.view(annotated=True, backend="bokeh", crosshair=True)

            print("Corrected slice with target features")
            self.mc.view(
                image=self.mc.slice_corrected,
                annotated=True,
                points={"feats": self.mc.ptargs},
                backend="bokeh",
                crosshair=True,
            )

            print("Original slice with target features")
            self.mc.view(
                image=self.mc.slice,
                points={"feats": self.mc.ptargs},
                annotated=True,
                backend="bokeh",
            )

    # 3a. Save spline-warp parameters to config file.
    def save_splinewarp(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated spline-warp parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        points = []
        try:
            for point in self.mc.pouter_ord:
                points.append([float(i) for i in point])
            if self.mc.include_center:
                points.append([float(i) for i in self.mc.pcent])
        except AttributeError as exc:
            raise AttributeError(
                "Momentum correction parameters not found, need to generate parameters first!",
            ) from exc
        config = {
            "momentum": {
                "correction": {
                    "rotation_symmetry": self.mc.rotsym,
                    "feature_points": points,
                    "include_center": self.mc.include_center,
                    "use_center": self.mc.use_center,
                },
            },
        }
        save_config(config, filename, overwrite)

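    # The saved entry in the config file then has roughly this shape (feature
    # point values are hypothetical):
    #
    #     momentum:
    #       correction:
    #         rotation_symmetry: 6
    #         feature_points: [[203.2, 341.9], [299.2, 386.8], ...]
    #         include_center: true
    #         use_center: true
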
    # 4. Pose corrections. Provide interactive interface for correcting
    # scaling, shift and rotation
    def pose_adjustment(
        self,
        scale: float = 1,
        xtrans: float = 0,
        ytrans: float = 0,
        angle: float = 0,
        apply: bool = False,
        use_correction: bool = True,
        reset: bool = True,
    ):
        """4. step of the distortion correction workflow: Generate an interactive panel
        to adjust affine transformations that are applied to the image. Applies first
        a scaling, next an x/y translation, and last a rotation around the center of
        the image.

        Args:
            scale (float, optional): Initial value of the scaling slider.
                Defaults to 1.
            xtrans (float, optional): Initial value of the xtrans slider.
                Defaults to 0.
            ytrans (float, optional): Initial value of the ytrans slider.
                Defaults to 0.
            angle (float, optional): Initial value of the angle slider.
                Defaults to 0.
            apply (bool, optional): Option to directly apply the provided
                transformations. Defaults to False.
            use_correction (bool, optional): Whether to use the spline warp correction
                or not. Defaults to True.
            reset (bool, optional):
                Option to reset the correction before transformation. Defaults to True.
        """
        # Generate homography as default if no distortion correction has been applied
        if self.mc.slice_corrected is None:
            if self.mc.slice is None:
                raise ValueError(
                    "No slice for corrections and transformations loaded!",
                )
            self.mc.slice_corrected = self.mc.slice

        if not use_correction:
            self.mc.reset_deformation()

        if self.mc.cdeform_field is None or self.mc.rdeform_field is None:
            # Generate distortion correction from config values
            self.mc.add_features()
            self.mc.spline_warp_estimate()

        self.mc.pose_adjustment(
            scale=scale,
            xtrans=xtrans,
            ytrans=ytrans,
            angle=angle,
            apply=apply,
            reset=reset,
        )

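    # Usage sketch (transformation values are hypothetical and typically found
    # interactively with the sliders):
    #
    #     processor.pose_adjustment(
    #         xtrans=8,
    #         ytrans=7,
    #         angle=-1.5,
    #         apply=True,
    #     )
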
    # 5. Apply the momentum correction to the dataframe
    def apply_momentum_correction(
        self,
        preview: bool = False,
    ):
        """Applies the distortion correction and pose adjustment (optional)
        to the dataframe.

        Args:
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding corrected X/Y columns to dataframe:")
            self._dataframe, metadata = self.mc.apply_corrections(
                df=self._dataframe,
            )
            if self._timed_dataframe is not None:
                if (
                    self._config["dataframe"]["x_column"] in self._timed_dataframe.columns
                    and self._config["dataframe"]["y_column"] in self._timed_dataframe.columns
                ):
                    self._timed_dataframe, _ = self.mc.apply_corrections(
                        self._timed_dataframe,
                    )
            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_correction",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                if self.verbose:
                    print(self._dataframe)

    # Momentum calibration workflow
    # 1. Calculate momentum calibration
    def calibrate_momentum_axes(
        self,
        point_a: Union[np.ndarray, List[int]] = None,
        point_b: Union[np.ndarray, List[int]] = None,
        k_distance: float = None,
        k_coord_a: Union[np.ndarray, List[float]] = None,
        k_coord_b: Union[np.ndarray, List[float]] = np.array([0.0, 0.0]),
        equiscale: bool = True,
        apply=False,
    ):
        """1. step of the momentum calibration workflow: Calibrate momentum
        axes using either provided pixel coordinates of a high-symmetry point and its
        distance to the BZ center, or the k-coordinates of two points in the BZ
        (depending on the equiscale option). Opens an interactive panel for selecting
        the points.

        Args:
            point_a (Union[np.ndarray, List[int]]): Pixel coordinates of the first
                point used for momentum calibration.
            point_b (Union[np.ndarray, List[int]], optional): Pixel coordinates of the
                second point used for momentum calibration.
                Defaults to config["momentum"]["center_pixel"].
            k_distance (float, optional): Momentum distance between point a and b.
                Needs to be provided if no specific k-coordinates for the two points
                are given. Defaults to None.
            k_coord_a (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the first point used for calibration. Used if equiscale is False.
                Defaults to None.
            k_coord_b (Union[np.ndarray, List[float]], optional): Momentum coordinate
                of the second point used for calibration. Defaults to [0.0, 0.0].
            equiscale (bool, optional): Option to apply the same scale to kx and ky.
                If True, the distance between points a and b, and the absolute
                position of point a are used for defining the scale. If False, the
                scale is calculated from the k-positions of both points a and b.
                Defaults to True.
            apply (bool, optional): Option to directly store the momentum calibration
                in the class. Defaults to False.
        """
        if point_b is None:
            point_b = self._config["momentum"]["center_pixel"]

        self.mc.select_k_range(
            point_a=point_a,
            point_b=point_b,
            k_distance=k_distance,
            k_coord_a=k_coord_a,
            k_coord_b=k_coord_b,
            equiscale=equiscale,
            apply=apply,
        )

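    # Usage sketch (pixel coordinates and momentum distance are hypothetical,
    # e.g. the distance from the Brillouin zone center to a high-symmetry
    # point known from the lattice constant):
    #
    #     processor.calibrate_momentum_axes(
    #         point_a=[308, 345],
    #         k_distance=1.70,  # in inverse Angstrom, hypothetical value
    #         apply=True,
    #     )
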
    # 1a. Save momentum calibration parameters to config file.
    def save_momentum_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated momentum calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key in [
                "kx_scale",
                "ky_scale",
                "x_center",
                "y_center",
                "rstart",
                "cstart",
                "rstep",
                "cstep",
            ]:
                calibration[key] = float(self.mc.calibration[key])
        except KeyError as exc:
            raise KeyError(
                "Momentum calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"momentum": {"calibration": calibration}}
        save_config(config, filename, overwrite)
        print(f"Saved momentum calibration parameters to {filename}")

    # 2. Apply correction and calibration to the dataframe
    def apply_momentum_calibration(
        self,
        calibration: dict = None,
        preview: bool = False,
    ):
        """2. step of the momentum calibration workflow: Apply the momentum
        calibration stored in the class to the dataframe. If corrected X/Y axes exist,
        these are used.

        Args:
            calibration (dict, optional): Optional dictionary with calibration data to
                use. Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
        """
        if self._dataframe is not None:
            print("Adding kx/ky columns to dataframe:")
            self._dataframe, metadata = self.mc.append_k_axis(
                df=self._dataframe,
                calibration=calibration,
            )
            if self._timed_dataframe is not None:
                if (
                    self._config["dataframe"]["x_column"] in self._timed_dataframe.columns
                    and self._config["dataframe"]["y_column"] in self._timed_dataframe.columns
                ):
                    self._timed_dataframe, _ = self.mc.append_k_axis(
                        df=self._timed_dataframe,
                        calibration=calibration,
                    )

            # Add Metadata
            self._attributes.add(
                metadata,
                "momentum_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                if self.verbose:
                    print(self._dataframe)

    # Energy correction workflow
    # 1. Adjust the energy correction parameters
    def adjust_energy_correction(
        self,
        correction_type: str = None,
        amplitude: float = None,
        center: Tuple[float, float] = None,
        apply=False,
        **kwds,
    ):
        """1. step of the energy correction workflow: Opens an interactive plot to
        adjust the parameters for the TOF/energy correction. Also pre-bins the data if
        they are not present yet.

        Args:
            correction_type (str, optional): Type of correction to apply to the TOF
                axis. Valid values are:

                - 'spherical'
                - 'Lorentzian'
                - 'Gaussian'
                - 'Lorentzian_asymmetric'

                Defaults to config["energy"]["correction_type"].
            amplitude (float, optional): Amplitude of the correction.
                Defaults to config["energy"]["correction"]["amplitude"].
            center (Tuple[float, float], optional): Center X/Y coordinates for the
                correction. Defaults to config["energy"]["correction"]["center"].
            apply (bool, optional): Option to directly apply the provided or default
                correction parameters. Defaults to False.
            **kwds: Keyword args passed to ``EnergyCalibrator.adjust_energy_correction``.
        """
        if self._pre_binned is None:
            print(
                "Pre-binned data not present, binning using defaults from config...",
            )
            self._pre_binned = self.pre_binning()

        self.ec.adjust_energy_correction(
            self._pre_binned,
            correction_type=correction_type,
            amplitude=amplitude,
            center=center,
            apply=apply,
            **kwds,
        )

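    # Usage sketch (correction type and parameters are hypothetical and
    # instrument dependent):
    #
    #     processor.adjust_energy_correction(
    #         correction_type="Lorentzian",
    #         amplitude=2.5,
    #         center=(730, 730),
    #         apply=True,
    #     )
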
    # 1a. Save energy correction parameters to config file.
    def save_energy_correction(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy correction parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        correction = {}
        try:
            for key, val in self.ec.correction.items():
                if key == "correction_type":
                    correction[key] = val
                elif key == "center":
                    correction[key] = [float(i) for i in val]
                else:
                    correction[key] = float(val)
        except AttributeError as exc:
            raise AttributeError(
                "Energy correction parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"correction": correction}}
        save_config(config, filename, overwrite)
        print(f"Saved energy correction parameters to {filename}")

    # 2. Apply energy correction to dataframe
    def apply_energy_correction(
        self,
        correction: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """2. step of the energy correction workflow: Apply the energy correction
        parameters stored in the class to the dataframe.

        Args:
            correction (dict, optional): Dictionary containing the correction
                parameters. Defaults to config["energy"]["correction"].
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.apply_energy_correction``.
        """
        if self._dataframe is not None:
            print("Applying energy correction to dataframe...")
            self._dataframe, metadata = self.ec.apply_energy_correction(
                df=self._dataframe,
                correction=correction,
                **kwds,
            )
            if self._timed_dataframe is not None:
                if self._config["dataframe"]["tof_column"] in self._timed_dataframe.columns:
                    self._timed_dataframe, _ = self.ec.apply_energy_correction(
                        df=self._timed_dataframe,
                        correction=correction,
                        **kwds,
                    )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_correction",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                if self.verbose:
                    print(self._dataframe)

    # Energy calibrator workflow
    # 1. Load and normalize data
    def load_bias_series(
        self,
        binned_data: Union[xr.DataArray, Tuple[np.ndarray, np.ndarray, np.ndarray]] = None,
        data_files: List[str] = None,
        axes: List[str] = None,
        bins: List = None,
        ranges: Sequence[Tuple[float, float]] = None,
        biases: np.ndarray = None,
        bias_key: str = None,
        normalize: bool = None,
        span: int = None,
        order: int = None,
    ):
        """1. step of the energy calibration workflow: Load and bin data from
        single-event files, or load binned bias/TOF traces.

        Args:
            binned_data (Union[xr.DataArray, Tuple[np.ndarray, np.ndarray, np.ndarray]], optional):
                Binned data. If provided as a DataArray, it needs to contain the dimensions
                config["dataframe"]["tof_column"] and config["dataframe"]["bias_column"]. If
                provided as a tuple, it needs to contain the elements tof, biases, traces.
            data_files (List[str], optional): list of file paths to bin.
            axes (List[str], optional): bin axes.
                Defaults to config["dataframe"]["tof_column"].
            bins (List, optional): number of bins.
                Defaults to config["energy"]["bins"].
            ranges (Sequence[Tuple[float, float]], optional): bin ranges.
                Defaults to config["energy"]["ranges"].
            biases (np.ndarray, optional): Bias voltages used. If missing, bias
                voltages are extracted from the data files.
            bias_key (str, optional): hdf5 path where bias values are stored.
                Defaults to config["energy"]["bias_key"].
            normalize (bool, optional): Option to normalize traces.
                Defaults to config["energy"]["normalize"].
            span (int, optional): span smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_span"].
            order (int, optional): order smoothing parameter of the LOESS method
                (see ``scipy.signal.savgol_filter()``).
                Defaults to config["energy"]["normalize_order"].
        """
        if binned_data is not None:
            if isinstance(binned_data, xr.DataArray):
                if (
                    self._config["dataframe"]["tof_column"] not in binned_data.dims
                    or self._config["dataframe"]["bias_column"] not in binned_data.dims
                ):
                    raise ValueError(
                        "If binned_data is provided as an xarray, it needs to contain dimensions "
                        f"'{self._config['dataframe']['tof_column']}' and "
                        f"'{self._config['dataframe']['bias_column']}'!",
                    )
                tof = binned_data.coords[self._config["dataframe"]["tof_column"]].values
                biases = binned_data.coords[self._config["dataframe"]["bias_column"]].values
                traces = binned_data.values[:, :]
            else:
                try:
                    (tof, biases, traces) = binned_data
                except ValueError as exc:
                    raise ValueError(
                        "If binned_data is provided as tuple, it needs to contain "
                        "(tof, biases, traces)!",
                    ) from exc
            self.ec.load_data(biases=biases, traces=traces, tof=tof)

        elif data_files is not None:
            self.ec.bin_data(
                data_files=cast(List[str], self.cpy(data_files)),
                axes=axes,
                bins=bins,
                ranges=ranges,
                biases=biases,
                bias_key=bias_key,
            )

        else:
            raise ValueError("Either binned_data or data_files needs to be provided!")

        if (normalize is not None and normalize is True) or (
            normalize is None and self._config["energy"]["normalize"]
        ):
            if span is None:
                span = self._config["energy"]["normalize_span"]
            if order is None:
                order = self._config["energy"]["normalize_order"]
            self.ec.normalize(smooth=True, span=span, order=order)
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.tof,
            backend="bokeh",
        )

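    # Usage sketch (file names and the hdf5 bias key are hypothetical):
    #
    #     processor.load_bias_series(
    #         data_files=["bias_22V.h5", "bias_23V.h5", "bias_24V.h5"],
    #         bias_key="@KTOF:Lens:Sample:V",
    #         normalize=True,
    #     )
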
    # 2. Extract ranges and get peak positions
    def find_bias_peaks(
        self,
        ranges: Union[List[Tuple], Tuple],
        ref_id: int = 0,
        infer_others: bool = True,
        mode: str = "replace",
        radius: int = None,
        peak_window: int = None,
        apply: bool = False,
    ):
        """2. step of the energy calibration workflow: Find a peak within a given range
        for the indicated reference trace, and try to find the same peak for all
        other traces. Uses fast_dtw to align curves, which might not perform well if
        the shape of the curves changes qualitatively. Ideally, choose a reference
        trace in the middle of the set, and don't choose the range too narrow around
        the peak. Alternatively, a list of ranges for all traces can be provided.

        Args:
            ranges (Union[List[Tuple], Tuple]): Tuple of TOF values indicating a range.
                Alternatively, a list of ranges for all traces can be given.
            ref_id (int, optional): The id of the trace the range refers to.
                Defaults to 0.
            infer_others (bool, optional): Whether to determine the range for the other
                traces. Defaults to True.
            mode (str, optional): Whether to "add" or "replace" existing ranges.
                Defaults to "replace".
            radius (int, optional): Radius parameter for fast_dtw.
                Defaults to config["energy"]["fastdtw_radius"].
            peak_window (int, optional): Peak_window parameter for the peak detection
                algorithm: the number of points that have to behave monotonically
                around a peak. Defaults to config["energy"]["peak_window"].
            apply (bool, optional): Option to directly apply the provided parameters.
                Defaults to False.
        """
        if radius is None:
            radius = self._config["energy"]["fastdtw_radius"]
        if peak_window is None:
            peak_window = self._config["energy"]["peak_window"]
        if not infer_others:
            self.ec.add_ranges(
                ranges=ranges,
                ref_id=ref_id,
                infer_others=infer_others,
                mode=mode,
                radius=radius,
            )
            print(self.ec.featranges)
            try:
                self.ec.feature_extract(peak_window=peak_window)
                self.ec.view(
                    traces=self.ec.traces_normed,
                    segs=self.ec.featranges,
                    xaxis=self.ec.tof,
                    peaks=self.ec.peaks,
                    backend="bokeh",
                )
            except IndexError:
                print("Could not determine all peaks!")
                raise
        else:
            # New adjustment tool
            assert isinstance(ranges, tuple)
            self.ec.adjust_ranges(
                ranges=ranges,
                ref_id=ref_id,
                traces=self.ec.traces_normed,
                infer_others=infer_others,
                radius=radius,
                peak_window=peak_window,
                apply=apply,
            )

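    # Usage sketch (the TOF range and reference trace are hypothetical):
    #
    #     processor.find_bias_peaks(
    #         ranges=(64500, 66000),
    #         ref_id=5,
    #         apply=True,
    #     )
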
    # 3. Fit the energy calibration relation
    def calibrate_energy_axis(
        self,
        ref_id: int,
        ref_energy: float,
        method: str = None,
        energy_scale: str = None,
        **kwds,
    ):
        """3. step of the energy calibration workflow: Calculate the calibration
        function for the energy axis. Two approximations are implemented, a
        (normally 3rd order) polynomial approximation, and a d^2/(t-t0)^2 relation.

        Args:
            ref_id (int): id of the trace at the bias where the reference energy is
                given.
            ref_energy (float): Absolute energy of the detected feature at the bias
                of ref_id.
            method (str, optional): Method for determining the energy calibration.

                - **'lmfit'**: Energy calibration using lmfit and 1/t^2 form.
                - **'lstsq'**, **'lsqr'**: Energy calibration using polynomial form.

                Defaults to config["energy"]["calibration_method"].
            energy_scale (str, optional): Direction of increasing energy scale.

                - **'kinetic'**: increasing energy with decreasing TOF.
                - **'binding'**: increasing energy with increasing TOF.

                Defaults to config["energy"]["energy_scale"].
            **kwds: Keyword parameters passed to ``EnergyCalibrator.calibrate()``.
        """
        if method is None:
            method = self._config["energy"]["calibration_method"]

        if energy_scale is None:
            energy_scale = self._config["energy"]["energy_scale"]

        self.ec.calibrate(
            ref_id=ref_id,
            ref_energy=ref_energy,
            method=method,
            energy_scale=energy_scale,
            **kwds,
        )
        print("Quality of Calibration:")
        self.ec.view(
            traces=self.ec.traces_normed,
            xaxis=self.ec.calibration["axis"],
            align=True,
            energy_scale=energy_scale,
            backend="bokeh",
        )
        print("E/TOF relationship:")
        self.ec.view(
            traces=self.ec.calibration["axis"][None, :],
            xaxis=self.ec.tof,
            backend="matplotlib",
            show_legend=False,
        )
        if energy_scale == "kinetic":
            plt.scatter(
                self.ec.peaks[:, 0],
                -(self.ec.biases - self.ec.biases[ref_id]) + ref_energy,
                s=50,
                c="k",
            )
        elif energy_scale == "binding":
            plt.scatter(
                self.ec.peaks[:, 0],
                self.ec.biases - self.ec.biases[ref_id] + ref_energy,
                s=50,
                c="k",
            )
        else:
            raise ValueError(
                'energy_scale needs to be either "binding" or "kinetic"'
                f", got {energy_scale}.",
            )
        plt.xlabel("Time-of-flight", fontsize=15)
        plt.ylabel("Energy (eV)", fontsize=15)
        plt.show()

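    # Usage sketch (trace index and reference energy are hypothetical):
    #
    #     processor.calibrate_energy_axis(
    #         ref_id=4,
    #         ref_energy=-0.5,  # eV, energy of the reference feature in trace 4
    #         method="lmfit",
    #         energy_scale="kinetic",
    #     )
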
    # 3a. Save energy calibration parameters to config file.
    def save_energy_calibration(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy calibration parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
        """
        if filename is None:
            filename = "sed_config.yaml"
        calibration = {}
        try:
            for key, value in self.ec.calibration.items():
                if key in ["axis", "refid", "Tmat", "bvec"]:
                    continue
                if key == "energy_scale":
                    calibration[key] = value
                elif key == "coeffs":
                    calibration[key] = [float(i) for i in value]
                else:
                    calibration[key] = float(value)
        except AttributeError as exc:
            raise AttributeError(
                "Energy calibration parameters not found, need to generate parameters first!",
            ) from exc

        config = {"energy": {"calibration": calibration}}
        save_config(config, filename, overwrite)
        print(f'Saved energy calibration parameters to "{filename}".')

    # 4. Apply energy calibration to the dataframe
    def append_energy_axis(
        self,
        calibration: dict = None,
        preview: bool = False,
        **kwds,
    ):
        """4. step of the energy calibration workflow: Apply the calibration function
        to the dataframe. Two approximations are implemented, a (normally 3rd order)
        polynomial approximation, and a d^2/(t-t0)^2 relation. A calibration dictionary
        can be provided.

        Args:
            calibration (dict, optional): Calibration dict containing calibration
                parameters. Overrides calibration from class or config.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds:
                Keyword args passed to ``EnergyCalibrator.append_energy_axis``.
        """
        if self._dataframe is not None:
            print("Adding energy column to dataframe:")
            self._dataframe, metadata = self.ec.append_energy_axis(
                df=self._dataframe,
                calibration=calibration,
                **kwds,
            )
            if self._timed_dataframe is not None:
                if self._config["dataframe"]["tof_column"] in self._timed_dataframe.columns:
                    self._timed_dataframe, _ = self.ec.append_energy_axis(
                        df=self._timed_dataframe,
                        calibration=calibration,
                        **kwds,
                    )

            # Add Metadata
            self._attributes.add(
                metadata,
                "energy_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                if self.verbose:
                    print(self._dataframe)

    def add_energy_offset(
        self,
        constant: float = None,
        columns: Union[str, Sequence[str]] = None,
        signs: Union[int, Sequence[int]] = None,
        reductions: Union[str, Sequence[str]] = None,
        preserve_mean: Union[bool, Sequence[bool]] = None,
    ) -> None:
        """Shift the energy axis of the dataframe by a given amount.

        Args:
            constant (float, optional): The constant to shift the energy axis by.
            columns (Union[str, Sequence[str]]): Name of the column(s) to apply the shift from.
            signs (Union[int, Sequence[int]]): Sign of the shift to apply (+1 or -1). A positive
                sign shifts the energy axis to higher kinetic energies. Defaults to +1.
            reductions (Union[str, Sequence[str]]): The reduction to apply to the column. Should
                be an available method of dask.dataframe.Series, e.g. "mean". In this case the
                function is applied to the column to generate a single value for the whole
                dataset. If None, the shift is applied per dataframe row. Defaults to None.
                Currently only "mean" is supported.
            preserve_mean (Union[bool, Sequence[bool]]): Whether to subtract the mean of the
                column before applying the shift. Defaults to False.

        Raises:
            ValueError: If the energy column is not in the dataframe.
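
        Example:
            A minimal sketch (hypothetical usage; assumes a processor ``sp``
            with an energy axis appended; the column name and constant are
            illustrative)::

                sp.add_energy_offset(
                    constant=16.8,
                    columns="monochromatorPhotonEnergy",
                    signs=-1,
                    preserve_mean=True,
                )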
        """
        energy_column = self._config["dataframe"]["energy_column"]
        if self.dataframe is not None:
            if energy_column not in self._dataframe.columns:
                raise ValueError(
                    f"Energy column {energy_column} not found in dataframe! "
                    "Run `append_energy_axis` first.",
                )
            df, metadata = self.ec.add_offsets(
                df=self._dataframe,
                constant=constant,
                columns=columns,
                energy_column=energy_column,
                signs=signs,
                reductions=reductions,
                preserve_mean=preserve_mean,
            )
            if self._timed_dataframe is not None:
                if energy_column in self._timed_dataframe.columns:
                    self._timed_dataframe, _ = self.ec.add_offsets(
                        df=self._timed_dataframe,
                        constant=constant,
                        columns=columns,
                        energy_column=energy_column,
                        signs=signs,
                        reductions=reductions,
                        preserve_mean=preserve_mean,
                    )
            self._attributes.add(
                metadata,
                "add_energy_offset",
                # TODO: allow appending only when no offset along this column(s) was applied
                # TODO: clear memory of modifications if the energy axis is recalculated
                duplicate_policy="append",
            )
            self._dataframe = df
        else:
            raise ValueError("No dataframe loaded!")

    def save_energy_offset(
        self,
        filename: str = None,
        overwrite: bool = False,
    ):
        """Save the generated energy offset parameters to the folder config file.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
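
        Example:
            A minimal sketch (hypothetical usage; assumes offsets were added
            via ``add_energy_offset``)::

                sp.save_energy_offset(overwrite=True)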
        """
        if filename is None:
            filename = "sed_config.yaml"
        if len(self.ec.offset) == 0:
            raise ValueError("No energy offset parameters to save!")
        config = {"energy": {"offset": self.ec.offset}}
        save_config(config, filename, overwrite)
        print(f'Saved energy offset parameters to "{filename}".')

    def append_tof_ns_axis(
        self,
        **kwargs,
    ):
        """Convert time-of-flight channel steps to nanoseconds.

        Args:
            tof_ns_column (str, optional): Name of the generated column containing the
                time-of-flight in nanoseconds.
                Defaults to config["dataframe"]["tof_ns_column"].
            kwargs: Additional arguments passed to ``energy.tof_step_to_ns``.
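
        Example:
            A minimal sketch (hypothetical usage, with all parameters taken
            from the config)::

                sp.append_tof_ns_axis()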
        """
        if self._dataframe is not None:
            print("Adding time-of-flight column in nanoseconds to dataframe:")
            # TODO assert order of execution through metadata

            self._dataframe, metadata = self.ec.append_tof_ns_axis(
                df=self._dataframe,
                **kwargs,
            )
            if self._timed_dataframe is not None:
                if self._config["dataframe"]["tof_column"] in self._timed_dataframe.columns:
                    self._timed_dataframe, _ = self.ec.append_tof_ns_axis(
                        df=self._timed_dataframe,
                        **kwargs,
                    )
            self._attributes.add(
                metadata,
                "tof_ns_conversion",
                duplicate_policy="append",
            )

    def align_dld_sectors(self, sector_delays: np.ndarray = None, **kwargs):
        """Align the 8 sectors of the HEXTOF endstation.

        Args:
            sector_delays (np.ndarray, optional): Array containing the sector delays.
                Defaults to config["dataframe"]["sector_delays"].
            kwargs: Additional keyword arguments passed to
                ``EnergyCalibrator.align_dld_sectors``.
        """
        if self._dataframe is not None:
            print("Aligning 8 sectors of dataframe")
            # TODO assert order of execution through metadata
            self._dataframe, metadata = self.ec.align_dld_sectors(
                df=self._dataframe,
                sector_delays=sector_delays,
                **kwargs,
            )
            if self._timed_dataframe is not None:
                if self._config["dataframe"]["tof_column"] in self._timed_dataframe.columns:
                    self._timed_dataframe, _ = self.ec.align_dld_sectors(
                        df=self._timed_dataframe,
                        sector_delays=sector_delays,
                        **kwargs,
                    )
            self._attributes.add(
                metadata,
                "dld_sector_alignment",
                duplicate_policy="raise",
            )

    # Delay calibration function
    def calibrate_delay_axis(
        self,
        delay_range: Tuple[float, float] = None,
        datafile: str = None,
        preview: bool = False,
        **kwds,
    ):
        """Append a delay column to the dataframe. Either provide delay ranges,
        or read them from a file.

        Args:
            delay_range (Tuple[float, float], optional): The scanned delay range in
                picoseconds. Defaults to None.
            datafile (str, optional): The file from which to read the delay ranges.
                Defaults to None.
            preview (bool): Option to preview the first elements of the data frame.
            **kwds: Keyword args passed to ``DelayCalibrator.append_delay_axis``.
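
        Example:
            A minimal sketch (hypothetical usage; the delay range values are
            illustrative)::

                sp.calibrate_delay_axis(delay_range=(-500.0, 1500.0), preview=True)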
        """
        if self._dataframe is not None:
            print("Adding delay column to dataframe:")

            if delay_range is not None:
                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    delay_range=delay_range,
                    **kwds,
                )
                if self._timed_dataframe is not None:
                    if self._config["dataframe"]["adc_column"] in self._timed_dataframe.columns:
                        self._timed_dataframe, _ = self.dc.append_delay_axis(
                            self._timed_dataframe,
                            delay_range=delay_range,
                            **kwds,
                        )
            else:
                if datafile is None:
                    try:
                        datafile = self._files[0]
                    except IndexError:
                        print(
                            "No datafile available, specify either 'datafile' or 'delay_range'",
                        )
                        raise

                self._dataframe, metadata = self.dc.append_delay_axis(
                    self._dataframe,
                    datafile=datafile,
                    **kwds,
                )
                if self._timed_dataframe is not None:
                    if self._config["dataframe"]["adc_column"] in self._timed_dataframe.columns:
                        self._timed_dataframe, _ = self.dc.append_delay_axis(
                            self._timed_dataframe,
                            datafile=datafile,
                            **kwds,
                        )

            # Add Metadata
            self._attributes.add(
                metadata,
                "delay_calibration",
                duplicate_policy="merge",
            )
            if preview:
                print(self._dataframe.head(10))
            else:
                if self.verbose:
                    print(self._dataframe)

    def save_workflow_params(
        self,
        filename: str = None,
        overwrite: bool = False,
    ) -> None:
        """Run all save calibration parameter methods.

        Args:
            filename (str, optional): Filename of the config dictionary to save to.
                Defaults to "sed_config.yaml" in the current folder.
            overwrite (bool, optional): Option to overwrite the present dictionary.
                Defaults to False.
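
        Example:
            A minimal sketch (hypothetical usage; writes all generated
            parameters to the default "sed_config.yaml")::

                sp.save_workflow_params(overwrite=True)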
        """
        for method in [
            self.save_momentum_calibration,
            self.save_energy_correction,
            self.save_energy_calibration,
            self.save_energy_offset,
            # self.save_delay_calibration,  # TODO: uncomment once implemented
        ]:
            try:
                method(filename, overwrite)
            except (ValueError, AttributeError, KeyError):
                pass

    def add_jitter(
        self,
        cols: List[str] = None,
        amps: Union[float, Sequence[float]] = None,
        **kwds,
    ):
        """Add jitter to the selected dataframe columns.

        Args:
            cols (List[str], optional): The columns onto which to apply jitter.
                Defaults to config["dataframe"]["jitter_cols"].
            amps (Union[float, Sequence[float]], optional): Amplitude scalings for the
                jittering noise. If one number is given, the same is used for all axes.
                For uniform noise (default) it will cover the interval [-amp, +amp].
                Defaults to config["dataframe"]["jitter_amps"].
            **kwds: Additional keyword arguments passed to ``apply_jitter``.
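
        Example:
            A minimal sketch (hypothetical usage; the column names are
            illustrative)::

                sp.add_jitter()
                # or jitter selected columns with a custom amplitude:
                sp.add_jitter(cols=["dldPosX", "dldPosY"], amps=0.5)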
        """
        if cols is None:
            cols = self._config["dataframe"]["jitter_cols"]
        for loc, col in enumerate(cols):
            if col.startswith("@"):
                cols[loc] = self._config["dataframe"].get(col.strip("@"))

        if amps is None:
            amps = self._config["dataframe"]["jitter_amps"]

        self._dataframe = self._dataframe.map_partitions(
            apply_jitter,
            cols=cols,
            cols_jittered=cols,
            amps=amps,
            **kwds,
        )
        if self._timed_dataframe is not None:
            cols_timed = cols.copy()
            for col in cols:
                if col not in self._timed_dataframe.columns:
                    cols_timed.remove(col)

            if cols_timed:
                self._timed_dataframe = self._timed_dataframe.map_partitions(
                    apply_jitter,
                    cols=cols_timed,
                    cols_jittered=cols_timed,
                )
        metadata = []
        for col in cols:
            metadata.append(col)
        self._attributes.add(metadata, "jittering", duplicate_policy="append")

    def pre_binning(
        self,
        df_partitions: int = 100,
        axes: List[str] = None,
        bins: List[int] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        **kwds,
    ) -> xr.DataArray:
        """Function to do an initial binning of the dataframe loaded to the class.

        Args:
            df_partitions (int, optional): Number of dataframe partitions to use for
                the initial binning. Defaults to 100.
            axes (List[str], optional): Axes to bin.
                Defaults to config["momentum"]["axes"].
            bins (List[int], optional): Bin numbers to use for binning.
                Defaults to config["momentum"]["bins"].
            ranges (List[Tuple], optional): Ranges to use for binning.
                Defaults to config["momentum"]["ranges"].
            **kwds: Keyword arguments passed to ``compute``.

        Returns:
            xr.DataArray: pre-binned data-array.
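
        Example:
            A minimal sketch (hypothetical usage; bins the momentum axes
            defined in the config over the first 100 partitions)::

                binned = sp.pre_binning(df_partitions=100)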
        """
        if axes is None:
            axes = self._config["momentum"]["axes"]
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))

        if bins is None:
            bins = self._config["momentum"]["bins"]
        if ranges is None:
            ranges_ = list(self._config["momentum"]["ranges"])
            ranges_[2] = np.asarray(ranges_[2]) / 2 ** (
                self._config["dataframe"]["tof_binning"] - 1
            )
            ranges = [cast(Tuple[float, float], tuple(v)) for v in ranges_]

        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        return self.compute(
            bins=bins,
            axes=axes,
            ranges=ranges,
            df_partitions=df_partitions,
            **kwds,
        )

1610
    def compute(
1✔
1611
        self,
1612
        bins: Union[
1613
            int,
1614
            dict,
1615
            tuple,
1616
            List[int],
1617
            List[np.ndarray],
1618
            List[tuple],
1619
        ] = 100,
1620
        axes: Union[str, Sequence[str]] = None,
1621
        ranges: Sequence[Tuple[float, float]] = None,
1622
        normalize_to_acquisition_time: Union[bool, str] = False,
1623
        **kwds,
1624
    ) -> xr.DataArray:
1625
        """Compute the histogram along the given dimensions.
1626

1627
        Args:
1628
            bins (int, dict, tuple, List[int], List[np.ndarray], List[tuple], optional):
1629
                Definition of the bins. Can be any of the following cases:
1630

1631
                - an integer describing the number of bins in on all dimensions
1632
                - a tuple of 3 numbers describing start, end and step of the binning
1633
                  range
1634
                - a np.arrays defining the binning edges
1635
                - a list (NOT a tuple) of any of the above (int, tuple or np.ndarray)
1636
                - a dictionary made of the axes as keys and any of the above as values.
1637

1638
                This takes priority over the axes and range arguments. Defaults to 100.
1639
            axes (Union[str, Sequence[str]], optional): The names of the axes (columns)
1640
                on which to calculate the histogram. The order will be the order of the
1641
                dimensions in the resulting array. Defaults to None.
1642
            ranges (Sequence[Tuple[float, float]], optional): list of tuples containing
1643
                the start and end point of the binning range. Defaults to None.
1644
            normalize_to_acquisition_time (Union[bool, str]): Option to normalize the
1645
                result to the acquistion time. If a "slow" axis was scanned, providing
1646
                the name of the scanned axis will compute and apply the corresponding
1647
                normalization histogram. Defaults to False.
1648
            **kwds: Keyword arguments:
1649

1650
                - **hist_mode**: Histogram calculation method. "numpy" or "numba". See
1651
                  ``bin_dataframe`` for details. Defaults to
1652
                  config["binning"]["hist_mode"].
1653
                - **mode**: Defines how the results from each partition are combined.
1654
                  "fast", "lean" or "legacy". See ``bin_dataframe`` for details.
1655
                  Defaults to config["binning"]["mode"].
1656
                - **pbar**: Option to show the tqdm progress bar. Defaults to
1657
                  config["binning"]["pbar"].
1658
                - **n_cores**: Number of CPU cores to use for parallelization.
1659
                  Defaults to config["binning"]["num_cores"] or N_CPU-1.
1660
                - **threads_per_worker**: Limit the number of threads that
1661
                  multiprocessing can spawn per binning thread. Defaults to
1662
                  config["binning"]["threads_per_worker"].
1663
                - **threadpool_api**: The API to use for multiprocessing. "blas",
1664
                  "openmp" or None. See ``threadpool_limit`` for details. Defaults to
1665
                  config["binning"]["threadpool_API"].
1666
                - **df_partitions**: A range or list of dataframe partitions, or the
1667
                  number of the dataframe partitions to use. Defaults to all partitions.
1668

1669
                Additional kwds are passed to ``bin_dataframe``.
1670

1671
        Raises:
1672
            AssertError: Rises when no dataframe has been loaded.
1673

1674
        Returns:
1675
            xr.DataArray: The result of the n-dimensional binning represented in an
1676
            xarray object, combining the data with the axes.
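
        Example:
            A minimal sketch (hypothetical usage; the axis names and ranges
            are illustrative)::

                res = sp.compute(
                    bins=[80, 80, 200],
                    axes=["kx", "ky", "energy"],
                    ranges=[(-2.0, 2.0), (-2.0, 2.0), (-4.0, 0.5)],
                )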
        """
        assert self._dataframe is not None, "dataframe needs to be loaded first!"

        hist_mode = kwds.pop("hist_mode", self._config["binning"]["hist_mode"])
        mode = kwds.pop("mode", self._config["binning"]["mode"])
        pbar = kwds.pop("pbar", self._config["binning"]["pbar"])
        num_cores = kwds.pop("num_cores", self._config["binning"]["num_cores"])
        threads_per_worker = kwds.pop(
            "threads_per_worker",
            self._config["binning"]["threads_per_worker"],
        )
        threadpool_api = kwds.pop(
            "threadpool_API",
            self._config["binning"]["threadpool_API"],
        )
        df_partitions = kwds.pop("df_partitions", None)
        if isinstance(df_partitions, int):
            df_partitions = slice(
                0,
                min(df_partitions, self._dataframe.npartitions),
            )
        if df_partitions is not None:
            dataframe = self._dataframe.partitions[df_partitions]
        else:
            dataframe = self._dataframe

        self._binned = bin_dataframe(
            df=dataframe,
            bins=bins,
            axes=axes,
            ranges=ranges,
            hist_mode=hist_mode,
            mode=mode,
            pbar=pbar,
            n_cores=num_cores,
            threads_per_worker=threads_per_worker,
            threadpool_api=threadpool_api,
            **kwds,
        )

        for dim in self._binned.dims:
            try:
                self._binned[dim].attrs["unit"] = self._config["dataframe"]["units"][dim]
            except KeyError:
                pass

        self._binned.attrs["units"] = "counts"
        self._binned.attrs["long_name"] = "photoelectron counts"
        self._binned.attrs["metadata"] = self._attributes.metadata

        if normalize_to_acquisition_time:
            if isinstance(normalize_to_acquisition_time, str):
                axis = normalize_to_acquisition_time
                print(
                    f"Calculate normalization histogram for axis '{axis}'...",
                )
                self._normalization_histogram = self.get_normalization_histogram(
                    axis=axis,
                    df_partitions=df_partitions,
                )
                # if the axes are named correctly, xarray figures out the normalization correctly
                self._normalized = self._binned / self._normalization_histogram
                self._attributes.add(
                    self._normalization_histogram.values,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )
            else:
                acquisition_time = self.loader.get_elapsed_time(
                    fids=df_partitions,
                )
                if acquisition_time > 0:
                    self._normalized = self._binned / acquisition_time
                self._attributes.add(
                    acquisition_time,
                    name="normalization_histogram",
                    duplicate_policy="overwrite",
                )

            self._normalized.attrs["units"] = "counts/second"
            self._normalized.attrs["long_name"] = "photoelectron counts per second"
            self._normalized.attrs["metadata"] = self._attributes.metadata

            return self._normalized

        return self._binned

    def get_normalization_histogram(
        self,
        axis: str = "delay",
        use_time_stamps: bool = False,
        **kwds,
    ) -> xr.DataArray:
        """Generates a normalization histogram from the timed dataframe. Optionally,
        use the TimeStamps column instead.

        Args:
            axis (str, optional): The axis for which to compute the histogram.
                Defaults to "delay".
            use_time_stamps (bool, optional): Use the TimeStamps column of the
                dataframe, rather than the timed dataframe. Defaults to False.
            **kwds: Keyword arguments:

                - **df_partitions** (int, optional): Number of dataframe partitions
                  to use. Defaults to all.

        Raises:
            ValueError: Raised if no data are binned.
            ValueError: Raised if 'axis' is not in the binned coordinates.
            ValueError: Raised if config["dataframe"]["time_stamp_alias"] is not
                found in the dataframe.

        Returns:
            xr.DataArray: The computed normalization histogram (in TimeStamp units
            per bin).
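
        Example:
            A minimal sketch (hypothetical usage; requires that data have been
            binned along the "delay" axis first)::

                hist = sp.get_normalization_histogram(axis="delay")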
        """

        if self._binned is None:
            raise ValueError("Need to bin data first!")
        if axis not in self._binned.coords:
            raise ValueError(f"Axis '{axis}' not found in binned data!")

        df_partitions: Union[int, slice] = kwds.pop("df_partitions", None)
        if isinstance(df_partitions, int):
            df_partitions = slice(
                0,
                min(df_partitions, self._dataframe.npartitions),
            )

        if use_time_stamps or self._timed_dataframe is None:
            if df_partitions is not None:
                self._normalization_histogram = normalization_histogram_from_timestamps(
                    self._dataframe.partitions[df_partitions],
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["time_stamp_alias"],
                )
            else:
                self._normalization_histogram = normalization_histogram_from_timestamps(
                    self._dataframe,
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["time_stamp_alias"],
                )
        else:
            if df_partitions is not None:
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
                    self._timed_dataframe.partitions[df_partitions],
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["timed_dataframe_unit_time"],
                )
            else:
                self._normalization_histogram = normalization_histogram_from_timed_dataframe(
                    self._timed_dataframe,
                    axis,
                    self._binned.coords[axis].values,
                    self._config["dataframe"]["timed_dataframe_unit_time"],
                )

        return self._normalization_histogram

    def view_event_histogram(
        self,
        dfpid: int,
        ncol: int = 2,
        bins: Sequence[int] = None,
        axes: Sequence[str] = None,
        ranges: Sequence[Tuple[float, float]] = None,
        backend: str = "bokeh",
        legend: bool = True,
        histkwds: dict = None,
        legkwds: dict = None,
        **kwds,
    ):
        """Plot individual histograms of specified dimensions (axes) from a single
        dataframe partition.

        Args:
            dfpid (int): Number of the data frame partition to look at.
            ncol (int, optional): Number of columns in the plot grid. Defaults to 2.
            bins (Sequence[int], optional): Number of bins to use for the specified
                axes. Defaults to config["histogram"]["bins"].
            axes (Sequence[str], optional): Names of the axes to display.
                Defaults to config["histogram"]["axes"].
            ranges (Sequence[Tuple[float, float]], optional): Value ranges of all
                specified axes. Defaults to config["histogram"]["ranges"].
            backend (str, optional): Backend of the plotting library
                ('matplotlib' or 'bokeh'). Defaults to "bokeh".
            legend (bool, optional): Option to include a legend in the histogram plots.
                Defaults to True.
            histkwds (dict, optional): Keyword arguments for histograms
                (see ``matplotlib.pyplot.hist()``). Defaults to {}.
            legkwds (dict, optional): Keyword arguments for legend
                (see ``matplotlib.pyplot.legend()``). Defaults to {}.
            **kwds: Extra keyword arguments passed to
                ``sed.diagnostics.grid_histogram()``.

        Raises:
            TypeError: Raised when the input values are not of the correct type.
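
        Example:
            A minimal sketch (hypothetical usage; inspects the first partition
            using the axes defined in the config)::

                sp.view_event_histogram(dfpid=0)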
        """
        if bins is None:
            bins = self._config["histogram"]["bins"]
        if axes is None:
            axes = self._config["histogram"]["axes"]
        axes = list(axes)
        for loc, axis in enumerate(axes):
            if axis.startswith("@"):
                axes[loc] = self._config["dataframe"].get(axis.strip("@"))
        if ranges is None:
            ranges = list(self._config["histogram"]["ranges"])
            for loc, axis in enumerate(axes):
                if axis == self._config["dataframe"]["tof_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
                        self._config["dataframe"]["tof_binning"] - 1
                    )
                elif axis == self._config["dataframe"]["adc_column"]:
                    ranges[loc] = np.asarray(ranges[loc]) / 2 ** (
                        self._config["dataframe"]["adc_binning"] - 1
                    )

        input_types = map(type, [axes, bins, ranges])
        allowed_types = [list, tuple]

        df = self._dataframe

        if not set(input_types).issubset(allowed_types):
            raise TypeError(
                "Inputs of axes, bins, ranges need to be list or tuple!",
            )

        # Read out the values for the specified groups
        group_dict_dd = {}
        dfpart = df.get_partition(dfpid)
        cols = dfpart.columns
        for ax in axes:
            group_dict_dd[ax] = dfpart.values[:, cols.get_loc(ax)]
        group_dict = ddf.compute(group_dict_dd)[0]

        # Plot multiple histograms in a grid
        grid_histogram(
            group_dict,
            ncol=ncol,
            rvs=axes,
            rvbins=bins,
            rvranges=ranges,
            backend=backend,
            legend=legend,
            histkwds=histkwds,
            legkwds=legkwds,
            **kwds,
        )

1930
    def save(
1✔
1931
        self,
1932
        faddr: str,
1933
        **kwds,
1934
    ):
1935
        """Saves the binned data to the provided path and filename.
1936

1937
        Args:
1938
            faddr (str): Path and name of the file to write. Its extension determines
1939
                the file type to write. Valid file types are:
1940

1941
                - "*.tiff", "*.tif": Saves a TIFF stack.
1942
                - "*.h5", "*.hdf5": Saves an HDF5 file.
1943
                - "*.nxs", "*.nexus": Saves a NeXus file.
1944

1945
            **kwds: Keyword argumens, which are passed to the writer functions:
1946
                For TIFF writing:
1947

1948
                - **alias_dict**: Dictionary of dimension aliases to use.
1949

1950
                For HDF5 writing:
1951

1952
                - **mode**: hdf5 read/write mode. Defaults to "w".
1953

1954
                For NeXus:
1955

1956
                - **reader**: Name of the nexustools reader to use.
1957
                  Defaults to config["nexus"]["reader"]
1958
                - **definiton**: NeXus application definition to use for saving.
1959
                  Must be supported by the used ``reader``. Defaults to
1960
                  config["nexus"]["definition"]
1961
                - **input_files**: A list of input files to pass to the reader.
1962
                  Defaults to config["nexus"]["input_files"]
1963
                - **eln_data**: An electronic-lab-notebook file in '.yaml' format
1964
                  to add to the list of files to pass to the reader.
1965
        """
1966
        if self._binned is None:
1✔
1967
            raise NameError("Need to bin data first!")
1✔
1968

1969
        if self._normalized is not None:
1✔
1970
            data = self._normalized
×
1971
        else:
1972
            data = self._binned
1✔
1973

1974
        extension = pathlib.Path(faddr).suffix
1✔
1975

1976
        if extension in (".tif", ".tiff"):
1✔
1977
            to_tiff(
1✔
1978
                data=data,
1979
                faddr=faddr,
1980
                **kwds,
1981
            )
1982
        elif extension in (".h5", ".hdf5"):
1✔
1983
            to_h5(
1✔
1984
                data=data,
1985
                faddr=faddr,
1986
                **kwds,
1987
            )
1988
        elif extension in (".nxs", ".nexus"):
1✔
1989
            try:
1✔
1990
                reader = kwds.pop("reader", self._config["nexus"]["reader"])
1✔
1991
                definition = kwds.pop(
1✔
1992
                    "definition",
1993
                    self._config["nexus"]["definition"],
1994
                )
1995
                input_files = kwds.pop(
1✔
1996
                    "input_files",
1997
                    self._config["nexus"]["input_files"],
1998
                )
1999
            except KeyError as exc:
×
2000
                raise ValueError(
×
2001
                    "The nexus reader, definition and input files need to be provide!",
2002
                ) from exc
2003

2004
            if isinstance(input_files, str):
1✔
2005
                input_files = [input_files]
1✔
2006

2007
            if "eln_data" in kwds:
1✔
2008
                input_files.append(kwds.pop("eln_data"))
×
2009

2010
            to_nexus(
1✔
2011
                data=data,
2012
                faddr=faddr,
2013
                reader=reader,
2014
                definition=definition,
2015
                input_files=input_files,
2016
                **kwds,
2017
            )
2018

2019
        else:
2020
            raise NotImplementedError(
1✔
2021
                f"Unrecognized file format: {extension}.",
2022
            )