• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jveitchmichaelis / rascal / 4515862454

pending completion
4515862454

push

github

cylammarco
added remarklint to pre-commit. linted everything.

1884 of 2056 relevant lines covered (91.63%)

3.65 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.57
/src/rascal/calibrator.py
1
import copy
4✔
2
import itertools
4✔
3
import logging
4✔
4

5
import numpy as np
4✔
6
from scipy.spatial import Delaunay
4✔
7
from scipy.optimize import minimize
4✔
8
from scipy import interpolate
4✔
9
from tqdm.autonotebook import tqdm
4✔
10

11
from .util import _derivative
4✔
12
from .util import gauss
4✔
13

14
from . import plotting
4✔
15
from . import models
4✔
16
from .houghtransform import HoughTransform
4✔
17
from .atlas import Atlas
4✔
18

19

20
class Calibrator:
4✔
21
    def __init__(self, peaks, spectrum=None):
        """
        Create a calibrator for a set of arc-line peaks.

        All configuration attributes are first created as placeholders and
        then populated by calling the three ``set_*_properties`` methods
        with their default arguments.

        Parameters
        ----------
        peaks: list
            List of identified arc line pixel values. A deep copy is stored.
        spectrum: list
            The spectral intensity as a function of pixel. A deep copy is
            stored.

        """

        # logging (configured in set_calibrator_properties)
        self.logger = self.log_level = None

        # private copies of the inputs, so the caller's data is untouched
        self.peaks = copy.deepcopy(peaks)
        self.spectrum = copy.deepcopy(spectrum)

        # plotting state
        self.matplotlib_imported = False
        self.plotly_imported = False
        self.plot_with_matplotlib = False
        self.plot_with_plotly = False

        # atlas, user-supplied known pairs and Hough transform results
        self.atlas = None
        self.pix_known = self.wave_known = None
        self.hough_lines = self.hough_points = None
        self.ht = HoughTransform()

        # calibrator_properties
        self.num_pix = self.pixel_list = None
        self.plotting_library = None
        self.constrain_poly = None

        # hough_properties
        self.num_slopes = None
        self.xbins = self.ybins = None
        self.min_wavelength = self.max_wavelength = None
        self.range_tolerance = self.linearity_tolerance = None

        # ransac_properties
        self.sample_size = None
        self.top_n_candidate = None
        self.linear = None
        self.filter_close = None
        self.ransac_tolerance = None
        self.candidate_weighted = None
        self.hough_weight = None
        self.minimum_matches = None
        self.minimum_peak_utilisation = None
        self.minimum_fit_error = None

        # results
        self.matched_peaks = []
        self.matched_atlas = []
        self.fit_coeff = None

        # apply the default configuration
        self.set_calibrator_properties()
        self.set_hough_properties()
        self.set_ransac_properties()
85

86
    def _generate_pairs(self):
4✔
87
        """
88
        Generate pixel-wavelength pairs without the allowed regions set by the
89
        linearity limit. This assumes a relatively linear spectrograph.
90

91
        Parameters
92
        ----------
93
        candidate_tolerance: float (default: 10)
94
            toleranceold  (Angstroms) for considering a point to be an inlier
95
            during candidate peak/line selection. This should be reasonable
96
            small as we want to search for candidate points which are
97
            *locally* linear.
98
        constrain_poly: boolean
99
            Apply a polygonal constraint on possible peak/atlas pairs
100

101
        """
102

103
        pairs = [
4✔
104
            pair
105
            for pair in itertools.product(self.peaks, self.atlas.get_lines())
106
        ]
107

108
        if self.constrain_poly:
4✔
109
            # Remove pairs outside polygon
110
            valid_area = Delaunay(
4✔
111
                [
112
                    (0, self.max_intercept + self.candidate_tolerance),
113
                    (0, self.min_intercept - self.candidate_tolerance),
114
                    (
115
                        self.pixel_list.max(),
116
                        self.max_wavelength
117
                        - self.range_tolerance
118
                        - self.candidate_tolerance,
119
                    ),
120
                    (
121
                        self.pixel_list.max(),
122
                        self.max_wavelength
123
                        + self.range_tolerance
124
                        + self.candidate_tolerance,
125
                    ),
126
                ]
127
            )
128

129
            mask = valid_area.find_simplex(pairs) >= 0
4✔
130
            self.pairs = np.array(pairs)[mask]
4✔
131

132
        else:
133
            self.pairs = np.array(pairs)
4✔
134

135
    def _merge_candidates(self, candidates):
4✔
136
        """
137
        Merge two candidate lists.
138

139
        Parameters
140
        ----------
141
        candidates: list
142
            list containing pixel-wavelength pairs.
143

144
        """
145

146
        merged = []
4✔
147

148
        for pairs in candidates:
4✔
149
            for pair in np.array(pairs).T:
4✔
150
                merged.append(pair)
4✔
151

152
        return np.sort(np.array(merged))
4✔
153

154
    def _get_most_common_candidates(
4✔
155
        self, candidates, top_n_candidate, weighted
156
    ):
157
        """
158
        Takes a number of candidate pair sets and returns the most common
159
        pair for each wavelength
160

161
        Parameters
162
        ----------
163
        candidates: list of list(float, float)
164
            A list of list of peak/line pairs
165
        top_n_candidate: int
166
            Top ranked lines to be fitted.
167
        weighted: boolean
168
            If True, the distance from the atlas wavelength will be used to
169
            compute the probilitiy based on how far it is from the Gaussian
170
            distribution from the known line.
171

172
        """
173

174
        peaks = []
4✔
175
        wavelengths = []
4✔
176
        probabilities = []
4✔
177

178
        for candidate in candidates:
4✔
179
            peaks.extend(candidate[0])
4✔
180
            wavelengths.extend(candidate[1])
4✔
181
            probabilities.extend(candidate[2])
4✔
182

183
        peaks = np.array(peaks)
4✔
184
        wavelengths = np.array(wavelengths)
4✔
185
        probabilities = np.array(probabilities)
4✔
186

187
        out_peaks = []
4✔
188
        out_wavelengths = []
4✔
189

190
        for peak in np.unique(peaks):
4✔
191
            idx = np.where(peaks == peak)
4✔
192

193
            if len(idx) > 0:
4✔
194
                wavelengths_matched = wavelengths[idx]
4✔
195

196
                if weighted:
4✔
197
                    counts = probabilities[idx]
4✔
198

199
                else:
200
                    counts = np.ones_like(probabilities[idx])
×
201

202
                n = int(
4✔
203
                    min(top_n_candidate, len(np.unique(wavelengths_matched)))
204
                )
205

206
                unique_wavelengths = np.unique(wavelengths_matched)
4✔
207
                aggregated_count = np.zeros_like(unique_wavelengths)
4✔
208
                for j, w in enumerate(unique_wavelengths):
4✔
209
                    idx_j = np.where(wavelengths_matched == w)
4✔
210
                    aggregated_count[j] = np.sum(counts[idx_j])
4✔
211

212
                out_peaks.extend([peak] * n)
4✔
213
                out_wavelengths.extend(
4✔
214
                    wavelengths_matched[np.argsort(-aggregated_count)[:n]]
215
                )
216

217
        return out_peaks, out_wavelengths
4✔
218

219
    def _get_candidate_points_linear(self, candidate_tolerance):
4✔
220
        """
221
        Returns a list of peak/wavelengths pairs which agree with the fit
222

223
        (wavelength - gradient * x + intercept) < tolerance
224

225
        Note: depending on the candidate_tolerance , one peak may match with
226
        multiple wavelengths.
227

228
        Parameters
229
        ----------
230
        candidate_tolerance: float (default: 10)
231
            tolerance  (Angstroms) for considering a point to be an inlier
232
            during candidate peak/line selection. This should be reasonable
233
            small as we want to search for candidate points which are
234
            *locally* linear.
235

236
        """
237

238
        # Locate candidate points for these lines fits
239
        self.candidates = []
4✔
240

241
        for line in self.hough_lines:
4✔
242
            gradient, intercept = line
4✔
243

244
            predicted = gradient * self.pairs[:, 0] + intercept
4✔
245
            actual = self.pairs[:, 1]
4✔
246
            diff = np.abs(predicted - actual)
4✔
247
            mask = diff <= candidate_tolerance
4✔
248

249
            # Match the range_tolerance to 1.1775 s.d. to match the FWHM
250
            # Note that the pairs outside of the range_tolerance were already
251
            # removed in an earlier stage
252
            weight = gauss(
4✔
253
                actual[mask],
254
                1.0,
255
                predicted[mask],
256
                (self.range_tolerance + self.linearity_tolerance) * 1.1775,
257
            )
258

259
            self.candidates.append(
4✔
260
                (self.pairs[:, 0][mask], actual[mask], weight)
261
            )
262

263
    def _get_candidate_points_poly(self, candidate_tolerance):
4✔
264
        """
265
        **EXPERIMENTAL**
266

267
        Returns a list of peak/wavelengths pairs which agree with the fit
268

269
        (wavelength - gradient * x + intercept) < tolerance
270

271
        Note: depending on the candidate_tolerance, one peak may
272
        match with multiple wavelengths.
273

274
        Parameters
275
        ----------
276
        candidate_tolerance: float (default: 10)
277
            toleranceold  (Angstroms) for considering a point to be an inlier
278
            during candidate peak/line selection. This should be reasonable
279
            small as we want to search for candidate points which are
280
            *locally* linear.
281

282
        """
283

284
        if self.fit_coeff is None:
4✔
285
            raise ValueError(
×
286
                "A guess solution for a polynomial fit has to "
287
                "be provided as fit_coeff in fit() in order to generate "
288
                "candidates for RANSAC sampling."
289
            )
290

291
        self.candidates = []
4✔
292

293
        # actual wavelengths
294
        actual = np.array(self.atlas.get_lines())
4✔
295

296
        n = len(self.hough_lines)
4✔
297

298
        delta = (
4✔
299
            np.random.random(n) * self.range_tolerance * 2.0
300
            - self.range_tolerance
301
        )
302

303
        for d in delta:
4✔
304
            # predicted wavelength
305
            predicted = self.polyval(self.peaks, self.fit_coeff) + d
4✔
306
            diff = np.abs(actual - predicted)
4✔
307
            mask = diff < candidate_tolerance
4✔
308

309
            if np.sum(mask) > 0:
4✔
310
                weight = gauss(
4✔
311
                    actual[mask], 1.0, predicted[mask], self.range_tolerance
312
                )
313
                self.candidates.append(
4✔
314
                    [self.peaks[mask], actual[mask], weight]
315
                )
316

317
    def _match_bijective(self, candidates, peaks, fit_coeff):
4✔
318
        """
319

320
        Internal function used to return a list of inliers with a
321
        one-to-one relationship between peaks and wavelengths. This
322
        is critical as often we have several potential candidate lines
323
        for each peak. This function first iterates through each peak
324
        and selects the wavelength with the smallest error. It then
325
        iterates through this list and does the same for duplicate
326
        wavelengths.
327

328
        parameters
329
        ----------
330
        candidates: dict
331
            match candidates, internal to ransac
332

333
        peaks: list
334
            list of peaks [px]
335

336
        fit_coeff: list
337
            polynomial fit coefficients
338

339
        """
340

341
        err = []
4✔
342
        matched_x = []
4✔
343
        matched_y = []
4✔
344

345
        for peak in peaks:
4✔
346
            fit = self.polyval(peak, fit_coeff)
4✔
347

348
            # Get closest match for this peak
349
            errs = np.abs(fit - candidates[peak])
4✔
350
            idx = np.argmin(errs)
4✔
351

352
            err.append(errs[idx])
4✔
353
            matched_x.append(peak)
4✔
354
            matched_y.append(candidates[peak][idx])
4✔
355

356
        err = np.array(err)
4✔
357
        matched_x = np.array(matched_x)
4✔
358
        matched_y = np.array(matched_y)
4✔
359

360
        # Now we also need to resolve duplicate y's
361
        filtered_x = []
4✔
362
        filtered_y = []
4✔
363
        filtered_err = []
4✔
364

365
        for wavelength in np.unique(matched_y):
4✔
366
            mask = matched_y == wavelength
4✔
367
            filtered_y.append(wavelength)
4✔
368

369
            err_idx = np.argmin(err[mask])
4✔
370
            filtered_x.append(matched_x[mask][err_idx])
4✔
371
            filtered_err.append(err[mask][err_idx])
4✔
372

373
        # overwrite
374
        err = np.array(filtered_err)
4✔
375
        matched_x = np.array(filtered_x)
4✔
376
        matched_y = np.array(filtered_y)
4✔
377

378
        assert len(np.unique(matched_x)) == len(np.unique(matched_y))
4✔
379

380
        return err, matched_x, matched_y
4✔
381

382
    def _solve_candidate_ransac(
4✔
383
        self,
384
        fit_deg,
385
        fit_coeff,
386
        max_tries,
387
        candidate_tolerance,
388
        brute_force,
389
        progress,
390
    ):
391
        """
392
        Use RANSAC to sample the parameter space and give best guess
393

394
        Parameters
395
        ----------
396
        fit_deg: int
397
            The order of polynomial.
398
        fit_coeff: None or 1D numpy array
399
            Initial polynomial fit fit_coefficients.
400
        max_tries: int
401
            Number of trials of polynomial fitting.
402
        candidate_tolerance: float
403
            toleranceold  (Angstroms) for considering a point to be an inlier
404
            during candidate peak/line selection. This should be reasonable
405
            small as we want to search for candidate points which are
406
            *locally* linear.
407
        brute_force: boolean
408
            Solve all pixel-wavelength combinations with set to True.
409
        progress: boolean
410
            Show the progress bar with tdqm if set to True.
411

412
        Returns
413
        -------
414
        best_p: list
415
            A list of size fit_deg of the best fit polynomial
416
            fit_coefficient.
417
        best_err: float
418
            Arithmetic mean of the residuals.
419
        sum(best_inliers): int
420
            Number of lines fitted within the ransac_tolerance.
421
        valid_solution: boolean
422
            False if overfitted.
423

424
        """
425

426
        if self.linear:
4✔
427
            self._get_candidate_points_linear(candidate_tolerance)
4✔
428

429
        else:
430
            self._get_candidate_points_poly(candidate_tolerance)
4✔
431

432
        (
4✔
433
            self.candidate_peak,
434
            self.candidate_arc,
435
        ) = self._get_most_common_candidates(
436
            self.candidates,
437
            top_n_candidate=self.top_n_candidate,
438
            weighted=self.candidate_weighted,
439
        )
440

441
        self.fit_deg = fit_deg
4✔
442

443
        valid_solution = False
4✔
444
        best_p = None
4✔
445
        best_cost = 1e50
4✔
446
        best_err = 1e50
4✔
447
        best_mask = [False]
4✔
448
        best_residual = None
4✔
449
        best_inliers = 0
4✔
450

451
        # Note that there may be multiple matches for
452
        # each peak, that is len(x) > len(np.unique(x))
453
        x = np.array(self.candidate_peak)
4✔
454
        y = np.array(self.candidate_arc)
4✔
455

456
        # Filter close wavelengths
457
        if self.filter_close:
4✔
458
            unique_y = np.unique(y)
4✔
459
            idx = np.argwhere(
4✔
460
                unique_y[1:] - unique_y[0:-1] < 3 * self.ransac_tolerance
461
            )
462
            separation_mask = np.argwhere((y == unique_y[idx]).sum(0) == 0)
4✔
463
            y = y[separation_mask].flatten()
4✔
464
            x = x[separation_mask].flatten()
4✔
465

466
        # If the number of lines is smaller than the number of degree of
467
        # polynomial fit, return failed fit.
468
        if len(np.unique(x)) <= self.fit_deg:
4✔
469
            return (best_p, best_err, sum(best_mask), 0, False)
×
470

471
        # Brute force check all combinations. If the request sample_size is
472
        # the same or larger than the available lines, it is essentially a
473
        # brute force.
474
        if brute_force or (self.sample_size >= len(np.unique(x))):
4✔
475
            idx = range(len(x))
×
476
            sampler = itertools.combinations(idx, self.sample_size)
×
477
            self.sample_size = len(np.unique(x))
×
478

479
        else:
480
            sampler = range(int(max_tries))
4✔
481

482
        if progress:
4✔
483
            sampler_list = tqdm(sampler)
4✔
484

485
        else:
486
            sampler_list = sampler
×
487

488
        peaks = np.sort(np.unique(x))
4✔
489
        idx = range(len(peaks))
4✔
490

491
        # Build a key(pixel)-value(wavelength) dictionary from the candidates
492
        candidates = {}
4✔
493

494
        for p in np.unique(x):
4✔
495
            candidates[p] = y[x == p]
4✔
496

497
        if self.ht.xedges is not None:
4✔
498
            xbin_size = (self.ht.xedges[1] - self.ht.xedges[0]) / 2.0
4✔
499
            ybin_size = (self.ht.yedges[1] - self.ht.yedges[0]) / 2.0
4✔
500

501
            if np.isfinite(self.hough_weight):
4✔
502
                twoditp = interpolate.RectBivariateSpline(
4✔
503
                    self.ht.xedges[1:] - xbin_size,
504
                    self.ht.yedges[1:] - ybin_size,
505
                    self.ht.hist,
506
                )
507

508
        else:
509
            twoditp = None
×
510

511
        # Calculate initial error given pre-existing fit
512
        if fit_coeff is not None:
4✔
513
            err, _, _ = self._match_bijective(candidates, peaks, fit_coeff)
4✔
514
            best_cost = sum(err)
4✔
515
            best_err = np.sqrt(np.mean(err**2.0))
4✔
516

517
        # The histogram is fixed, so pre-computed outside the loop
518
        if not brute_force:
4✔
519
            # weight the probability of choosing the sample by the inverse
520
            # line density
521
            h = np.histogram(peaks, bins=10)
4✔
522
            prob = 1.0 / h[0][np.digitize(peaks, h[1], right=True) - 1]
4✔
523
            prob = prob / np.sum(prob)
4✔
524

525
        for sample in sampler_list:
4✔
526
            keep_trying = True
4✔
527
            self.logger.debug(sample)
4✔
528

529
            while keep_trying:
4✔
530
                stop_n_candidateow = False
4✔
531

532
                if brute_force:
4✔
533
                    x_hat = x[[sample]]
×
534
                    y_hat = y[[sample]]
×
535

536
                else:
537
                    # Pick some random peaks
538
                    x_hat = np.random.choice(
4✔
539
                        peaks, self.sample_size, replace=False, p=prob
540
                    )
541
                    y_hat = []
4✔
542

543
                    # Pick a random wavelength for this x
544
                    for _x in x_hat:
4✔
545
                        y_choice = candidates[_x]
4✔
546

547
                        # Avoid picking a y that's already associated with
548
                        # another x
549
                        if not set(y_choice).issubset(set(y_hat)):
4✔
550
                            y_temp = np.random.choice(y_choice)
4✔
551

552
                            while y_temp in y_hat:
4✔
553
                                y_temp = np.random.choice(y_choice)
4✔
554

555
                            y_hat.append(y_temp)
4✔
556

557
                        else:
558
                            self.logger.debug(
4✔
559
                                "Not possible to draw a unique "
560
                                "set of atlas wavelengths."
561
                            )
562
                            stop_n_candidateow = True
4✔
563
                            break
4✔
564

565
                if stop_n_candidateow:
4✔
566
                    break
4✔
567

568
                # insert user given known pairs
569
                if self.pix_known is not None:
4✔
570
                    x_hat = np.concatenate((x_hat, self.pix_known))
×
571
                    y_hat = np.concatenate((y_hat, self.wave_known))
×
572

573
                # Try to fit the data.
574
                # This doesn't need to be robust, it's an exact fit.
575
                fit_coeffs = self.polyfit(x_hat, y_hat, self.fit_deg)
4✔
576

577
                # Check the intercept.
578
                if (fit_coeffs[0] < self.min_intercept) | (
4✔
579
                    fit_coeffs[0] > self.max_intercept
580
                ):
581
                    self.logger.debug("Intercept exceeds bounds.")
4✔
582
                    continue
4✔
583

584
                # Check monotonicity.
585
                pix_min = peaks[0] - np.ptp(peaks) * 0.2
4✔
586
                pix_max = peaks[-1] + np.ptp(peaks) * 0.2
4✔
587
                self.logger.debug((pix_min, pix_max))
4✔
588

589
                if not np.all(
4✔
590
                    np.diff(
591
                        self.polyval(
592
                            np.arange(pix_min, pix_max, 1), fit_coeffs
593
                        )
594
                    )
595
                    > 0
596
                ):
597
                    self.logger.debug(
4✔
598
                        "Solution is not monotonically increasing."
599
                    )
600
                    continue
4✔
601

602
                # Compute error and filter out many-to-one matches
603
                err, matched_x, matched_y = self._match_bijective(
4✔
604
                    candidates, peaks, fit_coeffs
605
                )
606

607
                if len(matched_x) == 0:
4✔
608
                    continue
×
609

610
                # M-SAC Estimator (Torr and Zisserman, 1996)
611
                err[err > self.ransac_tolerance] = self.ransac_tolerance
4✔
612

613
                # use the Hough space density as weights for the cost function
614
                wave = self.polyval(self.pixel_list, fit_coeffs)
4✔
615
                gradient = self.polyval(
4✔
616
                    self.pixel_list, _derivative(fit_coeffs)
617
                )
618
                intercept = wave - gradient * self.pixel_list
4✔
619

620
                # modified cost function weighted by the Hough space density
621
                if (self.hough_weight is not None) & (twoditp is not None):
4✔
622
                    weight = self.hough_weight * np.sum(
4✔
623
                        twoditp(intercept, gradient, grid=False)
624
                    )
625

626
                else:
627
                    weight = 1.0
×
628

629
                cost = (
4✔
630
                    sum(err)
631
                    / (len(err) - len(fit_coeffs) + 1)
632
                    / (weight + 1e-9)
633
                )
634

635
                # If this is potentially a new best fit, then handle that first
636
                if cost <= best_cost:
4✔
637
                    # reject lines outside the rms limit (ransac_tolerance)
638
                    # TODO: should n_inliers be recalculated from the robust
639
                    # fit?
640
                    mask = err < self.ransac_tolerance
4✔
641
                    n_inliers = sum(mask)
4✔
642
                    matched_peaks = matched_x[mask]
4✔
643
                    matched_atlas = matched_y[mask]
4✔
644

645
                    if len(matched_peaks) <= self.fit_deg:
4✔
646
                        self.logger.debug(
4✔
647
                            "Too few good candidates for fitting."
648
                        )
649
                        continue
4✔
650

651
                    # Now we do a robust fit
652
                    try:
4✔
653
                        coeffs = models.robust_polyfit(
4✔
654
                            matched_peaks, matched_atlas, self.fit_deg
655
                        )
656

657
                    except np.linalg.LinAlgError:
×
658
                        self.logger.warning(
×
659
                            "Linear algebra error in robust fit"
660
                        )
661
                        continue
×
662

663
                    # Get the residual of the fit
664
                    residual = (
4✔
665
                        self.polyval(matched_peaks, coeffs) - matched_atlas
666
                    )
667
                    residual[
4✔
668
                        np.abs(residual) > self.ransac_tolerance
669
                    ] = self.ransac_tolerance
670

671
                    rms_residual = np.sqrt(np.mean(residual**2))
4✔
672

673
                    # Make sure that we don't accept fits with zero error
674
                    if rms_residual < self.minimum_fit_error:
4✔
675
                        self.logger.debug(
4✔
676
                            "Fit error too small, " "{:1.2f}.".format(best_err)
677
                        )
678

679
                        continue
4✔
680

681
                    # Check that we have enough inliers based on user specified
682
                    # constraints
683

684
                    if n_inliers < self.minimum_matches:
4✔
685
                        self.logger.debug(
4✔
686
                            "Not enough matched peaks for valid solution, "
687
                            "user specified {}.".format(self.minimum_matches)
688
                        )
689
                        continue
4✔
690

691
                    if n_inliers < self.minimum_peak_utilisation * len(
4✔
692
                        self.peaks
693
                    ):
694
                        self.logger.debug(
×
695
                            "Not enough matched peaks for valid solution, "
696
                            "user specified {:1.2f} %.".format(
697
                                100 * self.minimum_matches
698
                            )
699
                        )
700
                        continue
×
701

702
                    # If the best fit is accepted, update the lists
703
                    best_cost = cost
4✔
704
                    best_inliers = n_inliers
4✔
705
                    best_p = coeffs
4✔
706
                    best_err = rms_residual
4✔
707
                    best_residual = residual
4✔
708
                    self.matched_peaks = list(copy.deepcopy(matched_peaks))
4✔
709
                    self.matched_atlas = list(copy.deepcopy(matched_atlas))
4✔
710

711
                    # Sanity check that matching peaks/atlas lines are 1:1
712
                    assert len(np.unique(self.matched_peaks)) == len(
4✔
713
                        self.matched_peaks
714
                    )
715
                    assert len(np.unique(self.matched_atlas)) == len(
4✔
716
                        self.matched_atlas
717
                    )
718
                    assert len(np.unique(self.matched_atlas)) == len(
4✔
719
                        np.unique(self.matched_peaks)
720
                    )
721

722
                    if progress:
4✔
723
                        sampler_list.set_description(
4✔
724
                            "Most inliers: {:d}, "
725
                            "best error: {:1.4f}".format(
726
                                best_inliers, best_err
727
                            )
728
                        )
729

730
                    # Break early if all peaks are matched
731
                    if best_inliers == len(peaks):
4✔
732
                        break
4✔
733

734
                # If we got this far, then we can continue to the next sample
735
                keep_trying = False
4✔
736

737
        # Overfit check
738
        if best_inliers <= self.fit_deg + 1:
4✔
739
            valid_solution = False
×
740

741
        else:
742
            valid_solution = True
4✔
743

744
        # If we totally failed then this can be empty
745
        assert best_inliers == len(self.matched_peaks)
4✔
746
        assert best_inliers == len(self.matched_atlas)
4✔
747

748
        assert len(self.matched_atlas) == len(set(self.matched_atlas))
4✔
749

750
        self.logger.info("Found: {}".format(best_inliers))
4✔
751

752
        return best_p, best_err, best_residual, best_inliers, valid_solution
4✔
753

754
    def _adjust_polyfit(self, delta, fit, tolerance, min_frac):
4✔
755
        """
756
        **EXPERIMENTAL**
757

758
        Parameters
759
        ----------
760
        delta: list or numpy.ndarray
761
            The first n polynomial coefficients to be shifted by delta.
762
        fit: list or numpy.ndarray
763
            The polynomial coefficients.
764
        tolerance: float
765
            The maximum difference between fit and atlas to be accounted for
766
            the best fit.
767
        min_frac: float
768
            The minimum fraction of lines to be used.
769

770
        Return
771
        ------
772
        lsq: float
773
            The least squared value of the fit.
774

775
        """
776

777
        # x is wavelength
778
        # x_matched is pixel
779
        x_matched = []
4✔
780
        # y_matched is wavelength
781
        y_matched = []
4✔
782
        fit_new = fit.copy()
4✔
783

784
        atlas_lines = self.atlas.get_lines()
4✔
785

786
        for i, d in enumerate(delta):
4✔
787
            fit_new[i] += d
4✔
788

789
        for p in self.peaks:
4✔
790
            x = self.polyval(p, fit_new)
4✔
791
            diff = atlas_lines - x
4✔
792
            diff_abs = np.abs(diff)
4✔
793
            idx = np.argmin(diff_abs)
4✔
794

795
            if diff_abs[idx] < tolerance:
4✔
796
                x_matched.append(p)
4✔
797
                y_matched.append(atlas_lines[idx])
4✔
798

799
        x_matched = np.array(x_matched)
4✔
800
        y_matched = np.array(y_matched)
4✔
801

802
        dof = len(x_matched) - len(fit_new) - 1
4✔
803

804
        if dof < 1:
4✔
805
            return np.inf
×
806

807
        if len(x_matched) < len(self.peaks) * min_frac:
4✔
808
            return np.inf
×
809

810
        if not np.all(
4✔
811
            np.diff(self.polyval(np.sort(self.pixel_list), fit_new)) > 0
812
        ):
813
            self.logger.info("not monotonic")
×
814
            return np.inf
×
815

816
        lsq = (
4✔
817
            np.sum((y_matched - self.polyval(x_matched, fit_new)) ** 2.0) / dof
818
        )
819

820
        return lsq
4✔
821

822
    def which_plotting_library(self):
4✔
823
        """
824
        Call to show if the Calibrator is using matplotlib or plotly library
825
        (or neither).
826

827
        """
828

829
        if self.plot_with_matplotlib:
4✔
830
            self.logger.info("Using matplotlib.")
4✔
831
            return "matplotlib"
4✔
832

833
        elif self.plot_with_plotly:
4✔
834
            self.logger.info("Using plotly.")
4✔
835
            return "plotly"
4✔
836

837
        else:
838
            self.logger.warning("Neither maplotlib nor plotly are imported.")
×
839
            return None
×
840

841
    def use_matplotlib(self):
4✔
842
        """
843
        Call to switch to matplotlib.
844

845
        """
846

847
        self.plot_with_matplotlib = True
4✔
848
        self.plot_with_plotly = False
4✔
849

850
    def use_plotly(self):
4✔
851
        """
852
        Call to switch to plotly.
853

854
        """
855

856
        self.plot_with_plotly = True
4✔
857
        self.plot_with_matplotlib = False
4✔
858

859
    def set_calibrator_properties(
4✔
860
        self,
861
        num_pix=None,
862
        pixel_list=None,
863
        plotting_library=None,
864
        seed=None,
865
        logger_name="Calibrator",
866
        log_level="warning",
867
    ):
868
        """
869
        Initialise the calibrator object.
870

871
        Parameters
872
        ----------
873
        num_pix: int
874
            Number of pixels in the spectral axis.
875
        pixel_list: list
876
            pixel value of the of the spectrum, this is only needed if the
877
            spectrum spans multiple detector arrays.
878
        plotting_library: string (default: 'matplotlib')
879
            Choose between matplotlib and plotly.
880
        seed: int
881
            Set an optional seed for random number generators. If used,
882
            this parameter must be set prior to calling RANSAC. Useful
883
            for deterministic debugging.
884
        logger_name: string (default: 'Calibrator')
885
            The name of the logger. It can use an existing logger if a
886
            matching name is provided.
887
        log_level: string (default: 'info')
888
            Choose {critical, error, warning, info, debug, notset}.
889

890
        """
891

892
        # initialise the logger
893
        self.logger = logging.getLogger(logger_name)
4✔
894
        self.logger.propagate = False
4✔
895
        level = logging.getLevelName(log_level.upper())
4✔
896
        self.logger.setLevel(level)
4✔
897
        self.log_level = level
4✔
898

899
        formatter = logging.Formatter(
4✔
900
            "[%(asctime)s] %(levelname)s [%(filename)s:%(lineno)d] "
901
            "%(message)s",
902
            datefmt="%a, %d %b %Y %H:%M:%S",
903
        )
904

905
        if len(self.logger.handlers) == 0:
4✔
906
            handler = logging.StreamHandler()
4✔
907
            handler.setFormatter(formatter)
4✔
908
            self.logger.addHandler(handler)
4✔
909

910
        # set the num_pix
911
        if num_pix is not None:
4✔
912
            self.num_pix = num_pix
4✔
913

914
        elif self.num_pix is None:
4✔
915
            try:
4✔
916
                self.num_pix = len(self.spectrum)
4✔
917

918
            except Exception as e:
4✔
919
                self.logger.warning(e)
4✔
920
                self.logger.warning(
4✔
921
                    "Neither num_pix nor spectrum is given, "
922
                    "it uses 1.1 times max(peaks) as the "
923
                    "maximum pixel value."
924
                )
925
                self.num_pix = 1.1 * max(self.peaks)
4✔
926

927
        else:
928
            pass
2✔
929

930
        self.logger.info("num_pix is set to {}.".format(num_pix))
4✔
931

932
        # set the pixel_list
933
        if pixel_list is not None:
4✔
934
            self.pixel_list = np.asarray(pixel_list)
4✔
935

936
        elif self.pixel_list is None:
4✔
937
            self.pixel_list = np.arange(self.num_pix)
4✔
938

939
        else:
940
            pass
2✔
941

942
        self.logger.info("pixel_list is set to {}.".format(pixel_list))
4✔
943

944
        # map the list position to the pixel value
945
        self.pix_to_rawpix = interpolate.interp1d(
4✔
946
            self.pixel_list,
947
            np.arange(len(self.pixel_list)),
948
            fill_value="extrapolate",
949
        )
950

951
        if seed is not None:
4✔
952
            np.random.seed(seed)
×
953

954
        # if the plotting library is supplied
955
        if plotting_library is not None:
4✔
956
            # set the plotting library
957
            self.plotting_library = plotting_library
×
958

959
        # if the plotting library is not supplied but the calibrator does not
960
        # know which library to use yet.
961
        elif self.plotting_library is None:
4✔
962
            self.plotting_library = "matplotlib"
4✔
963

964
        # everything is good
965
        else:
966
            pass
2✔
967

968
        # check the choice of plotting library is available and used.
969
        if self.plotting_library == "matplotlib":
4✔
970
            self.use_matplotlib()
4✔
971
            self.logger.info("Plotting with matplotlib.")
4✔
972

973
        elif self.plotting_library == "plotly":
×
974
            self.use_plotly()
×
975
            self.logger.info("Plotting with plotly.")
×
976

977
        else:
978
            self.logger.warning(
×
979
                "Unknown plotting_library, please choose from "
980
                "matplotlib or plotly. Execute use_matplotlib() or "
981
                "use_plotly() to manually select the library."
982
            )
983

984
    def set_hough_properties(
4✔
985
        self,
986
        num_slopes=None,
987
        xbins=None,
988
        ybins=None,
989
        min_wavelength=None,
990
        max_wavelength=None,
991
        range_tolerance=None,
992
        linearity_tolerance=None,
993
    ):
994
        """
995
        parameters
996
        ----------
997
        num_slopes: int (default: 1000)
998
            Number of slopes to consider during Hough transform
999
        xbins: int (default: 50)
1000
            Number of bins for Hough accumulation
1001
        ybins: int (default: 50)
1002
            Number of bins for Hough accumulation
1003
        min_wavelength: float (default: 3000)
1004
            Minimum wavelength of the spectrum.
1005
        max_wavelength: float (default: 9000)
1006
            Maximum wavelength of the spectrum.
1007
        range_tolerance: float (default: 500)
1008
            Estimation of the error on the provided spectral range
1009
            e.g. 3000-5000 with tolerance 500 will search for
1010
            solutions that may satisfy 2500-5500
1011
        linearity_tolerance: float (default: 100)
1012
            A toleranceold (Ansgtroms) which defines some padding around the
1013
            range tolerance to allow for non-linearity. This should be the
1014
            maximum expected excursion from linearity.
1015

1016
        """
1017

1018
        # set the num_slopes
1019
        if num_slopes is not None:
4✔
1020
            self.num_slopes = int(num_slopes)
4✔
1021

1022
        elif self.num_slopes is None:
4✔
1023
            self.num_slopes = 2000
4✔
1024

1025
        else:
1026
            pass
2✔
1027

1028
        # set the xbins
1029
        if xbins is not None:
4✔
1030
            self.xbins = xbins
4✔
1031

1032
        elif self.xbins is None:
4✔
1033
            self.xbins = 100
4✔
1034

1035
        else:
1036
            pass
2✔
1037

1038
        # set the ybins
1039
        if ybins is not None:
4✔
1040
            self.ybins = ybins
4✔
1041

1042
        elif self.ybins is None:
4✔
1043
            self.ybins = 100
4✔
1044

1045
        else:
1046
            pass
2✔
1047

1048
        # set the min_wavelength
1049
        if min_wavelength is not None:
4✔
1050
            self.min_wavelength = min_wavelength
4✔
1051

1052
        elif self.min_wavelength is None:
4✔
1053
            self.min_wavelength = 3000.0
4✔
1054

1055
        else:
1056
            pass
×
1057

1058
        # set the max_wavelength
1059
        if max_wavelength is not None:
4✔
1060
            self.max_wavelength = max_wavelength
4✔
1061

1062
        elif self.max_wavelength is None:
4✔
1063
            self.max_wavelength = 9000.0
4✔
1064

1065
        else:
1066
            pass
×
1067

1068
        # Set the range_tolerance
1069
        if range_tolerance is not None:
4✔
1070
            self.range_tolerance = range_tolerance
4✔
1071

1072
        elif self.range_tolerance is None:
4✔
1073
            self.range_tolerance = 500
4✔
1074

1075
        else:
1076
            pass
×
1077

1078
        # Set the linearity_tolerance
1079
        if linearity_tolerance is not None:
4✔
1080
            self.linearity_tolerance = linearity_tolerance
×
1081

1082
        elif self.linearity_tolerance is None:
4✔
1083
            self.linearity_tolerance = 100
4✔
1084

1085
        else:
1086
            pass
2✔
1087

1088
        # Start wavelength in the spectrum, +/- some tolerance
1089
        self.min_intercept = self.min_wavelength - self.range_tolerance
4✔
1090
        self.max_intercept = self.min_wavelength + self.range_tolerance
4✔
1091

1092
        self.min_slope = (
4✔
1093
            (
1094
                self.max_wavelength
1095
                - self.range_tolerance
1096
                - self.linearity_tolerance
1097
            )
1098
            - (
1099
                self.min_intercept
1100
                + self.range_tolerance
1101
                + self.linearity_tolerance
1102
            )
1103
        ) / np.ptp(self.pixel_list)
1104

1105
        self.max_slope = (
4✔
1106
            (
1107
                self.max_wavelength
1108
                + self.range_tolerance
1109
                + self.linearity_tolerance
1110
            )
1111
            - (
1112
                self.min_intercept
1113
                - self.range_tolerance
1114
                - self.linearity_tolerance
1115
            )
1116
        ) / np.ptp(self.pixel_list)
1117

1118
        if self.atlas is not None:
4✔
1119
            self._generate_pairs()
×
1120

1121
    def set_ransac_properties(
4✔
1122
        self,
1123
        sample_size=None,
1124
        top_n_candidate=None,
1125
        linear=None,
1126
        filter_close=None,
1127
        ransac_tolerance=None,
1128
        candidate_weighted=None,
1129
        hough_weight=None,
1130
        minimum_matches=None,
1131
        minimum_peak_utilisation=None,
1132
        minimum_fit_error=None,
1133
    ):
1134
        """
1135
        Configure the Calibrator. This may require some manual twiddling before
1136
        the calibrator can work efficiently. However, in theory, a large
1137
        max_tries in fit() should provide a good solution in the expense of
1138
        performance (minutes instead of seconds).
1139

1140
        Parameters
1141
        ----------
1142
        sample_size: int (default: 5)
1143
            Number of samples used for fitting, this is automatically
1144
            set to the polynomial degree + 1, but a larger value can
1145
            be specified here.
1146
        top_n_candidate: int (default: 5)
1147
            Top ranked lines to be fitted.
1148
        linear: boolean (default: True)
1149
            True to use the hough transformed gradient, otherwise, use the
1150
            known polynomial.
1151
        filter_close: boolean (default: False)
1152
            Remove the pairs that are out of bounds in the hough space.
1153
        ransac_tolerance: float (default: 1)
1154
            The distance criteria  (Angstroms) to be considered an inlier to a
1155
            fit. This should be close to the size of the expected residuals on
1156
            the final fit (e.g. 1A is typical)
1157
        candidate_weighted: boolean (default: True)
1158
            Set to True to down-weight pairs that are far from the fit.
1159
        hough_weight: float or None (default: 1.0)
1160
            Set to use the hough space to weigh the fit. The theoretical
1161
            optimal weighting is unclear. The larger the value, the heavily it
1162
            relies on the overdensity in the hough space for a good fit.
1163
        minimum_matches: int or None (default: 0)
1164
            Set to only accept fit solutions with a minimum number of
1165
            matches. Setting this will prevent the fitting function from
1166
            accepting spurious low-error fits.
1167
        minimum_peak_utilisation: int or None (default: 0)
1168
            Set to only accept fit solutions with a fraction of matches. This
1169
            option is convenient if you don't want to specify an absolute
1170
            number of atlas lines. Range is 0 - 1 inclusive.
1171
        minimum_fit_error: float or None (default: 1e-4)
1172
            Set to only accept fits with a minimum error. This avoids
1173
            accepting "perfect" fit solutions with zero errors. However
1174
            if you have an extremely good system, you may want to set this
1175
            tolerance lower.
1176

1177
        """
1178

1179
        # Setting the sample_size
1180
        if sample_size is not None:
4✔
1181
            self.sample_size = sample_size
4✔
1182

1183
        elif self.sample_size is None:
4✔
1184
            self.sample_size = 5
4✔
1185

1186
        else:
1187
            pass
2✔
1188

1189
        # Set top_n_candidate
1190
        if top_n_candidate is not None:
4✔
1191
            self.top_n_candidate = top_n_candidate
4✔
1192

1193
        elif self.top_n_candidate is None:
4✔
1194
            self.top_n_candidate = 5
4✔
1195

1196
        else:
1197
            pass
2✔
1198

1199
        # Set linear
1200
        if linear is not None:
4✔
1201
            self.linear = linear
4✔
1202

1203
        elif self.linear is None:
4✔
1204
            self.linear = True
4✔
1205

1206
        else:
1207
            pass
2✔
1208

1209
        # Set to filter closely spaced lines
1210
        if filter_close is not None:
4✔
1211
            self.filter_close = filter_close
4✔
1212

1213
        elif self.filter_close is None:
4✔
1214
            self.filter_close = False
4✔
1215

1216
        else:
1217
            pass
2✔
1218

1219
        # Set the ransac_tolerance
1220
        if ransac_tolerance is not None:
4✔
1221
            self.ransac_tolerance = ransac_tolerance
×
1222

1223
        elif self.ransac_tolerance is None:
4✔
1224
            self.ransac_tolerance = 5
4✔
1225

1226
        else:
1227
            pass
2✔
1228

1229
        # Set to weigh the candidate pairs by the density (pixel)
1230
        if candidate_weighted is not None:
4✔
1231
            self.candidate_weighted = candidate_weighted
×
1232

1233
        elif self.candidate_weighted is None:
4✔
1234
            self.candidate_weighted = True
4✔
1235

1236
        else:
1237
            pass
2✔
1238

1239
        # Set the multiplier of the weight of the hough density
1240
        if hough_weight is not None:
4✔
1241
            self.hough_weight = hough_weight
×
1242

1243
        elif self.hough_weight is None:
4✔
1244
            self.hough_weight = 1.0
4✔
1245

1246
        else:
1247
            pass
2✔
1248

1249
        # Set the minimum number of desired matches
1250
        if minimum_matches is not None:
4✔
1251
            assert minimum_matches > 0
4✔
1252
            self.minimum_matches = minimum_matches
4✔
1253

1254
        elif self.minimum_matches is None:
4✔
1255
            self.minimum_matches = 0
4✔
1256

1257
        else:
1258
            pass
2✔
1259

1260
        # Set the minimum utilisation required
1261
        if minimum_peak_utilisation is not None:
4✔
1262
            assert (
×
1263
                minimum_peak_utilisation >= 0
1264
                and minimum_peak_utilisation <= 1.0
1265
            )
1266
            self.minimum_peak_utilisation = minimum_peak_utilisation
×
1267

1268
        elif self.minimum_peak_utilisation is None:
4✔
1269
            self.minimum_peak_utilisation = 0
4✔
1270

1271
        else:
1272
            pass
2✔
1273

1274
        # Set the minimum fit error
1275
        if minimum_fit_error is not None:
4✔
1276
            assert minimum_fit_error >= 0
4✔
1277
            self.minimum_fit_error = minimum_fit_error
4✔
1278

1279
        elif self.minimum_fit_error is None:
4✔
1280
            self.minimum_fit_error = 1e-4
4✔
1281

1282
        else:
1283
            pass
4✔
1284

1285
    def add_atlas(
        self,
        elements,
        min_atlas_wavelength=None,
        max_atlas_wavelength=None,
        min_intensity=10.0,
        min_distance=10.0,
        candidate_tolerance=10.0,
        constrain_poly=False,
        vacuum=False,
        pressure=101325.0,
        temperature=273.15,
        relative_humidity=0.0,
    ):
        """
        Deprecated: build or extend the calibrator's Atlas, then regenerate
        the peak/atlas pairs. A deprecation warning is logged on every call.

        If min_atlas_wavelength / max_atlas_wavelength are None they default
        to min_wavelength - range_tolerance and
        max_wavelength + range_tolerance respectively. When no atlas is
        attached yet, a new Atlas is created from the arguments; otherwise
        the lines are added to the existing atlas. candidate_tolerance and
        constrain_poly are stored on the calibrator.

        """

        self.logger.warning(
            "Using add_atlas is now deprecated. "
            "Please use the new Atlas class."
        )

        if min_atlas_wavelength is None:
            min_atlas_wavelength = self.min_wavelength - self.range_tolerance

        if max_atlas_wavelength is None:
            max_atlas_wavelength = self.max_wavelength + self.range_tolerance

        # Keyword arguments shared by Atlas(...) and Atlas.add(...).
        line_selection = {
            "min_atlas_wavelength": min_atlas_wavelength,
            "max_atlas_wavelength": max_atlas_wavelength,
            "min_intensity": min_intensity,
            "min_distance": min_distance,
            "vacuum": vacuum,
            "pressure": pressure,
            "temperature": temperature,
            "relative_humidity": relative_humidity,
        }

        if self.atlas is not None:
            self.atlas.add(elements, **line_selection)

        else:
            # Only the constructor receives the range_tolerance.
            self.atlas = Atlas(
                elements,
                range_tolerance=self.range_tolerance,
                **line_selection,
            )

        self.candidate_tolerance = candidate_tolerance
        self.constrain_poly = constrain_poly

        self._generate_pairs()
1342

1343
    def remove_atlas_lines_range(self, wavelength, tolerance=10):
4✔
1344
        """
1345
        Remove arc lines within a certain wavelength range.
1346
        """
1347

1348
        self.atlas.remove_atlas_lines_range(wavelength, tolerance)
×
1349

1350
    def list_atlas(self):
4✔
1351
        """
1352
        List all the lines loaded to the Calibrator.
1353
        """
1354

1355
        self.atlas.list()
×
1356

1357
    def clear_atlas(self):
4✔
1358
        """
1359
        Remove all the lines loaded to the Calibrator.
1360
        """
1361

1362
        self.atlas.clear()
×
1363

1364
    def add_user_atlas(
        self,
        elements,
        wavelengths,
        intensities=None,
        vacuum=False,
        pressure=101325.0,
        temperature=273.15,
        relative_humidity=0.0,
        candidate_tolerance=10,
        constrain_poly=False,
    ):
        """
        Deprecated: add user-supplied lines to the calibrator's Atlas, then
        regenerate the peak/atlas pairs. A deprecation warning is logged on
        every call.

        An empty Atlas is created first if none is attached. The line data
        and environment arguments are forwarded positionally to
        Atlas.add_user_atlas; candidate_tolerance and constrain_poly are
        stored on the calibrator.

        """

        self.logger.warning(
            "Using add_user_atlas is now deprecated. "
            "Please use the new Atlas class."
        )

        if self.atlas is None:
            self.atlas = Atlas()

        # Forwarded positionally, in the order the atlas expects them.
        user_lines = (
            elements,
            wavelengths,
            intensities,
            vacuum,
            pressure,
            temperature,
            relative_humidity,
        )
        self.atlas.add_user_atlas(*user_lines)

        self.candidate_tolerance = candidate_tolerance
        self.constrain_poly = constrain_poly

        self._generate_pairs()
1398

1399
    def set_atlas(self, atlas, candidate_tolerance=10.0, constrain_poly=False):
        """
        Attach an atlas of arc lines to the calibrator and regenerate the
        peak/atlas pairs.

        Parameters
        ----------
        atlas: rascal.Atlas
            The atlas object holding the arc lines. It replaces any atlas
            previously attached.
        candidate_tolerance: float (default: 10)
            Threshold (Angstroms) for considering a point to be an inlier
            during candidate peak/line selection. This should be reasonably
            small as we want to search for candidate points which are
            *locally* linear.
        constrain_poly: boolean
            Apply a polygonal constraint on possible peak/atlas pairs
        """

        self.constrain_poly = constrain_poly
        self.candidate_tolerance = candidate_tolerance
        self.atlas = atlas

        # Build the list of all possible pairs of detected peaks and
        # atlas lines with the settings stored above.
        self._generate_pairs()
1424

1425
    def do_hough_transform(self, brute_force=False):
4✔
1426
        if self.pairs == []:
4✔
1427
            logging.warning("pairs list is empty. Try generating now.")
×
1428
            self._generate_pairs()
×
1429

1430
            if self.pairs == []:
×
1431
                logging.error("pairs list is still empty.")
×
1432

1433
        # Generate the hough_points from the pairs
1434
        self.ht.set_constraints(
4✔
1435
            self.min_slope,
1436
            self.max_slope,
1437
            self.min_intercept,
1438
            self.max_intercept,
1439
        )
1440

1441
        if brute_force:
4✔
1442
            self.ht.generate_hough_points_brute_force(
4✔
1443
                self.pairs[:, 0], self.pairs[:, 1]
1444
            )
1445
        else:
1446
            self.ht.generate_hough_points(
4✔
1447
                self.pairs[:, 0], self.pairs[:, 1], num_slopes=self.num_slopes
1448
            )
1449

1450
        self.ht.bin_hough_points(self.xbins, self.ybins)
4✔
1451
        self.hough_points = self.ht.hough_points
4✔
1452
        self.hough_lines = self.ht.hough_lines
4✔
1453

1454
    def save_hough_transform(
4✔
1455
        self,
1456
        filename="hough_transform",
1457
        fileformat="npy",
1458
        delimiter="+",
1459
        to_disk=True,
1460
    ):
1461
        """
1462
        Save the HoughTransform object to memory or to disk.
1463

1464
        Parameters
1465
        ----------
1466
        filename: str
1467
            The filename of the output, not used if to_disk is False. It
1468
            will be appended with the content type.
1469
        format: str (default: 'npy')
1470
            Choose from 'npy' and json'
1471
        delimiter: str (default: '+')
1472
            Delimiter for format and content types
1473
        to_disk: boolean
1474
            Set to True to save to disk, else return a numpy array object
1475

1476
        Returns
1477
        -------
1478
        hp_hough_points: numpy.ndarray
1479
            only return if to_disk is False.
1480

1481
        """
1482

1483
        self.ht.save(
4✔
1484
            filename=filename,
1485
            fileformat=fileformat,
1486
            delimiter=delimiter,
1487
            to_disk=to_disk,
1488
        )
1489

1490
    def load_hough_transform(self, filename="hough_transform", filetype="npy"):
4✔
1491
        """
1492
        Store the binned Hough space and/or the raw Hough pairs.
1493

1494
        Parameters
1495
        ----------
1496
        filename: str (default: 'hough_transform')
1497
            The filename of the output, not used if to_disk is False. It
1498
            will be appended with the content type.
1499
        filetype: str (default: 'npy')
1500
            The file type of the saved hough transform. Choose from 'npy'
1501
            and 'json'.
1502

1503
        """
1504

1505
        self.ht.load(filename=filename, filetype=filetype)
4✔
1506

1507
    def set_known_pairs(self, pix=(), wave=()):
4✔
1508
        """
1509
        Provide manual pixel-wavelength pair(s), they will be appended to the
1510
        list of pixel-wavelength pairs after the random sample being drawn from
1511
        the RANSAC step, i.e. they are ALWAYS PRESENT in the fitting step. Use
1512
        with caution because it can skew or bias the fit significantly, make
1513
        sure the pixel value is accurate to at least 1/10 of a pixel. We do not
1514
        recommend supplying more than a coupld of known pairs unless you are
1515
        very confident with the solution and intend to skew with the known
1516
        pairs.
1517

1518
        This can be used for example for low intensity lines at the edge of
1519
        the spectrum. Or saturated lines where peaks cannot be well positioned.
1520

1521
        Parameters
1522
        ----------
1523
        pix: numeric value, list or numpy 1D array (N) (default: ())
1524
            Any pixel value, can be outside the detector chip and
1525
            serve purely as anchor points.
1526
        wave: numeric value, list or numpy 1D array (N) (default: ())
1527
            The matching wavelength for each of the pix.
1528

1529
        """
1530

1531
        pix = np.asarray(pix, dtype="float").reshape(-1)
4✔
1532
        wave = np.asarray(wave, dtype="float").reshape(-1)
4✔
1533

1534
        assert pix.size == wave.size, ValueError(
4✔
1535
            "Please check the length of the input arrays. pix has size {} "
1536
            "and wave has size {}.".format(pix.size, wave.size)
1537
        )
1538

1539
        if not all(
4✔
1540
            isinstance(p, (float, int)) & (not np.isnan(p)) for p in pix
1541
        ):
1542
            raise ValueError("All pix elements have to be numeric.")
4✔
1543

1544
        if not all(
4✔
1545
            isinstance(w, (float, int)) & (not np.isnan(w)) for w in wave
1546
        ):
1547
            raise ValueError("All wave elements have to be numeric.")
4✔
1548

1549
        self.pix_known = pix
4✔
1550
        self.wave_known = wave
4✔
1551

1552
    def fit(
        self,
        max_tries=500,
        fit_deg=4,
        fit_coeff=None,
        fit_tolerance=5.0,
        fit_type="poly",
        candidate_tolerance=2.0,
        brute_force=False,
        progress=True,
    ):
        """
        Solve for the wavelength calibration polynomial by getting the most
        likely solution with RANSAC.

        Parameters
        ----------
        max_tries: int (default: 500)
            Maximum number of iteration.
        fit_deg: int (default: 4)
            The degree of the polynomial to be fitted. It is overridden by
            len(fit_coeff) - 1 when fit_coeff is supplied.
        fit_coeff: list (default: None)
            Set the baseline of the least square fit. If no fits outperform
            this set of polynomial coefficients, this will be used as the
            best fit.
        fit_tolerance: float (default: 5.0)
            Sets a tolerance on whether a fit found by RANSAC is considered
            acceptable. A warning is logged if the rms exceeds it.
        fit_type: string (default: 'poly')
            One of 'poly', 'legendre' or 'chebyshev'
        candidate_tolerance: float (default: 2.0)
            Threshold (Angstroms) for considering a point to be an inlier
        brute_force: boolean (default: False)
            Set to True to try all possible combination in the given parameter
            space
        progress: boolean (default: True)
            True to show progress with tqdm. It is overridden if tqdm cannot
            be imported.

        Returns
        -------
        fit_coeff: list
            List of best fit polynomial fit_coefficient.
        matched_peaks: list
            Peaks used for final fit
        matched_atlas: list
            Atlas lines used for final fit
        rms: float
            The root-mean-squared of the residuals
        residual: float
            Residual from the best fit
        peak_utilisation: float
            Fraction of detected peaks (pixel) used for calibration [0-1].
        atlas_utilisation: float
            Fraction of supplied arc lines (wavelength) used for
            calibration [0-1].

        Raises
        ------
        ValueError
            If fit_type is not one of the supported types.
        AssertionError
            If no fit coefficients could be found.

        """

        self.max_tries = max_tries
        self.fit_deg = fit_deg
        self.fit_coeff = fit_coeff

        # Supplied coefficients dictate the polynomial degree.
        if fit_coeff is not None:
            self.fit_deg = len(fit_coeff) - 1

        self.fit_tolerance = fit_tolerance
        self.fit_type = fit_type
        self.brute_force = brute_force
        self.progress = progress

        if self.fit_type == "poly":
            self.polyfit = np.polynomial.polynomial.polyfit
            self.polyval = np.polynomial.polynomial.polyval

        elif self.fit_type == "legendre":
            self.polyfit = np.polynomial.legendre.legfit
            self.polyval = np.polynomial.legendre.legval

        elif self.fit_type == "chebyshev":
            self.polyfit = np.polynomial.chebyshev.chebfit
            self.polyval = np.polynomial.chebyshev.chebval

        else:
            raise ValueError(
                "fit_type must be: (1) poly, (2) legendre or (3) chebyshev"
            )

        # Reduce sample_size if it is larger than the number of atlas available
        if self.sample_size > len(self.atlas):
            self.logger.warning(
                "Size of sample_size is larger than the size of atlas, "
                "the sample_size is set to match the size of atlas = "
                + str(len(self.atlas))
                + "."
            )
            self.sample_size = len(self.atlas)

        # At least degree + 1 points are needed. Compare against the degree
        # actually in use, which differs from the fit_deg argument when
        # fit_coeff is supplied.
        if self.sample_size <= self.fit_deg:
            self.sample_size = self.fit_deg + 1

        if (self.hough_lines is None) or (self.hough_points is None):
            self.do_hough_transform()

        if self.minimum_matches > len(self.atlas):
            self.logger.warning(
                "Requested minimum matches is greater than the atlas size, "
                "setting the minimum number of matches to equal the atlas "
                "size = " + str(len(self.atlas)) + "."
            )
            self.minimum_matches = len(self.atlas)

        if self.minimum_matches > len(self.peaks):
            self.logger.warning(
                "Requested minimum matches is greater than the number of "
                "peaks detected, setting the minimum number of matches to "
                "equal the number of peaks = " + str(len(self.peaks)) + "."
            )
            self.minimum_matches = len(self.peaks)

        # TODO also check whether minimum peak utilisation is greater than
        # minimum matches.

        (
            fit_coeff,
            rms,
            residual,
            n_inliers,
            valid,
        ) = self._solve_candidate_ransac(
            fit_deg=self.fit_deg,
            fit_coeff=self.fit_coeff,
            max_tries=self.max_tries,
            candidate_tolerance=candidate_tolerance,
            brute_force=self.brute_force,
            progress=self.progress,
        )

        peak_utilisation = n_inliers / len(self.peaks)
        atlas_utilisation = n_inliers / len(self.atlas)

        if not valid:
            self.logger.warning("Invalid fit")

        if rms > self.fit_tolerance:
            self.logger.warning(
                "RMS too large {} > {}".format(rms, self.fit_tolerance)
            )

        assert fit_coeff is not None, "Couldn't fit"

        self.fit_coeff = fit_coeff
        self.rms = rms
        self.residual = residual
        self.peak_utilisation = peak_utilisation
        self.atlas_utilisation = atlas_utilisation

        # matched_peaks and matched_atlas are not assigned in this method;
        # presumably they are set by _solve_candidate_ransac (TODO confirm).
        return (
            self.fit_coeff,
            self.matched_peaks,
            self.matched_atlas,
            self.rms,
            self.residual,
            self.peak_utilisation,
            self.atlas_utilisation,
        )
1716

1717
    def match_peaks(
        self,
        fit_coeff=None,
        n_delta=None,
        refine=False,
        tolerance=10.0,
        method="Nelder-Mead",
        convergence=1e-6,
        min_frac=0.5,
        robust_refit=True,
        fit_deg=None,
    ):
        """
        ** refine option is EXPERIMENTAL, use with caution **

        Match the detected peaks to the atlas lines using a polynomial
        solution, and optionally refine and/or refit that solution.

        Set refine to True to adjust the lowest n_delta coefficients with
        scipy.optimize.minimize before matching. It is recommended to call
        this multiple times, refining the lowest order first and gradually
        increasing n_delta.

        Set robust_refit to True to refit all the matched peak-atlas pairs
        with a polynomial of degree fit_deg.

        Setting both refine and robust_refit to False returns the input
        solution together with the list of arc lines that are fitted by it
        within the tolerance limit provided.

        Parameters
        ----------
        fit_coeff: list (default: None)
            List of polynomial fit coefficients. If None, a copy of
            self.fit_coeff is used.
        n_delta: int (default: None)
            The number of the lowest polynomial orders to be adjusted when
            refine is True. Defaults to len(fit_coeff) - 1.
        refine: boolean (default: False)
            Set to True to refine solution.
        tolerance: float (default: 10.)
            Maximum absolute difference between the fitted wavelength of a
            peak and an atlas line for them to be considered a match, in
            the same unit as the atlas lines.
        method: string (default: 'Nelder-Mead')
            scipy.optimize.minimize method.
        convergence: float (default: 1e-6)
            scipy.optimize.minimize tol.
        min_frac: float (default: 0.5)
            Minimum fraction of peaks to be refitted (passed on to
            _adjust_polyfit).
        robust_refit: boolean (default: True)
            Set to True to fit all the matched peaks with the given
            polynomial solution.
        fit_deg: int (default: length of the input coefficients - 1)
            Order of polynomial fit with all the matched peaks.

        Returns
        -------
        fit_coeff: list
            List of best fit polynomial coefficients.
        peak_match: numpy 1D array
            Matched peaks
        atlas_match: numpy 1D array
            Corresponding atlas matches
        rms: float
            The root-mean-squared of the residuals
        residual: numpy 1D array
            The difference between the matched atlas lines and the
            solution: signed if robust_refit succeeded, absolute otherwise.
            * EXPERIMENTAL *
        peak_utilisation: float
            Fraction of detected peaks (pixel) used for calibration [0-1].
        atlas_utilisation: float
            Fraction of supplied arc lines (wavelength) used for
            calibration [0-1].

        If the refinement produces NaN, the input fit_coeff is returned
        with None in every other position.

        """

        if fit_coeff is None:
            fit_coeff = self.fit_coeff.copy()

        if fit_deg is None:
            fit_deg = len(fit_coeff) - 1

        # Always defined, so that with refine=False and robust_refit=False
        # the (unchanged) input solution is the one stored and returned.
        fit_coeff_new = copy.copy(fit_coeff)

        if refine:
            if n_delta is None:
                n_delta = len(fit_coeff_new) - 1

            # fit everything; the starting deltas are 0.1% of the
            # coefficients being adjusted
            fitted_delta = minimize(
                self._adjust_polyfit,
                fit_coeff_new[: int(n_delta)] * 1e-3,
                args=(fit_coeff, tolerance, min_frac),
                method=method,
                tol=convergence,
                options={"maxiter": 10000},
            ).x

            for i, d in enumerate(fitted_delta):
                fit_coeff_new[i] += d

            if np.any(np.isnan(fit_coeff_new)):
                self.logger.warning(
                    "_adjust_polyfit() returns None. "
                    "Input solution is returned."
                )
                return fit_coeff, None, None, None, None, None, None

        matched_peaks = []
        matched_atlas = []

        atlas_lines = np.asarray(self.atlas.get_lines())

        # Find all Atlas peaks within tolerance of each detected peak
        for p in self.peaks:
            x = self.polyval(p, fit_coeff)
            within_tolerance = np.abs(atlas_lines - x) < tolerance

            if within_tolerance.any():
                matched_peaks.append(p)
                matched_atlas.append(list(atlas_lines[within_tolerance]))

        # Create permutations: every candidate is a flat list holding
        # exactly one atlas line per matched peak. Note that the number of
        # candidates is multiplied by the number of lines matched to each
        # ambiguous peak.
        candidates = [[]]

        for match in matched_atlas:
            self.logger.info("matched: {}".format(match))

            new_candidates = []

            for c in candidates:
                # Prefer atlas lines that are not already assigned in this
                # candidate; if all of them are taken, allow the reuse so
                # that the candidate keeps one line per peak.
                options = [line for line in match if line not in c] or match
                new_candidates.extend(c + [line] for line in options)

            candidates = new_candidates

        if len(candidates) > 1:
            self.logger.info(
                "More than one match solution found, checking permutations."
            )

        self.matched_peaks = np.array(copy.deepcopy(matched_peaks))

        # Check all candidates and keep the one with the smallest summed
        # absolute residual from a plain polynomial fit
        best_err = 1e9
        self.matched_atlas = None
        self.residuals = None

        for candidate in candidates:
            candidate_atlas = np.array(candidate)

            # Kept separate from fit_coeff so that the input solution is
            # still available to be returned below.
            candidate_coeff = self.polyfit(
                matched_peaks, candidate_atlas, fit_deg
            )

            candidate_residuals = np.abs(
                candidate_atlas - self.polyval(matched_peaks, candidate_coeff)
            )
            err = np.sum(candidate_residuals)

            if err < best_err:
                self.matched_atlas = candidate_atlas
                self.residuals = candidate_residuals

                best_err = err

        assert self.matched_atlas is not None
        assert self.residuals is not None

        self.rms = np.sqrt(
            np.nansum(self.residuals**2.0) / len(self.residuals)
        )

        self.peak_utilisation = len(self.matched_peaks) / len(self.peaks)
        self.atlas_utilisation = len(self.matched_atlas) / len(self.atlas)

        if robust_refit:
            refit_coeff = models.robust_polyfit(
                np.asarray(self.matched_peaks),
                np.asarray(self.matched_atlas),
                fit_deg,
            )

            if np.any(np.isnan(refit_coeff)):
                self.logger.warning(
                    "robust_polyfit() returns None. "
                    "Input solution is returned."
                )
                # self.fit_coeff is deliberately left untouched so that a
                # failed refit does not store NaN coefficients.
                return (
                    fit_coeff,
                    self.matched_peaks,
                    self.matched_atlas,
                    self.rms,
                    self.residuals,
                    self.peak_utilisation,
                    self.atlas_utilisation,
                )

            self.fit_coeff = refit_coeff
            self.residuals = self.matched_atlas - self.polyval(
                self.matched_peaks, self.fit_coeff
            )
            self.rms = np.sqrt(
                np.nansum(self.residuals**2.0) / len(self.residuals)
            )

        else:
            self.fit_coeff = fit_coeff_new

        return (
            self.fit_coeff,
            self.matched_peaks,
            self.matched_atlas,
            self.rms,
            self.residuals,
            self.peak_utilisation,
            self.atlas_utilisation,
        )
1957

1958
    def get_pix_wave_pairs(self):
        """
        Enumerate the currently matched peak-atlas pairs.

        Each pair is also reported through the logger at INFO level.

        Return
        ------
        pw_pairs: list
            List of tuples (position, peak, wavelength), where position is
            the index of the pair in matched_peaks / matched_atlas, peak is
            the pixel value and wavelength is the atlas value.

        """

        pw_pairs = [
            (position, pixel, wavelength)
            for position, (pixel, wavelength) in enumerate(
                zip(self.matched_peaks, self.matched_atlas)
            )
        ]

        for pair in pw_pairs:
            self.logger.info(
                "Position {}: pixel {} is matched to wavelength {}".format(
                    *pair
                )
            )

        return pw_pairs
4✔
1984

1985
    def add_pix_wave_pair(self, pix, wave):
        """
        Adding extra pixel-wavelength pair to the Calibrator for refitting.
        This DOES NOT work before the Calibrator having fit for a solution
        yet: use set_known_pairs() for that purpose.

        The pair is inserted at the position that keeps matched_peaks in
        ascending order (assuming it already is), including when pix is
        smaller or larger than every existing matched peak.

        Parameters
        ----------
        pix: float
            pixel position
        wave: float
            wavelength

        """

        # The number of matched peaks below pix is the insertion index for
        # an ascending array, and is always a valid index for np.insert
        # (0 when pix is the smallest, len when it is the largest).
        arg = int(np.count_nonzero(np.asarray(self.matched_peaks) < pix))

        # Only update the lists if both can be inserted
        matched_peaks = np.insert(self.matched_peaks, arg, pix)
        matched_atlas = np.insert(self.matched_atlas, arg, wave)

        self.matched_peaks = matched_peaks
        self.matched_atlas = matched_atlas
4✔
2008

2009
    def remove_pix_wave_pair(self, arg):
        """
        Drop one matched pixel-wavelength pair from the Calibrator before
        refitting. Use get_pix_wave_pairs() to look up the position of the
        pair to be removed. One at a time.

        Parameters
        ----------
        arg: int
            The position of the pairs in the arrays.

        """

        # Both deletions are evaluated before anything is stored, so the
        # two arrays are only updated if both can be deleted.
        trimmed_peaks, trimmed_atlas = (
            np.delete(values, arg)
            for values in (self.matched_peaks, self.matched_atlas)
        )

        self.matched_peaks = trimmed_peaks
        self.matched_atlas = trimmed_atlas
4✔
2027

2028
    def manual_refit(
        self, matched_peaks=None, matched_atlas=None, degree=None, x0=None
    ):
        """
        Refit the polynomial solution to a given set of peak-atlas pairs.

        The pairs default to the ones currently stored in the Calibrator
        (for example after editing them with add_pix_wave_pair() and
        remove_pix_wave_pair()). They are fitted with a robust polyfit of
        the desired degree; optionally, an initial fit x0 can be provided
        to condition the optimiser. The new coefficients, the pairs, the
        residuals and the rms are stored in the Calibrator.

        With manual changes to the lists of peaks and atlas,
        peak_utilisation and atlas_utilisation are meaningless, so this
        function does not return in the same format as fit() and
        match_peaks().

        Parameters
        ----------
        matched_peaks: list (default: None)
            List of matched peaks. If None, self.matched_peaks is used.
        matched_atlas: list (default: None)
            List of matched atlas lines. If None, self.matched_atlas is
            used.
        degree: int (default: None)
            Polynomial fit degree. If None, it is taken from the length of
            x0 (or of self.fit_coeff when x0 is also None).
        x0: list (default: None)
            Initial fit coefficients. If both x0 and degree are None,
            self.fit_coeff is used.

        Returns
        -------
        fit_coeff: list
            List of best fit polynomial coefficients
        matched_peaks: list
            List of matched peaks
        matched_atlas: list
            List of matched atlas lines
        rms: float
            The root-mean-squared of the residuals
        residuals: numpy 1D array
            Residual match error per-peak

        """

        peaks = self.matched_peaks if matched_peaks is None else matched_peaks
        atlas = self.matched_atlas if matched_atlas is None else matched_atlas

        if degree is None:
            # The degree follows from the initial coefficients
            if x0 is None:
                x0 = self.fit_coeff

            else:
                assert isinstance(x0, list)

            degree = len(x0) - 1

        else:
            if x0 is not None:
                assert isinstance(x0, list)

            assert isinstance(degree, int)

            if x0 is not None:
                assert len(x0) == degree + 1

        pixels = np.asarray(peaks)
        wavelengths = np.asarray(atlas)

        n_pairs = len(pixels)
        assert n_pairs == len(wavelengths)
        assert n_pairs > 0
        assert degree > 0
        assert degree <= n_pairs - 1

        # Run robust fitting again
        refit_coeff = models.robust_polyfit(pixels, wavelengths, degree, x0)
        self.logger.info("Input fit_coeff is {}.".format(x0))
        self.logger.info("Refit fit_coeff is {}.".format(refit_coeff))

        self.fit_coeff = refit_coeff
        self.matched_peaks = copy.deepcopy(peaks)
        self.matched_atlas = copy.deepcopy(atlas)
        self.residuals = wavelengths - self.polyval(pixels, refit_coeff)

        mean_square = np.nansum(self.residuals**2.0) / len(self.residuals)
        self.rms = np.sqrt(mean_square)

        return (
            self.fit_coeff,
            self.matched_peaks,
            self.matched_atlas,
            self.rms,
            self.residuals,
        )
2125

2126
    def plot_arc(
        self,
        pixel_list=None,
        log_spectrum=False,
        save_fig=False,
        fig_type="png",
        filename=None,
        return_jsonstring=False,
        renderer="default",
        display=True,
    ):
        """
        Plots the 1D spectrum of the extracted arc.

        This is a thin wrapper that forwards this Calibrator and all the
        arguments to plotting.plot_arc().

        Parameters
        ----------
        pixel_list: array (default: None)
            pixel value of the spectrum, this is only needed if the
            spectrum spans multiple detector arrays.
        log_spectrum: boolean (default: False)
            Set to true to display the arc spectrum in logarithmic space.
        save_fig: boolean (default: False)
            Save an image if set to True. matplotlib uses the pyplot.save_fig()
            while the plotly uses the pio.write_html() or pio.write_image().
            The support format types should be provided in fig_type.
        fig_type: string (default: 'png')
            Image type to be saved, choose from:
            jpg, png, svg, pdf and iframe. Delimiter is '+'.
        filename: string (default: None)
            Provide a filename or full path. If the extension is not provided
            it is defaulted to png.
        return_jsonstring: boolean (default: False)
            Set to True to return json strings if using plotly as the plotting
            library.
        renderer: string (default: 'default')
            Indicate the Plotly renderer. Nothing gets displayed if json is
            set to True.
        display: boolean (default: True)
            Set to True to display diagnostic plot.

        Returns
        -------
        Whatever plotting.plot_arc() returns: json strings if using plotly
        as the plotting library and json is True.

        """

        plot_options = {
            "pixel_list": pixel_list,
            "log_spectrum": log_spectrum,
            "save_fig": save_fig,
            "fig_type": fig_type,
            "filename": filename,
            "return_jsonstring": return_jsonstring,
            "renderer": renderer,
            "display": display,
        }

        return plotting.plot_arc(self, **plot_options)
2185

2186
    def plot_search_space(
        self,
        fit_coeff=None,
        top_n_candidate=3,
        weighted=True,
        save_fig=False,
        fig_type="png",
        filename=None,
        return_jsonstring=False,
        renderer="default",
        display=True,
    ):
        """
        Plots the peak/arc line pairs that are considered as potential match
        candidates.

        If fit coefficients are provided, the model solution will be
        overplotted. This is a thin wrapper that forwards this Calibrator
        and all the arguments to plotting.plot_search_space().

        Parameters
        ----------
        fit_coeff: list (default: None)
            List of best polynomial fit coefficients
        top_n_candidate: int (default: 3)
            Top ranked lines to be fitted.
        weighted: (default: True)
            Draw sample based on the distance from the matched known wavelength
            of the atlas.
        save_fig: boolean (default: False)
            Save an image if set to True. matplotlib uses the pyplot.save_fig()
            while the plotly uses the pio.write_html() or pio.write_image().
            The support format types should be provided in fig_type.
        fig_type: string (default: 'png')
            Image type to be saved, choose from:
            jpg, png, svg, pdf and iframe. Delimiter is '+'.
        filename: (default: None)
            The destination to save the image.
        return_jsonstring: (default: False)
            Set to True to save the plotly figure as json string. Ignored if
            matplotlib is used.
        renderer: (default: 'default')
            Set the renderer for the plotly display. Ignored if matplotlib is
            used.
        display: boolean (default: True)
            Set to True to display diagnostic plot.

        Return
        ------
        Whatever plotting.plot_search_space() returns: json object if json
        is True.

        """

        plot_options = {
            "fit_coeff": fit_coeff,
            "top_n_candidate": top_n_candidate,
            "weighted": weighted,
            "save_fig": save_fig,
            "fig_type": fig_type,
            "filename": filename,
            "return_jsonstring": return_jsonstring,
            "renderer": renderer,
            "display": display,
        }

        return plotting.plot_search_space(self, **plot_options)
2250

2251
    def plot_fit(
        self,
        fit_coeff=None,
        spectrum=None,
        tolerance=5.0,
        plot_atlas=True,
        log_spectrum=False,
        save_fig=False,
        fig_type="png",
        filename=None,
        return_jsonstring=False,
        renderer="default",
        display=True,
    ):
        """
        Plots of the wavelength calibrated arc spectrum, the residual and the
        pixel-to-wavelength solution.

        This is a thin wrapper that forwards this Calibrator and all the
        arguments to plotting.plot_fit().

        Parameters
        ----------
        fit_coeff: 1D numpy array or list (default: None)
            Best fit polynomial coefficients. If None, self.fit_coeff is
            used.
        spectrum: 1D numpy array (N) (default: None)
            Array of length N pixels
        tolerance: float (default: 5)
            Absolute difference between model and fitted wavelengths in unit
            of angstrom.
        plot_atlas: boolean (default: True)
            Display all the relevant lines available in the atlas library.
        log_spectrum: boolean (default: False)
            Display the arc in log-space if set to True.
        save_fig: boolean (default: False)
            Save an image if set to True. matplotlib uses the pyplot.save_fig()
            while the plotly uses the pio.write_html() or pio.write_image().
            The support format types should be provided in fig_type.
        fig_type: string (default: 'png')
            Image type to be saved, choose from:
            jpg, png, svg, pdf and iframe. Delimiter is '+'.
        filename: string (default: None)
            Provide a filename or full path. If the extension is not provided
            it is defaulted to png.
        return_jsonstring: boolean (default: False)
            Set to True to return json strings if using plotly as the plotting
            library.
        renderer: string (default: 'default')
            Indicate the Plotly renderer. Nothing gets displayed if json is
            set to True.
        display: boolean (default: True)
            Set to True to display diagnostic plot.

        Returns
        -------
        Whatever plotting.plot_fit() returns: json strings if using plotly
        as the plotting library and json is True.

        """

        # Fall back to the solution stored in the Calibrator
        coefficients = self.fit_coeff if fit_coeff is None else fit_coeff

        plot_options = {
            "fit_coeff": coefficients,
            "spectrum": spectrum,
            "tolerance": tolerance,
            "plot_atlas": plot_atlas,
            "log_spectrum": log_spectrum,
            "save_fig": save_fig,
            "fig_type": fig_type,
            "filename": filename,
            "return_jsonstring": return_jsonstring,
            "renderer": renderer,
            "display": display,
        }

        return plotting.plot_fit(self, **plot_options)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc