• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

OpenCOMPES / sed / 9799510297

04 Jul 2024 08:17PM UTC coverage: 92.349% (-0.2%) from 92.511%
9799510297

Pull #466

github

rettigl
add tests for illegal keyword errors
Pull Request #466: Catch illegal kwds

106 of 123 new or added lines in 17 files covered. (86.18%)

3 existing lines in 1 file now uncovered.

6940 of 7515 relevant lines covered (92.35%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.75
/sed/calibrator/momentum.py
1
"""sed.calibrator.momentum module. Code for momentum calibration and distortion
2
correction. Mostly ported from https://github.com/mpes-kit/mpes.
3
"""
4
from __future__ import annotations
1✔
5

6
import itertools as it
1✔
7
from copy import deepcopy
1✔
8
from datetime import datetime
1✔
9
from typing import Any
1✔
10

11
import bokeh.palettes as bp
1✔
12
import bokeh.plotting as pbk
1✔
13
import dask.dataframe
1✔
14
import ipywidgets as ipw
1✔
15
import matplotlib
1✔
16
import matplotlib.pyplot as plt
1✔
17
import numpy as np
1✔
18
import pandas as pd
1✔
19
import scipy.ndimage as ndi
1✔
20
import xarray as xr
1✔
21
from bokeh.colors import RGB
1✔
22
from bokeh.io import output_notebook
1✔
23
from bokeh.palettes import Category10 as ColorCycle
1✔
24
from IPython.display import display
1✔
25
from joblib import delayed
1✔
26
from joblib import Parallel
1✔
27
from matplotlib import cm
1✔
28
from numpy.linalg import norm
1✔
29
from scipy.interpolate import griddata
1✔
30
from scipy.ndimage import map_coordinates
1✔
31
from symmetrize import pointops as po
1✔
32
from symmetrize import sym
1✔
33
from symmetrize import tps
1✔
34

35

36
class MomentumCorrector:
1✔
37
    """
38
    Momentum distortion correction and momentum calibration workflow functions.
39

40
    Args:
41
        data (xr.DataArray | np.ndarray, optional): Multidimensional hypervolume
42
            containing the data. Defaults to None.
43
        bin_ranges (list[tuple], optional): Binning ranges of the data volume, if
44
            provided as np.ndarray. Defaults to None.
45
        rotsym (int, optional): Rotational symmetry of the data. Defaults to 6.
46
        config (dict, optional): Config dictionary. Defaults to None.
47
    """
48

49
    def __init__(
1✔
50
        self,
51
        data: xr.DataArray | np.ndarray = None,
52
        bin_ranges: list[tuple] = None,
53
        rotsym: int = 6,
54
        config: dict = None,
55
    ):
56
        """Constructor of the MomentumCorrector class.
57

58
        Args:
59
            data (xr.DataArray | np.ndarray, optional): Multidimensional
60
                hypervolume containing the data. Defaults to None.
61
            bin_ranges (list[tuple], optional): Binning ranges of the data volume,
62
                if provided as np.ndarray. Defaults to None.
63
            rotsym (int, optional): Rotational symmetry of the data. Defaults to 6.
64
            config (dict, optional): Config dictionary. Defaults to None.
65
        """
66
        if config is None:
1✔
67
            config = {}
×
68

69
        self._config = config
1✔
70

71
        self.image: np.ndarray = None
1✔
72
        self.img_ndim: int = None
1✔
73
        self.slice: np.ndarray = None
1✔
74
        self.slice_corrected: np.ndarray = None
1✔
75
        self.slice_transformed: np.ndarray = None
1✔
76
        self.bin_ranges: list[tuple] = self._config["momentum"].get("bin_ranges", [])
1✔
77

78
        if data is not None:
1✔
79
            self.load_data(data=data, bin_ranges=bin_ranges)
×
80

81
        self.detector_ranges = self._config["momentum"]["detector_ranges"]
1✔
82

83
        self.rotsym = int(rotsym)
1✔
84
        self.rotsym_angle = int(360 / self.rotsym)
1✔
85
        self.arot = np.array([0] + [self.rotsym_angle] * (self.rotsym - 1))
1✔
86
        self.ascale = np.array([1.0] * self.rotsym)
1✔
87
        self.peaks: np.ndarray = None
1✔
88
        self.include_center: bool = False
1✔
89
        self.use_center: bool = False
1✔
90
        self.pouter: np.ndarray = None
1✔
91
        self.pcent: tuple[float, ...] = None
1✔
92
        self.pouter_ord: np.ndarray = None
1✔
93
        self.prefs: np.ndarray = None
1✔
94
        self.ptargs: np.ndarray = None
1✔
95
        self.csm_original: float = np.nan
1✔
96
        self.mdist: float = np.nan
1✔
97
        self.mcvdist: float = np.nan
1✔
98
        self.mvvdist: float = np.nan
1✔
99
        self.cvdist: np.ndarray = np.array(np.nan)
1✔
100
        self.vvdist: np.ndarray = np.array(np.nan)
1✔
101
        self.rdeform_field: np.ndarray = None
1✔
102
        self.cdeform_field: np.ndarray = None
1✔
103
        self.rdeform_field_bkp: np.ndarray = None
1✔
104
        self.cdeform_field_bkp: np.ndarray = None
1✔
105
        self.inverse_dfield: np.ndarray = None
1✔
106
        self.dfield_updated: bool = False
1✔
107
        self.transformations: dict[str, Any] = self._config["momentum"].get("transformations", {})
1✔
108
        self.correction: dict[str, Any] = self._config["momentum"].get("correction", {})
1✔
109
        self.adjust_params: dict[str, Any] = {}
1✔
110
        self.calibration: dict[str, Any] = self._config["momentum"].get("calibration", {})
1✔
111

112
        self.x_column = self._config["dataframe"]["x_column"]
1✔
113
        self.y_column = self._config["dataframe"]["y_column"]
1✔
114
        self.corrected_x_column = self._config["dataframe"]["corrected_x_column"]
1✔
115
        self.corrected_y_column = self._config["dataframe"]["corrected_y_column"]
1✔
116
        self.kx_column = self._config["dataframe"]["kx_column"]
1✔
117
        self.ky_column = self._config["dataframe"]["ky_column"]
1✔
118

119
        self._state: int = 0
1✔
120

121
    @property
1✔
122
    def features(self) -> dict:
1✔
123
        """Dictionary of detected features for the symmetrization process.
124
        ``self.features`` is a derived attribute from existing ones.
125

126
        Returns:
127
            dict: Dict containing features "verts" and "center".
128
        """
129
        feature_dict = {
1✔
130
            "verts": np.asarray(self.__dict__.get("pouter_ord", [])),
131
            "center": np.asarray(self.__dict__.get("pcent", [])),
132
        }
133

134
        return feature_dict
1✔
135

136
    @property
1✔
137
    def symscores(self) -> dict:
1✔
138
        """Dictionary of symmetry-related scores.
139

140
        Returns:
141
            dict: Dictionary containing symmetry scores.
142
        """
143
        sym_dict = {
×
144
            "csm_original": self.__dict__.get("csm_original", ""),
145
            "csm_current": self.__dict__.get("csm_current", ""),
146
            "arm_original": self.__dict__.get("arm_original", ""),
147
            "arm_current": self.__dict__.get("arm_current", ""),
148
        }
149

150
        return sym_dict
×
151

152
    def load_data(
        self,
        data: xr.DataArray | np.ndarray,
        bin_ranges: list[tuple] = None,
    ):
        """Load binned data into the momentum calibrator class.

        Args:
            data (xr.DataArray | np.ndarray):
                2D or 3D data array, either as np.ndarray or xr.DataArray.
            bin_ranges (list[tuple], optional):
                Binning ranges. Needs to be provided in case the data are given
                as np.ndarray. Otherwise, they are determined from the coords of
                the xr.DataArray, one (start, end) tuple per dimension in the order
                of ``data.dims``. Defaults to None.

        Raises:
            ValueError: Raised if ``bin_ranges`` is missing for non-xarray data, or
                if the dimensions of the input data do not fit.
        """
        if isinstance(data, xr.DataArray):
            self.image = np.squeeze(data.data)
            self.bin_ranges = []
            # Follow the dimension order of the array. The coords mapping order need
            # not match the array axes and may also hold non-dimension coordinates.
            for axis in data.dims:
                coord = data.coords[axis]
                self.bin_ranges.append(
                    (
                        coord[0].values,
                        # endpoint: last coordinate plus one coordinate step
                        2 * coord[-1].values - coord[-2].values,
                    ),
                )
        else:
            # Explicit check rather than ``assert`` so it still runs under ``python -O``.
            if bin_ranges is None:
                raise ValueError("bin_ranges need to be provided for np.ndarray data!")
            self.image = np.squeeze(data)
            self.bin_ranges = bin_ranges

        self.img_ndim = self.image.ndim
        if (self.img_ndim > 3) or (self.img_ndim < 2):
            raise ValueError("The input image dimension need to be 2 or 3!")
        if self.img_ndim == 2:
            self.slice = self.image

        # Reset the derived slices to the current slice (if any).
        if self.slice is not None:
            self.slice_corrected = self.slice_transformed = self.slice
193

194
    def select_slicer(
        self,
        plane: int = 0,
        width: int = 5,
        axis: int = 2,
        apply: bool = False,
    ):
        """Interactive panel to select (hyper)slice from a (hyper)volume.

        Shows the sum of ``width`` consecutive planes starting at ``plane`` along
        ``axis``. Sliders change both values, and an "apply" button stores the
        current selection via ``select_slice``.

        Args:
            plane (int, optional): initial value of the plane slider. Defaults to 0.
            width (int, optional): initial value of the width slider. Defaults to 5.
            axis (int, optional): Axis along which to slice the image. Defaults to 2.
            apply (bool, optional):  Option to directly apply the values and select the
                slice. Defaults to False.
        """
        matplotlib.use("module://ipympl.backend_nbagg")

        assert self.img_ndim == 3
        # Put the slicing axis first so that all plane indexing below uses axis 0.
        image = np.moveaxis(self.image, axis, 0)

        def sum_planes(start: int, stop: int) -> np.ndarray:
            """Sum the planes [start:stop) of the re-ordered volume."""
            selection = image[slice(start, stop), ...]
            try:
                return selection.sum(axis=0)
            except AttributeError:
                # fallback for selections without a ``sum`` method
                return selection

        fig, ax = plt.subplots(1, 1)
        img = ax.imshow(sum_planes(plane, plane + width).T, origin="lower", cmap="terrain_r")

        def update(plane: int, width: int):
            img_slice = sum_planes(plane, plane + width)
            img.set_data(img_slice.T)
            axmin = np.min(img_slice, axis=(0, 1))
            axmax = np.max(img_slice, axis=(0, 1))
            if axmin < axmax:
                img.set_clim(axmin, axmax)
            ax.set_title(f"Plane[{plane}:{plane+width}]")
            fig.canvas.draw_idle()

        update(plane, width)

        plane_slider = ipw.IntSlider(
            value=plane,
            min=0,
            # Bound by the length of the axis being sliced (axis 0 after moveaxis),
            # so the slider range is right for any ``axis``, not only axis=2.
            max=max(image.shape[0] - width, 0),
            step=1,
        )
        width_slider = ipw.IntSlider(value=width, min=1, max=20, step=1)

        ipw.interact(
            update,
            plane=plane_slider,
            width=width_slider,
        )

        def apply_fun(apply: bool):  # noqa: ARG001
            start = plane_slider.value
            stop = plane_slider.value + width_slider.value

            selector = slice(
                start,
                stop,
            )
            self.select_slice(selector=selector, axis=axis)

            img.set_data(self.slice.T)
            axmin = np.min(self.slice, axis=(0, 1))
            axmax = np.max(self.slice, axis=(0, 1))
            if axmin < axmax:
                img.set_clim(axmin, axmax)
            ax.set_title(f"Plane[{start}:{stop}]")
            fig.canvas.draw_idle()

            # The selection is final: remove the interactive controls.
            plane_slider.close()
            width_slider.close()
            apply_button.close()

        apply_button = ipw.Button(description="apply")
        display(apply_button)
        apply_button.on_click(apply_fun)

        plt.show()

        if apply:
            apply_fun(True)
283

284
    def select_slice(
1✔
285
        self,
286
        selector: slice | list[int] | int,
287
        axis: int = 2,
288
    ):
289
        """Select (hyper)slice from a (hyper)volume.
290

291
        Args:
292
            selector (slice | list[int] | int):
293
                Selector along the specified axis to extract the slice (image). Use
294
                the construct slice(start, stop, step) to select a range of images
295
                and sum them. Use an integer to specify only a particular slice.
296
            axis (int, optional): Axis along which to select the image. Defaults to 2.
297

298
        Raises:
299
            ValueError: Raised if self.image is already 2D.
300
        """
301
        if self.img_ndim > 2:
1✔
302
            image = np.moveaxis(self.image, axis, 0)
1✔
303
            try:
1✔
304
                self.slice = image[selector, ...].sum(axis=0)
1✔
305
            except AttributeError:
×
306
                self.slice = image[selector, ...]
×
307

308
            if self.slice is not None:
1✔
309
                self.slice_corrected = self.slice_transformed = self.slice
1✔
310

311
        elif self.img_ndim == 2:
×
312
            raise ValueError("Input image dimension is already 2!")
×
313

314
    def add_features(
1✔
315
        self,
316
        features: np.ndarray,
317
        direction: str = "ccw",
318
        rotsym: int = 6,
319
        symscores: bool = True,
320
        symtype: str = "rotation",
321
    ):
322
        """Add features as reference points provided as np.ndarray. If provided,
323
        detects the center of the points and orders the points.
324

325
        Args:
326
            features (np.ndarray):
327
                Array of landmarks, possibly including a center peak. Its shape should
328
                be (n,2), where n is equal to the rotation symmetry, or the rotation
329
                symmetry+1, if the center is included.
330
            direction (str, optional):
331
                Direction for ordering the points. Defaults to "ccw".
332
            symscores (bool, optional):
333
                Option to calculate symmetry scores. Defaults to False.
334
            symtype (str, optional): Type of symmetry scores to calculate
335
                  if symscores is True. Defaults to "rotation".
336

337
        Raises:
338
            ValueError: Raised if the number of points does not match the rotsym.
339
        """
340
        self.rotsym = int(rotsym)
1✔
341
        self.rotsym_angle = int(360 / self.rotsym)
1✔
342
        self.arot = np.array([0] + [self.rotsym_angle] * (self.rotsym - 1))
1✔
343
        self.ascale = np.array([1.0] * self.rotsym)
1✔
344

345
        if features.shape[0] == self.rotsym:  # assume no center present
1✔
346
            self.pcent, self.pouter = po.pointset_center(
1✔
347
                features,
348
                method="centroid",
349
            )
350
            self.include_center = False
1✔
351
        elif features.shape[0] == self.rotsym + 1:  # assume center included
1✔
352
            self.pcent, self.pouter = po.pointset_center(
1✔
353
                features,
354
                method="centroidnn",
355
            )
356
            self.include_center = True
1✔
357
        else:
358
            raise ValueError(
1✔
359
                f"Found {features.shape[0]} points, ",
360
                f"but {self.rotsym} or {self.rotsym+1} (incl.center) required.",
361
            )
362
        if isinstance(self.pcent, np.ndarray):
1✔
363
            self.pcent = tuple(val.item() for val in self.pcent)
1✔
364
        # Order the point landmarks
365
        self.pouter_ord = po.pointset_order(
1✔
366
            self.pouter,
367
            direction=direction,
368
        )
369

370
        # Calculate geometric distances
371
        if self.pcent is not None:
1✔
372
            self.calc_geometric_distances()
1✔
373

374
        if symscores is True:
1✔
375
            self.csm_original = self.calc_symmetry_scores(symtype=symtype)
1✔
376

377
        if self.rotsym == 6 and self.pcent is not None:
1✔
378
            self.mdist = (self.mcvdist + self.mvvdist) / 2
1✔
379
            self.mcvdist = self.mdist
1✔
380
            self.mvvdist = self.mdist
1✔
381

382
    def feature_extract(
1✔
383
        self,
384
        image: np.ndarray = None,
385
        direction: str = "ccw",
386
        feature_type: str = "points",
387
        rotsym: int = 6,
388
        symscores: bool = True,
389
        **kwds,
390
    ):
391
        """Extract features from the selected 2D slice.
392
        Currently only point feature detection is implemented.
393

394
        Args:
395
            image (np.ndarray, optional):
396
                The (2D) image slice to extract features from.
397
                Defaults to self.slice
398
            direction (str, optional):
399
                The circular direction to reorder the features in ('cw' or 'ccw').
400
                Defaults to "ccw".
401
            feature_type (str, optional):
402
                The type of features to extract. Defaults to "points".
403
            rotsym (int, optional): Rotational symmetry of the data. Defaults to 6.
404
            symscores (bool, optional):
405
                Option for calculating symmetry scores. Defaults to True.
406
            **kwds:
407
                Extra keyword arguments for ``symmetrize.pointops.peakdetect2d()`` and
408
                ``add_features()``.
409

410
        Raises:
411
            NotImplementedError:
412
                Raised for undefined feature_types.
413
        """
414
        if image is None:
1✔
415
            if self.slice is not None:
1✔
416
                image = self.slice
1✔
417
            else:
418
                raise ValueError("No image loaded for feature extraction!")
×
419

420
        # split off config keywords
421
        feature_kwds = {
1✔
422
            key: value
423
            for key, value in kwds.items()
424
            if key in self.add_features.__code__.co_varnames
425
        }
426
        for key in feature_kwds.keys():
1✔
NEW
427
            del kwds[key]
×
428

429
        if feature_type == "points":
1✔
430
            # Detect the point landmarks
431
            self.peaks = po.peakdetect2d(image, **kwds)
1✔
432

433
            self.add_features(
1✔
434
                features=self.peaks,
435
                direction=direction,
436
                rotsym=rotsym,
437
                symscores=symscores,
438
                **feature_kwds,
439
            )
440
        else:
441
            raise NotImplementedError
×
442

443
    def feature_select(
        self,
        image: np.ndarray = None,
        features: np.ndarray = None,
        include_center: bool = True,
        rotsym: int = 6,
        apply: bool = False,
        **kwds,
    ):
        """Interactively select point features on the selected 2D slice.

        Displays the image with one marker per feature point. A point is chosen with
        the dropdown and moved by typing coordinates or by clicking in the figure;
        each click sets the current point and advances to the next one. Pressing
        "apply" (or passing ``apply=True``) hands the points to ``add_features``.

        Args:
            image (np.ndarray, optional):
                The (2D) image slice to select features on.
                Defaults to self.slice
            features (np.ndarray, optional):
                Array of landmarks, possibly including a center peak. Its shape should
                be (n,2), where n is equal to the rotation symmetry, or the rotation
                symmetry+1, if the center is included. It is converted to a float
                array; a float64 ndarray is used (and edited) in place.
                If omitted, an array filled with zeros is generated.
            include_center (bool, optional):
                Only used when ``features`` is omitted: whether the generated array
                gets an extra row for the center point. Defaults to True.
            rotsym (int, optional): Rotational symmetry of the data. Defaults to 6.
            apply (bool, optional): Option to directly store the features in the class.
                Defaults to False.
            **kwds:
                Keyword arguments for ``add_features``.

        Raises:
            ValueError: If no valid image is found from which to get the coordinates.
        """
        matplotlib.use("module://ipympl.backend_nbagg")
        if image is None:
            if self.slice is not None:
                image = self.slice
            else:
                raise ValueError("No valid image loaded!")

        fig, ax = plt.subplots(1, 1)
        ax.imshow(image.T, origin="lower", cmap="terrain_r")

        if features is None:
            features = np.zeros((rotsym + (include_center), 2))
        else:
            # Float array, so typed/clicked sub-pixel positions are not truncated and
            # array-likes (e.g. lists of tuples) work with the indexing below.
            features = np.asarray(features, dtype=float)

        markers = []
        for peak in features:
            markers.append(ax.plot(peak[0], peak[1], "o")[0])

        def update_point_no(
            point_no: int,
        ):
            fig.canvas.draw_idle()

            # Read both coordinates before touching the widgets: setting
            # point_input_x triggers update_point_pos with the previous y value,
            # which temporarily overwrites features[point_no][1].
            point_x = features[point_no][0]
            point_y = features[point_no][1]

            point_input_x.value = point_x
            point_input_y.value = point_y

        def update_point_pos(
            point_x: float,
            point_y: float,
        ):
            fig.canvas.draw_idle()
            point_no = point_no_input.value
            features[point_no][0] = point_x
            features[point_no][1] = point_y

            # Line2D.set_xdata/set_ydata expect sequences. Scalars were deprecated in
            # matplotlib 3.7 and are rejected by later releases.
            markers[point_no].set_xdata([point_x])
            markers[point_no].set_ydata([point_y])

        point_no_input = ipw.Dropdown(
            options=range(features.shape[0]),
            description="Point:",
        )

        point_input_x = ipw.FloatText(features[0][0])
        point_input_y = ipw.FloatText(features[0][1])
        ipw.interact(
            update_point_no,
            point_no=point_no_input,
        )
        ipw.interact(
            update_point_pos,
            point_y=point_input_y,
            point_x=point_input_x,
        )

        def onclick(event):
            point_input_x.value = event.xdata
            point_input_y.value = event.ydata
            # advance to the next point so consecutive clicks fill all points
            point_no_input.value = (point_no_input.value + 1) % features.shape[0]

        cid = fig.canvas.mpl_connect("button_press_event", onclick)

        def apply_func(apply: bool):  # noqa: ARG001
            fig.canvas.mpl_disconnect(cid)

            point_no_input.close()
            point_input_x.close()
            point_input_y.close()
            apply_button.close()

            fig.canvas.draw_idle()

            self.add_features(features=features, rotsym=rotsym, **kwds)

        apply_button = ipw.Button(description="apply")
        display(apply_button)
        apply_button.on_click(apply_func)

        if apply:
            apply_func(True)

        plt.show()
560

561
    def calc_geometric_distances(self) -> None:
        """Compute center-vertex and nearest-neighbor vertex-vertex distances.

        Uses the ordered outer points ``self.pouter_ord`` and the center
        ``self.pcent``. The individual distances are stored in ``self.cvdist`` and
        ``self.vvdist``, and their means in ``self.mcvdist`` and ``self.mvvdist``.
        """
        center_vertex = po.cvdist(self.pouter_ord, self.pcent)
        self.cvdist = center_vertex
        self.mcvdist = center_vertex.mean()

        vertex_vertex = po.vvdist(self.pouter_ord)
        self.vvdist = vertex_vertex
        self.mvvdist = vertex_vertex.mean()
570

571
    def calc_symmetry_scores(self, symtype: str = "rotation") -> float:
        """Score how symmetric the current landmark set is.

        Args:
            symtype (str, optional): Kind of symmetry score, passed to ``po.csm``
                as ``type``. Defaults to "rotation".

        Returns:
            float: The score returned by ``po.csm`` for the center ``self.pcent``,
            the ordered outer points ``self.pouter_ord`` and ``self.rotsym``.
        """
        return po.csm(self.pcent, self.pouter_ord, rotsym=self.rotsym, type=symtype)
589

590
    def spline_warp_estimate(
1✔
591
        self,
592
        image: np.ndarray = None,
593
        use_center: bool = None,
594
        fixed_center: bool = True,
595
        interp_order: int = 1,
596
        ascale: float | list | tuple | np.ndarray = None,
597
        verbose: bool = True,
598
        **kwds,
599
    ) -> np.ndarray:
600
        """Estimate the spline deformation field using thin plate spline registration.
601

602
        Args:
603
            image (np.ndarray, optional):
604
                2D array. Image slice to be corrected. Defaults to self.slice.
605
            use_center (bool, optional):
606
                Option to use the image center/centroid in the registration
607
                process. Defaults to config value, or True.
608
            fixed_center (bool, optional):
609
                Option to have a fixed center during registration-based
610
                symmetrization. Defaults to True.
611
            interp_order (int, optional):
612
                Order of interpolation (see ``scipy.ndimage.map_coordinates()``).
613
                Defaults to 1.
614
            ascale: (float | list | tuple | np.ndarray, optional): Scale parameter determining a
615
                relative scale for each symmetry feature. If provided as single float, rotsym has
616
                to be 4. This parameter describes the relative scaling between the two orthogonal
617
                symmetry directions (for an orthorhombic system). This requires the correction
618
                points to be located along the principal axes (X/Y points of the Brillouin zone).
619
                Otherwise, an array with ``rotsym`` elements is expected, containing relative
620
                scales for each feature. Defaults to an array of equal scales.
621
            verbose (bool, optional): Option to report the used landmarks for correction.
622
                Defaults to True.
623
            **kwds: keyword arguments:
624

625
                - **landmarks**: (list/array): Landmark positions (row, column) used
626
                  for registration. Defaults to  self.pouter_ord
627
                - **targets**: (list/array): Target positions (row, column) used for
628
                  registration. If empty, it will be generated by
629
                  ``symmetrize.rotVertexGenerator()``.
630
                - **new_centers**: (dict): User-specified center positions for the
631
                  reference and target sets. {'lmkcenter': (row, col),
632
                  'targcenter': (row, col)}
633

634
                Additional keywords are passed to ``tpsWarping()``.
635

636
        Returns:
637
            np.ndarray: The corrected image.
638
        """
639
        if image is None:
1✔
640
            if self.slice is not None:
1✔
641
                image = self.slice
1✔
642
            else:
643
                image = np.zeros(self._config["momentum"]["bins"][0:2])
1✔
644
                self.bin_ranges = self._config["momentum"]["ranges"]
1✔
645

646
        if self.pouter_ord is None:
1✔
647
            if self.pouter is not None:
1✔
648
                self.pouter_ord = po.pointset_order(self.pouter)
×
649
                self.correction["creation_date"] = datetime.now().timestamp()
×
650
                self.correction["creation_date"] = datetime.now().timestamp()
×
651
            else:
652
                try:
1✔
653
                    features = np.asarray(
1✔
654
                        self.correction["feature_points"],
655
                    )
656
                    rotsym = self.correction["rotation_symmetry"]
1✔
657
                    include_center = self.correction["include_center"]
1✔
658
                    if not include_center and len(features) > rotsym:
1✔
659
                        features = features[:rotsym, :]
×
660
                    ascale = self.correction.get("ascale", None)
1✔
661
                    if ascale is not None:
1✔
662
                        ascale = np.asarray(ascale)
1✔
663

664
                    if verbose:
1✔
665
                        if "creation_date" in self.correction:
1✔
666
                            datestring = datetime.fromtimestamp(
1✔
667
                                self.correction["creation_date"],
668
                            ).strftime(
669
                                "%m/%d/%Y, %H:%M:%S",
670
                            )
671
                            print(
1✔
672
                                "No landmarks defined, using momentum correction parameters "
673
                                f"generated on {datestring}",
674
                            )
675
                        else:
676
                            print(
×
677
                                "No landmarks defined, using momentum correction parameters "
678
                                "from config.",
679
                            )
680
                except KeyError as exc:
×
681
                    raise ValueError(
×
682
                        "No valid landmarks defined, and no landmarks found in configuration!",
683
                    ) from exc
684

685
                self.add_features(features=features, rotsym=rotsym)
1✔
686

687
        else:
688
            self.correction["creation_date"] = datetime.now().timestamp()
1✔
689

690
        if ascale is not None:
1✔
691
            if isinstance(ascale, (int, float, np.floating, np.integer)):
1✔
692
                if self.rotsym != 4:
1✔
693
                    raise ValueError(
1✔
694
                        "Providing ascale as scalar number is only valid for 'rotsym'==4.",
695
                    )
696
                self.ascale = np.array([1.0, ascale, 1.0, ascale])
1✔
697
            elif isinstance(ascale, (tuple, list, np.ndarray)):
1✔
698
                if len(ascale) != len(self.ascale):
1✔
699
                    raise ValueError(
1✔
700
                        f"ascale needs to be of length 'rotsym', but has length {len(ascale)}.",
701
                    )
702
                self.ascale = np.asarray(ascale)
1✔
703
            else:
704
                raise TypeError(
1✔
705
                    (
706
                        "ascale needs to be a single number or a list/tuple/np.ndarray of length ",
707
                        f"'rotsym' ({self.rotsym})!",
708
                    ),
709
                )
710

711
        if use_center is None:
1✔
712
            try:
1✔
713
                use_center = self.correction["use_center"]
1✔
714
            except KeyError:
1✔
715
                use_center = True
1✔
716
        self.use_center = use_center
1✔
717

718
        self.prefs = kwds.pop("landmarks", self.pouter_ord)
1✔
719
        self.ptargs = kwds.pop("targets", [])
1✔
720
        newcenters = kwds.pop("new_centers", {})
1✔
721

722
        # Generate the target point set
723
        if not self.ptargs:
1✔
724
            self.ptargs = sym.rotVertexGenerator(
1✔
725
                self.pcent,
726
                fixedvertex=self.pouter_ord[0, :],
727
                arot=self.arot,
728
                direction=-1,
729
                scale=self.ascale,
730
                ret="all",
731
            )[1:, :]
732

733
        if use_center is True:
1✔
734
            # Use center of image pattern in the registration-based symmetrization
735
            if fixed_center is True:
1✔
736
                # Add the same center to both the reference and target sets
737

738
                self.prefs = np.column_stack((self.prefs.T, self.pcent)).T
1✔
739
                self.ptargs = np.column_stack((self.ptargs.T, self.pcent)).T
1✔
740

741
            else:  # Add different centers to the reference and target sets
742
                self.prefs = np.column_stack(
×
743
                    (self.prefs.T, newcenters["lmkcenter"]),
744
                ).T
745
                self.ptargs = np.column_stack(
×
746
                    (self.ptargs.T, newcenters["targcenter"]),
747
                ).T
748

749
        # Non-iterative estimation of deformation field
750
        corrected_image, splinewarp = tps.tpsWarping(
1✔
751
            self.prefs,
752
            self.ptargs,
753
            image,
754
            None,
755
            interp_order,
756
            ret="all",
757
            **kwds,
758
        )
759

760
        self.reset_deformation(image=image, coordtype="cartesian")
1✔
761

762
        self.update_deformation(
1✔
763
            splinewarp[0],
764
            splinewarp[1],
765
        )
766

767
        # save backup copies to reset transformations
768
        self.rdeform_field_bkp = self.rdeform_field
1✔
769
        self.cdeform_field_bkp = self.cdeform_field
1✔
770

771
        self.correction["outer_points"] = self.pouter_ord
1✔
772
        self.correction["center_point"] = np.asarray(self.pcent)
1✔
773
        self.correction["reference_points"] = self.prefs
1✔
774
        self.correction["target_points"] = self.ptargs
1✔
775
        self.correction["rotation_symmetry"] = self.rotsym
1✔
776
        self.correction["use_center"] = self.use_center
1✔
777
        self.correction["include_center"] = self.include_center
1✔
778
        if self.include_center:
1✔
779
            self.correction["feature_points"] = np.concatenate(
1✔
780
                (self.pouter_ord, np.asarray([self.pcent])),
781
            )
782
        else:
783
            self.correction["feature_points"] = self.pouter_ord
1✔
784
        self.correction["ascale"] = self.ascale
1✔
785

786
        if self.slice is not None:
1✔
787
            self.slice_corrected = corrected_image
1✔
788

789
        if verbose:
1✔
790
            print("Calculated thin spline correction based on the following landmarks:")
1✔
791
            print(f"pouter: {self.pouter}")
1✔
792
            if use_center:
1✔
793
                print(f"pcent: {self.pcent}")
1✔
794

795
        return corrected_image
1✔
796

797
    def apply_correction(
        self,
        image: np.ndarray,
        axis: int,
        dfield: np.ndarray = None,
    ) -> np.ndarray:
        """Warp a stack of 2D images (3D volume) slice by slice along one axis.

        Args:
            image (np.ndarray): Image stack to which the transformation is applied.
            axis (int): Axis along which the 2D slices are selected.
            dfield (np.ndarray, optional): Stacked row and column deformation fields.
                Defaults to [self.rdeform_field, self.cdeform_field].

        Returns:
            np.ndarray: The corrected image.
        """
        field = dfield
        if field is None:
            # Fall back to the deformation fields currently stored on the instance
            field = np.asarray([self.rdeform_field, self.cdeform_field])

        return sym.applyWarping(image, axis, warptype="deform_field", dfield=field)
825

826
    def reset_deformation(self, **kwds):
        """Reset the row/column deformation fields to the plain coordinate grid.

        Args:
            **kwds: keyword arguments:

                - **image**: the image whose size defines the deformation fields.
                  Defaults to self.slice
                - **coordtype**: The coordinate system to use. Defaults to 'cartesian'.

        Raises:
            TypeError: If any keyword argument other than the ones above is passed.
        """
        base_image = kwds.pop("image", self.slice)
        coord_system = kwds.pop("coordtype", "cartesian")
        if kwds:
            raise TypeError(f"reset_deformation() got unexpected keyword arguments {kwds.keys()}.")

        grid = sym.coordinate_matrix_2D(base_image, coordtype=coord_system, stackaxis=0)
        grid = grid.astype("float64")

        # Index 1 along the stacking axis holds the row field, index 0 the column field
        self.rdeform_field, self.cdeform_field = grid[1, ...], grid[0, ...]
        self.dfield_updated = True
852

853
    def update_deformation(self, rdeform: np.ndarray, cdeform: np.ndarray):
        """Compose the stored deformation fields with a new row/column deformation.

        Both stored fields are resampled (linear interpolation) at the coordinates
        given by ``rdeform``/``cdeform``; samples outside the field become NaN.

        Args:
            rdeform (np.ndarray): 2D array of row-ordered deformation field.
            cdeform (np.ndarray): 2D array of column-ordered deformation field.
        """
        sample_at = [rdeform, cdeform]
        new_rfield, new_cfield = (
            ndi.map_coordinates(field, sample_at, order=1, cval=np.nan)
            for field in (self.rdeform_field, self.cdeform_field)
        )
        self.rdeform_field = new_rfield
        self.cdeform_field = new_cfield

        self.dfield_updated = True
875

876
    def coordinate_transform(
        self,
        transform_type: str,
        keep: bool = False,
        interp_order: int = 1,
        mapkwds: dict = None,
        **kwds,
    ) -> np.ndarray:
        """Apply a pixel-wise coordinate transform to the image
        by means of the deformation field.

        Args:
            transform_type (str): Type of deformation to apply to image slice. Possible
                values are:

                - translation.
                - rotation.
                - rotation_auto.
                - scaling.
                - scaling_auto (not implemented).
                - shearing.
                - homography.

            keep (bool, optional): Option to keep the specified coordinate transform in
                the class. Defaults to False.
            interp_order (int, optional): Interpolation order for filling in missed
                pixels. Defaults to 1.
            mapkwds (dict, optional): Additional arguments passed to
                ``scipy.ndimage.map_coordinates()``. Defaults to None.
            **kwds: keyword arguments.

                - **image**: Image to use. Defaults to self.slice.
                - **stackaxis**: Stacking axis for coordinate transformation matrices.
                  Only 0 and -1 are supported. Defaults to 0.

                Additional arguments are passed to the specific deformation field generators.
                See ``symmetrize.sym`` module.

        Returns:
            np.ndarray: The corrected image.

        Raises:
            ValueError: If transform_type is unknown or stackaxis is not 0 or -1.
            NotImplementedError: If transform_type is "scaling_auto".
        """
        if mapkwds is None:
            mapkwds = {}

        image = kwds.pop("image", self.slice)
        stackax = kwds.pop("stackaxis", 0)

        # Validate up front: previously these cases fell through the branches below
        # and crashed with an UnboundLocalError on rdisp/cdisp or rdeform/cdeform.
        known_types = (
            "translation",
            "rotation",
            "rotation_auto",
            "scaling",
            "scaling_auto",
            "shearing",
            "homography",
        )
        if transform_type not in known_types:
            raise ValueError(
                f"Unknown transform_type '{transform_type}'. Possible values are {known_types}.",
            )
        if transform_type == "scaling_auto":
            raise NotImplementedError("The 'scaling_auto' transformation is not implemented.")
        if stackax not in (0, -1):
            raise ValueError(f"stackaxis needs to be 0 or -1, but is {stackax}.")

        coordmat = sym.coordinate_matrix_2D(
            image,
            coordtype="homogeneous",
            stackaxis=stackax,
        )

        if transform_type == "translation":
            # x and y shifts are passed to symmetrize with swapped roles (presumably
            # because the images are handled transposed, cf. image.T when plotting —
            # TODO confirm). Swap whenever either key is given, not only when both are,
            # so a single xtrans/ytrans is interpreted consistently.
            if "xtrans" in kwds or "ytrans" in kwds:
                xtrans = kwds.pop("xtrans", 0)
                ytrans = kwds.pop("ytrans", 0)
                kwds["xtrans"] = ytrans
                kwds["ytrans"] = xtrans

            rdisp, cdisp = sym.translationDF(
                coordmat,
                stackaxis=stackax,
                ret="displacement",
                **kwds,
            )
        elif transform_type == "rotation":
            rdisp, cdisp = sym.rotationDF(
                coordmat,
                stackaxis=stackax,
                ret="displacement",
                **kwds,
            )
        elif transform_type == "rotation_auto":
            center = kwds.pop("center", self.pcent)
            # Estimate the optimal rotation angle using intensity symmetry
            angle_auto, _ = sym.sym_pose_estimate(
                image / image.max(),
                center=center,
                **kwds,
            )
            self.adjust_params = dictmerge(
                self.adjust_params,
                {"center": center, "angle": angle_auto},
            )
            rdisp, cdisp = sym.rotationDF(
                coordmat,
                stackaxis=stackax,
                ret="displacement",
                angle=angle_auto,
            )
        elif transform_type == "scaling":
            rdisp, cdisp = sym.scalingDF(
                coordmat,
                stackaxis=stackax,
                ret="displacement",
                **kwds,
            )
        elif transform_type == "shearing":
            rdisp, cdisp = sym.shearingDF(
                coordmat,
                stackaxis=stackax,
                ret="displacement",
                **kwds,
            )
        else:  # homography (all other types were validated above)
            transform = kwds.pop("transform", np.eye(3))
            rdisp, cdisp = sym.compose_deform_field(
                coordmat,
                mat_transform=transform,
                stackaxis=stackax,
                ret="displacement",
                **kwds,
            )

        # Compute deformation field: row coordinates sit at index 1 and column
        # coordinates at index 0 of the stacking axis.
        if stackax == 0:
            rdeform, cdeform = (
                coordmat[1, ...] + rdisp,
                coordmat[0, ...] + cdisp,
            )
        else:  # stackax == -1
            rdeform, cdeform = (
                coordmat[..., 1] + rdisp,
                coordmat[..., 0] + cdisp,
            )

        # Resample image in the deformation field
        if image is self.slice:  # resample using all previous displacement fields
            total_rdeform = ndi.map_coordinates(
                self.rdeform_field,
                [rdeform, cdeform],
                order=1,
            )
            total_cdeform = ndi.map_coordinates(
                self.cdeform_field,
                [rdeform, cdeform],
                order=1,
            )
            slice_transformed = ndi.map_coordinates(
                image,
                [total_rdeform, total_cdeform],
                order=interp_order,
                **mapkwds,
            )
            self.slice_transformed = slice_transformed
        else:
            # if external image is provided, apply only the new additional transformation
            slice_transformed = ndi.map_coordinates(
                image,
                [rdeform, cdeform],
                order=interp_order,
                **mapkwds,
            )

        # Combine deformation fields
        if keep is True:
            self.update_deformation(
                rdeform,
                cdeform,
            )
            self.adjust_params["applied"] = True
            self.adjust_params = dictmerge(self.adjust_params, kwds)

        return slice_transformed
1039

1040
    def pose_adjustment(
        self,
        transformations: dict[str, Any] = None,
        apply: bool = False,
        reset: bool = True,
        verbose: bool = True,
        **kwds,
    ):
        """Interactive panel to adjust transformations that are applied to the image.
        Applies first a scaling, next a x/y translation, and last a rotation around
        the center pixel given by config["momentum"]["center_pixel"].

        Args:
            transformations (dict, optional): Dictionary with transformations.
                Defaults to self.transformations or config["momentum"]["transformations"].
                The dictionary is copied; the passed object is not modified.
            apply (bool, optional):
                Option to directly apply the provided transformations.
                Defaults to False.
            reset (bool, optional):
                Option to reset the correction before transformation. Defaults to True.
            verbose (bool, optional):
                Option to report the performed transformations. Defaults to True.
            **kwds: Keyword parameters defining defaults for the transformations:

                - **scale** (float): Initial value of the scaling slider.
                - **xtrans** (float): Initial value of the xtrans slider.
                - **ytrans** (float): Initial value of the ytrans slider.
                - **angle** (float): Initial value of the angle slider.

        Raises:
            TypeError: If a keyword argument other than the ones listed is passed.
        """
        matplotlib.use("module://ipympl.backend_nbagg")
        if self.slice_corrected is None or not self.slice_corrected.any():
            if self.slice is None or not self.slice.any():
                self.slice = np.zeros(self._config["momentum"]["bins"][0:2])
            source_image = self.slice
            plot = False
        else:
            source_image = self.slice_corrected
            plot = True

        transformed_image = source_image

        if reset:
            if self.rdeform_field_bkp is not None and self.cdeform_field_bkp is not None:
                self.rdeform_field = self.rdeform_field_bkp
                self.cdeform_field = self.cdeform_field_bkp
            else:
                self.reset_deformation()

        center = self._config["momentum"]["center_pixel"]
        if plot:
            fig, ax = plt.subplots(1, 1)
            img = ax.imshow(transformed_image.T, origin="lower", cmap="terrain_r")
            ax.axvline(x=center[0])
            ax.axhline(y=center[1])

        # Always work on a private copy: the sliders and kwds write into this dict, and
        # neither the caller's dict nor self.transformations may change before "apply".
        transformations = deepcopy(
            self.transformations if transformations is None else transformations,
        )

        if len(kwds) > 0:
            for key in ["scale", "xtrans", "ytrans", "angle"]:
                if key in kwds:
                    transformations[key] = kwds.pop(key)

            if len(kwds) > 0:
                raise TypeError(
                    f"pose_adjustment() got unexpected keyword arguments {kwds.keys()}.",
                )

        elif "creation_date" in transformations and verbose:
            datestring = datetime.fromtimestamp(transformations["creation_date"]).strftime(
                "%m/%d/%Y, %H:%M:%S",
            )
            print(f"Using transformation parameters generated on {datestring}")

        def _track(key: str, value: float, identity: float):
            # Record a slider value. Identity values (scale 1, shift/angle 0) are only
            # stored when the key is already tracked, so that moving a slider back to
            # neutral actually resets the parameter instead of keeping the old value
            # (which "apply" would otherwise still apply).
            if value != identity or key in transformations:
                transformations[key] = value

        def update(scale: float, xtrans: float, ytrans: float, angle: float):
            transformed_image = source_image
            _track("scale", scale, 1)
            if scale != 1:
                transformed_image = self.coordinate_transform(
                    image=transformed_image,
                    transform_type="scaling",
                    xscale=scale,
                    yscale=scale,
                )
            _track("xtrans", xtrans, 0)
            _track("ytrans", ytrans, 0)
            if xtrans != 0 or ytrans != 0:
                transformed_image = self.coordinate_transform(
                    image=transformed_image,
                    transform_type="translation",
                    xtrans=xtrans,
                    ytrans=ytrans,
                )
            _track("angle", angle, 0)
            if angle != 0:
                transformed_image = self.coordinate_transform(
                    image=transformed_image,
                    transform_type="rotation",
                    angle=angle,
                    center=center,
                )
            if plot:
                img.set_data(transformed_image.T)
                axmin = np.min(transformed_image, axis=(0, 1))
                axmax = np.max(transformed_image, axis=(0, 1))
                if axmin < axmax:
                    img.set_clim(axmin, axmax)
                fig.canvas.draw_idle()

        update(
            scale=transformations.get("scale", 1),
            xtrans=transformations.get("xtrans", 0),
            ytrans=transformations.get("ytrans", 0),
            angle=transformations.get("angle", 0),
        )

        scale_slider = ipw.FloatSlider(
            value=transformations.get("scale", 1),
            min=0.8,
            max=1.2,
            step=0.01,
        )
        xtrans_slider = ipw.FloatSlider(
            value=transformations.get("xtrans", 0),
            min=-200,
            max=200,
            step=1,
        )
        ytrans_slider = ipw.FloatSlider(
            value=transformations.get("ytrans", 0),
            min=-200,
            max=200,
            step=1,
        )
        angle_slider = ipw.FloatSlider(
            value=transformations.get("angle", 0),
            min=-180,
            max=180,
            step=1,
        )
        results_box = ipw.Output()
        ipw.interact(
            update,
            scale=scale_slider,
            xtrans=xtrans_slider,
            ytrans=ytrans_slider,
            angle=angle_slider,
        )

        def apply_func(apply: bool):  # noqa: ARG001
            if transformations.get("scale", 1) != 1:
                self.coordinate_transform(
                    transform_type="scaling",
                    xscale=transformations["scale"],
                    yscale=transformations["scale"],
                    keep=True,
                )
                if verbose:
                    with results_box:
                        print(f"Applied scaling with scale={transformations['scale']}.")
            if transformations.get("xtrans", 0) != 0 or transformations.get("ytrans", 0) != 0:
                self.coordinate_transform(
                    transform_type="translation",
                    xtrans=transformations.get("xtrans", 0),
                    ytrans=transformations.get("ytrans", 0),
                    keep=True,
                )
                if verbose:
                    with results_box:
                        print(
                            f"Applied translation with (xtrans={transformations.get('xtrans', 0)},",
                            f"ytrans={transformations.get('ytrans', 0)}).",
                        )
            if transformations.get("angle", 0) != 0:
                self.coordinate_transform(
                    transform_type="rotation",
                    angle=transformations["angle"],
                    center=center,
                    keep=True,
                )
                if verbose:
                    with results_box:
                        print(f"Applied rotation with angle={transformations['angle']}.")

            # Show the report for every applied transformation, not only after a rotation
            display(results_box)

            # slice_transformed is only set once a transformation has been applied
            if plot and self.slice_transformed is not None:
                img.set_data(self.slice_transformed.T)
                axmin = np.min(self.slice_transformed, axis=(0, 1))
                axmax = np.max(self.slice_transformed, axis=(0, 1))
                if axmin < axmax:
                    img.set_clim(axmin, axmax)
                fig.canvas.draw_idle()

            if transformations != self.transformations:
                transformations["creation_date"] = datetime.now().timestamp()
                self.transformations = transformations

            if verbose:
                plt.figure()
                subs = 20
                plt.title("Deformation field")
                plt.scatter(
                    self.rdeform_field[::subs, ::subs].ravel(),
                    self.cdeform_field[::subs, ::subs].ravel(),
                    c="b",
                )
                plt.show()
            scale_slider.close()
            xtrans_slider.close()
            ytrans_slider.close()
            angle_slider.close()
            apply_button.close()

        apply_button = ipw.Button(description="apply")
        display(apply_button)
        apply_button.on_click(apply_func)

        if plot:
            plt.show()

        if apply:
            apply_func(True)
1265

1266
    def calc_inverse_dfield(self):
        """Compute, store and return the inverse of the current deformation field.

        The inverse is obtained from ``generate_inverse_dfield`` using the row and
        column deformation fields together with the binning and detector ranges.

        Returns:
            The inverse deformation field, also stored in self.inverse_dfield.
        """
        inverse = generate_inverse_dfield(
            self.rdeform_field,
            self.cdeform_field,
            self.bin_ranges,
            self.detector_ranges,
        )
        self.inverse_dfield = inverse
        return inverse
1276

1277
    def view(
        self,
        image: np.ndarray = None,
        origin: str = "lower",
        cmap: str = "terrain_r",
        figsize: tuple[int, int] = None,
        points: dict = None,
        annotated: bool = False,
        backend: str = "matplotlib",
        imkwds: dict = None,
        scatterkwds: dict = None,
        cross: bool = False,
        crosshair: bool = False,
        crosshair_radii: list[int] = None,
        crosshair_thickness: int = 1,
        **kwds,
    ):
        """Display image slice with specified annotations.

        Args:
            image (np.ndarray, optional): The image to plot. Defaults to self.slice.
            origin (str, optional): Figure origin specification ('lower' or 'upper').
                Defaults to "lower".
            cmap (str, optional):  Colormap specification. Defaults to "terrain_r".
            figsize (tuple[int, int], optional): Figure size. Defaults to (4, 4) (inches)
                for the matplotlib backend and (320, 300) (pixels) for the bokeh backend.
            points (dict, optional): Points for annotation. Defaults to self.features.
            annotated (bool, optional): Option to add annotation. Defaults to False.
            backend (str, optional): Visualization backend specification. Defaults to
                "matplotlib".

                - 'matplotlib': use static display rendered by matplotlib.
                - 'bokeh': use interactive display rendered by bokeh.

            imkwds (dict, optional): Keyword arguments for
                ``matplotlib.pyplot.imshow()``. Defaults to {}.
            scatterkwds (dict, optional): Keyword arguments for
                ``matplotlib.pyplot.scatter()``. Defaults to {}.
            cross (bool, optional): Option to display a horizontal/vertical lines at
                the configured center pixel. Works only in matplotlib backend.
                Defaults to False.
            crosshair (bool, optional): Display option to plot circles around center
                self.pcent. Works only in bokeh backend. Defaults to False.
            crosshair_radii (list[int], optional): Pixel radii of circles to plot when
                crosshair option is activated. Defaults to [50, 100, 150].
            crosshair_thickness (int, optional): Thickness of crosshair circles.
                Defaults to 1.
            **kwds: keyword arguments.

                - **textshift** (tuple): Row/column offset of annotation labels.
                  Defaults to (3, 3).
                - **textsize** (int): Font size of annotation labels. Defaults to 12.

        Raises:
            TypeError: If an unsupported keyword argument is passed.
            ValueError: If backend is neither 'matplotlib' nor 'bokeh'.
        """
        if image is None:
            image = self.slice
        num_rows, num_cols = image.shape

        if points is None:
            points = self.features
        # Fresh containers per call instead of shared mutable defaults
        if imkwds is None:
            imkwds = {}
        if scatterkwds is None:
            scatterkwds = {}
        if crosshair_radii is None:
            crosshair_radii = [50, 100, 150]

        # Annotation options are consumed unconditionally so that passing them
        # without annotated=True is not reported as an illegal keyword.
        tsr, tsc = kwds.pop("textshift", (3, 3))
        txtsize = kwds.pop("textsize", 12)

        # Handle unexpected kwds:
        if len(kwds) > 0:
            raise TypeError(
                f"view() got unexpected keyword arguments {set(kwds.keys())}.",
            )

        if backend == "matplotlib":
            if figsize is None:
                figsize = (4, 4)
            _, ax = plt.subplots(figsize=figsize)
            ax.imshow(image.T, origin=origin, cmap=cmap, **imkwds)

            if cross:
                center = self._config["momentum"]["center_pixel"]
                ax.axvline(x=center[0])
                ax.axhline(y=center[1])

            # Add annotation to the figure
            if annotated:
                for p_vals in points.values():
                    try:
                        ax.scatter(p_vals[:, 0], p_vals[:, 1], **scatterkwds)
                    except IndexError:
                        # Not a 2D point array: try a single (x, y) point, else skip
                        try:
                            ax.scatter(p_vals[0], p_vals[1], **scatterkwds)
                        except IndexError:
                            pass

                    if p_vals.size > 2:
                        for i_pval, pval in enumerate(p_vals):
                            ax.text(
                                pval[0] + tsc,
                                pval[1] + tsr,
                                str(i_pval),
                                fontsize=txtsize,
                            )

        elif backend == "bokeh":
            # The figsize parameter was previously ignored here (the value was looked
            # up in kwds, where an explicit parameter can never appear).
            if figsize is None:
                figsize = (320, 300)
            output_notebook(hide_banner=True)
            colors = it.cycle(ColorCycle[10])
            ttp = [("(x, y)", "($x, $y)")]
            palette = cm2palette(cmap)  # Retrieve palette colors
            fig = pbk.figure(
                width=figsize[0],
                height=figsize[1],
                tooltips=ttp,
                x_range=(0, num_rows),
                y_range=(0, num_cols),
            )
            fig.image(
                image=[image.T],
                x=0,
                y=0,
                dw=num_rows,
                dh=num_cols,
                palette=palette,
                **imkwds,
            )

            if annotated is True:
                for p_vals in points.values():
                    try:
                        xcirc, ycirc = p_vals[:, 0], p_vals[:, 1]
                        fig.scatter(
                            xcirc,
                            ycirc,
                            size=8,
                            color=next(colors),
                            **scatterkwds,
                        )
                    except IndexError:
                        # Not a 2D point array: try a single (x, y) point, else skip
                        try:
                            xcirc, ycirc = p_vals[0], p_vals[1]
                            fig.scatter(
                                xcirc,
                                ycirc,
                                size=8,
                                color=next(colors),
                                **scatterkwds,
                            )
                        except IndexError:
                            pass
            if crosshair and self.pcent is not None:
                for radius in crosshair_radii:
                    fig.annulus(
                        x=[self.pcent[0]],
                        y=[self.pcent[1]],
                        inner_radius=radius - crosshair_thickness,
                        outer_radius=radius,
                        color="red",
                        alpha=0.6,
                    )

            pbk.show(fig)

        else:
            raise ValueError(
                f"Unsupported backend '{backend}'. Use 'matplotlib' or 'bokeh'.",
            )
1433

1434
    def select_k_range(
1✔
1435
        self,
1436
        point_a: np.ndarray | list[int] = None,
1437
        point_b: np.ndarray | list[int] = None,
1438
        k_distance: float = None,
1439
        k_coord_a: np.ndarray | list[float] = None,
1440
        k_coord_b: np.ndarray | list[float] = np.array([0.0, 0.0]),
1441
        equiscale: bool = True,
1442
        apply: bool = False,
1443
    ):
1444
        """Interactive selection function for features for the Momentum axes calibration. It allows
1445
        the user to select the pixel positions of two symmetry points (a and b) and the k-space
1446
        distance of the two. Alternatively, the coordinates of both points can be provided. See the
1447
        equiscale option for details on the specifications of point coordinates.
1448

1449
        Args:
1450
            point_a (np.ndarray | list[int], optional): Pixel coordinates of the
1451
                symmetry point a.
1452
            point_b (np.ndarray | list[int], optional): Pixel coordinates of the
1453
                symmetry point b. Defaults to the center pixel of the image, defined by
1454
                config["momentum"]["center_pixel"].
1455
            k_distance (float, optional): The known momentum space distance between the
1456
                two symmetry points.
1457
            k_coord_a (np.ndarray | list[float], optional): Momentum coordinate
1458
                of the symmetry points a. Only valid if equiscale=False.
1459
            k_coord_b (np.ndarray | list[float], optional): Momentum coordinate
1460
                of the symmetry points b. Only valid if equiscale=False. Defaults to
1461
                the k-space center np.array([0.0, 0.0]).
1462
            equiscale (bool, optional): Option to adopt equal scale along both the x
1463
                and y directions.
1464

1465
                - **True**: Use a uniform scale for both x and y directions in the
1466
                  image coordinate system. This applies to the situation where
1467
                  k_distance is given and the points a and b are (close to) parallel
1468
                  with one of the two image axes.
1469
                - **False**: Calculate the momentum scale for both x and y directions
1470
                  separately. This applies to the situation where the points a and b
1471
                  are sufficiently different in both x and y directions in the image
1472
                  coordinate system.
1473

1474
                Defaults to 'True'.
1475

1476
            apply (bool, optional): Option to directly store the calibration parameters
1477
                to the class. Defaults to False.
1478

1479
        Raises:
1480
            ValueError: If no valid image is found from which to ge the coordinates.
1481
        """
1482
        matplotlib.use("module://ipympl.backend_nbagg")
1✔
1483
        if self.slice_transformed is not None:
1✔
1484
            image = self.slice_transformed
1✔
1485
        elif self.slice_corrected is not None:
1✔
1486
            image = self.slice_corrected
×
1487
        elif self.slice is not None:
1✔
1488
            image = self.slice
×
1489
        else:
1490
            raise ValueError("No valid image loaded!")
1✔
1491

1492
        if point_b is None:
1✔
1493
            point_b = self._config["momentum"]["center_pixel"]
×
1494

1495
        if point_a is None:
1✔
1496
            point_a = [0, 0]
×
1497

1498
        fig, ax = plt.subplots(1, 1)
1✔
1499
        img = ax.imshow(image.T, origin="lower", cmap="terrain_r")
1✔
1500

1501
        (marker_a,) = ax.plot(point_a[0], point_a[1], "o")
1✔
1502
        (marker_b,) = ax.plot(point_b[0], point_b[1], "ro")
1✔
1503

1504
        def update(
1✔
1505
            point_a_x: int,
1506
            point_a_y: int,
1507
            point_b_x: int,
1508
            point_b_y: int,
1509
            k_distance: float,  # noqa: ARG001
1510
        ):
1511
            fig.canvas.draw_idle()
1✔
1512
            marker_a.set_xdata(point_a_x)
1✔
1513
            marker_a.set_ydata(point_a_y)
×
1514
            marker_b.set_xdata(point_b_x)
×
1515
            marker_b.set_ydata(point_b_y)
×
1516

1517
        point_a_input_x = ipw.IntText(point_a[0])
1✔
1518
        point_a_input_y = ipw.IntText(point_a[1])
1✔
1519
        point_b_input_x = ipw.IntText(point_b[0])
1✔
1520
        point_b_input_y = ipw.IntText(point_b[1])
1✔
1521
        k_distance_input = ipw.FloatText(k_distance)
1✔
1522
        ipw.interact(
1✔
1523
            update,
1524
            point_a_x=point_a_input_x,
1525
            point_a_y=point_a_input_y,
1526
            point_b_x=point_b_input_x,
1527
            point_b_y=point_b_input_y,
1528
            k_distance=k_distance_input,
1529
        )
1530

1531
        self._state = 0
1✔
1532

1533
        def onclick(event):
1✔
1534
            if self._state == 0:
×
1535
                point_a_input_x.value = event.xdata
×
1536
                point_a_input_y.value = event.ydata
×
1537
                self._state = 1
×
1538
            else:
1539
                point_b_input_x.value = event.xdata
×
1540
                point_b_input_y.value = event.ydata
×
1541
                self._state = 0
×
1542

1543
        cid = fig.canvas.mpl_connect("button_press_event", onclick)
1✔
1544

1545
        def apply_func(apply: bool):  # noqa: ARG001
1✔
1546
            point_a = [point_a_input_x.value, point_a_input_y.value]
1✔
1547
            point_b = [point_b_input_x.value, point_b_input_y.value]
1✔
1548
            calibration = self.calibrate(
1✔
1549
                point_a=point_a,
1550
                point_b=point_b,
1551
                k_distance=k_distance,
1552
                equiscale=equiscale,
1553
                k_coord_a=k_coord_a,
1554
                k_coord_b=k_coord_b,
1555
            )
1556

1557
            img.set_extent(calibration["extent"])
1✔
1558
            plt.title("Momentum calibrated data")
1✔
1559
            plt.xlabel("$k_x$", fontsize=15)
1✔
1560
            plt.ylabel("$k_y$", fontsize=15)
1✔
1561
            ax.axhline(0)
1✔
1562
            ax.axvline(0)
1✔
1563

1564
            fig.canvas.mpl_disconnect(cid)
1✔
1565

1566
            point_a_input_x.close()
1✔
1567
            point_a_input_y.close()
1✔
1568
            point_b_input_x.close()
1✔
1569
            point_b_input_y.close()
1✔
1570
            k_distance_input.close()
1✔
1571
            apply_button.close()
1✔
1572

1573
            fig.canvas.draw_idle()
1✔
1574

1575
        apply_button = ipw.Button(description="apply")
1✔
1576
        display(apply_button)
1✔
1577
        apply_button.on_click(apply_func)
1✔
1578

1579
        if apply:
1✔
1580
            apply_func(True)
1✔
1581

1582
        plt.show()
1✔
1583

1584
    def calibrate(
        self,
        point_a: np.ndarray | list[int],
        point_b: np.ndarray | list[int],
        k_distance: float = None,
        k_coord_a: np.ndarray | list[float] = None,
        k_coord_b: np.ndarray | list[float] = None,
        equiscale: bool = True,
        image: np.ndarray = None,
    ) -> dict:
        """Momentum axes calibration using the pixel positions of two symmetry points
        (a and b) and the absolute coordinate of a single point (b), defaulted to
        [0., 0.]. Index 0 of a point is matched against the first image dimension
        (rows of ``image``), index 1 against the second one. See the equiscale
        option for details on the specifications of point coordinates.

        Args:
            point_a (np.ndarray | list[int]): Pixel coordinates of the symmetry point a.
            point_b (np.ndarray | list[int]): Pixel coordinates of the symmetry point b.
            k_distance (float, optional): The known momentum space distance between the
                two symmetry points. Required if equiscale=True.
            k_coord_a (np.ndarray | list[float], optional): Momentum coordinate
                of the symmetry point a. Required (and only used) if equiscale=False.
            k_coord_b (np.ndarray | list[float], optional): Momentum coordinate
                of the symmetry point b. Defaults to the k-space center
                np.array([0.0, 0.0]).
            equiscale (bool, optional): Option to adopt equal scale along both the x
                and y directions.

                - **True**: Use a uniform scale for both x and y directions in the
                  image coordinate system. This applies to the situation where
                  k_distance is given and the points a and b are (close to) parallel
                  with one of the two image axes.
                - **False**: Calculate the momentum scale for both x and y directions
                  separately. This applies to the situation where the points a and b
                  are sufficiently different in both x and y directions in the image
                  coordinate system.

                Defaults to 'True'.
            image (np.ndarray, optional): The energy slice for which to return the
                calibration. Defaults to self.slice_corrected.

        Returns:
            dict: The calibration dictionary, also stored in ``self.calibration``,
            with the following entries:

                - "creation_date": Timestamp of the calibration.
                - "kx_axis", "ky_axis": 1D momentum axes along the first and second
                  image dimension.
                - "grid": Tuple of 2D mesh grids generated from the two axes
                  (can be used directly in pcolormesh).
                - "extent": Extent of the two momentum axes (can be used directly in
                  imshow).
                - "kx_scale", "ky_scale": Pixel-to-momentum conversion factors.
                - "x_center", "y_center": Pixel positions of the k-space center.
                - "rstart", "cstart", "rstep", "cstep": Detector start positions and
                  step sizes of the binned image. Only present if ``self.bin_ranges``
                  is available.
        """
        if image is None:
            image = self.slice_corrected

        if k_coord_b is None:
            k_coord_b = np.array([0.0, 0.0])

        nrows, ncols = image.shape
        point_a, point_b = map(np.array, [point_a, point_b])

        # Pixel offsets of every row/column index relative to point b
        rowdist = np.arange(nrows) - point_b[0]
        coldist = np.arange(ncols) - point_b[1]

        if equiscale is True:
            assert k_distance is not None, "k_distance is required if equiscale=True"
            # Same conversion factor along both directions, derived from the pixel
            # distance between a and b and their known momentum distance.
            pixel_distance = norm(point_a - point_b)
            xratio = yratio = k_distance / pixel_distance
        else:
            assert k_coord_a is not None, "k_coord_a is required if equiscale=False"
            # Independent conversion factors from the known momentum coordinates
            kxb, kyb = k_coord_b
            kxa, kya = k_coord_a
            xratio = (kxa - kxb) / (point_a[0] - point_b[0])
            yratio = (kya - kyb) / (point_a[1] - point_b[1])

        k_row = rowdist * xratio + k_coord_b[0]
        k_col = coldist * yratio + k_coord_b[1]

        k_rowgrid, k_colgrid = np.meshgrid(k_row, k_col)

        self.calibration = {}
        self.calibration["creation_date"] = datetime.now().timestamp()
        self.calibration["kx_axis"] = k_row
        self.calibration["ky_axis"] = k_col
        self.calibration["grid"] = (k_rowgrid, k_colgrid)
        self.calibration["extent"] = (k_row[0], k_row[-1], k_col[0], k_col[-1])
        self.calibration["kx_scale"] = xratio
        self.calibration["ky_scale"] = yratio
        # Pixel position where the momentum coordinate is zero
        self.calibration["x_center"] = point_b[0] - k_coord_b[0] / xratio
        self.calibration["y_center"] = point_b[1] - k_coord_b[1] / yratio
        # Copy binning parameters needed for applying the calibration to detector
        # coordinates. bin_ranges may be missing, None or too short, in which case
        # these entries are simply omitted.
        try:
            self.calibration["rstart"] = self.bin_ranges[0][0]
            self.calibration["cstart"] = self.bin_ranges[1][0]
            self.calibration["rstep"] = (self.bin_ranges[0][1] - self.bin_ranges[0][0]) / nrows
            self.calibration["cstep"] = (self.bin_ranges[1][1] - self.bin_ranges[1][0]) / ncols
        except (AttributeError, IndexError, TypeError):
            pass

        return self.calibration
1700

1701
    def apply_corrections(
        self,
        df: pd.DataFrame | dask.dataframe.DataFrame,
        x_column: str = None,
        y_column: str = None,
        new_x_column: str = None,
        new_y_column: str = None,
        verbose: bool = True,
    ) -> tuple[pd.DataFrame | dask.dataframe.DataFrame, dict]:
        """Calculate and replace the X and Y values with their distortion-corrected
        version.

        The inverse deformation field is (re)generated if it does not exist yet or if
        ``self.dfield_updated`` is set. If no deformation fields are present either,
        they are first estimated from ``self.correction`` (spline warp) and/or
        ``self.transformations`` (pose adjustment).

        Args:
            df (pd.DataFrame | dask.dataframe.DataFrame): Dataframe to apply
                the distortion correction to. A pandas dataframe gets the new
                columns assigned directly; a dask dataframe is processed
                partition-wise.
            x_column (str, optional): Label of the 'X' column before momentum
                distortion correction. Defaults to config["momentum"]["x_column"].
            y_column (str, optional): Label of the 'Y' column before momentum
                distortion correction. Defaults to config["momentum"]["y_column"].
            new_x_column (str, optional): Label of the 'X' column after momentum
                distortion correction.
                Defaults to config["momentum"]["corrected_x_column"].
            new_y_column (str, optional): Label of the 'Y' column after momentum
                distortion correction.
                Defaults to config["momentum"]["corrected_y_column"].
            verbose (bool, optional): Option to report the used landmarks for correction.
                Defaults to True.

        Returns:
            tuple[pd.DataFrame | dask.dataframe.DataFrame, dict]: Dataframe with
            added columns and momentum correction metadata dictionary.

        Raises:
            ValueError: If no deformation fields exist and neither corrections nor
                transformations are defined.
        """
        if x_column is None:
            x_column = self.x_column
        if y_column is None:
            y_column = self.y_column

        if new_x_column is None:
            new_x_column = self.corrected_x_column
        if new_y_column is None:
            new_y_column = self.corrected_y_column

        if self.inverse_dfield is None or self.dfield_updated:
            if self.rdeform_field is None and self.cdeform_field is None:
                if not (self.correction or self.transformations):
                    raise ValueError("No corrections or transformations defined!")
                if self.correction:
                    # Generate spline warp from class features or config
                    self.spline_warp_estimate(verbose=verbose)
                if self.transformations:
                    # Apply config pose adjustments
                    self.pose_adjustment()

            self.inverse_dfield = generate_inverse_dfield(
                self.rdeform_field,
                self.cdeform_field,
                self.bin_ranges,
                self.detector_ranges,
            )
            self.dfield_updated = False

        dfield_kwds = {
            "dfield": self.inverse_dfield,
            "x_column": x_column,
            "y_column": y_column,
            "new_x_column": new_x_column,
            "new_y_column": new_y_column,
            "detector_ranges": self.detector_ranges,
        }
        if isinstance(df, pd.DataFrame):
            # pandas dataframes have no map_partitions(); apply the field directly.
            # apply_dfield assigns the new columns on the passed dataframe.
            out_df = apply_dfield(df, **dfield_kwds)
        else:
            out_df = df.map_partitions(apply_dfield, **dfield_kwds)

        metadata = self.gather_correction_metadata()

        return out_df, metadata
1776

1777
    def gather_correction_metadata(self) -> dict:
        """Collect meta data for momentum correction.

        The "correction" and "registration" entries are shallow copies of
        ``self.correction`` and ``self.adjust_params``. The bookkeeping keys added
        here ("applied", deformation fields, creation date and the
        ``trans_x``/``trans_y``/``rot_z`` transformation entries) therefore do not
        leak back into the class state.

        Returns:
            dict: generated correction metadata dictionary.
        """
        metadata: dict[Any, Any] = {}
        if len(self.correction) > 0:
            # A copy keeps an existing "creation_date" entry of self.correction.
            correction = dict(self.correction)
            correction["applied"] = True
            correction["cdeform_field"] = self.cdeform_field
            correction["rdeform_field"] = self.rdeform_field
            metadata["correction"] = correction

        if len(self.adjust_params) > 0:
            registration = dict(self.adjust_params)
            registration["creation_date"] = datetime.now().timestamp()
            registration["applied"] = True

            transformations = "/entry/process/registration/transformations"
            has_rotation = bool(registration.get("angle"))
            has_xtrans = bool(registration.get("xtrans"))
            has_ytrans = bool(registration.get("ytrans"))

            # Transformation chain: trans_x (from "ytrans") -> trans_y (from "xtrans")
            # -> rot_z. Each entry depends on the previous one that is present.
            if has_xtrans:
                translation_head = f"{transformations}/trans_y"
            elif has_ytrans:
                translation_head = f"{transformations}/trans_x"
            else:
                translation_head = "."
            registration["depends_on"] = (
                f"{transformations}/rot_z" if has_rotation else translation_head
            )

            if has_ytrans:  # swapped definitions
                registration["trans_x"] = {
                    "value": registration["ytrans"],
                    "type": "translation",
                    "units": "pixel",
                    "vector": np.asarray([1.0, 0.0, 0.0]),
                    "depends_on": ".",
                }
            if has_xtrans:
                registration["trans_y"] = {
                    "value": registration["xtrans"],
                    "type": "translation",
                    "units": "pixel",
                    "vector": np.asarray([0.0, 1.0, 0.0]),
                    "depends_on": f"{transformations}/trans_x" if has_ytrans else ".",
                }
            if has_rotation:
                registration["rot_z"] = {
                    "value": registration["angle"],
                    "type": "rotation",
                    "units": "degrees",
                    "vector": np.asarray([0.0, 0.0, 1.0]),
                    "offset": np.concatenate((registration["center"], [0.0])),
                    "depends_on": translation_head,
                }
            metadata["registration"] = registration

        return metadata
1850

1851
    def append_k_axis(
        self,
        df: pd.DataFrame | dask.dataframe.DataFrame,
        x_column: str = None,
        y_column: str = None,
        new_x_column: str = None,
        new_y_column: str = None,
        calibration: dict = None,
        **kwds,
    ) -> tuple[pd.DataFrame | dask.dataframe.DataFrame, dict]:
        """Calculate and append the k axis coordinates (kx, ky) to the events dataframe.

        The new columns are assigned on the passed dataframe, which is also returned.

        Args:
            df (pd.DataFrame | dask.dataframe.DataFrame): Dataframe to apply the
                distortion correction to.
            x_column (str, optional): Label of the source 'X' column.
                Defaults to config["momentum"]["corrected_x_column"] or
                config["momentum"]["x_column"] (whichever is present).
            y_column (str, optional): Label of the source 'Y' column.
                Defaults to config["momentum"]["corrected_y_column"] or
                config["momentum"]["y_column"] (whichever is present).
            new_x_column (str, optional): Label of the destination 'X' column after
                momentum calibration. Defaults to config["momentum"]["kx_column"].
            new_y_column (str, optional): Label of the destination 'Y' column after
                momentum calibration. Defaults to config["momentum"]["ky_column"].
            calibration (dict, optional): Dictionary containing calibration parameters.
                Defaults to 'self.calibration' or config["momentum"]["calibration"].
                The dictionary is copied and never modified.
            **kwds: Keyword parameters for momentum calibration. Parameters are added
                to (a copy of) the calibration dictionary.

        Returns:
            tuple[pd.DataFrame | dask.dataframe.DataFrame, dict]: Dataframe with
            added columns and momentum calibration metadata dictionary.

        Raises:
            TypeError: If kwds contains keys that are not calibration parameters.
            ValueError: If required calibration parameters are missing.
        """
        if x_column is None:
            if self.corrected_x_column in df.columns:
                x_column = self.corrected_x_column
            else:
                x_column = self.x_column
        if y_column is None:
            if self.corrected_y_column in df.columns:
                y_column = self.corrected_y_column
            else:
                y_column = self.y_column

        if new_x_column is None:
            new_x_column = self.kx_column

        if new_y_column is None:
            new_y_column = self.ky_column

        # pylint: disable=duplicate-code
        # Work on a private copy. Neither self.calibration nor a caller-supplied
        # dictionary may be changed by the kwds overrides or the metadata bookkeeping.
        calibration = deepcopy(self.calibration if calibration is None else calibration)

        if len(kwds) > 0:
            for key in [
                "rstart",
                "cstart",
                "x_center",
                "y_center",
                "kx_scale",
                "ky_scale",
                "rstep",
                "cstep",
            ]:
                if key in kwds:
                    calibration[key] = kwds.pop(key)
            calibration["creation_date"] = datetime.now().timestamp()

            # Anything left over is not a calibration parameter
            if len(kwds) > 0:
                raise TypeError(f"append_k_axis() got unexpected keyword arguments {kwds.keys()}.")

        try:
            (df[new_x_column], df[new_y_column]) = detector_coordinates_2_k_coordinates(
                r_det=df[x_column],
                c_det=df[y_column],
                r_start=calibration["rstart"],
                c_start=calibration["cstart"],
                r_center=calibration["x_center"],
                c_center=calibration["y_center"],
                r_conversion=calibration["kx_scale"],
                c_conversion=calibration["ky_scale"],
                r_step=calibration["rstep"],
                c_step=calibration["cstep"],
            )
        except KeyError as exc:
            raise ValueError(
                "Required calibration parameters missing!",
            ) from exc

        metadata = self.gather_calibration_metadata(calibration=calibration)

        return df, metadata
1945

1946
    def gather_calibration_metadata(self, calibration: dict = None) -> dict:
        """Collect meta data for momentum calibration.

        Args:
            calibration (dict, optional): Dictionary with momentum calibration
                parameters. If omitted will be taken from the class. The dictionary
                itself is not modified; a shallow copy is stored in the metadata.

        Returns:
            dict: Generated metadata dictionary with the keys "creation_date" (if
            present in the calibration), "applied" and "calibration".
        """
        if calibration is None:
            calibration = self.calibration
        # Shallow copy so the placeholder axis entries added below do not leak into
        # self.calibration or the caller's dictionary.
        calibration = dict(calibration)

        metadata: dict[Any, Any] = {}
        if "creation_date" in calibration:
            metadata["creation_date"] = calibration["creation_date"]
        metadata["applied"] = True
        metadata["calibration"] = calibration
        # create empty calibrated axis entries, if they are not present.
        calibration.setdefault("kx_axis", 0)
        calibration.setdefault("ky_axis", 0)

        return metadata
1972

1973

1974
def cm2palette(cmap_name: str) -> list:
    """Translate a colormap name into a bokeh palette.

    Names known to bokeh (``bokeh.palettes.all_palettes``) are looked up directly in
    ``bokeh.palettes``. Any other name is resolved as a matplotlib colormap, sampled
    at 256 points and converted to hex color strings.

    NOTE(review): for bokeh palette names the returned object is whatever
    ``getattr(bokeh.palettes, cmap_name)`` yields, which may be a size-to-palette
    mapping rather than a plain list — confirm callers handle that.

    Args:
        cmap_name (str): Name of the colormap/palette.

    Returns:
        list: List of colors in hex representation (a bokeh palette).
    """
    if cmap_name not in bp.all_palettes:
        colormap = getattr(cm, cmap_name)
        # RGBA floats in [0, 1] -> integer 0..255 channels
        channel_values = (255 * colormap(range(256))).astype("int")
        return [RGB(*tuple(channels)).to_hex() for channels in channel_values]

    return getattr(bp, cmap_name)
1993

1994

1995
def dictmerge(
    main_dict: dict,
    other_entries: list[dict] | tuple[dict] | dict,
) -> dict:
    """Combine a dictionary with one or several other dictionaries.

    Later entries win on key collisions. Inputs are not modified; a new dictionary
    is built whenever anything is merged.

    Args:
        main_dict (dict): Main dictionary.
        other_entries (list[dict] | tuple[dict] | dict):
            Other dictionary or composite dictionarized elements.

    Returns:
        dict: Merged dictionary (``main_dict`` itself if there is nothing to merge).
    """
    if isinstance(other_entries, dict):
        return {**main_dict, **other_entries}

    merged = main_dict
    if isinstance(other_entries, (list, tuple)):
        for entry in other_entries:
            merged = {**merged, **entry}
    return merged
2018

2019

2020
def detector_coordinates_2_k_coordinates(
    r_det: float,
    c_det: float,
    r_start: float,
    c_start: float,
    r_center: float,
    c_center: float,
    r_conversion: float,
    c_conversion: float,
    r_step: float,
    c_step: float,
) -> tuple[float, float]:
    """Map detector coordinates (r_det, c_det) onto momentum coordinates (kr, kc).

    For each axis, the detector position of the k-space origin is
    ``start + step * center``. The offset from it is converted to binned-pixel
    units (divided by ``step``) and scaled by the conversion factor.

    Args:
        r_det (float): Row detector coordinates.
        c_det (float): Column detector coordinates.
        r_start (float): Start value for row detector coordinates.
        c_start (float): Start value for column detector coordinates.
        r_center (float): Center value for row detector coordinates.
        c_center (float): Center value for column detector coordinates.
        r_conversion (float): Row conversion factor.
        c_conversion (float): Column conversion factor.
        r_step (float): Row stepping factor.
        c_step (float): Column stepping factor.

    Returns:
        tuple[float, float]: Converted momentum space row/column coordinates.
    """

    def _axis_to_k(det, start, center, conversion, step):
        # detector coordinate where the momentum along this axis is zero
        origin = start + step * center
        return conversion * ((det - origin) / step)

    k_r = _axis_to_k(r_det, r_start, r_center, r_conversion, r_step)
    k_c = _axis_to_k(c_det, c_start, c_center, c_conversion, c_step)
    return (k_r, k_c)
2056

2057

2058
def apply_dfield(
    df: pd.DataFrame | dask.dataframe.DataFrame,
    dfield: np.ndarray,
    x_column: str,
    y_column: str,
    new_x_column: str,
    new_y_column: str,
    detector_ranges: list[tuple],
) -> pd.DataFrame | dask.dataframe.DataFrame:
    """Look up the inverse displacement field at every event and store the result.

    Both field components are sampled with linear interpolation at the
    (x_column, y_column) positions. Each is then rescaled by the detector range
    divided by the field size along the matching axis. Both axis sizes are taken
    from ``dfield[0]``.

    Args:
        df (pd.DataFrame | dask.dataframe.DataFrame): Dataframe to apply the
            distortion correction to. The new columns are assigned on it.
        dfield (np.ndarray): The distortion correction field; ``dfield[0]`` feeds
            ``new_x_column`` and ``dfield[1]`` feeds ``new_y_column``.
        x_column (str): Label of the 'X' source column.
        y_column (str): Label of the 'Y' source column.
        new_x_column (str): Label of the 'X' destination column.
        new_y_column (str): Label of the 'Y' destination column.
        detector_ranges (list[tuple]): tuple of pixel ranges of the detector x/y
            coordinates

    Returns:
        pd.DataFrame | dask.dataframe.DataFrame: dataframe with added columns
    """
    sample_points = (df[x_column], df[y_column])
    first_field = dfield[0]
    second_field = dfield[1]

    x_scale = (detector_ranges[0][1] - detector_ranges[0][0]) / first_field.shape[0]
    y_scale = (detector_ranges[1][1] - detector_ranges[1][0]) / first_field.shape[1]

    # Both results are computed before either column is assigned, so overwriting a
    # source column cannot affect the second lookup.
    df[new_x_column], df[new_y_column] = (
        map_coordinates(first_field, sample_points, order=1) * x_scale,
        map_coordinates(second_field, sample_points, order=1) * y_scale,
    )
    return df
2095

2096

2097
def generate_inverse_dfield(
    rdeform_field: np.ndarray,
    cdeform_field: np.ndarray,
    bin_ranges: list[tuple],
    detector_ranges: list[tuple],
) -> np.ndarray:
    """Generate inverse deformation field using interpolation with griddata.
    Assuming the binning range of the input ``rdeform_field`` and ``cdeform_field``
    covers the whole detector.

    Args:
        rdeform_field (np.ndarray): Row-wise deformation field.
        cdeform_field (np.ndarray): Column-wise deformation field.
        bin_ranges (list[tuple]): Detector ranges of the binned coordinates.
        detector_ranges (list[tuple]): Ranges of detector coordinates to interpolate to.

    Returns:
        np.ndarray: The calculated inverse deformation field (row/column), stacked
        along the first dimension.
    """
    print(
        "Calculating inverse deformation field, this might take a moment...",
    )

    # Evaluation grid. Per axis, np.linspace(start, stop, num) yields
    # detector_ranges[i][1] points from detector_ranges[i][0] up to (excluding)
    # cdeform_field.shape[i], i.e. a detector-resolution grid in binned-pixel units.
    r_mesh, c_mesh = np.meshgrid(
        np.linspace(
            detector_ranges[0][0],
            cdeform_field.shape[0],
            detector_ranges[0][1],
            endpoint=False,
        ),
        np.linspace(
            detector_ranges[1][0],
            cdeform_field.shape[1],
            detector_ranges[1][1],
            endpoint=False,
        ),
        sparse=False,
        indexing="ij",
    )

    bin_range_array = np.asarray(bin_ranges)[0:2]
    bin_step = (bin_range_array[:, 1] - bin_range_array[:, 0]) / cdeform_field.shape

    # Only pixels where both deformation components are defined contribute.
    # Boolean-mask selection and np.nonzero both enumerate in row-major order, so
    # positions and destinations stay aligned.
    valid = ~np.isnan(rdeform_field) & ~np.isnan(cdeform_field)
    rows, cols = np.nonzero(valid)

    # row/column position in r/cdeform_field, shifted by each axis' own bin start
    rc_position = np.column_stack(
        (
            rdeform_field[valid] + bin_ranges[0][0] / bin_step[0],
            cdeform_field[valid] + bin_ranges[1][0] / bin_step[1],
        ),
    )
    r_dest = bin_step[0] * rows + bin_ranges[0][0]  # destination pixel row position
    c_dest = bin_step[1] * cols + bin_ranges[1][0]  # destination pixel column position

    ret = Parallel(n_jobs=2)(
        delayed(griddata)(rc_position, np.asarray(dest), (r_mesh, c_mesh))
        for dest in (r_dest, c_dest)
    )

    inverse_dfield = np.asarray([ret[0], ret[1]])

    return inverse_dfield
2170

2171

2172
def load_dfield(file: str) -> tuple[np.ndarray, np.ndarray]:
    """Read a stacked inverse deformation field with ``np.load``.

    Args:
        file (str): Path to file containing the inverse dfield

    Returns:
        tuple[np.ndarray, np.ndarray]: the loaded inverse row and column deformation
        fields (elements 0 and 1 of the stored data), or ``(None, None)`` if the
        file does not exist.
    """
    try:
        stacked = np.load(file)
    except FileNotFoundError:
        return None, None

    return stacked[0], stacked[1]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc