• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

OpenCOMPES / sed / 9636700237

23 Jun 2024 10:01PM UTC coverage: 91.857% (-0.1%) from 91.962%
9636700237

Pull #437

github

web-flow
Merge pull request #446 from OpenCOMPES/update_benchmark_targets

Update benchmark targets
Pull Request #437: Upgrade to V1

182 of 183 new or added lines in 41 files covered. (99.45%)

7 existing lines in 2 files now uncovered.

6430 of 7000 relevant lines covered (91.86%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.78
/sed/calibrator/momentum.py
1
"""sed.calibrator.momentum module. Code for momentum calibration and distortion
2
correction. Mostly ported from https://github.com/mpes-kit/mpes.
3
"""
4
from __future__ import annotations
1✔
5

6
import itertools as it
1✔
7
from copy import deepcopy
1✔
8
from datetime import datetime
1✔
9
from typing import Any
1✔
10

11
import bokeh.palettes as bp
1✔
12
import bokeh.plotting as pbk
1✔
13
import dask.dataframe
1✔
14
import ipywidgets as ipw
1✔
15
import matplotlib
1✔
16
import matplotlib.pyplot as plt
1✔
17
import numpy as np
1✔
18
import pandas as pd
1✔
19
import scipy.ndimage as ndi
1✔
20
import xarray as xr
1✔
21
from bokeh.colors import RGB
1✔
22
from bokeh.io import output_notebook
1✔
23
from bokeh.palettes import Category10 as ColorCycle
1✔
24
from IPython.display import display
1✔
25
from joblib import delayed
1✔
26
from joblib import Parallel
1✔
27
from matplotlib import cm
1✔
28
from numpy.linalg import norm
1✔
29
from scipy.interpolate import griddata
1✔
30
from scipy.ndimage import map_coordinates
1✔
31
from symmetrize import pointops as po
1✔
32
from symmetrize import sym
1✔
33
from symmetrize import tps
1✔
34

35

36
class MomentumCorrector:
1✔
37
    """
38
    Momentum distortion correction and momentum calibration workflow functions.
39

40
    Args:
41
        data (xr.DataArray | np.ndarray, optional): Multidimensional hypervolume
42
            containing the data. Defaults to None.
43
        bin_ranges (list[tuple], optional): Binning ranges of the data volume, if
44
            provided as np.ndarray. Defaults to None.
45
        rotsym (int, optional): Rotational symmetry of the data. Defaults to 6.
46
        config (dict, optional): Config dictionary. Defaults to None.
47
    """
48

49
    def __init__(
1✔
50
        self,
51
        data: xr.DataArray | np.ndarray = None,
52
        bin_ranges: list[tuple] = None,
53
        rotsym: int = 6,
54
        config: dict = None,
55
    ):
56
        """Constructor of the MomentumCorrector class.
57

58
        Args:
59
            data (xr.DataArray | np.ndarray, optional): Multidimensional
60
                hypervolume containing the data. Defaults to None.
61
            bin_ranges (list[tuple], optional): Binning ranges of the data volume,
62
                if provided as np.ndarray. Defaults to None.
63
            rotsym (int, optional): Rotational symmetry of the data. Defaults to 6.
64
            config (dict, optional): Config dictionary. Defaults to None.
65
        """
66
        if config is None:
1✔
67
            config = {}
×
68

69
        self._config = config
1✔
70

71
        self.image: np.ndarray = None
1✔
72
        self.img_ndim: int = None
1✔
73
        self.slice: np.ndarray = None
1✔
74
        self.slice_corrected: np.ndarray = None
1✔
75
        self.slice_transformed: np.ndarray = None
1✔
76
        self.bin_ranges: list[tuple] = self._config["momentum"].get("bin_ranges", [])
1✔
77

78
        if data is not None:
1✔
79
            self.load_data(data=data, bin_ranges=bin_ranges)
×
80

81
        self.detector_ranges = self._config["momentum"]["detector_ranges"]
1✔
82

83
        self.rotsym = int(rotsym)
1✔
84
        self.rotsym_angle = int(360 / self.rotsym)
1✔
85
        self.arot = np.array([0] + [self.rotsym_angle] * (self.rotsym - 1))
1✔
86
        self.ascale = np.array([1.0] * self.rotsym)
1✔
87
        self.peaks: np.ndarray = None
1✔
88
        self.include_center: bool = False
1✔
89
        self.use_center: bool = False
1✔
90
        self.pouter: np.ndarray = None
1✔
91
        self.pcent: tuple[float, ...] = None
1✔
92
        self.pouter_ord: np.ndarray = None
1✔
93
        self.prefs: np.ndarray = None
1✔
94
        self.ptargs: np.ndarray = None
1✔
95
        self.csm_original: float = np.nan
1✔
96
        self.mdist: float = np.nan
1✔
97
        self.mcvdist: float = np.nan
1✔
98
        self.mvvdist: float = np.nan
1✔
99
        self.cvdist: np.ndarray = np.array(np.nan)
1✔
100
        self.vvdist: np.ndarray = np.array(np.nan)
1✔
101
        self.rdeform_field: np.ndarray = None
1✔
102
        self.cdeform_field: np.ndarray = None
1✔
103
        self.rdeform_field_bkp: np.ndarray = None
1✔
104
        self.cdeform_field_bkp: np.ndarray = None
1✔
105
        self.inverse_dfield: np.ndarray = None
1✔
106
        self.dfield_updated: bool = False
1✔
107
        self.transformations: dict[str, Any] = self._config["momentum"].get("transformations", {})
1✔
108
        self.correction: dict[str, Any] = self._config["momentum"].get("correction", {})
1✔
109
        self.adjust_params: dict[str, Any] = {}
1✔
110
        self.calibration: dict[str, Any] = self._config["momentum"].get("calibration", {})
1✔
111

112
        self.x_column = self._config["dataframe"]["x_column"]
1✔
113
        self.y_column = self._config["dataframe"]["y_column"]
1✔
114
        self.corrected_x_column = self._config["dataframe"]["corrected_x_column"]
1✔
115
        self.corrected_y_column = self._config["dataframe"]["corrected_y_column"]
1✔
116
        self.kx_column = self._config["dataframe"]["kx_column"]
1✔
117
        self.ky_column = self._config["dataframe"]["ky_column"]
1✔
118

119
        self._state: int = 0
1✔
120

121
    @property
1✔
122
    def features(self) -> dict:
1✔
123
        """Dictionary of detected features for the symmetrization process.
124
        ``self.features`` is a derived attribute from existing ones.
125

126
        Returns:
127
            dict: Dict containing features "verts" and "center".
128
        """
129
        feature_dict = {
1✔
130
            "verts": np.asarray(self.__dict__.get("pouter_ord", [])),
131
            "center": np.asarray(self.__dict__.get("pcent", [])),
132
        }
133

134
        return feature_dict
1✔
135

136
    @property
1✔
137
    def symscores(self) -> dict:
1✔
138
        """Dictionary of symmetry-related scores.
139

140
        Returns:
141
            dict: Dictionary containing symmetry scores.
142
        """
143
        sym_dict = {
×
144
            "csm_original": self.__dict__.get("csm_original", ""),
145
            "csm_current": self.__dict__.get("csm_current", ""),
146
            "arm_original": self.__dict__.get("arm_original", ""),
147
            "arm_current": self.__dict__.get("arm_current", ""),
148
        }
149

150
        return sym_dict
×
151

152
    def load_data(
        self,
        data: xr.DataArray | np.ndarray,
        bin_ranges: list[tuple] = None,
    ):
        """Load binned data into the momentum calibrator class.

        Singleton dimensions are squeezed out of the data. For 2D data, the image
        is also used as the current slice.

        Args:
            data (xr.DataArray | np.ndarray):
                2D or 3D data array, either as np.ndarray or xr.DataArray.
            bin_ranges (list[tuple], optional):
                Binning ranges. Needs to be provided in case the data are given
                as np.ndarray. Otherwise, they are determined from the coordinates
                of the (non-singleton) dimensions of the xr.DataArray, in dimension
                order. Defaults to None.

        Raises:
            ValueError: Raised if the dimensions of the input data do not fit.
        """
        if isinstance(data, xr.DataArray):
            self.image = np.squeeze(data.data)
            # One (start, end) range per image axis. Iterate over the dimensions
            # rather than ``data.coords``: the dimension order matches the image
            # axes, whereas ``data.coords`` may be ordered differently and may
            # contain non-dimension (e.g. scalar) coordinates. Singleton dimensions
            # are skipped because np.squeeze removed them from the image.
            self.bin_ranges = []
            for dim in data.dims:
                if data.sizes[dim] < 2:
                    continue
                axis_coords = data[dim]
                self.bin_ranges.append(
                    (
                        axis_coords[0].values,
                        # endpoint: last coordinate extrapolated by one bin width
                        2 * axis_coords[-1].values - axis_coords[-2].values,
                    ),
                )
        else:
            assert bin_ranges is not None, "bin_ranges are required for np.ndarray data."
            self.image = np.squeeze(data)
            self.bin_ranges = bin_ranges

        self.img_ndim = self.image.ndim
        if (self.img_ndim > 3) or (self.img_ndim < 2):
            raise ValueError("The input image dimension need to be 2 or 3!")
        if self.img_ndim == 2:
            self.slice = self.image

        if self.slice is not None:
            self.slice_corrected = self.slice_transformed = self.slice
193

194
    def select_slicer(
        self,
        plane: int = 0,
        width: int = 5,
        axis: int = 2,
        apply: bool = False,
    ):
        """Interactive panel to select (hyper)slice from a (hyper)volume.

        Displays the sum of ``width`` consecutive planes along ``axis``, starting at
        ``plane``. Sliders change both values, and an "apply" button stores the
        selected slice via ``select_slice``.

        Args:
            plane (int, optional): initial value of the plane slider. Defaults to 0.
            width (int, optional): initial value of the width slider. Defaults to 5.
            axis (int, optional): Axis along which to slice the image. Defaults to 2.
            apply (bool, optional):  Option to directly apply the values and select the
                slice. Defaults to False.
        """
        matplotlib.use("module://ipympl.backend_nbagg")

        assert self.img_ndim == 3
        # Bring the slicing axis to the front, so planes are indexed along axis 0.
        image = np.moveaxis(self.image, axis, 0)

        def sum_planes(start: int, stop: int):
            """Sum the planes [start:stop) along the slicing axis."""
            selector = slice(start, stop)
            try:
                return image[selector, ...].sum(axis=0)
            except AttributeError:
                return image[selector, ...]

        fig, ax = plt.subplots(1, 1)
        img = ax.imshow(sum_planes(plane, plane + width).T, origin="lower", cmap="terrain_r")

        def show(img_slice, title: str):
            """Display an image slice, rescaling the color limits when possible."""
            img.set_data(img_slice.T)
            axmin = np.min(img_slice, axis=(0, 1))
            axmax = np.max(img_slice, axis=(0, 1))
            if axmin < axmax:
                img.set_clim(axmin, axmax)
            ax.set_title(title)
            fig.canvas.draw_idle()

        def update(plane: int, width: int):
            show(sum_planes(plane, plane + width), f"Plane[{plane}:{plane+width}]")

        update(plane, width)

        plane_slider = ipw.IntSlider(
            value=plane,
            min=0,
            # Bounded by the length of the selected axis (which is not always axis 2).
            max=max(image.shape[0] - width, 0),
            step=1,
        )
        width_slider = ipw.IntSlider(value=width, min=1, max=20, step=1)

        ipw.interact(
            update,
            plane=plane_slider,
            width=width_slider,
        )

        def apply_fun(apply: bool):  # noqa: ARG001
            start = plane_slider.value
            stop = plane_slider.value + width_slider.value

            self.select_slice(selector=slice(start, stop), axis=axis)
            show(self.slice, f"Plane[{start}:{stop}]")

            plane_slider.close()
            width_slider.close()
            apply_button.close()

        apply_button = ipw.Button(description="apply")
        display(apply_button)
        apply_button.on_click(apply_fun)

        plt.show()

        if apply:
            apply_fun(True)
283

284
    def select_slice(
1✔
285
        self,
286
        selector: slice | list[int] | int,
287
        axis: int = 2,
288
    ):
289
        """Select (hyper)slice from a (hyper)volume.
290

291
        Args:
292
            selector (slice | list[int] | int):
293
                Selector along the specified axis to extract the slice (image). Use
294
                the construct slice(start, stop, step) to select a range of images
295
                and sum them. Use an integer to specify only a particular slice.
296
            axis (int, optional): Axis along which to select the image. Defaults to 2.
297

298
        Raises:
299
            ValueError: Raised if self.image is already 2D.
300
        """
301
        if self.img_ndim > 2:
1✔
302
            image = np.moveaxis(self.image, axis, 0)
1✔
303
            try:
1✔
304
                self.slice = image[selector, ...].sum(axis=0)
1✔
305
            except AttributeError:
×
306
                self.slice = image[selector, ...]
×
307

308
            if self.slice is not None:
1✔
309
                self.slice_corrected = self.slice_transformed = self.slice
1✔
310

311
        elif self.img_ndim == 2:
×
312
            raise ValueError("Input image dimension is already 2!")
×
313

314
    def add_features(
1✔
315
        self,
316
        features: np.ndarray,
317
        direction: str = "ccw",
318
        rotsym: int = 6,
319
        symscores: bool = True,
320
        **kwds,
321
    ):
322
        """Add features as reference points provided as np.ndarray. If provided,
323
        detects the center of the points and orders the points.
324

325
        Args:
326
            features (np.ndarray):
327
                Array of landmarks, possibly including a center peak. Its shape should
328
                be (n,2), where n is equal to the rotation symmetry, or the rotation
329
                symmetry+1, if the center is included.
330
            direction (str, optional):
331
                Direction for ordering the points. Defaults to "ccw".
332
            symscores (bool, optional):
333
                Option to calculate symmetry scores. Defaults to False.
334
            **kwds: Keyword arguments.
335

336
                - **symtype** (str): Type of symmetry scores to calculte
337
                  if symscores is True. Defaults to "rotation".
338

339
        Raises:
340
            ValueError: Raised if the number of points does not match the rotsym.
341
        """
342
        self.rotsym = int(rotsym)
1✔
343
        self.rotsym_angle = int(360 / self.rotsym)
1✔
344
        self.arot = np.array([0] + [self.rotsym_angle] * (self.rotsym - 1))
1✔
345
        self.ascale = np.array([1.0] * self.rotsym)
1✔
346

347
        if features.shape[0] == self.rotsym:  # assume no center present
1✔
348
            self.pcent, self.pouter = po.pointset_center(
1✔
349
                features,
350
                method="centroid",
351
            )
352
            self.include_center = False
1✔
353
        elif features.shape[0] == self.rotsym + 1:  # assume center included
1✔
354
            self.pcent, self.pouter = po.pointset_center(
1✔
355
                features,
356
                method="centroidnn",
357
            )
358
            self.include_center = True
1✔
359
        else:
360
            raise ValueError(
1✔
361
                f"Found {features.shape[0]} points, ",
362
                f"but {self.rotsym} or {self.rotsym+1} (incl.center) required.",
363
            )
364
        if isinstance(self.pcent, np.ndarray):
1✔
365
            self.pcent = tuple(val.item() for val in self.pcent)
1✔
366
        # Order the point landmarks
367
        self.pouter_ord = po.pointset_order(
1✔
368
            self.pouter,
369
            direction=direction,
370
        )
371

372
        # Calculate geometric distances
373
        if self.pcent is not None:
1✔
374
            self.calc_geometric_distances()
1✔
375

376
        if symscores is True:
1✔
377
            symtype = kwds.pop("symtype", "rotation")
1✔
378
            self.csm_original = self.calc_symmetry_scores(symtype=symtype)
1✔
379

380
        if self.rotsym == 6 and self.pcent is not None:
1✔
381
            self.mdist = (self.mcvdist + self.mvvdist) / 2
1✔
382
            self.mcvdist = self.mdist
1✔
383
            self.mvvdist = self.mdist
1✔
384

385
    def feature_extract(
1✔
386
        self,
387
        image: np.ndarray = None,
388
        direction: str = "ccw",
389
        feature_type: str = "points",
390
        rotsym: int = 6,
391
        symscores: bool = True,
392
        **kwds,
393
    ):
394
        """Extract features from the selected 2D slice.
395
        Currently only point feature detection is implemented.
396

397
        Args:
398
            image (np.ndarray, optional):
399
                The (2D) image slice to extract features from.
400
                Defaults to self.slice
401
            direction (str, optional):
402
                The circular direction to reorder the features in ('cw' or 'ccw').
403
                Defaults to "ccw".
404
            feature_type (str, optional):
405
                The type of features to extract. Defaults to "points".
406
            rotsym (int, optional): Rotational symmetry of the data. Defaults to 6.
407
            symscores (bool, optional):
408
                Option for calculating symmetry scores. Defaults to True.
409
            **kwds:
410
                Extra keyword arguments for ``symmetrize.pointops.peakdetect2d()``.
411

412
        Raises:
413
            NotImplementedError:
414
                Raised for undefined feature_types.
415
        """
416
        if image is None:
1✔
417
            if self.slice is not None:
1✔
418
                image = self.slice
1✔
419
            else:
420
                raise ValueError("No image loaded for feature extraction!")
×
421

422
        if feature_type == "points":
1✔
423
            # Detect the point landmarks
424
            self.peaks = po.peakdetect2d(image, **kwds)
1✔
425

426
            self.add_features(
1✔
427
                features=self.peaks,
428
                direction=direction,
429
                rotsym=rotsym,
430
                symscores=symscores,
431
                **kwds,
432
            )
433
        else:
434
            raise NotImplementedError
×
435

436
    def feature_select(
        self,
        image: np.ndarray = None,
        features: np.ndarray = None,
        include_center: bool = True,
        rotsym: int = 6,
        apply: bool = False,
        **kwds,
    ):
        """Interactively select point features on the selected 2D slice.

        The image is shown together with a marker for each feature point. A point
        is chosen from a dropdown and positioned by entering its coordinates, or by
        clicking into the image; each click moves the current point and advances
        to the next one. Pressing "apply" passes the points to ``add_features()``.

        Args:
            image (np.ndarray, optional):
                The (2D) image slice to select features on.
                Defaults to self.slice
            features (np.ndarray, optional):
                Array of landmarks, possibly including a center peak. Its shape should
                be (n,2), where n is equal to the rotation symmetry, or the rotation
                symmetry+1, if the center is included. The array is updated in place
                while points are moved.
                If omitted, an array filled with zeros is generated.
            include_center (bool, optional):
                Option to include the image center/centroid. Used to size the
                generated array if ``features`` is omitted. Defaults to True.
            rotsym (int, optional): Rotational symmetry of the data. Defaults to 6.
            apply (bool, optional): Option to directly store the features in the class.
                Defaults to False.
            **kwds:
                Extra keyword arguments passed on to ``add_features()``.

        Raises:
            ValueError: If no valid image is found from which to get the coordinates.
        """
        matplotlib.use("module://ipympl.backend_nbagg")
        if image is None:
            if self.slice is not None:
                image = self.slice
            else:
                raise ValueError("No valid image loaded!")

        fig, ax = plt.subplots(1, 1)
        ax.imshow(image.T, origin="lower", cmap="terrain_r")

        if features is None:
            features = np.zeros((rotsym + int(include_center), 2))

        # One marker (Line2D) per feature point.
        markers = [ax.plot(peak[0], peak[1], "o")[0] for peak in features]

        def update_point_no(
            point_no: int,
        ):
            fig.canvas.draw_idle()

            point_input_x.value = features[point_no][0]
            point_input_y.value = features[point_no][1]

        def update_point_pos(
            point_x: float,
            point_y: float,
        ):
            fig.canvas.draw_idle()
            point_no = point_no_input.value
            features[point_no][0] = point_x
            features[point_no][1] = point_y

            # Line2D data must be sequences: scalars are deprecated since
            # matplotlib 3.7 and rejected by newer versions.
            markers[point_no].set_xdata([point_x])
            markers[point_no].set_ydata([point_y])

        point_no_input = ipw.Dropdown(
            options=range(features.shape[0]),
            description="Point:",
        )

        point_input_x = ipw.FloatText(features[0][0])
        point_input_y = ipw.FloatText(features[0][1])
        ipw.interact(
            update_point_no,
            point_no=point_no_input,
        )
        ipw.interact(
            update_point_pos,
            point_y=point_input_y,
            point_x=point_input_x,
        )

        def onclick(event):
            # Clicks outside the axes carry no data coordinates; ignore them.
            if event.xdata is None or event.ydata is None:
                return
            point_input_x.value = event.xdata
            point_input_y.value = event.ydata
            point_no_input.value = (point_no_input.value + 1) % features.shape[0]

        cid = fig.canvas.mpl_connect("button_press_event", onclick)

        def apply_func(apply: bool):  # noqa: ARG001
            fig.canvas.mpl_disconnect(cid)

            point_no_input.close()
            point_input_x.close()
            point_input_y.close()
            apply_button.close()

            fig.canvas.draw_idle()

            self.add_features(
                features=features,
                rotsym=rotsym,
                **kwds,
            )

        apply_button = ipw.Button(description="apply")
        display(apply_button)
        apply_button.on_click(apply_func)

        if apply:
            apply_func(True)

        plt.show()
557

558
    def calc_geometric_distances(self) -> None:
        """Calculate geometric distances involving the center and the vertices.

        Sets ``cvdist`` (center-vertex distances, from ``po.cvdist``) and
        ``vvdist`` (nearest-neighbor vertex-vertex distances, from ``po.vvdist``)
        using ``pouter_ord`` and ``pcent``, together with their means ``mcvdist``
        and ``mvvdist``.
        """
        vertices = self.pouter_ord

        self.cvdist = po.cvdist(vertices, self.pcent)
        self.mcvdist = self.cvdist.mean()

        self.vvdist = po.vvdist(vertices)
        self.mvvdist = self.vvdist.mean()
567

568
    def calc_symmetry_scores(self, symtype: str = "rotation") -> float:
        """Calculate the symmetry score of the current features.

        The score is computed by ``po.csm`` from the center point (``pcent``), the
        ordered outer points (``pouter_ord``) and the rotational symmetry
        (``rotsym``).

        Args:
            symtype (str, optional): Type of symmetry score to calculate.
                Defaults to "rotation".

        Returns:
            float: Calculated symmetry score.
        """
        return po.csm(
            self.pcent,
            self.pouter_ord,
            rotsym=self.rotsym,
            type=symtype,
        )
586

587
    def spline_warp_estimate(
1✔
588
        self,
589
        image: np.ndarray = None,
590
        use_center: bool = None,
591
        fixed_center: bool = True,
592
        interp_order: int = 1,
593
        ascale: float | list | tuple | np.ndarray = None,
594
        verbose: bool = True,
595
        **kwds,
596
    ) -> np.ndarray:
597
        """Estimate the spline deformation field using thin plate spline registration.
598

599
        Args:
600
            image (np.ndarray, optional):
601
                2D array. Image slice to be corrected. Defaults to self.slice.
602
            use_center (bool, optional):
603
                Option to use the image center/centroid in the registration
604
                process. Defaults to config value, or True.
605
            fixed_center (bool, optional):
606
                Option to have a fixed center during registration-based
607
                symmetrization. Defaults to True.
608
            interp_order (int, optional):
609
                Order of interpolation (see ``scipy.ndimage.map_coordinates()``).
610
                Defaults to 1.
611
            ascale: (float | list | tuple | np.ndarray, optional): Scale parameter determining a
612
                relative scale for each symmetry feature. If provided as single float, rotsym has
613
                to be 4. This parameter describes the relative scaling between the two orthogonal
614
                symmetry directions (for an orthorhombic system). This requires the correction
615
                points to be located along the principal axes (X/Y points of the Brillouin zone).
616
                Otherwise, an array with ``rotsym`` elements is expected, containing relative
617
                scales for each feature. Defaults to an array of equal scales.
618
            verbose (bool, optional): Option to report the used landmarks for correction.
619
                Defaults to True.
620
            **kwds: keyword arguments:
621

622
                - **landmarks**: (list/array): Landmark positions (row, column) used
623
                  for registration. Defaults to  self.pouter_ord
624
                - **targets**: (list/array): Target positions (row, column) used for
625
                  registration. If empty, it will be generated by
626
                  ``symmetrize.rotVertexGenerator()``.
627
                - **new_centers**: (dict): User-specified center positions for the
628
                  reference and target sets. {'lmkcenter': (row, col),
629
                  'targcenter': (row, col)}
630
        Returns:
631
            np.ndarray: The corrected image.
632
        """
633
        if image is None:
1✔
634
            if self.slice is not None:
1✔
635
                image = self.slice
1✔
636
            else:
637
                image = np.zeros(self._config["momentum"]["bins"][0:2])
1✔
638
                self.bin_ranges = self._config["momentum"]["ranges"]
1✔
639

640
        if self.pouter_ord is None:
1✔
641
            if self.pouter is not None:
1✔
642
                self.pouter_ord = po.pointset_order(self.pouter)
×
643
                self.correction["creation_date"] = datetime.now().timestamp()
×
644
                self.correction["creation_date"] = datetime.now().timestamp()
×
645
            else:
646
                try:
1✔
647
                    features = np.asarray(
1✔
648
                        self.correction["feature_points"],
649
                    )
650
                    rotsym = self.correction["rotation_symmetry"]
1✔
651
                    include_center = self.correction["include_center"]
1✔
652
                    if not include_center and len(features) > rotsym:
1✔
653
                        features = features[:rotsym, :]
×
654
                    ascale = self.correction.get("ascale", None)
1✔
655
                    if ascale is not None:
1✔
656
                        ascale = np.asarray(ascale)
1✔
657

658
                    if verbose:
1✔
659
                        if "creation_date" in self.correction:
1✔
660
                            datestring = datetime.fromtimestamp(
1✔
661
                                self.correction["creation_date"],
662
                            ).strftime(
663
                                "%m/%d/%Y, %H:%M:%S",
664
                            )
665
                            print(
1✔
666
                                "No landmarks defined, using momentum correction parameters "
667
                                f"generated on {datestring}",
668
                            )
669
                        else:
670
                            print(
×
671
                                "No landmarks defined, using momentum correction parameters "
672
                                "from config.",
673
                            )
674
                except KeyError as exc:
×
675
                    raise ValueError(
×
676
                        "No valid landmarks defined, and no landmarks found in configuration!",
677
                    ) from exc
678

679
                self.add_features(features=features, rotsym=rotsym, include_center=include_center)
1✔
680

681
        else:
682
            self.correction["creation_date"] = datetime.now().timestamp()
1✔
683

684
        if ascale is not None:
1✔
685
            if isinstance(ascale, (int, float, np.floating, np.integer)):
1✔
686
                if self.rotsym != 4:
1✔
687
                    raise ValueError(
1✔
688
                        "Providing ascale as scalar number is only valid for 'rotsym'==4.",
689
                    )
690
                self.ascale = np.array([1.0, ascale, 1.0, ascale])
1✔
691
            elif isinstance(ascale, (tuple, list, np.ndarray)):
1✔
692
                if len(ascale) != len(self.ascale):
1✔
693
                    raise ValueError(
1✔
694
                        f"ascale needs to be of length 'rotsym', but has length {len(ascale)}.",
695
                    )
696
                self.ascale = np.asarray(ascale)
1✔
697
            else:
698
                raise TypeError(
1✔
699
                    (
700
                        "ascale needs to be a single number or a list/tuple/np.ndarray of length ",
701
                        f"'rotsym' ({self.rotsym})!",
702
                    ),
703
                )
704

705
        if use_center is None:
1✔
706
            try:
1✔
707
                use_center = self.correction["use_center"]
1✔
708
            except KeyError:
1✔
709
                use_center = True
1✔
710
        self.use_center = use_center
1✔
711

712
        self.prefs = kwds.pop("landmarks", self.pouter_ord)
1✔
713
        self.ptargs = kwds.pop("targets", [])
1✔
714

715
        # Generate the target point set
716
        if not self.ptargs:
1✔
717
            self.ptargs = sym.rotVertexGenerator(
1✔
718
                self.pcent,
719
                fixedvertex=self.pouter_ord[0, :],
720
                arot=self.arot,
721
                direction=-1,
722
                scale=self.ascale,
723
                ret="all",
724
            )[1:, :]
725

726
        if use_center is True:
1✔
727
            # Use center of image pattern in the registration-based symmetrization
728
            if fixed_center is True:
1✔
729
                # Add the same center to both the reference and target sets
730

731
                self.prefs = np.column_stack((self.prefs.T, self.pcent)).T
1✔
732
                self.ptargs = np.column_stack((self.ptargs.T, self.pcent)).T
1✔
733

734
            else:  # Add different centers to the reference and target sets
735
                newcenters = kwds.pop("new_centers", {})
×
736
                self.prefs = np.column_stack(
×
737
                    (self.prefs.T, newcenters["lmkcenter"]),
738
                ).T
739
                self.ptargs = np.column_stack(
×
740
                    (self.ptargs.T, newcenters["targcenter"]),
741
                ).T
742

743
        # Non-iterative estimation of deformation field
744
        corrected_image, splinewarp = tps.tpsWarping(
1✔
745
            self.prefs,
746
            self.ptargs,
747
            image,
748
            None,
749
            interp_order,
750
            ret="all",
751
            **kwds,
752
        )
753

754
        self.reset_deformation(image=image, coordtype="cartesian")
1✔
755

756
        self.update_deformation(
1✔
757
            splinewarp[0],
758
            splinewarp[1],
759
        )
760

761
        # save backup copies to reset transformations
762
        self.rdeform_field_bkp = self.rdeform_field
1✔
763
        self.cdeform_field_bkp = self.cdeform_field
1✔
764

765
        self.correction["outer_points"] = self.pouter_ord
1✔
766
        self.correction["center_point"] = np.asarray(self.pcent)
1✔
767
        self.correction["reference_points"] = self.prefs
1✔
768
        self.correction["target_points"] = self.ptargs
1✔
769
        self.correction["rotation_symmetry"] = self.rotsym
1✔
770
        self.correction["use_center"] = self.use_center
1✔
771
        self.correction["include_center"] = self.include_center
1✔
772
        if self.include_center:
1✔
773
            self.correction["feature_points"] = np.concatenate(
1✔
774
                (self.pouter_ord, np.asarray([self.pcent])),
775
            )
776
        else:
777
            self.correction["feature_points"] = self.pouter_ord
1✔
778
        self.correction["ascale"] = self.ascale
1✔
779

780
        if self.slice is not None:
1✔
781
            self.slice_corrected = corrected_image
1✔
782

783
        if verbose:
1✔
784
            print("Calulated thin spline correction based on the following landmarks:")
1✔
785
            print(f"pouter: {self.pouter}")
1✔
786
            if use_center:
1✔
787
                print(f"pcent: {self.pcent}")
1✔
788

789
        return corrected_image
1✔
790

791
    def apply_correction(
        self,
        image: np.ndarray,
        axis: int,
        dfield: np.ndarray = None,
    ) -> np.ndarray:
        """Warp every 2D slice of a 3D image stack with a deformation field.

        Args:
            image (np.ndarray): Image stack the transformation is applied to.
            axis (int): Axis along which the 2D slices are selected.
            dfield (np.ndarray, optional): Stacked row and column deformation fields.
                Defaults to [self.rdeform_field, self.cdeform_field].

        Returns:
            np.ndarray: The corrected image stack.
        """
        deformation = (
            np.asarray([self.rdeform_field, self.cdeform_field])
            if dfield is None
            else dfield
        )
        return sym.applyWarping(
            image,
            axis,
            warptype="deform_field",
            dfield=deformation,
        )
819

820
    def reset_deformation(self, **kwds):
        """Reset the row/column deformation fields to the plain coordinate grid.

        Args:
            **kwds: keyword arguments:

                - **image**: Image whose shape defines the deformation fields.
                  Defaults to self.slice.
                - **coordtype**: Coordinate system passed to
                  ``sym.coordinate_matrix_2D``. Defaults to 'cartesian'.
        """
        reference = kwds.pop("image", self.slice)
        coordinate_system = kwds.pop("coordtype", "cartesian")
        grid = sym.coordinate_matrix_2D(
            reference,
            coordtype=coordinate_system,
            stackaxis=0,
        ).astype("float64")

        # The row field is taken from component 1, the column field from component 0.
        self.rdeform_field, self.cdeform_field = grid[1, ...], grid[0, ...]
        self.dfield_updated = True
842

843
    def update_deformation(self, rdeform: np.ndarray, cdeform: np.ndarray):
        """Compose the stored deformation fields with a new deformation.

        Both stored fields are resampled with linear interpolation at the
        coordinates given by ``rdeform``/``cdeform``. Samples outside the stored
        fields are set to NaN.

        Args:
            rdeform (np.ndarray): 2D array of row-ordered deformation field.
            cdeform (np.ndarray): 2D array of column-ordered deformation field.
        """
        sample_at = [rdeform, cdeform]
        # Both fields are resampled from their values *before* this update.
        self.rdeform_field, self.cdeform_field = [
            ndi.map_coordinates(field, sample_at, order=1, cval=np.nan)
            for field in (self.rdeform_field, self.cdeform_field)
        ]
        self.dfield_updated = True
865

866
    def coordinate_transform(
        self,
        transform_type: str,
        keep: bool = False,
        interp_order: int = 1,
        mapkwds: dict = None,
        **kwds,
    ) -> np.ndarray:
        """Apply a pixel-wise coordinate transform to the image
        by means of the deformation field.

        Args:
            transform_type (str): Type of deformation to apply to image slice. Possible
                values are:

                - translation.
                - rotation.
                - rotation_auto.
                - scaling.
                - shearing.
                - homography.
                - scaling_auto (not implemented yet, raises NotImplementedError).

            keep (bool, optional): Option to keep the specified coordinate transform in
                the class. Defaults to False.
            interp_order (int, optional): Interpolation order for filling in missed
                pixels. Defaults to 1.
            mapkwds (dict, optional): Additional arguments passed to
                ``scipy.ndimage.map_coordinates()``. Defaults to None.
            **kwds: keyword arguments.

                - **image**: Image to transform. Defaults to self.slice. If the image
                  is self.slice, all previously kept deformations are applied as
                  well; otherwise only the new transform is applied.
                - **stackaxis**: Stacking axis of the coordinate matrix, 0 or -1.
                  Defaults to 0.

                Additional arguments in specific deformation field.
                See ``symmetrize.sym`` module.

        Returns:
            np.ndarray: The corrected image.

        Raises:
            NotImplementedError: If transform_type is "scaling_auto".
            ValueError: If transform_type or stackaxis is not supported.
        """
        # Validate up front: an unsupported type would otherwise leave the
        # displacement fields undefined and fail later with UnboundLocalError.
        if transform_type == "scaling_auto":
            raise NotImplementedError(
                "The 'scaling_auto' transformation is not implemented.",
            )
        supported_types = (
            "translation",
            "rotation",
            "rotation_auto",
            "scaling",
            "shearing",
            "homography",
        )
        if transform_type not in supported_types:
            raise ValueError(
                f"Unsupported transform_type '{transform_type}'. "
                f"Supported types are: {', '.join(supported_types)}.",
            )

        if mapkwds is None:
            mapkwds = {}

        image = kwds.pop("image", self.slice)
        stackax = kwds.pop("stackaxis", 0)
        if stackax not in (0, -1):
            raise ValueError(f"stackaxis needs to be 0 or -1, but is {stackax}.")
        coordmat = sym.coordinate_matrix_2D(
            image,
            coordtype="homogeneous",
            stackaxis=stackax,
        )

        if transform_type == "translation":
            # x and y translations are exchanged before calling symmetrize
            # (presumably because images here are indexed [x, y] and displayed
            # transposed -- confirm). Each given value is swapped, so passing only
            # one of them moves along the same axis as passing both.
            swapped = {}
            if "ytrans" in kwds:
                swapped["xtrans"] = kwds.pop("ytrans")
            if "xtrans" in kwds:
                swapped["ytrans"] = kwds.pop("xtrans")
            kwds.update(swapped)

            rdisp, cdisp = sym.translationDF(
                coordmat,
                stackaxis=stackax,
                ret="displacement",
                **kwds,
            )
        elif transform_type == "rotation":
            rdisp, cdisp = sym.rotationDF(
                coordmat,
                stackaxis=stackax,
                ret="displacement",
                **kwds,
            )
        elif transform_type == "rotation_auto":
            center = kwds.pop("center", self.pcent)
            # Estimate the optimal rotation angle using intensity symmetry
            angle_auto, _ = sym.sym_pose_estimate(
                image / image.max(),
                center=center,
                **kwds,
            )
            self.adjust_params = dictmerge(
                self.adjust_params,
                {"center": center, "angle": angle_auto},
            )
            rdisp, cdisp = sym.rotationDF(
                coordmat,
                stackaxis=stackax,
                ret="displacement",
                angle=angle_auto,
            )
        elif transform_type == "scaling":
            rdisp, cdisp = sym.scalingDF(
                coordmat,
                stackaxis=stackax,
                ret="displacement",
                **kwds,
            )
        elif transform_type == "shearing":
            rdisp, cdisp = sym.shearingDF(
                coordmat,
                stackaxis=stackax,
                ret="displacement",
                **kwds,
            )
        else:  # transform_type == "homography"
            transform = kwds.pop("transform", np.eye(3))
            rdisp, cdisp = sym.compose_deform_field(
                coordmat,
                mat_transform=transform,
                stackaxis=stackax,
                ret="displacement",
                **kwds,
            )

        # Compute deformation field
        if stackax == 0:
            rdeform = coordmat[1, ...] + rdisp
            cdeform = coordmat[0, ...] + cdisp
        else:  # stackax == -1
            rdeform = coordmat[..., 1] + rdisp
            cdeform = coordmat[..., 0] + cdisp

        # Resample image in the deformation field
        if image is self.slice:  # resample using all previous displacement fields
            total_rdeform = ndi.map_coordinates(
                self.rdeform_field,
                [rdeform, cdeform],
                order=1,
            )
            total_cdeform = ndi.map_coordinates(
                self.cdeform_field,
                [rdeform, cdeform],
                order=1,
            )
            slice_transformed = ndi.map_coordinates(
                image,
                [total_rdeform, total_cdeform],
                order=interp_order,
                **mapkwds,
            )
            self.slice_transformed = slice_transformed
        else:
            # if an external image is provided, apply only the new additional transformation
            slice_transformed = ndi.map_coordinates(
                image,
                [rdeform, cdeform],
                order=interp_order,
                **mapkwds,
            )

        # Combine deformation fields
        if keep is True:
            self.update_deformation(
                rdeform,
                cdeform,
            )
            self.adjust_params["applied"] = True
            # NOTE(review): for translations, kwds holds the swapped xtrans/ytrans here.
            self.adjust_params = dictmerge(self.adjust_params, kwds)

        return slice_transformed
1024

1025
    def pose_adjustment(
        self,
        transformations: dict[str, Any] = None,
        apply: bool = False,
        reset: bool = True,
        verbose: bool = True,
        **kwds,
    ):
        """Interactive panel to adjust transformations that are applied to the image.
        Applies first a scaling, next a x/y translation, and last a rotation around
        the center pixel given by config["momentum"]["center_pixel"].

        Args:
            transformations (dict, optional): Dictionary with transformations. A copy
                is used, so the passed dictionary is not modified.
                Defaults to self.transformations or config["momentum"]["transformations"].
            apply (bool, optional):
                Option to directly apply the provided transformations.
                Defaults to False.
            reset (bool, optional):
                Option to reset the correction before transformation. Defaults to True.
            verbose (bool, optional):
                Option to report the performed transformations. Defaults to True.
            **kwds: Keyword parameters defining defaults for the transformations:

                - **scale** (float): Initial value of the scaling slider.
                - **xtrans** (float): Initial value of the xtrans slider.
                - **ytrans** (float): Initial value of the ytrans slider.
                - **angle** (float): Initial value of the angle slider.
        """
        matplotlib.use("module://ipympl.backend_nbagg")
        # Prefer the distortion-corrected slice; otherwise fall back to the raw slice
        # (or an all-zero placeholder) and do not plot.
        if self.slice_corrected is None or not self.slice_corrected.any():
            if self.slice is None or not self.slice.any():
                self.slice = np.zeros(self._config["momentum"]["bins"][0:2])
            source_image = self.slice
            plot = False
        else:
            source_image = self.slice_corrected
            plot = True

        transformed_image = source_image

        if reset:
            # Restore the backed-up deformation fields if available, else start fresh.
            if self.rdeform_field_bkp is not None and self.cdeform_field_bkp is not None:
                self.rdeform_field = self.rdeform_field_bkp
                self.cdeform_field = self.cdeform_field_bkp
            else:
                self.reset_deformation()

        center = self._config["momentum"]["center_pixel"]
        if plot:
            fig, ax = plt.subplots(1, 1)
            img = ax.imshow(transformed_image.T, origin="lower", cmap="terrain_r")
            ax.axvline(x=center[0])
            ax.axhline(y=center[1])

        # Work on a private copy so that neither self.transformations nor a
        # caller-supplied dictionary is modified by the sliders.
        transformations = deepcopy(
            self.transformations if transformations is None else transformations,
        )

        if len(kwds) > 0:
            transformations.update(kwds)

        elif "creation_date" in transformations and verbose:
            datestring = datetime.fromtimestamp(transformations["creation_date"]).strftime(
                "%m/%d/%Y, %H:%M:%S",
            )
            print(f"Using transformation parameters generated on {datestring}")

        def record(key: str, value: float, default: float):
            # Keep ``transformations`` in sync with the sliders. A default value is
            # only stored if the key already exists, so moving a slider back to its
            # default is honored on apply, while untouched defaults are not added.
            if value != default or key in transformations:
                transformations[key] = value

        def update(scale: float, xtrans: float, ytrans: float, angle: float):
            record("scale", scale, 1)
            record("xtrans", xtrans, 0)
            record("ytrans", ytrans, 0)
            record("angle", angle, 0)

            transformed_image = source_image
            if scale != 1:
                transformed_image = self.coordinate_transform(
                    image=transformed_image,
                    transform_type="scaling",
                    xscale=scale,
                    yscale=scale,
                )
            if xtrans != 0 or ytrans != 0:
                transformed_image = self.coordinate_transform(
                    image=transformed_image,
                    transform_type="translation",
                    xtrans=xtrans,
                    ytrans=ytrans,
                )
            if angle != 0:
                transformed_image = self.coordinate_transform(
                    image=transformed_image,
                    transform_type="rotation",
                    angle=angle,
                    center=center,
                )
            if plot:
                img.set_data(transformed_image.T)
                axmin = np.min(transformed_image, axis=(0, 1))
                axmax = np.max(transformed_image, axis=(0, 1))
                if axmin < axmax:
                    img.set_clim(axmin, axmax)
                fig.canvas.draw_idle()

        update(
            scale=transformations.get("scale", 1),
            xtrans=transformations.get("xtrans", 0),
            ytrans=transformations.get("ytrans", 0),
            angle=transformations.get("angle", 0),
        )

        scale_slider = ipw.FloatSlider(
            value=transformations.get("scale", 1),
            min=0.8,
            max=1.2,
            step=0.01,
        )
        xtrans_slider = ipw.FloatSlider(
            value=transformations.get("xtrans", 0),
            min=-200,
            max=200,
            step=1,
        )
        ytrans_slider = ipw.FloatSlider(
            value=transformations.get("ytrans", 0),
            min=-200,
            max=200,
            step=1,
        )
        angle_slider = ipw.FloatSlider(
            value=transformations.get("angle", 0),
            min=-180,
            max=180,
            step=1,
        )
        results_box = ipw.Output()
        ipw.interact(
            update,
            scale=scale_slider,
            xtrans=xtrans_slider,
            ytrans=ytrans_slider,
            angle=angle_slider,
        )

        def apply_func(apply: bool):  # noqa: ARG001
            if transformations.get("scale", 1) != 1:
                self.coordinate_transform(
                    transform_type="scaling",
                    xscale=transformations["scale"],
                    yscale=transformations["scale"],
                    keep=True,
                )
                if verbose:
                    with results_box:
                        print(f"Applied scaling with scale={transformations['scale']}.")
            if transformations.get("xtrans", 0) != 0 or transformations.get("ytrans", 0) != 0:
                self.coordinate_transform(
                    transform_type="translation",
                    xtrans=transformations.get("xtrans", 0),
                    ytrans=transformations.get("ytrans", 0),
                    keep=True,
                )
                if verbose:
                    with results_box:
                        print(
                            f"Applied translation with (xtrans={transformations.get('xtrans', 0)},",
                            f"ytrans={transformations.get('ytrans', 0)}).",
                        )
            if transformations.get("angle", 0) != 0:
                self.coordinate_transform(
                    transform_type="rotation",
                    angle=transformations["angle"],
                    center=center,
                    keep=True,
                )
                if verbose:
                    with results_box:
                        print(f"Applied rotation with angle={transformations['angle']}.")

            # Show the report once, regardless of which transformations were applied
            # (previously only shown when a rotation was applied).
            display(results_box)

            if plot:
                img.set_data(self.slice_transformed.T)
                axmin = np.min(self.slice_transformed, axis=(0, 1))
                axmax = np.max(self.slice_transformed, axis=(0, 1))
                if axmin < axmax:
                    img.set_clim(axmin, axmax)
                fig.canvas.draw_idle()

            if transformations != self.transformations:
                transformations["creation_date"] = datetime.now().timestamp()
                self.transformations = transformations

            if verbose:
                plt.figure()
                subs = 20
                plt.title("Deformation field")
                plt.scatter(
                    self.rdeform_field[::subs, ::subs].ravel(),
                    self.cdeform_field[::subs, ::subs].ravel(),
                    c="b",
                )
                plt.show()
            scale_slider.close()
            xtrans_slider.close()
            ytrans_slider.close()
            angle_slider.close()
            apply_button.close()

        apply_button = ipw.Button(description="apply")
        display(apply_button)
        apply_button.on_click(apply_func)

        if plot:
            plt.show()

        if apply:
            apply_func(True)
1244

1245
    def calc_inverse_dfield(self):
        """Compute, store and return the inverse of the current deformation field.

        The row/column deformation fields, together with ``self.bin_ranges`` and
        ``self.detector_ranges``, are passed to ``generate_inverse_dfield``. The
        result is stored in ``self.inverse_dfield``.

        Returns:
            The inverse deformation field as returned by ``generate_inverse_dfield``.
        """
        inverse = generate_inverse_dfield(
            self.rdeform_field,
            self.cdeform_field,
            self.bin_ranges,
            self.detector_ranges,
        )
        self.inverse_dfield = inverse
        return inverse
1255

1256
    def view(
        self,
        image: np.ndarray = None,
        origin: str = "lower",
        cmap: str = "terrain_r",
        figsize: tuple[int, int] = (4, 4),
        points: dict = None,
        annotated: bool = False,
        backend: str = "matplotlib",
        imkwds: dict = None,
        scatterkwds: dict = None,
        cross: bool = False,
        crosshair: bool = False,
        crosshair_radii: list[int] = None,
        crosshair_thickness: int = 1,
        **kwds,
    ):
        """Display image slice with specified annotations.

        Args:
            image (np.ndarray, optional): The image to plot. Defaults to self.slice.
            origin (str, optional): Figure origin specification ('lower' or 'upper').
                Defaults to "lower".
            cmap (str, optional):  Colormap specification. Defaults to "terrain_r".
            figsize (tuple[int, int], optional): Figure size of the matplotlib figure.
                Defaults to (4, 4). The bokeh figure always uses (320, 300) pixels.
            points (dict, optional): Points for annotation. Defaults to self.features.
            annotated (bool, optional): Option to add annotation. Defaults to False.
            backend (str, optional): Visualization backend specification. Defaults to
                "matplotlib".

                - 'matplotlib': use static display rendered by matplotlib.
                - 'bokeh': use interactive display rendered by bokeh.

            imkwds (dict, optional): Keyword arguments for
                ``matplotlib.pyplot.imshow()`` (or bokeh ``figure.image()``).
                Defaults to None (no extra arguments).
            scatterkwds (dict, optional): Keyword arguments for
                ``matplotlib.pyplot.scatter()`` (or bokeh ``figure.scatter()``).
                Defaults to None (no extra arguments).
            cross (bool, optional): Option to display horizontal/vertical lines at the
                center pixel config["momentum"]["center_pixel"]. Works only in the
                matplotlib backend. Defaults to False.
            crosshair (bool, optional): Display option to plot circles around center
                self.pcent. Works only in bokeh backend. Defaults to False.
            crosshair_radii (list[int], optional): Pixel radii of circles to plot when
                crosshair option is activated. Defaults to [50, 100, 150].
            crosshair_thickness (int, optional): Thickness of crosshair circles.
                Defaults to 1.
            **kwds: keyword arguments.

                - **textshift** (tuple): Offsets of the point labels relative to the
                  points (matplotlib only). Defaults to (3, 3).
                - **textsize** (int): Font size of the point labels (matplotlib
                  only). Defaults to 12.
        """
        # Fresh containers per call instead of shared mutable defaults.
        if imkwds is None:
            imkwds = {}
        if scatterkwds is None:
            scatterkwds = {}
        if crosshair_radii is None:
            crosshair_radii = [50, 100, 150]

        if image is None:
            image = self.slice
        num_rows, num_cols = image.shape

        if points is None:
            points = self.features

        if annotated:
            tsr, tsc = kwds.pop("textshift", (3, 3))
            txtsize = kwds.pop("textsize", 12)

        if backend == "matplotlib":
            _, ax = plt.subplots(figsize=figsize)
            ax.imshow(image.T, origin=origin, cmap=cmap, **imkwds)

            if cross:
                center = self._config["momentum"]["center_pixel"]
                ax.axvline(x=center[0])
                ax.axhline(y=center[1])

            # Add annotation to the figure
            if annotated:
                for p_vals in points.values():
                    try:
                        # 2D array: column 0 is x, column 1 is y
                        ax.scatter(p_vals[:, 0], p_vals[:, 1], **scatterkwds)
                    except IndexError:
                        try:
                            # presumably a single (x, y) point stored as a 1D array
                            ax.scatter(p_vals[0], p_vals[1], **scatterkwds)
                        except IndexError:
                            pass

                    if p_vals.size > 2:
                        for i_pval, pval in enumerate(p_vals):
                            ax.text(
                                pval[0] + tsc,
                                pval[1] + tsr,
                                str(i_pval),
                                fontsize=txtsize,
                            )

        elif backend == "bokeh":
            output_notebook(hide_banner=True)
            colors = it.cycle(ColorCycle[10])
            ttp = [("(x, y)", "($x, $y)")]
            # NOTE: ``figsize`` is a named parameter and therefore never ends up in
            # kwds, so this always yields the (320, 300) pixel default.
            bokeh_figsize = kwds.pop("figsize", (320, 300))
            palette = cm2palette(cmap)  # Retrieve palette colors
            fig = pbk.figure(
                width=bokeh_figsize[0],
                height=bokeh_figsize[1],
                tooltips=ttp,
                x_range=(0, num_rows),
                y_range=(0, num_cols),
            )
            fig.image(
                image=[image.T],
                x=0,
                y=0,
                dw=num_rows,
                dh=num_cols,
                palette=palette,
                **imkwds,
            )

            if annotated is True:
                for p_vals in points.values():
                    try:
                        xcirc, ycirc = p_vals[:, 0], p_vals[:, 1]
                        fig.scatter(
                            xcirc,
                            ycirc,
                            size=8,
                            color=next(colors),
                            **scatterkwds,
                        )
                    except IndexError:
                        try:
                            xcirc, ycirc = p_vals[0], p_vals[1]
                            fig.scatter(
                                xcirc,
                                ycirc,
                                size=8,
                                color=next(colors),
                                **scatterkwds,
                            )
                        except IndexError:
                            pass
            if crosshair and self.pcent is not None:
                for radius in crosshair_radii:
                    fig.annulus(
                        x=[self.pcent[0]],
                        y=[self.pcent[1]],
                        inner_radius=radius - crosshair_thickness,
                        outer_radius=radius,
                        color="red",
                        alpha=0.6,
                    )

            pbk.show(fig)
1405

1406
    def select_k_range(
1✔
1407
        self,
1408
        point_a: np.ndarray | list[int] = None,
1409
        point_b: np.ndarray | list[int] = None,
1410
        k_distance: float = None,
1411
        k_coord_a: np.ndarray | list[float] = None,
1412
        k_coord_b: np.ndarray | list[float] = np.array([0.0, 0.0]),
1413
        equiscale: bool = True,
1414
        apply: bool = False,
1415
    ):
1416
        """Interactive selection function for features for the Momentum axes calibra-
1417
        tion. It allows the user to select the pixel positions of two symmetry points
1418
        (a and b) and the k-space distance of the two. Alternatively, the corrdinates
1419
        of both points can be provided. See the equiscale option for details on the
1420
        specifications of point coordinates.
1421

1422
        Args:
1423
            point_a (np.ndarray | list[int], optional): Pixel coordinates of the
1424
                symmetry point a.
1425
            point_b (np.ndarray | list[int], optional): Pixel coordinates of the
1426
                symmetry point b. Defaults to the center pixel of the image, defined by
1427
                config["momentum"]["center_pixel"].
1428
            k_distance (float, optional): The known momentum space distance between the
1429
                two symmetry points.
1430
            k_coord_a (np.ndarray | list[float], optional): Momentum coordinate
1431
                of the symmetry points a. Only valid if equiscale=False.
1432
            k_coord_b (np.ndarray | list[float], optional): Momentum coordinate
1433
                of the symmetry points b. Only valid if equiscale=False. Defaults to
1434
                the k-space center np.array([0.0, 0.0]).
1435
            equiscale (bool, optional): Option to adopt equal scale along both the x
1436
                and y directions.
1437

1438
                - **True**: Use a uniform scale for both x and y directions in the
1439
                  image coordinate system. This applies to the situation where
1440
                  k_distance is given and the points a and b are (close to) parallel
1441
                  with one of the two image axes.
1442
                - **False**: Calculate the momentum scale for both x and y directions
1443
                  separately. This applies to the situation where the points a and b
1444
                  are sufficiently different in both x and y directions in the image
1445
                  coordinate system.
1446

1447
                Defaults to 'True'.
1448

1449
            apply (bool, optional): Option to directly store the calibration parameters
1450
                to the class. Defaults to False.
1451

1452
        Raises:
1453
            ValueError: If no valid image is found from which to ge the coordinates.
1454
        """
1455
        matplotlib.use("module://ipympl.backend_nbagg")
1✔
1456
        if self.slice_transformed is not None:
1✔
1457
            image = self.slice_transformed
1✔
1458
        elif self.slice_corrected is not None:
1✔
1459
            image = self.slice_corrected
×
1460
        elif self.slice is not None:
1✔
1461
            image = self.slice
×
1462
        else:
1463
            raise ValueError("No valid image loaded!")
1✔
1464

1465
        if point_b is None:
1✔
1466
            point_b = self._config["momentum"]["center_pixel"]
×
1467

1468
        if point_a is None:
1✔
1469
            point_a = [0, 0]
×
1470

1471
        fig, ax = plt.subplots(1, 1)
1✔
1472
        img = ax.imshow(image.T, origin="lower", cmap="terrain_r")
1✔
1473

1474
        (marker_a,) = ax.plot(point_a[0], point_a[1], "o")
1✔
1475
        (marker_b,) = ax.plot(point_b[0], point_b[1], "ro")
1✔
1476

1477
        def update(
1✔
1478
            point_a_x: int,
1479
            point_a_y: int,
1480
            point_b_x: int,
1481
            point_b_y: int,
1482
            k_distance: float,  # noqa: ARG001
1483
        ):
UNCOV
1484
            fig.canvas.draw_idle()
×
UNCOV
1485
            marker_a.set_xdata(point_a_x)
×
UNCOV
1486
            marker_a.set_ydata(point_a_y)
×
UNCOV
1487
            marker_b.set_xdata(point_b_x)
×
UNCOV
1488
            marker_b.set_ydata(point_b_y)
×
1489

1490
        point_a_input_x = ipw.IntText(point_a[0])
1✔
1491
        point_a_input_y = ipw.IntText(point_a[1])
1✔
1492
        point_b_input_x = ipw.IntText(point_b[0])
1✔
1493
        point_b_input_y = ipw.IntText(point_b[1])
1✔
1494
        k_distance_input = ipw.FloatText(k_distance)
1✔
1495
        ipw.interact(
1✔
1496
            update,
1497
            point_a_x=point_a_input_x,
1498
            point_a_y=point_a_input_y,
1499
            point_b_x=point_b_input_x,
1500
            point_b_y=point_b_input_y,
1501
            k_distance=k_distance_input,
1502
        )
1503

1504
        self._state = 0
1✔
1505

1506
        def onclick(event):
1✔
1507
            if self._state == 0:
×
1508
                point_a_input_x.value = event.xdata
×
1509
                point_a_input_y.value = event.ydata
×
1510
                self._state = 1
×
1511
            else:
1512
                point_b_input_x.value = event.xdata
×
1513
                point_b_input_y.value = event.ydata
×
1514
                self._state = 0
×
1515

1516
        cid = fig.canvas.mpl_connect("button_press_event", onclick)
1✔
1517

1518
        def apply_func(apply: bool):  # noqa: ARG001
1✔
1519
            point_a = [point_a_input_x.value, point_a_input_y.value]
1✔
1520
            point_b = [point_b_input_x.value, point_b_input_y.value]
1✔
1521
            calibration = self.calibrate(
1✔
1522
                point_a=point_a,
1523
                point_b=point_b,
1524
                k_distance=k_distance,
1525
                equiscale=equiscale,
1526
                k_coord_a=k_coord_a,
1527
                k_coord_b=k_coord_b,
1528
            )
1529

1530
            img.set_extent(calibration["extent"])
1✔
1531
            plt.title("Momentum calibrated data")
1✔
1532
            plt.xlabel("$k_x$", fontsize=15)
1✔
1533
            plt.ylabel("$k_y$", fontsize=15)
1✔
1534
            ax.axhline(0)
1✔
1535
            ax.axvline(0)
1✔
1536

1537
            fig.canvas.mpl_disconnect(cid)
1✔
1538

1539
            point_a_input_x.close()
1✔
1540
            point_a_input_y.close()
1✔
1541
            point_b_input_x.close()
1✔
1542
            point_b_input_y.close()
1✔
1543
            k_distance_input.close()
1✔
1544
            apply_button.close()
1✔
1545

1546
            fig.canvas.draw_idle()
1✔
1547

1548
        apply_button = ipw.Button(description="apply")
1✔
1549
        display(apply_button)
1✔
1550
        apply_button.on_click(apply_func)
1✔
1551

1552
        if apply:
1✔
1553
            apply_func(True)
1✔
1554

1555
        plt.show()
1✔
1556

1557
    def calibrate(
        self,
        point_a: np.ndarray | list[int],
        point_b: np.ndarray | list[int],
        k_distance: float = None,
        k_coord_a: np.ndarray | list[float] = None,
        k_coord_b: np.ndarray | list[float] = None,
        equiscale: bool = True,
        image: np.ndarray = None,
    ) -> dict:
        """Momentum axes calibration using the pixel positions of two symmetry points
        (a and b) and the absolute coordinate of a single point (b), defaulted to
        [0., 0.]. All coordinates should be specified in the (x/y), i.e. (column_index,
        row_index) format. See the equiscale option for details on the specifications
        of point coordinates.

        The resulting calibration is stored in ``self.calibration`` and returned.

        Args:
            point_a (np.ndarray | list[int]): Pixel coordinates of the symmetry
                point a.
            point_b (np.ndarray | list[int]): Pixel coordinates of the symmetry
                point b.
            k_distance (float, optional): The known momentum space distance between the
                two symmetry points. Required if equiscale is True.
            k_coord_a (np.ndarray | list[float], optional): Momentum coordinate
                of the symmetry point a. Required if equiscale is False.
            k_coord_b (np.ndarray | list[float], optional): Momentum coordinate
                of the symmetry point b, used as the momentum offset in both modes.
                Defaults to the k-space center [0.0, 0.0].
            equiscale (bool, optional): Option to adopt equal scale along both the x
                and y directions.

                - **True**: Use a uniform scale for both x and y directions in the
                  image coordinate system. This applies to the situation where
                  k_distance is given and the points a and b are (close to) parallel
                  with one of the two image axes.
                - **False**: Calculate the momentum scale for both x and y directions
                  separately. This applies to the situation where the points a and b
                  are sufficiently different in both x and y directions in the image
                  coordinate system.

                Defaults to 'True'.
            image (np.ndarray, optional): The energy slice for which to return the
                calibration. Defaults to self.slice_corrected.

        Returns:
            dict: The calibration dictionary (also stored as ``self.calibration``):

                - "creation_date": Timestamp of the calibration.
                - "kx_axis", "ky_axis": 1D momentum axes along the first and the
                  second image dimension.
                - "grid": Tuple of 2D arrays from ``np.meshgrid(kx_axis, ky_axis)``
                  (can be used directly in pcolormesh).
                - "extent": Extent of the two momentum axes (can be used directly in
                  imshow).
                - "kx_scale", "ky_scale": Pixel-to-momentum conversion factors.
                - "x_center", "y_center": Pixel positions of the k-space center.
                - "rstart", "cstart": Detector positions of the image used for
                  calibration (only if ``self.bin_ranges`` is available).
                - "rstep", "cstep": Step size of detector coordinates in the image
                  used for calibration (only if ``self.bin_ranges`` is available).

        Raises:
            AssertionError: If k_distance (equiscale=True) or k_coord_a
                (equiscale=False) is not provided.
        """
        if image is None:
            image = self.slice_corrected
        if k_coord_b is None:
            # Created per call instead of being a shared (mutable) default argument.
            k_coord_b = np.array([0.0, 0.0])

        nrows, ncols = image.shape
        point_a, point_b = map(np.array, [point_a, point_b])

        # Pixel offsets of every row/column index relative to point b
        rowdist = np.arange(nrows) - point_b[0]
        coldist = np.arange(ncols) - point_b[1]

        if equiscale:
            assert k_distance is not None, "k_distance is required if equiscale=True"
            # Use the same conversion factor along both x and y directions
            pixel_distance = norm(point_a - point_b)
            # Calculate the pixel to momentum conversion factor
            xratio = yratio = k_distance / pixel_distance

        else:
            assert k_coord_a is not None, "k_coord_a is required if equiscale=False"
            # Calculate the conversion factor along x and y directions separately
            kxb, kyb = k_coord_b
            kxa, kya = k_coord_a
            # Calculate the column- and row-wise conversion factor
            xratio = (kxa - kxb) / (point_a[0] - point_b[0])
            yratio = (kya - kyb) / (point_a[1] - point_b[1])

        k_row = rowdist * xratio + k_coord_b[0]
        k_col = coldist * yratio + k_coord_b[1]

        # Calculate other return parameters
        k_rowgrid, k_colgrid = np.meshgrid(k_row, k_col)

        # Assemble into return dictionary
        self.calibration = {}
        self.calibration["creation_date"] = datetime.now().timestamp()
        self.calibration["kx_axis"] = k_row
        self.calibration["ky_axis"] = k_col
        self.calibration["grid"] = (k_rowgrid, k_colgrid)
        self.calibration["extent"] = (k_row[0], k_row[-1], k_col[0], k_col[-1])
        self.calibration["kx_scale"] = xratio
        self.calibration["ky_scale"] = yratio
        self.calibration["x_center"] = point_b[0] - k_coord_b[0] / xratio
        self.calibration["y_center"] = point_b[1] - k_coord_b[1] / yratio
        # copy parameters for applying calibration; bin_ranges may be unavailable
        # (e.g. when the image was not binned by this class), then they are omitted.
        try:
            self.calibration["rstart"] = self.bin_ranges[0][0]
            self.calibration["cstart"] = self.bin_ranges[1][0]
            self.calibration["rstep"] = (self.bin_ranges[0][1] - self.bin_ranges[0][0]) / nrows
            self.calibration["cstep"] = (self.bin_ranges[1][1] - self.bin_ranges[1][0]) / ncols
        except (AttributeError, IndexError):
            pass

        return self.calibration
1✔
1673

1674
    def apply_corrections(
        self,
        df: pd.DataFrame | dask.dataframe.DataFrame,
        x_column: str = None,
        y_column: str = None,
        new_x_column: str = None,
        new_y_column: str = None,
        verbose: bool = True,
        **kwds,
    ) -> tuple[pd.DataFrame | dask.dataframe.DataFrame, dict]:
        """Calculate and replace the X and Y values with their distortion-corrected
        version.

        If no inverse deformation field exists yet (or the deformation field was
        updated), it is (re)generated first: the spline warp and/or pose adjustment
        are estimated from ``self.correction``/``self.transformations`` if no
        deformation fields are present.

        Args:
            df (pd.DataFrame | dask.dataframe.DataFrame): Dataframe to apply
                the distortion correction to. A pandas dataframe gets the new
                columns assigned in place.
            x_column (str, optional): Label of the 'X' column before momentum
                distortion correction. Defaults to config["momentum"]["x_column"].
            y_column (str, optional): Label of the 'Y' column before momentum
                distortion correction. Defaults to config["momentum"]["y_column"].
            new_x_column (str, optional): Label of the 'X' column after momentum
                distortion correction.
                Defaults to config["momentum"]["corrected_x_column"].
            new_y_column (str, optional): Label of the 'Y' column after momentum
                distortion correction.
                Defaults to config["momentum"]["corrected_y_column"].
            verbose (bool, optional): Option to report the used landmarks for correction.
                Defaults to True.
            **kwds: Additional keyword arguments. For dask dataframes they are passed
                to ``df.map_partitions`` (e.g. ``meta``), which forwards arguments it
                does not consume to ``apply_dfield``. For pandas dataframes they are
                passed to ``apply_dfield`` directly.

        Returns:
            tuple[pd.DataFrame | dask.dataframe.DataFrame, dict]: Dataframe with
            added columns and momentum correction metadata dictionary.

        Raises:
            ValueError: If neither deformation fields nor corrections/transformations
                are available.
        """
        if x_column is None:
            x_column = self.x_column
        if y_column is None:
            y_column = self.y_column

        if new_x_column is None:
            new_x_column = self.corrected_x_column
        if new_y_column is None:
            new_y_column = self.corrected_y_column

        if self.inverse_dfield is None or self.dfield_updated:
            if self.rdeform_field is None and self.cdeform_field is None:
                if not (self.correction or self.transformations):
                    raise ValueError("No corrections or transformations defined!")
                if self.correction:
                    # Generate spline warp from class features or config
                    self.spline_warp_estimate(verbose=verbose)
                if self.transformations:
                    # Apply config pose adjustments
                    self.pose_adjustment()

            self.inverse_dfield = generate_inverse_dfield(
                self.rdeform_field,
                self.cdeform_field,
                self.bin_ranges,
                self.detector_ranges,
            )
            self.dfield_updated = False

        dfield_kwds = {
            "dfield": self.inverse_dfield,
            "x_column": x_column,
            "y_column": y_column,
            "new_x_column": new_x_column,
            "new_y_column": new_y_column,
            "detector_ranges": self.detector_ranges,
        }
        if isinstance(df, pd.DataFrame):
            # pandas dataframes have no ``map_partitions``; apply to the whole frame.
            out_df = apply_dfield(df, **dfield_kwds, **kwds)
        else:
            out_df = df.map_partitions(apply_dfield, **dfield_kwds, **kwds)

        metadata = self.gather_correction_metadata()

        return out_df, metadata
1✔
1758

1759
    def gather_correction_metadata(self) -> dict:
        """Collect meta data for momentum correction.

        ``self.correction`` and ``self.adjust_params`` are shallow-copied into the
        result. The bookkeeping entries added here ("applied", "creation_date",
        "depends_on", the deformation fields and the "trans_x"/"trans_y"/"rot_z"
        transformation entries) therefore do not leak back into the corrector's state.

        Returns:
            dict: generated correction metadata dictionary. It contains a "correction"
            entry if ``self.correction`` is non-empty, and a "registration" entry if
            ``self.adjust_params`` is non-empty.
        """
        metadata: dict[Any, Any] = {}
        if len(self.correction) > 0:
            correction = dict(self.correction)
            correction["applied"] = True
            correction["cdeform_field"] = self.cdeform_field
            correction["rdeform_field"] = self.rdeform_field
            metadata["correction"] = correction

        if len(self.adjust_params) > 0:
            registration = dict(self.adjust_params)
            registration["creation_date"] = datetime.now().timestamp()
            registration["applied"] = True

            # NOTE(review): "tranformations" (sic) is kept verbatim, as these paths are
            # emitted as metadata that downstream consumers may match on — confirm
            # before correcting the spelling.
            prefix = "/entry/process/registration/tranformations/"
            # Only present and non-zero (truthy) parameters yield a transformation.
            has_angle = bool(registration.get("angle"))
            has_xtrans = bool(registration.get("xtrans"))
            has_ytrans = bool(registration.get("ytrans"))

            # Head of the transformation chain rot_z -> trans_y -> trans_x,
            # skipping transformations that are not present.
            if has_angle:
                registration["depends_on"] = prefix + "rot_z"
            elif has_xtrans:
                registration["depends_on"] = prefix + "trans_y"
            elif has_ytrans:
                registration["depends_on"] = prefix + "trans_x"
            else:
                registration["depends_on"] = "."

            if has_ytrans:  # swapped definitions: ytrans is stored as trans_x
                registration["trans_x"] = {
                    "value": registration["ytrans"],
                    "type": "translation",
                    "units": "pixel",
                    "vector": np.asarray([1.0, 0.0, 0.0]),
                    "depends_on": ".",
                }
            if has_xtrans:  # swapped definitions: xtrans is stored as trans_y
                registration["trans_y"] = {
                    "value": registration["xtrans"],
                    "type": "translation",
                    "units": "pixel",
                    "vector": np.asarray([0.0, 1.0, 0.0]),
                    "depends_on": (prefix + "trans_x") if has_ytrans else ".",
                }
            if has_angle:
                if has_xtrans:
                    rot_depends_on = prefix + "trans_y"
                elif has_ytrans:
                    rot_depends_on = prefix + "trans_x"
                else:
                    rot_depends_on = "."
                registration["rot_z"] = {
                    "value": registration["angle"],
                    "type": "rotation",
                    "units": "degrees",
                    "vector": np.asarray([0.0, 0.0, 1.0]),
                    "offset": np.concatenate((registration["center"], [0.0])),
                    "depends_on": rot_depends_on,
                }

            metadata["registration"] = registration

        return metadata
1✔
1832

1833
    def append_k_axis(
        self,
        df: pd.DataFrame | dask.dataframe.DataFrame,
        x_column: str = None,
        y_column: str = None,
        new_x_column: str = None,
        new_y_column: str = None,
        calibration: dict = None,
        **kwds,
    ) -> tuple[pd.DataFrame | dask.dataframe.DataFrame, dict]:
        """Calculate and append the k axis coordinates (kx, ky) to the events dataframe.

        Args:
            df (pd.DataFrame | dask.dataframe.DataFrame): Dataframe to apply the
                momentum calibration to. The new columns are assigned to it.
            x_column (str, optional): Label of the source 'X' column.
                Defaults to config["momentum"]["corrected_x_column"] or
                config["momentum"]["x_column"] (whichever is present).
            y_column (str, optional): Label of the source 'Y' column.
                Defaults to config["momentum"]["corrected_y_column"] or
                config["momentum"]["y_column"] (whichever is present).
            new_x_column (str, optional): Label of the destination 'X' column after
                momentum calibration. Defaults to config["momentum"]["kx_column"].
            new_y_column (str, optional): Label of the destination 'Y' column after
                momentum calibration. Defaults to config["momentum"]["ky_column"].
            calibration (dict, optional): Dictionary containing calibration parameters.
                Defaults to ``self.calibration``. The dictionary is copied, so neither
                it nor ``self.calibration`` is modified.
            **kwds: Keyword parameters for momentum calibration. They override the
                corresponding entries of the (copied) calibration dictionary and
                refresh its "creation_date".

        Returns:
            tuple[pd.DataFrame | dask.dataframe.DataFrame, dict]: Dataframe with
            added columns and momentum calibration metadata dictionary.

        Raises:
            ValueError: If required calibration parameters are missing.
        """
        if x_column is None:
            if self.corrected_x_column in df.columns:
                x_column = self.corrected_x_column
            else:
                x_column = self.x_column
        if y_column is None:
            if self.corrected_y_column in df.columns:
                y_column = self.corrected_y_column
            else:
                y_column = self.y_column

        if new_x_column is None:
            new_x_column = self.kx_column

        if new_y_column is None:
            new_y_column = self.ky_column

        # pylint: disable=duplicate-code
        # Work on a copy so that the keyword overrides below (and the metadata
        # gathering) do not modify self.calibration or a caller-supplied dict.
        calibration = deepcopy(self.calibration if calibration is None else calibration)

        if kwds:
            calibration.update(kwds)
            calibration["creation_date"] = datetime.now().timestamp()

        try:
            (df[new_x_column], df[new_y_column]) = detector_coordiantes_2_k_koordinates(
                r_det=df[x_column],
                c_det=df[y_column],
                r_start=calibration["rstart"],
                c_start=calibration["cstart"],
                r_center=calibration["x_center"],
                c_center=calibration["y_center"],
                r_conversion=calibration["kx_scale"],
                c_conversion=calibration["ky_scale"],
                r_step=calibration["rstep"],
                c_step=calibration["cstep"],
            )
        except KeyError as exc:
            raise ValueError(
                "Required calibration parameters missing!",
            ) from exc

        metadata = self.gather_calibration_metadata(calibration=calibration)

        return df, metadata
1✔
1914

1915
    def gather_calibration_metadata(self, calibration: dict = None) -> dict:
        """Collect meta data for momentum calibration

        Args:
            calibration (dict, optional): Dictionary with momentum calibration
                parameters. If omitted will be taken from the class. It is
                shallow-copied and not modified.

        Returns:
            dict: Generated metadata dictionary with the keys "creation_date" (if
            present in the calibration), "applied" and "calibration". Placeholder
            entries (0) are filled in for missing "kx_axis"/"ky_axis".
        """
        if calibration is None:
            calibration = self.calibration
        # Copy so that the placeholder axes added below do not leak into
        # self.calibration or into the caller's dictionary.
        calibration = dict(calibration)

        metadata: dict[Any, Any] = {}
        if "creation_date" in calibration:
            metadata["creation_date"] = calibration["creation_date"]
        metadata["applied"] = True
        metadata["calibration"] = calibration
        # create empty calibrated axis entries, if they are not present.
        calibration.setdefault("kx_axis", 0)
        calibration.setdefault("ky_axis", 0)

        return metadata
1✔
1941

1942

1943
def cm2palette(cmap_name: str) -> list:
    """Convert certain matplotlib colormap (cm) to bokeh palette.

    Args:
        cmap_name (str): Name of the colormap/palette. Either a bokeh palette
            family (a key of ``bokeh.palettes.all_palettes``, e.g. "Viridis") or a
            matplotlib colormap name (e.g. "terrain_r").

    Returns:
        list: List of colors in hex representation (a bokeh palette). For bokeh
        palette families, the largest available palette of that family is used.
    """
    if cmap_name in bp.all_palettes:
        # all_palettes maps a family name to {number_of_colors: palette}; pick the
        # largest palette instead of returning that whole mapping.
        palettes_by_size = bp.all_palettes[cmap_name]
        return list(palettes_by_size[max(palettes_by_size)])

    palette_func = getattr(cm, cmap_name)
    # Sample the colormap at 256 points; rows are RGBA scaled to 0..255.
    mpl_cm_rgb = (255 * palette_func(range(256))).astype("int")
    # Only the RGB channels are passed; bokeh's RGB alpha is a 0..1 float.
    return [RGB(*rgb[:3]).to_hex() for rgb in mpl_cm_rgb]
1✔
1962

1963

1964
def dictmerge(
    main_dict: dict,
    other_entries: list[dict] | tuple[dict] | dict,
) -> dict:
    """Combine ``main_dict`` with one or several further dictionaries.

    Entries merged later take precedence on key collisions. A new dictionary is
    built for every merge step. If ``other_entries`` is an empty list/tuple, or is
    neither a dict nor a list/tuple, ``main_dict`` itself is returned unchanged.

    Args:
        main_dict (dict): Main dictionary.
        other_entries (list[dict] | tuple[dict] | dict):
            Other dictionary or composite dictionarized elements.

    Returns:
        dict: Merged dictionary.
    """
    if isinstance(other_entries, dict):
        return {**main_dict, **other_entries}

    merged = main_dict
    if isinstance(other_entries, (list, tuple)):
        for entry in other_entries:
            merged = {**merged, **entry}
    return merged
1✔
1987

1988

1989
def detector_coordiantes_2_k_koordinates(
    r_det: float,
    c_det: float,
    r_start: float,
    c_start: float,
    r_center: float,
    c_center: float,
    r_conversion: float,
    c_conversion: float,
    r_step: float,
    c_step: float,
) -> tuple[float, float]:
    """Convert detector coordinates (rdet, cdet) into momentum coordinates (kr, kc).

    Each axis is handled independently: the detector position of the k-space
    center is ``start + step * center``. The distance to it is divided by ``step``
    (giving image pixels) and multiplied by the conversion factor. Works
    element-wise on arrays/Series as well as on scalars.

    Args:
        r_det (float): Row detector coordinates.
        c_det (float): Column detector coordinates.
        r_start (float): Start value for row detector coordinates.
        c_start (float): Start value for column detector coordinates.
        r_center (float): Center value for row detector coordinates.
        c_center (float): Center value for column detector coordinates.
        r_conversion (float): Row conversion factor.
        c_conversion (float): Column conversion factor.
        r_step (float): Row stepping factor.
        c_step (float): Column stepping factor.

    Returns:
        tuple[float, float]: Converted momentum space row/column coordinates.
    """

    def _axis_to_momentum(det, start, center, conversion, step):
        # Detector coordinate of the momentum-space origin along this axis
        origin = start + step * center
        return conversion * ((det - origin) / step)

    return (
        _axis_to_momentum(r_det, r_start, r_center, r_conversion, r_step),
        _axis_to_momentum(c_det, c_start, c_center, c_conversion, c_step),
    )
1✔
2025

2026

2027
def apply_dfield(
    df: pd.DataFrame | dask.dataframe.DataFrame,
    dfield: np.ndarray,
    x_column: str,
    y_column: str,
    new_x_column: str,
    new_y_column: str,
    detector_ranges: list[tuple],
) -> pd.DataFrame | dask.dataframe.DataFrame:
    """Map the dataframe coordinates through the inverse displacement field.

    Both fields are sampled at the (x, y) positions with linear interpolation
    (``map_coordinates(order=1)``). The results are scaled by the detector range
    per field pixel along the respective axis and written to the new columns.
    Both new columns are computed before either is assigned, so a destination
    may safely overwrite a source column.

    Args:
        df (pd.DataFrame | dask.dataframe.DataFrame): Dataframe to apply the
            distortion correction to. The new columns are assigned to it.
        dfield (np.ndarray): The distortion correction field. 3D matrix with the
            row-wise (index 0) and column-wise (index 1) fields stacked along the
            first dimension.
        x_column (str): Label of the 'X' source column.
        y_column (str): Label of the 'Y' source column.
        new_x_column (str): Label of the 'X' destination column.
        new_y_column (str): Label of the 'Y' destination column.
        detector_ranges (list[tuple]): tuple of pixel ranges of the detector x/y
            coordinates

    Returns:
        pd.DataFrame | dask.dataframe.DataFrame: dataframe with added columns
    """
    sample_points = (df[x_column], df[y_column])

    row_field, col_field = dfield[0], dfield[1]
    row_extent = detector_ranges[0][1] - detector_ranges[0][0]
    col_extent = detector_ranges[1][1] - detector_ranges[1][0]
    # Both scale factors refer to the shape of the first (row) field.
    row_scale = row_extent / row_field.shape[0]
    col_scale = col_extent / row_field.shape[1]

    corrected_x = map_coordinates(row_field, sample_points, order=1) * row_scale
    corrected_y = map_coordinates(col_field, sample_points, order=1) * col_scale

    df[new_x_column] = corrected_x
    df[new_y_column] = corrected_y
    return df
1✔
2064

2065

2066
def generate_inverse_dfield(
    rdeform_field: np.ndarray,
    cdeform_field: np.ndarray,
    bin_ranges: list[tuple],
    detector_ranges: list[tuple],
) -> np.ndarray:
    """Generate inverse deformation field using interpolation with griddata.
    Assuming the binning range of the input ``rdeform_field`` and ``cdeform_field``
    covers the whole detector.

    Args:
        rdeform_field (np.ndarray): Row-wise deformation field.
        cdeform_field (np.ndarray): Column-wise deformation field (same shape as
            ``rdeform_field``).
        bin_ranges (list[tuple]): Detector ranges of the binned coordinates.
        detector_ranges (list[tuple]): Ranges of detector coordinates to interpolate to.

    Returns:
        np.ndarray: The calculated inverse deformation field, with the row field
        (index 0) and the column field (index 1) stacked along the first axis.
    """
    print(
        "Calculating inverse deformation field, this might take a moment...",
    )

    # Interpolation grid with detector_ranges[i][1] sample points per axis, running
    # from detector_ranges[i][0] up to the deformation-field size (endpoint excluded).
    # NOTE(review): start is a detector coordinate while stop is a bin count; this is
    # only consistent for detector ranges starting at 0 (see docstring assumption).
    r_mesh, c_mesh = np.meshgrid(
        np.linspace(
            detector_ranges[0][0],
            cdeform_field.shape[0],
            detector_ranges[0][1],
            endpoint=False,
        ),
        np.linspace(
            detector_ranges[1][0],
            cdeform_field.shape[1],
            detector_ranges[1][1],
            endpoint=False,
        ),
        sparse=False,
        indexing="ij",
    )

    binned_ranges = np.asarray(bin_ranges)[0:2]
    bin_step = (binned_ranges[:, 1] - binned_ranges[:, 0]) / cdeform_field.shape

    # Only pixels where both deformation fields are defined contribute. Boolean
    # indexing and np.nonzero both traverse in row-major order, i.e. the same point
    # order as a nested row/column loop.
    valid = ~(np.isnan(rdeform_field) | np.isnan(cdeform_field))
    rows, cols = np.nonzero(valid)

    # Source positions: deformed positions offset by the bin range start (in bin units)
    rc_position = np.column_stack(
        (
            rdeform_field[valid] + bin_ranges[0][0] / bin_step[0],
            cdeform_field[valid] + bin_ranges[1][0] / bin_step[1],
        ),
    )
    # Destination: detector coordinates of the undistorted bin positions
    r_dest = bin_step[0] * rows + bin_ranges[0][0]
    c_dest = bin_step[1] * cols + bin_ranges[1][0]

    # Interpolate the row and column destinations in parallel
    ret = Parallel(n_jobs=2)(
        delayed(griddata)(rc_position, np.asarray(arg), (r_mesh, c_mesh))
        for arg in [r_dest, c_dest]
    )

    inverse_dfield = np.asarray([ret[0], ret[1]])

    return inverse_dfield
1✔
2139

2140

2141
def load_dfield(file: str) -> tuple[np.ndarray, np.ndarray]:
    """Read a stacked inverse deformation field from ``file``.

    Args:
        file (str): Path of the file, read with ``np.load``.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(rdeform_field, cdeform_field)``, i.e. the
        first and second entries along the leading axis of the stored array. Both
        are ``None`` if the file does not exist.
    """
    try:
        stacked_fields = np.load(file)
    except FileNotFoundError:
        # Missing file is tolerated: report "no field available".
        return None, None

    return stacked_fields[0], stacked_fields[1]
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc