• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jveitchmichaelis / rascal / 4216281273

pending completion
4216281273

Pull #89

github

GitHub
<a href="https://github.com/jveitchmichaelis/rascal/commit/1dbc3970bf02065ba877f56362cf01f0a7927705">1dbc3970b</a>: Merge <a class="double-link" href="https://github.com/jveitchmichaelis/rascal/commit/4758a61c955757d2fccc7787f14a64961d1bac25">4758a61c9</a> into cec48f2a6
Pull Request #89: main catching up to v0.3.9

3 of 8 new or added lines in 3 files covered. (37.5%)

1884 of 2056 relevant lines covered (91.63%)

3.65 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.57
/src/rascal/calibrator.py
1
import copy
4✔
2
import itertools
4✔
3
import logging
4✔
4

5
import numpy as np
4✔
6
from scipy.spatial import Delaunay
4✔
7
from scipy.optimize import minimize
4✔
8
from scipy import interpolate
4✔
9
from tqdm.autonotebook import tqdm
4✔
10

11
from .util import _derivative
4✔
12
from .util import gauss
4✔
13

14
from . import plotting
4✔
15
from . import models
4✔
16
from .houghtransform import HoughTransform
4✔
17
from .atlas import Atlas
4✔
18

19

20
class Calibrator:
4✔
21
    def __init__(self, peaks, spectrum=None):
        """
        Initialise the calibrator object.

        All configurable attributes start out as None (results start out
        empty). The three set_*_properties() calls at the end, made with
        their default arguments, are expected to populate them.

        Parameters
        ----------
        peaks: list
            List of identified arc line pixel values. A deep copy is kept.
        spectrum: list
            The spectral intensity as a function of pixel. A deep copy is
            kept.

        """
        # Logging; presumably configured by set_calibrator_properties(),
        # which takes logger_name and log_level arguments.
        self.logger = self.log_level = None

        # Private copies so later changes made by the caller do not leak in.
        self.peaks = copy.deepcopy(peaks)
        self.spectrum = copy.deepcopy(spectrum)

        # Plotting backend flags
        self.matplotlib_imported = self.plotly_imported = False
        self.plot_with_matplotlib = self.plot_with_plotly = False

        # Atlas and user-supplied known pixel/wavelength pairs
        self.atlas = None
        self.pix_known = self.wave_known = None

        # Hough transform state
        self.hough_lines = self.hough_points = None
        self.ht = HoughTransform()

        # calibrator_properties
        self.num_pix = self.pixel_list = None
        self.plotting_library = self.constrain_poly = None

        # hough_properties
        self.num_slopes = self.xbins = self.ybins = None
        self.min_wavelength = self.max_wavelength = None
        self.range_tolerance = self.linearity_tolerance = None

        # ransac_properties
        self.sample_size = self.top_n_candidate = self.linear = None
        self.filter_close = self.ransac_tolerance = None
        self.candidate_weighted = self.hough_weight = None
        self.minimum_matches = self.minimum_peak_utilisation = None
        self.minimum_fit_error = None

        # results -- two distinct lists, so these must not be chained
        self.matched_peaks = []
        self.matched_atlas = []
        self.fit_coeff = None

        # Apply the default settings of every property group.
        self.set_calibrator_properties()
        self.set_hough_properties()
        self.set_ransac_properties()
85

86
    def _generate_pairs(self):
4✔
87
        """
88
        Generate pixel-wavelength pairs without the allowed regions set by the
89
        linearity limit. This assumes a relatively linear spectrograph.
90

91
        Parameters
92
        ----------
93
        candidate_tolerance: float (default: 10)
94
            toleranceold  (Angstroms) for considering a point to be an inlier
95
            during candidate peak/line selection. This should be reasonable
96
            small as we want to search for candidate points which are
97
            *locally* linear.
98
        constrain_poly: boolean
99
            Apply a polygonal constraint on possible peak/atlas pairs
100

101
        """
102

103
        pairs = [
4✔
104
            pair
105
            for pair in itertools.product(self.peaks, self.atlas.get_lines())
106
        ]
107

108
        if self.constrain_poly:
4✔
109

110
            # Remove pairs outside polygon
111
            valid_area = Delaunay(
4✔
112
                [
113
                    (0, self.max_intercept + self.candidate_tolerance),
114
                    (0, self.min_intercept - self.candidate_tolerance),
115
                    (
116
                        self.pixel_list.max(),
117
                        self.max_wavelength
118
                        - self.range_tolerance
119
                        - self.candidate_tolerance,
120
                    ),
121
                    (
122
                        self.pixel_list.max(),
123
                        self.max_wavelength
124
                        + self.range_tolerance
125
                        + self.candidate_tolerance,
126
                    ),
127
                ]
128
            )
129

130
            mask = valid_area.find_simplex(pairs) >= 0
4✔
131
            self.pairs = np.array(pairs)[mask]
4✔
132

133
        else:
134

135
            self.pairs = np.array(pairs)
4✔
136

137
    def _merge_candidates(self, candidates):
4✔
138
        """
139
        Merge two candidate lists.
140

141
        Parameters
142
        ----------
143
        candidates: list
144
            list containing pixel-wavelength pairs.
145

146
        """
147

148
        merged = []
4✔
149

150
        for pairs in candidates:
4✔
151

152
            for pair in np.array(pairs).T:
4✔
153

154
                merged.append(pair)
4✔
155

156
        return np.sort(np.array(merged))
4✔
157

158
    def _get_most_common_candidates(
4✔
159
        self, candidates, top_n_candidate, weighted
160
    ):
161
        """
162
        Takes a number of candidate pair sets and returns the most common
163
        pair for each wavelength
164

165
        Parameters
166
        ----------
167
        candidates: list of list(float, float)
168
            A list of list of peak/line pairs
169
        top_n_candidate: int
170
            Top ranked lines to be fitted.
171
        weighted: boolean
172
            If True, the distance from the atlas wavelength will be used to
173
            compute the probilitiy based on how far it is from the Gaussian
174
            distribution from the known line.
175

176
        """
177

178
        peaks = []
4✔
179
        wavelengths = []
4✔
180
        probabilities = []
4✔
181

182
        for candidate in candidates:
4✔
183

184
            peaks.extend(candidate[0])
4✔
185
            wavelengths.extend(candidate[1])
4✔
186
            probabilities.extend(candidate[2])
4✔
187

188
        peaks = np.array(peaks)
4✔
189
        wavelengths = np.array(wavelengths)
4✔
190
        probabilities = np.array(probabilities)
4✔
191

192
        out_peaks = []
4✔
193
        out_wavelengths = []
4✔
194

195
        for peak in np.unique(peaks):
4✔
196

197
            idx = np.where(peaks == peak)
4✔
198

199
            if len(idx) > 0:
4✔
200

201
                wavelengths_matched = wavelengths[idx]
4✔
202

203
                if weighted:
4✔
204

205
                    counts = probabilities[idx]
4✔
206

207
                else:
208

209
                    counts = np.ones_like(probabilities[idx])
×
210

211
                n = int(
4✔
212
                    min(top_n_candidate, len(np.unique(wavelengths_matched)))
213
                )
214

215
                unique_wavelengths = np.unique(wavelengths_matched)
4✔
216
                aggregated_count = np.zeros_like(unique_wavelengths)
4✔
217
                for j, w in enumerate(unique_wavelengths):
4✔
218

219
                    idx_j = np.where(wavelengths_matched == w)
4✔
220
                    aggregated_count[j] = np.sum(counts[idx_j])
4✔
221

222
                out_peaks.extend([peak] * n)
4✔
223
                out_wavelengths.extend(
4✔
224
                    wavelengths_matched[np.argsort(-aggregated_count)[:n]]
225
                )
226

227
        return out_peaks, out_wavelengths
4✔
228

229
    def _get_candidate_points_linear(self, candidate_tolerance):
        """
        Collect the peak/wavelength pairs that lie close to each Hough line.

        For every ``(gradient, intercept)`` in ``self.hough_lines``, the rows
        of ``self.pairs`` satisfying

            |wavelength - (gradient * pixel + intercept)| <= tolerance

        are kept together with a Gaussian weight computed by ``gauss``.
        ``self.candidates`` becomes a list holding one
        ``(pixels, wavelengths, weights)`` tuple per Hough line.

        Note: depending on the candidate_tolerance, one peak may match with
        multiple wavelengths.

        Parameters
        ----------
        candidate_tolerance: float
            Tolerance (Angstroms) for considering a point to be an inlier
            during candidate peak/line selection. This should be reasonably
            small as we want to search for candidate points which are
            *locally* linear.

        """
        self.candidates = []

        for gradient, intercept in self.hough_lines:
            pixels = self.pairs[:, 0]
            wavelengths = self.pairs[:, 1]

            predicted = gradient * pixels + intercept
            close = np.abs(predicted - wavelengths) <= candidate_tolerance

            # 1.1775 is half of 2.3548, the FWHM/sigma ratio of a Gaussian;
            # the intent noted originally is to match the FWHM. Pairs outside
            # the range_tolerance were already removed in an earlier stage.
            # NOTE(review): the tolerance is multiplied (not divided) by this
            # factor -- confirm against gauss()'s width convention.
            sigma = (self.range_tolerance + self.linearity_tolerance) * 1.1775
            weights = gauss(wavelengths[close], 1.0, predicted[close], sigma)

            self.candidates.append(
                (pixels[close], wavelengths[close], weights)
            )
273

274
    def _get_candidate_points_poly(self, candidate_tolerance):
        """
        **EXPERIMENTAL**

        Populate ``self.candidates`` with peak/wavelength pairs that agree
        with a randomly shifted copy of the guess solution
        ``self.fit_coeff``:

            |wavelength - (polyval(peak, fit_coeff) + delta)| < tolerance

        One shift ``delta`` is drawn uniformly from
        [-range_tolerance, range_tolerance) per entry of
        ``self.hough_lines``. Every peak is compared with every atlas line.

        Note: depending on the candidate_tolerance, one peak may
        match with multiple wavelengths.

        Parameters
        ----------
        candidate_tolerance: float
            Tolerance (Angstroms) for considering a point to be an inlier
            during candidate peak/line selection. This should be reasonably
            small as we want to search for candidate points which are
            *locally* linear.

        Raises
        ------
        ValueError
            If no guess solution ``self.fit_coeff`` is available.

        """
        if self.fit_coeff is None:
            raise ValueError(
                "A guess solution for a polynomial fit has to "
                "be provided as fit_coeff in fit() in order to generate "
                "candidates for RANSAC sampling."
            )

        self.candidates = []

        # asarray so boolean/fancy indexing works when peaks is a list
        peaks = np.asarray(self.peaks)
        # actual wavelengths
        actual = np.asarray(self.atlas.get_lines())

        n = len(self.hough_lines)

        delta = (
            np.random.random(n) * self.range_tolerance * 2.0
            - self.range_tolerance
        )

        # The unshifted prediction does not depend on the shift: hoist it.
        base_prediction = np.asarray(
            self.polyval(peaks, self.fit_coeff), dtype=float
        )

        for d in delta:
            # predicted wavelength of every peak
            predicted = base_prediction + d

            # Peaks and atlas lines generally differ in number, so compare
            # all of them pairwise: rows are peaks, columns are atlas lines.
            diff = np.abs(actual[np.newaxis, :] - predicted[:, np.newaxis])
            peak_idx, line_idx = np.nonzero(diff < candidate_tolerance)

            if len(peak_idx) > 0:
                weight = gauss(
                    actual[line_idx],
                    1.0,
                    predicted[peak_idx],
                    self.range_tolerance,
                )
                self.candidates.append(
                    [peaks[peak_idx], actual[line_idx], weight]
                )
330

331
    def _match_bijective(self, candidates, peaks, fit_coeff):
4✔
332
        """
333

334
        Internal function used to return a list of inliers with a
335
        one-to-one relationship between peaks and wavelengths. This
336
        is critical as often we have several potential candidate lines
337
        for each peak. This function first iterates through each peak
338
        and selects the wavelength with the smallest error. It then
339
        iterates through this list and does the same for duplicate
340
        wavelengths.
341

342
        parameters
343
        ----------
344
        candidates: dict
345
            match candidates, internal to ransac
346

347
        peaks: list
348
            list of peaks [px]
349

350
        fit_coeff: list
351
            polynomial fit coefficients
352

353
        """
354

355
        err = []
4✔
356
        matched_x = []
4✔
357
        matched_y = []
4✔
358

359
        for peak in peaks:
4✔
360

361
            fit = self.polyval(peak, fit_coeff)
4✔
362

363
            # Get closest match for this peak
364
            errs = np.abs(fit - candidates[peak])
4✔
365
            idx = np.argmin(errs)
4✔
366

367
            err.append(errs[idx])
4✔
368
            matched_x.append(peak)
4✔
369
            matched_y.append(candidates[peak][idx])
4✔
370

371
        err = np.array(err)
4✔
372
        matched_x = np.array(matched_x)
4✔
373
        matched_y = np.array(matched_y)
4✔
374

375
        # Now we also need to resolve duplicate y's
376
        filtered_x = []
4✔
377
        filtered_y = []
4✔
378
        filtered_err = []
4✔
379

380
        for wavelength in np.unique(matched_y):
4✔
381

382
            mask = matched_y == wavelength
4✔
383
            filtered_y.append(wavelength)
4✔
384

385
            err_idx = np.argmin(err[mask])
4✔
386
            filtered_x.append(matched_x[mask][err_idx])
4✔
387
            filtered_err.append(err[mask][err_idx])
4✔
388

389
        # overwrite
390
        err = np.array(filtered_err)
4✔
391
        matched_x = np.array(filtered_x)
4✔
392
        matched_y = np.array(filtered_y)
4✔
393

394
        assert len(np.unique(matched_x)) == len(np.unique(matched_y))
4✔
395

396
        return err, matched_x, matched_y
4✔
397

398
    def _solve_candidate_ransac(
        self,
        fit_deg,
        fit_coeff,
        max_tries,
        candidate_tolerance,
        brute_force,
        progress,
    ):
        """
        Use RANSAC to sample the parameter space and give best guess

        Candidate peak/wavelength pairs are generated (linear or polynomial
        mode) and reduced to the most common candidates per peak. They are
        then either sampled at random (``max_tries`` draws) or enumerated
        exhaustively (brute force). Each sample is fitted exactly and
        rejected if its intercept is out of bounds or it is not
        monotonically increasing. It is then scored with an M-SAC cost,
        optionally weighted by the Hough-space density. A new best sample is
        refitted robustly on its inliers.

        Parameters
        ----------
        fit_deg: int
            The order of polynomial.
        fit_coeff: None or 1D numpy array
            Initial polynomial fit coefficients. If given, its matching error
            seeds the best cost/error that samples must beat.
        max_tries: int
            Number of trials of polynomial fitting.
        candidate_tolerance: float
            Tolerance (Angstroms) for considering a point to be an inlier
            during candidate peak/line selection. This should be reasonably
            small as we want to search for candidate points which are
            *locally* linear.
        brute_force: boolean
            Solve all pixel-wavelength combinations with set to True.
        progress: boolean
            Show the progress bar with tqdm if set to True.

        Returns
        -------
        best_p: numpy.ndarray or None
            Coefficients of the best robust fit, or None if no sample was
            accepted.
        best_err: float
            Root-mean-square of the (ransac_tolerance-clipped) residuals of
            the best fit. If no sample was accepted, this is the error of
            ``fit_coeff`` (or the initial value 1e50).
        best_residual: numpy.ndarray or None
            Clipped residuals of the best fit at its matched peaks.
        best_inliers: int
            Number of lines fitted within the ransac_tolerance.
        valid_solution: boolean
            False if no fit was accepted or it is overfitted
            (best_inliers <= fit_deg + 1).

        """
        if self.linear:
            self._get_candidate_points_linear(candidate_tolerance)
        else:
            self._get_candidate_points_poly(candidate_tolerance)

        (
            self.candidate_peak,
            self.candidate_arc,
        ) = self._get_most_common_candidates(
            self.candidates,
            top_n_candidate=self.top_n_candidate,
            weighted=self.candidate_weighted,
        )

        self.fit_deg = fit_deg

        best_p = None
        best_cost = 1e50
        best_err = 1e50
        best_residual = None
        best_inliers = 0

        # Note that there may be multiple matches for
        # each peak, that is len(x) > len(np.unique(x))
        x = np.array(self.candidate_peak)
        y = np.array(self.candidate_arc)

        # Filter close wavelengths
        if self.filter_close:
            unique_y = np.unique(y)
            idx = np.argwhere(
                unique_y[1:] - unique_y[0:-1] < 3 * self.ransac_tolerance
            )
            separation_mask = np.argwhere((y == unique_y[idx]).sum(0) == 0)
            y = y[separation_mask].flatten()
            x = x[separation_mask].flatten()

        # If the number of lines is smaller than the number of degree of
        # polynomial fit, return a failed fit (same tuple layout as below).
        if len(np.unique(x)) <= self.fit_deg:
            return best_p, best_err, best_residual, best_inliers, False

        # Brute force check all combinations. If the request sample_size is
        # the same or larger than the available lines, it is essentially a
        # brute force.
        if brute_force or (self.sample_size >= len(np.unique(x))):
            sampler = itertools.combinations(range(len(x)), self.sample_size)
            self.sample_size = len(np.unique(x))
        else:
            sampler = range(int(max_tries))

        sampler_list = tqdm(sampler) if progress else sampler

        peaks = np.sort(np.unique(x))

        # Build a key(pixel)-value(wavelength) dictionary from the candidates
        candidates = {p: y[x == p] for p in peaks}

        # Interpolator over the Hough accumulator, used to weight the cost.
        # It stays None (unit weight) when there is no histogram or
        # hough_weight is None/non-finite, instead of failing later.
        twoditp = None
        if (
            self.ht.xedges is not None
            and self.hough_weight is not None
            and np.isfinite(self.hough_weight)
        ):
            xbin_size = (self.ht.xedges[1] - self.ht.xedges[0]) / 2.0
            ybin_size = (self.ht.yedges[1] - self.ht.yedges[0]) / 2.0
            twoditp = interpolate.RectBivariateSpline(
                self.ht.xedges[1:] - xbin_size,
                self.ht.yedges[1:] - ybin_size,
                self.ht.hist,
            )

        # Calculate initial error given pre-existing fit
        if fit_coeff is not None:
            err, _, _ = self._match_bijective(candidates, peaks, fit_coeff)
            best_cost = sum(err)
            best_err = np.sqrt(np.mean(err**2.0))

        # The histogram is fixed, so pre-computed outside the loop
        if not brute_force:
            # weight the probability of choosing the sample by the inverse
            # line density. np.histogram bins are left-closed (the last one
            # is closed on both sides); digitize with right=False matches
            # that, and the clip puts the maximum peak into the last bin, so
            # every peak maps to a bin with a non-zero count.
            counts, edges = np.histogram(peaks, bins=10)
            bin_idx = np.clip(np.digitize(peaks, edges) - 1, 0, len(counts) - 1)
            prob = 1.0 / counts[bin_idx]
            prob = prob / np.sum(prob)

        for sample in sampler_list:
            self.logger.debug(sample)

            keep_trying = True
            while keep_trying:
                # A brute-force sample is fixed, so retrying it after a
                # rejection would loop forever; only random draws retry.
                keep_trying = not brute_force
                cannot_draw = False

                if brute_force:
                    # ``sample`` is a tuple of indices; a list gives a 1-D
                    # fancy index (``x[[sample]]`` would not).
                    sample_idx = list(sample)
                    x_hat = x[sample_idx]
                    y_hat = y[sample_idx]

                else:
                    # Pick some random peaks
                    x_hat = np.random.choice(
                        peaks, self.sample_size, replace=False, p=prob
                    )
                    y_hat = []

                    # Pick a random wavelength for this x
                    for _x in x_hat:
                        y_choice = candidates[_x]

                        # Avoid picking a y that's already associated with
                        # another x
                        if set(y_choice).issubset(set(y_hat)):
                            self.logger.debug(
                                "Not possible to draw a unique "
                                "set of atlas wavelengths."
                            )
                            cannot_draw = True
                            break

                        y_temp = np.random.choice(y_choice)
                        while y_temp in y_hat:
                            y_temp = np.random.choice(y_choice)
                        y_hat.append(y_temp)

                if cannot_draw:
                    break

                # insert user given known pairs
                if self.pix_known is not None:
                    x_hat = np.concatenate((x_hat, self.pix_known))
                    y_hat = np.concatenate((y_hat, self.wave_known))

                # Try to fit the data.
                # This doesn't need to be robust, it's an exact fit.
                fit_coeffs = self.polyfit(x_hat, y_hat, self.fit_deg)

                # Check the intercept.
                if (fit_coeffs[0] < self.min_intercept) | (
                    fit_coeffs[0] > self.max_intercept
                ):
                    self.logger.debug("Intercept exceeds bounds.")
                    continue

                # Check monotonicity over the peak range padded by 20%.
                pix_min = peaks[0] - np.ptp(peaks) * 0.2
                pix_max = peaks[-1] + np.ptp(peaks) * 0.2
                self.logger.debug((pix_min, pix_max))

                if not np.all(
                    np.diff(
                        self.polyval(
                            np.arange(pix_min, pix_max, 1), fit_coeffs
                        )
                    )
                    > 0
                ):
                    self.logger.debug(
                        "Solution is not monotonically increasing."
                    )
                    continue

                # Compute error and filter out many-to-one matches
                err, matched_x, matched_y = self._match_bijective(
                    candidates, peaks, fit_coeffs
                )

                if len(matched_x) == 0:
                    continue

                # M-SAC Estimator (Torr and Zisserman, 1996)
                err[err > self.ransac_tolerance] = self.ransac_tolerance

                # modified cost function weighted by the Hough space density
                if twoditp is not None:
                    # Evaluate the accumulator at the tangent line
                    # (intercept, gradient) of the solution at every pixel.
                    wave = self.polyval(self.pixel_list, fit_coeffs)
                    gradient = self.polyval(
                        self.pixel_list, _derivative(fit_coeffs)
                    )
                    intercept = wave - gradient * self.pixel_list
                    weight = self.hough_weight * np.sum(
                        twoditp(intercept, gradient, grid=False)
                    )
                else:
                    weight = 1.0

                cost = (
                    sum(err)
                    / (len(err) - len(fit_coeffs) + 1)
                    / (weight + 1e-9)
                )

                # If this is potentially a new best fit, then handle that first
                if cost <= best_cost:
                    # reject lines outside the rms limit (ransac_tolerance)
                    # TODO: should n_inliers be recalculated from the robust
                    # fit?
                    mask = err < self.ransac_tolerance
                    n_inliers = sum(mask)
                    matched_peaks = matched_x[mask]
                    matched_atlas = matched_y[mask]

                    if len(matched_peaks) <= self.fit_deg:
                        self.logger.debug(
                            "Too few good candidates for fitting."
                        )
                        continue

                    # Now we do a robust fit
                    try:
                        coeffs = models.robust_polyfit(
                            matched_peaks, matched_atlas, self.fit_deg
                        )
                    except np.linalg.LinAlgError:
                        self.logger.warning(
                            "Linear algebra error in robust fit"
                        )
                        continue

                    # Get the residual of the fit
                    residual = (
                        self.polyval(matched_peaks, coeffs) - matched_atlas
                    )
                    residual[
                        np.abs(residual) > self.ransac_tolerance
                    ] = self.ransac_tolerance

                    rms_residual = np.sqrt(np.mean(residual**2))

                    # Make sure that we don't accept fits with zero error
                    if rms_residual < self.minimum_fit_error:
                        self.logger.debug(
                            "Fit error too small, "
                            "{:1.2f}.".format(rms_residual)
                        )
                        continue

                    # Check that we have enough inliers based on user specified
                    # constraints
                    if n_inliers < self.minimum_matches:
                        self.logger.debug(
                            "Not enough matched peaks for valid solution, "
                            "user specified {}.".format(self.minimum_matches)
                        )
                        continue

                    if n_inliers < self.minimum_peak_utilisation * len(
                        self.peaks
                    ):
                        self.logger.debug(
                            "Not enough matched peaks for valid solution, "
                            "user specified {:1.2f} %.".format(
                                100 * self.minimum_peak_utilisation
                            )
                        )
                        continue

                    # If the best fit is accepted, update the lists
                    best_cost = cost
                    best_inliers = n_inliers
                    best_p = coeffs
                    best_err = rms_residual
                    best_residual = residual
                    self.matched_peaks = list(copy.deepcopy(matched_peaks))
                    self.matched_atlas = list(copy.deepcopy(matched_atlas))

                    # Sanity check that matching peaks/atlas lines are 1:1
                    assert len(np.unique(self.matched_peaks)) == len(
                        self.matched_peaks
                    )
                    assert len(np.unique(self.matched_atlas)) == len(
                        self.matched_atlas
                    )
                    assert len(np.unique(self.matched_atlas)) == len(
                        np.unique(self.matched_peaks)
                    )

                    if progress:
                        sampler_list.set_description(
                            "Most inliers: {:d}, "
                            "best error: {:1.4f}".format(
                                best_inliers, best_err
                            )
                        )

                # If we got this far, then we can continue to the next sample
                keep_trying = False

            # Break early (out of the sampling loop) if all peaks are matched
            if best_inliers == len(peaks):
                break

        # Overfit check
        valid_solution = best_inliers > self.fit_deg + 1

        # If we totally failed then this can be empty
        assert best_inliers == len(self.matched_peaks)
        assert best_inliers == len(self.matched_atlas)

        assert len(self.matched_atlas) == len(set(self.matched_atlas))

        self.logger.info("Found: {}".format(best_inliers))

        return best_p, best_err, best_residual, best_inliers, valid_solution
806

807
    def _adjust_polyfit(self, delta, fit, tolerance, min_frac):
        """
        **EXPERIMENTAL**

        Shift the leading polynomial coefficients by ``delta`` and score the
        shifted solution against the atlas lines.

        Each peak is mapped to a wavelength with the shifted coefficients and
        paired with its nearest atlas line. The pair is kept only when the
        separation is below ``tolerance``.

        Parameters
        ----------
        delta: list or numpy.ndarray
            Offsets added to the first len(delta) polynomial coefficients.
        fit: list or numpy.ndarray
            The polynomial coefficients (not modified; a copy is shifted).
        tolerance: float
            The maximum difference between fit and atlas to be accounted for
            the best fit.
        min_frac: float
            The minimum fraction of peaks that must be matched.

        Return
        ------
        lsq: float
            Sum of squared wavelength residuals of the matched pairs divided
            by the degrees of freedom, or numpy.inf when the shifted solution
            has too few matches or is not monotonically increasing over
            pixel_list.

        """

        shifted = fit.copy()
        lines = self.atlas.get_lines()

        for position, offset in enumerate(delta):
            shifted[position] += offset

        # Pair every peak (pixel) with its closest atlas line (wavelength).
        matched_pix = []
        matched_wave = []
        for peak in self.peaks:
            separation = np.abs(lines - self.polyval(peak, shifted))
            nearest = np.argmin(separation)
            if separation[nearest] < tolerance:
                matched_pix.append(peak)
                matched_wave.append(lines[nearest])

        matched_pix = np.array(matched_pix)
        matched_wave = np.array(matched_wave)

        dof = len(matched_pix) - len(shifted) - 1

        # Reject solutions without spare degrees of freedom or with too few
        # of the peaks matched.
        if dof < 1 or len(matched_pix) < len(self.peaks) * min_frac:
            return np.inf

        # A valid wavelength solution must increase across the pixel axis.
        model = self.polyval(np.sort(self.pixel_list), shifted)
        if not np.all(np.diff(model) > 0):
            self.logger.info("not monotonic")
            return np.inf

        residual = matched_wave - self.polyval(matched_pix, shifted)
        return np.sum(residual**2.0) / dof
880

881
    def which_plotting_library(self):
        """
        Report which plotting library the Calibrator is set to use.

        Returns
        -------
        library: str or None
            "matplotlib" if plot_with_matplotlib is set, otherwise "plotly"
            if plot_with_plotly is set, otherwise None (a warning is logged).

        """

        if self.plot_with_matplotlib:
            self.logger.info("Using matplotlib.")
            return "matplotlib"

        if self.plot_with_plotly:
            self.logger.info("Using plotly.")
            return "plotly"

        self.logger.warning("Neither matplotlib nor plotly are imported.")
        return None
902

903
    def use_matplotlib(self):
        """
        Select matplotlib as the plotting backend and deselect plotly.

        """

        self.plot_with_matplotlib, self.plot_with_plotly = True, False
911

912
    def use_plotly(self):
        """
        Select plotly as the plotting backend and deselect matplotlib.

        """

        self.plot_with_plotly, self.plot_with_matplotlib = True, False
920

921
    def set_calibrator_properties(
        self,
        num_pix=None,
        pixel_list=None,
        plotting_library=None,
        seed=None,
        logger_name="Calibrator",
        log_level="warning",
    ):
        """
        Initialise the logger and the basic properties of the calibrator.

        Arguments left as None keep the value already stored on the
        calibrator, or fall back to a default when nothing is stored yet.

        Parameters
        ----------
        num_pix: int
            Number of pixels in the spectral axis. If neither given nor set
            before, len(spectrum) is used; if that fails, 1.1 times
            max(peaks) is used.
        pixel_list: list
            pixel value of the of the spectrum, this is only needed if the
            spectrum spans multiple detector arrays. Defaults to
            numpy.arange(num_pix).
        plotting_library: string (default: 'matplotlib')
            Choose between matplotlib and plotly.
        seed: int
            Set an optional seed for numpy's global random number generator
            (numpy.random.seed). If used, this parameter must be set prior
            to calling RANSAC. Useful for deterministic debugging.
        logger_name: string (default: 'Calibrator')
            The name of the logger. It can use an existing logger if a
            matching name is provided.
        log_level: string (default: 'warning')
            Choose {critical, error, warning, info, debug, notset}.

        """

        # initialise the logger
        self.logger = logging.getLogger(logger_name)
        self.logger.propagate = False
        level = logging.getLevelName(log_level.upper())
        self.logger.setLevel(level)
        self.log_level = level

        # Attach a handler only the first time this named logger is set up;
        # re-using an existing logger must not duplicate its output.
        if not self.logger.handlers:
            formatter = logging.Formatter(
                "[%(asctime)s] %(levelname)s [%(filename)s:%(lineno)d] "
                "%(message)s",
                datefmt="%a, %d %b %Y %H:%M:%S",
            )
            handler = logging.StreamHandler()
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

        # set the num_pix
        if num_pix is not None:
            self.num_pix = num_pix

        elif self.num_pix is None:
            try:
                self.num_pix = len(self.spectrum)

            except Exception as e:
                # Best effort: without a spectrum, estimate from the peaks.
                self.logger.warning(e)
                self.logger.warning(
                    "Neither num_pix nor spectrum is given, "
                    "it uses 1.1 times max(peaks) as the "
                    "maximum pixel value."
                )
                self.num_pix = 1.1 * max(self.peaks)

        # Report the value in effect, not the (possibly None) argument.
        self.logger.info("num_pix is set to {}.".format(self.num_pix))

        # set the pixel_list
        if pixel_list is not None:
            self.pixel_list = np.asarray(pixel_list)

        elif self.pixel_list is None:
            self.pixel_list = np.arange(self.num_pix)

        self.logger.info("pixel_list is set to {}.".format(self.pixel_list))

        # map the list position to the pixel value
        self.pix_to_rawpix = interpolate.interp1d(
            self.pixel_list,
            np.arange(len(self.pixel_list)),
            fill_value="extrapolate",
        )

        if seed is not None:
            np.random.seed(seed)

        # Use the supplied plotting library; otherwise keep the current one,
        # defaulting to matplotlib if none has been chosen yet.
        if plotting_library is not None:
            self.plotting_library = plotting_library

        elif self.plotting_library is None:
            self.plotting_library = "matplotlib"

        # check the choice of plotting library is available and used.
        if self.plotting_library == "matplotlib":
            self.use_matplotlib()
            self.logger.info("Plotting with matplotlib.")

        elif self.plotting_library == "plotly":
            self.use_plotly()
            self.logger.info("Plotting with plotly.")

        else:
            self.logger.warning(
                "Unknown plotting_library, please choose from "
                "matplotlib or plotly. Execute use_matplotlib() or "
                "use_plotly() to manually select the library."
            )
1059

1060
    def set_hough_properties(
        self,
        num_slopes=None,
        xbins=None,
        ybins=None,
        min_wavelength=None,
        max_wavelength=None,
        range_tolerance=None,
        linearity_tolerance=None,
    ):
        """
        Configure the Hough transform search space.

        Arguments left as None keep the value already stored on the
        calibrator, or fall back to the default listed below. The slope and
        intercept limits are then recomputed from the wavelength range.

        Parameters
        ----------
        num_slopes: int (default: 2000)
            Number of slopes to consider during Hough transform
        xbins: int (default: 100)
            Number of bins for Hough accumulation
        ybins: int (default: 100)
            Number of bins for Hough accumulation
        min_wavelength: float (default: 3000)
            Minimum wavelength of the spectrum.
        max_wavelength: float (default: 9000)
            Maximum wavelength of the spectrum.
        range_tolerance: float (default: 500)
            Estimation of the error on the provided spectral range
            e.g. 3000-5000 with tolerance 500 will search for
            solutions that may satisfy 2500-5500
        linearity_tolerance: float (default: 100)
            A tolerance (Angstroms) which defines some padding around the
            range tolerance to allow for non-linearity. This should be the
            maximum expected excursion from linearity.

        """

        # set the num_slopes
        if num_slopes is not None:
            self.num_slopes = int(num_slopes)
        elif self.num_slopes is None:
            self.num_slopes = 2000

        # set the xbins
        if xbins is not None:
            self.xbins = xbins
        elif self.xbins is None:
            self.xbins = 100

        # set the ybins
        if ybins is not None:
            self.ybins = ybins
        elif self.ybins is None:
            self.ybins = 100

        # set the min_wavelength
        if min_wavelength is not None:
            self.min_wavelength = min_wavelength
        elif self.min_wavelength is None:
            self.min_wavelength = 3000.0

        # set the max_wavelength
        if max_wavelength is not None:
            self.max_wavelength = max_wavelength
        elif self.max_wavelength is None:
            self.max_wavelength = 9000.0

        # Set the range_tolerance
        if range_tolerance is not None:
            self.range_tolerance = range_tolerance
        elif self.range_tolerance is None:
            self.range_tolerance = 500

        # Set the linearity_tolerance
        if linearity_tolerance is not None:
            self.linearity_tolerance = linearity_tolerance
        elif self.linearity_tolerance is None:
            self.linearity_tolerance = 100

        # Start wavelength in the spectrum, +/- some tolerance
        self.min_intercept = self.min_wavelength - self.range_tolerance
        self.max_intercept = self.min_wavelength + self.range_tolerance

        # Pixel extent used to turn a wavelength span into a slope.
        pixel_span = np.ptp(self.pixel_list)

        # NOTE(review): min_intercept already has range_tolerance subtracted,
        # so the lower end of max_slope applies it a second time (and
        # min_slope adds it back). Confirm this asymmetric widening is
        # intended before changing it: it alters the Hough search space.
        self.min_slope = (
            (
                self.max_wavelength
                - self.range_tolerance
                - self.linearity_tolerance
            )
            - (
                self.min_intercept
                + self.range_tolerance
                + self.linearity_tolerance
            )
        ) / pixel_span

        self.max_slope = (
            (
                self.max_wavelength
                + self.range_tolerance
                + self.linearity_tolerance
            )
            - (
                self.min_intercept
                - self.range_tolerance
                - self.linearity_tolerance
            )
        ) / pixel_span

        # Rebuild the candidate pairs when an atlas is already loaded
        # (presumably they depend on the limits computed above).
        if self.atlas is not None:
            self._generate_pairs()
1218

1219
    def set_ransac_properties(
        self,
        sample_size=None,
        top_n_candidate=None,
        linear=None,
        filter_close=None,
        ransac_tolerance=None,
        candidate_weighted=None,
        hough_weight=None,
        minimum_matches=None,
        minimum_peak_utilisation=None,
        minimum_fit_error=None,
    ):
        """
        Configure the Calibrator. This may require some manual twiddling before
        the calibrator can work efficiently. However, in theory, a large
        max_tries in fit() should provide a good solution in the expense of
        performance (minutes instead of seconds).

        Arguments left as None keep the value already stored on the
        calibrator, or fall back to the default listed below.

        Parameters
        ----------
        sample_size: int (default: 5)
            Number of samples used for fitting, this is automatically
            set to the polynomial degree + 1, but a larger value can
            be specified here.
        top_n_candidate: int (default: 5)
            Top ranked lines to be fitted.
        linear: boolean (default: True)
            True to use the hough transformed gradient, otherwise, use the
            known polynomial.
        filter_close: boolean (default: False)
            Remove the pairs that are out of bounds in the hough space.
        ransac_tolerance: float (default: 5)
            The distance criteria  (Angstroms) to be considered an inlier to a
            fit. This should be close to the size of the expected residuals on
            the final fit (e.g. 1A is typical)
        candidate_weighted: boolean (default: True)
            Set to True to down-weight pairs that are far from the fit.
        hough_weight: float or None (default: 1.0)
            Set to use the hough space to weigh the fit. The theoretical
            optimal weighting is unclear. The larger the value, the heavily it
            relies on the overdensity in the hough space for a good fit.
        minimum_matches: int or None (default: 0)
            Set to only accept fit solutions with a minimum number of
            matches. Setting this will prevent the fitting function from
            accepting spurious low-error fits. Must be >= 0.
        minimum_peak_utilisation: float or None (default: 0)
            Set to only accept fit solutions with a fraction of matches. This
            option is convenient if you don't want to specify an absolute
            number of atlas lines. Range is 0 - 1 inclusive.
        minimum_fit_error: float or None (default: 1e-4)
            Set to only accept fits with a minimum error. This avoids
            accepting "perfect" fit solutions with zero errors. However
            if you have an extremely good system, you may want to set this
            tolerance lower. Must be >= 0.

        """

        # Setting the sample_size
        if sample_size is not None:
            self.sample_size = sample_size
        elif self.sample_size is None:
            self.sample_size = 5

        # Set top_n_candidate
        if top_n_candidate is not None:
            self.top_n_candidate = top_n_candidate
        elif self.top_n_candidate is None:
            self.top_n_candidate = 5

        # Set linear
        if linear is not None:
            self.linear = linear
        elif self.linear is None:
            self.linear = True

        # Set to filter closely spaced lines
        if filter_close is not None:
            self.filter_close = filter_close
        elif self.filter_close is None:
            self.filter_close = False

        # Set the ransac_tolerance
        if ransac_tolerance is not None:
            self.ransac_tolerance = ransac_tolerance
        elif self.ransac_tolerance is None:
            self.ransac_tolerance = 5

        # Set to weigh the candidate pairs by the density (pixel)
        if candidate_weighted is not None:
            self.candidate_weighted = candidate_weighted
        elif self.candidate_weighted is None:
            self.candidate_weighted = True

        # Set the multiplier of the weight of the hough density
        if hough_weight is not None:
            self.hough_weight = hough_weight
        elif self.hough_weight is None:
            self.hough_weight = 1.0

        # Set the minimum number of desired matches. 0 (the default) means
        # "no minimum", so it must also be accepted when passed explicitly.
        if minimum_matches is not None:
            assert minimum_matches >= 0, "minimum_matches must be >= 0."
            self.minimum_matches = minimum_matches
        elif self.minimum_matches is None:
            self.minimum_matches = 0

        # Set the minimum utilisation required
        if minimum_peak_utilisation is not None:
            assert 0 <= minimum_peak_utilisation <= 1.0, (
                "minimum_peak_utilisation must be between 0 and 1."
            )
            self.minimum_peak_utilisation = minimum_peak_utilisation
        elif self.minimum_peak_utilisation is None:
            self.minimum_peak_utilisation = 0

        # Set the minimum fit error
        if minimum_fit_error is not None:
            assert minimum_fit_error >= 0, "minimum_fit_error must be >= 0."
            self.minimum_fit_error = minimum_fit_error
        elif self.minimum_fit_error is None:
            self.minimum_fit_error = 1e-4
1412

1413
    def add_atlas(
        self,
        elements,
        min_atlas_wavelength=None,
        max_atlas_wavelength=None,
        min_intensity=10.0,
        min_distance=10.0,
        candidate_tolerance=10.0,
        constrain_poly=False,
        vacuum=False,
        pressure=101325.0,
        temperature=273.15,
        relative_humidity=0.0,
    ):
        """
        Deprecated: add arc lines of the given element(s) to the calibrator.

        Builds a new Atlas if none is attached yet, otherwise forwards the
        lines to the existing atlas' add(). Then stores the candidate
        settings and regenerates the peak/line pairs.

        Parameters
        ----------
        elements: str or list
            Element(s) whose lines are added (passed through to the atlas).
        min_atlas_wavelength, max_atlas_wavelength: float or None
            Wavelength limits of the lines. They default to
            min_wavelength - range_tolerance and
            max_wavelength + range_tolerance respectively.
        min_intensity, min_distance, vacuum, pressure, temperature,
        relative_humidity:
            Passed through unchanged to the atlas.
        candidate_tolerance: float (default: 10.0)
            Stored on the calibrator for candidate peak/line selection.
        constrain_poly: boolean (default: False)
            Stored on the calibrator.

        """

        self.logger.warning(
            "Using add_atlas is now deprecated. "
            "Please use the new Atlas class."
        )

        lower = (
            self.min_wavelength - self.range_tolerance
            if min_atlas_wavelength is None
            else min_atlas_wavelength
        )
        upper = (
            self.max_wavelength + self.range_tolerance
            if max_atlas_wavelength is None
            else max_atlas_wavelength
        )

        line_options = dict(
            min_atlas_wavelength=lower,
            max_atlas_wavelength=upper,
            min_intensity=min_intensity,
            min_distance=min_distance,
            vacuum=vacuum,
            pressure=pressure,
            temperature=temperature,
            relative_humidity=relative_humidity,
        )

        if self.atlas is not None:
            # NOTE(review): range_tolerance is only forwarded when a new
            # Atlas is created, not to an existing one - confirm intended.
            self.atlas.add(elements, **line_options)
        else:
            self.atlas = Atlas(
                elements,
                range_tolerance=self.range_tolerance,
                **line_options,
            )

        self.candidate_tolerance = candidate_tolerance
        self.constrain_poly = constrain_poly

        self._generate_pairs()
1475

1476
    def remove_atlas_lines_range(self, wavelength, tolerance=10):
        """
        Remove arc lines within a certain wavelength range.

        The call is delegated to the attached atlas' own
        remove_atlas_lines_range with the same wavelength and tolerance.
        """

        atlas = self.atlas
        atlas.remove_atlas_lines_range(wavelength, tolerance)
1482

1483
    def list_atlas(self):
        """
        List all the lines loaded to the Calibrator.

        Delegates to the attached atlas' list() and returns None.
        """

        atlas = self.atlas
        atlas.list()
1489

1490
    def clear_atlas(self):
        """
        Remove all the lines loaded to the Calibrator.

        Delegates to the attached atlas' clear() and returns None.
        """

        atlas = self.atlas
        atlas.clear()
1496

1497
    def add_user_atlas(
        self,
        elements,
        wavelengths,
        intensities=None,
        vacuum=False,
        pressure=101325.0,
        temperature=273.15,
        relative_humidity=0.0,
        candidate_tolerance=10,
        constrain_poly=False,
    ):
        """
        Deprecated: add user-supplied arc lines to the calibrator.

        An empty Atlas is created first if none is attached. The lines are
        then forwarded to the atlas' add_user_atlas, the candidate settings
        are stored, and the peak/line pairs are regenerated.

        Parameters
        ----------
        elements, wavelengths, intensities, vacuum, pressure, temperature,
        relative_humidity:
            Passed positionally, in this order, to Atlas.add_user_atlas.
        candidate_tolerance: float (default: 10)
            Stored on the calibrator for candidate peak/line selection.
        constrain_poly: boolean (default: False)
            Stored on the calibrator.

        """

        self.logger.warning(
            "Using add_user_atlas is now deprecated. "
            "Please use the new Atlas class."
        )

        if self.atlas is None:
            self.atlas = Atlas()

        line_args = (
            elements,
            wavelengths,
            intensities,
            vacuum,
            pressure,
            temperature,
            relative_humidity,
        )
        self.atlas.add_user_atlas(*line_args)

        self.candidate_tolerance = candidate_tolerance
        self.constrain_poly = constrain_poly

        self._generate_pairs()
1533

1534
    def set_atlas(self, atlas, candidate_tolerance=10.0, constrain_poly=False):
        """
        Attach an atlas of arc lines to the calibrator.

        The atlas and the candidate settings are stored, then all candidate
        pairs of detected peaks and atlas lines are regenerated.

        Parameters
        ----------
        atlas: rascal.Atlas
            The atlas holding the arc lines to calibrate against.
        candidate_tolerance: float (default: 10)
            tolerance (Angstroms) for considering a point to be an inlier
            during candidate peak/line selection. This should be reasonably
            small as we want to search for candidate points which are
            *locally* linear.
        constrain_poly: boolean (default: False)
            Apply a polygonal constraint on possible peak/atlas pairs
        """

        self.atlas = atlas

        self.candidate_tolerance = candidate_tolerance
        self.constrain_poly = constrain_poly

        # Create a list of all possible pairs of detected peaks and lines
        # from atlas
        self._generate_pairs()
1559

1560
    def do_hough_transform(self, brute_force=False):
        """
        Run the Hough transform on the candidate peak/atlas pairs.

        If there are no pairs yet, they are generated first. The transform
        is constrained to [min_slope, max_slope] x [min_intercept,
        max_intercept], binned with xbins x ybins, and the resulting
        hough_points / hough_lines are copied from self.ht onto the
        calibrator.

        Parameters
        ----------
        brute_force: boolean (default: False)
            Use HoughTransform.generate_hough_points_brute_force instead of
            generate_hough_points with num_slopes.

        """

        # len() works for both a list and an (N, 2) array. Comparing an
        # array with [] is elementwise, and NumPy >= 1.25 raises on those
        # mismatched shapes instead of returning False.
        if self.pairs is None or len(self.pairs) == 0:
            self.logger.warning("pairs list is empty. Try generating now.")
            self._generate_pairs()

            if self.pairs is None or len(self.pairs) == 0:
                self.logger.error("pairs list is still empty.")

        # Generate the hough_points from the pairs
        self.ht.set_constraints(
            self.min_slope,
            self.max_slope,
            self.min_intercept,
            self.max_intercept,
        )

        # Column 0 of the pairs is the peak, column 1 the atlas line.
        if brute_force:
            self.ht.generate_hough_points_brute_force(
                self.pairs[:, 0], self.pairs[:, 1]
            )
        else:
            self.ht.generate_hough_points(
                self.pairs[:, 0], self.pairs[:, 1], num_slopes=self.num_slopes
            )

        self.ht.bin_hough_points(self.xbins, self.ybins)
        self.hough_points = self.ht.hough_points
        self.hough_lines = self.ht.hough_lines
1591

1592
    def save_hough_transform(
        self,
        filename="hough_transform",
        fileformat="npy",
        delimiter="+",
        to_disk=True,
    ):
        """
        Save the HoughTransform object to memory or to disk.

        This forwards all arguments to self.ht.save and returns its result.

        Parameters
        ----------
        filename: str
            The filename of the output, not used if to_disk is False. It
            will be appended with the content type.
        fileformat: str (default: 'npy')
            Choose from 'npy' and 'json'
        delimiter: str (default: '+')
            Delimiter for format and content types
        to_disk: boolean
            Set to True to save to disk, else return a numpy array object

        Returns
        -------
        hp_hough_points: numpy.ndarray
            only return if to_disk is False (whatever self.ht.save returns
            is passed back unchanged).

        """

        # Return the result so the in-memory (to_disk=False) case promised
        # above actually reaches the caller.
        return self.ht.save(
            filename=filename,
            fileformat=fileformat,
            delimiter=delimiter,
            to_disk=to_disk,
        )
1627

1628
    def load_hough_transform(self, filename="hough_transform", filetype="npy"):
        """
        Load a previously saved Hough transform into self.ht.

        Parameters
        ----------
        filename: str (default: 'hough_transform')
            The filename of the saved Hough transform.
        filetype: str (default: 'npy')
            The file type of the saved hough transform. Choose from 'npy'
            and 'json'.

        """

        # NOTE(review): only self.ht is updated here; self.hough_points and
        # self.hough_lines keep their previous values - confirm whether
        # callers expect them to be refreshed from self.ht after loading.
        self.ht.load(filename=filename, filetype=filetype)
1644

1645
    def set_known_pairs(self, pix=(), wave=()):
        """
        Provide manual pixel-wavelength pair(s), they will be appended to the
        list of pixel-wavelength pairs after the random sample being drawn from
        the RANSAC step, i.e. they are ALWAYS PRESENT in the fitting step. Use
        with caution because it can skew or bias the fit significantly, make
        sure the pixel value is accurate to at least 1/10 of a pixel. We do not
        recommend supplying more than a couple of known pairs unless you are
        very confident with the solution and intend to skew with the known
        pairs.

        This can be used for example for low intensity lines at the edge of
        the spectrum. Or saturated lines where peaks cannot be well positioned.

        Parameters
        ----------
        pix: numeric value, list or numpy 1D array (N) (default: ())
            Any pixel value, can be outside the detector chip and
            serve purely as anchor points.
        wave: numeric value, list or numpy 1D array (N) (default: ())
            The matching wavelength for each of the pix.

        Raises
        ------
        ValueError
            If pix and wave differ in size, or either contains a value that
            is non-numeric or NaN.

        """

        # Conversion to float already raises ValueError for non-numeric
        # strings; None becomes NaN and is caught below.
        pix = np.asarray(pix, dtype="float").reshape(-1)
        wave = np.asarray(wave, dtype="float").reshape(-1)

        # Raise (not assert) so the check is not stripped under ``-O`` and
        # the intended ValueError is what callers see.
        if pix.size != wave.size:
            raise ValueError(
                "Please check the length of the input arrays. pix has size {} "
                "and wave has size {}.".format(pix.size, wave.size)
            )

        # Every element is a float after the conversion above, so NaN is
        # the only remaining invalid value.
        if np.isnan(pix).any():
            raise ValueError("All pix elements have to be numeric.")

        if np.isnan(wave).any():
            raise ValueError("All wave elements have to be numeric.")

        self.pix_known = pix
        self.wave_known = wave
1691

1692
    def fit(
4✔
1693
        self,
1694
        max_tries=500,
1695
        fit_deg=4,
1696
        fit_coeff=None,
1697
        fit_tolerance=5.0,
1698
        fit_type="poly",
1699
        candidate_tolerance=2.0,
1700
        brute_force=False,
1701
        progress=True,
1702
    ):
1703
        """
1704
        Solve for the wavelength calibration polynomial by getting the most
1705
        likely solution with RANSAC.
1706

1707
        Parameters
1708
        ----------
1709
        max_tries: int (default: 5000)
1710
            Maximum number of iteration.
1711
        fit_deg: int (default: 4)
1712
            The degree of the polynomial to be fitted.
1713
        fit_coeff: list (default: None)
1714
            Set the baseline of the least square fit. If no fits outform this
1715
            set of polynomial coefficients, this will be used as the best fit.
1716
        fit_tolerance: float (default: 5.0)
1717
            Sets a tolerance on whether a fit found by RANSAC is considered
1718
            acceptable
1719
        fit_type: string (default: 'poly')
1720
            One of 'poly', 'legendre' or 'chebyshev'
1721
        candidate_tolerance: float (default: 2.0)
1722
            toleranceold  (Angstroms) for considering a point to be an inlier
1723
        brute_force: boolean (default: False)
1724
            Set to True to try all possible combination in the given parameter
1725
            space
1726
        progress: boolean (default: True)
1727
            True to show progress with tdqm. It is overrid if tdqm cannot be
1728
            imported.
1729

1730
        Returns
1731
        -------
1732
        fit_coeff: list
1733
            List of best fit polynomial fit_coefficient.
1734
        matched_peaks: list
1735
            Peaks used for final fit
1736
        matched_atlas: list
1737
            Atlas lines used for final fit
1738
        rms: float
1739
            The root-mean-squared of the residuals
1740
        residual: float
1741
            Residual from the best fit
1742
        peak_utilisation: float
1743
            Fraction of detected peaks (pixel) used for calibration [0-1].
1744
        atlas_utilisation: float
1745
            Fraction of supplied arc lines (wavelength) used for
1746
            calibration [0-1].
1747

1748
        """
1749

1750
        self.max_tries = max_tries
4✔
1751
        self.fit_deg = fit_deg
4✔
1752
        self.fit_coeff = fit_coeff
4✔
1753
        if fit_coeff is not None:
4✔
1754

1755
            self.fit_deg = len(fit_coeff) - 1
4✔
1756

1757
        self.fit_tolerance = fit_tolerance
4✔
1758
        self.fit_type = fit_type
4✔
1759
        self.brute_force = brute_force
4✔
1760
        self.progress = progress
4✔
1761

1762
        if self.fit_type == "poly":
4✔
1763

1764
            self.polyfit = np.polynomial.polynomial.polyfit
4✔
1765
            self.polyval = np.polynomial.polynomial.polyval
4✔
1766

1767
        elif self.fit_type == "legendre":
4✔
1768

1769
            self.polyfit = np.polynomial.legendre.legfit
4✔
1770
            self.polyval = np.polynomial.legendre.legval
4✔
1771

1772
        elif self.fit_type == "chebyshev":
4✔
1773

1774
            self.polyfit = np.polynomial.chebyshev.chebfit
4✔
1775
            self.polyval = np.polynomial.chebyshev.chebval
4✔
1776

1777
        else:
1778

1779
            raise ValueError(
×
1780
                "fit_type must be: (1) poly, (2) legendre or (3) chebyshev"
1781
            )
1782

1783
        # Reduce sample_size if it is larger than the number of atlas available
1784
        if self.sample_size > len(self.atlas):
4✔
1785

1786
            self.logger.warning(
×
1787
                "Size of sample_size is larger than the size of atlas, "
1788
                + "the sample_size is set to match the size of atlas = "
1789
                + str(len(self.atlas))
1790
                + "."
1791
            )
1792
            self.sample_size = len(self.atlas)
×
1793

1794
        if self.sample_size <= fit_deg:
4✔
1795

1796
            self.sample_size = fit_deg + 1
4✔
1797

1798
        if (self.hough_lines is None) or (self.hough_points is None):
4✔
1799

1800
            self.do_hough_transform()
4✔
1801

1802
        if self.minimum_matches > len(self.atlas):
4✔
1803
            self.logger.warning(
×
1804
                "Requested minimum matches is greater than the atlas size"
1805
                "setting the minimum number of matches to equal the atlas"
1806
                "size = " + str(len(self.atlas)) + "."
1807
            )
1808
            self.minimum_matches = len(self.atlas)
×
1809

1810
        if self.minimum_matches > len(self.peaks):
4✔
1811
            self.logger.warning(
×
1812
                "Requested minimum matches is greater than the number of "
1813
                "peaks detected, which has a size of "
1814
                "size = " + str(len(self.peaks)) + "."
1815
            )
1816
            self.minimum_matches = len(self.peaks)
×
1817

1818
        # TODO also check whether minimum peak utilisation is greater than
1819
        # minimum matches.
1820

1821
        (
4✔
1822
            fit_coeff,
1823
            rms,
1824
            residual,
1825
            n_inliers,
1826
            valid,
1827
        ) = self._solve_candidate_ransac(
1828
            fit_deg=self.fit_deg,
1829
            fit_coeff=self.fit_coeff,
1830
            max_tries=self.max_tries,
1831
            candidate_tolerance=candidate_tolerance,
1832
            brute_force=self.brute_force,
1833
            progress=self.progress,
1834
        )
1835

1836
        peak_utilisation = n_inliers / len(self.peaks)
4✔
1837
        atlas_utilisation = n_inliers / len(self.atlas)
4✔
1838

1839
        if not valid:
4✔
1840

1841
            self.logger.warning("Invalid fit")
×
1842

1843
        if rms > self.fit_tolerance:
4✔
1844

1845
            self.logger.warning(
×
1846
                "RMS too large {} > {}".format(rms, self.fit_tolerance)
1847
            )
1848

1849
        assert fit_coeff is not None, "Couldn't fit"
4✔
1850

1851
        self.fit_coeff = fit_coeff
4✔
1852
        self.rms = rms
4✔
1853
        self.residual = residual
4✔
1854
        self.peak_utilisation = peak_utilisation
4✔
1855
        self.atlas_utilisation = atlas_utilisation
4✔
1856

1857
        return (
4✔
1858
            self.fit_coeff,
1859
            self.matched_peaks,
1860
            self.matched_atlas,
1861
            self.rms,
1862
            self.residual,
1863
            self.peak_utilisation,
1864
            self.atlas_utilisation,
1865
        )
1866

1867
    def match_peaks(
        self,
        fit_coeff=None,
        n_delta=None,
        refine=False,
        tolerance=10.0,
        method="Nelder-Mead",
        convergence=1e-6,
        min_frac=0.5,
        robust_refit=True,
        fit_deg=None,
    ):
        """
        ** refine option is EXPERIMENTAL, use with caution **

        Match the detected peaks against the atlas lines using a polynomial
        solution. Optionally, the solution can be refined before matching
        and/or refitted to the matched pairs afterwards.

        Set refine to True to adjust the lowest n_delta coefficients by
        minimising self._adjust_polyfit with scipy.optimize.minimize before
        the peaks are matched. It is recommended to use this in multiple
        calls: first refine the lowest orders, then gradually increase
        n_delta.

        Set robust_refit to True to fit all the matched pairs with
        models.robust_polyfit, using a polynomial of degree fit_deg.

        Set both refine and robust_refit to False to get back the input
        solution together with the arc lines that lie within the tolerance
        of it.

        Parameters
        ----------
        fit_coeff: list or numpy 1D array (default: None)
            Polynomial fit coefficients. Defaults to self.fit_coeff.
        n_delta: int (default: None)
            The number of the lowest polynomial orders to be adjusted when
            refine is True. Defaults to len(fit_coeff) - 1.
        refine: boolean (default: False)
            Set to True to refine the solution before matching.
        tolerance: float (default: 10.)
            Maximum absolute difference between the solution and an atlas
            line for the pair to count as a match (same unit as the atlas).
        method: string (default: 'Nelder-Mead')
            scipy.optimize.minimize method.
        convergence: float (default: 1e-6)
            scipy.optimize.minimize tol.
        min_frac: float (default: 0.5)
            Minimum fraction of peaks to be refitted; passed on to
            self._adjust_polyfit.
        robust_refit: boolean (default: True)
            Set to True to refit all the matched pairs.
        fit_deg: int (default: len(fit_coeff) - 1)
            Degree of the polynomial used to fit the matched pairs.

        Returns
        -------
        fit_coeff: numpy 1D array
            The robust refit when robust_refit is True and it succeeds.
            Otherwise, the (possibly refined) solution used for matching.
        peak_match: numpy 1D array
            Matched peaks.
        atlas_match: numpy 1D array
            Corresponding atlas matches.
        rms: float
            The root-mean-squared of the residuals.
        residual: numpy 1D array
            After a successful robust refit: the signed difference
            atlas - solution. Otherwise: the absolute residuals of the
            least-squares fit of the best candidate assignment.
            * EXPERIMENTAL *
        peak_utilisation: float
            Fraction of detected peaks (pixel) used for calibration [0-1].
        atlas_utilisation: float
            Fraction of supplied arc lines (wavelength) used for
            calibration [0-1].

        If refinement produces NaN coefficients, or no peak lies within
        tolerance of any atlas line, the input solution is returned,
        followed by six None values.

        """

        if fit_coeff is None:
            fit_coeff = self.fit_coeff

        # Work on a float copy: the caller's coefficients are never mutated,
        # and list input supports the array arithmetic below.
        fit_coeff = np.array(fit_coeff, dtype=float)

        if fit_deg is None:
            fit_deg = len(fit_coeff) - 1

        # The solution used for matching: the input, or its refinement.
        fit_coeff_new = fit_coeff.copy()

        if refine:
            if n_delta is None:
                n_delta = len(fit_coeff_new) - 1

            fitted_delta = minimize(
                self._adjust_polyfit,
                fit_coeff_new[: int(n_delta)] * 1e-3,
                args=(fit_coeff, tolerance, min_frac),
                method=method,
                tol=convergence,
                options={"maxiter": 10000},
            ).x

            # Apply the fitted offsets to the lowest-order coefficients
            fit_coeff_new[: len(fitted_delta)] += fitted_delta

            if np.any(np.isnan(fit_coeff_new)):
                self.logger.warning(
                    "_adjust_polyfit() returns None. "
                    "Input solution is returned."
                )
                return fit_coeff, None, None, None, None, None, None

        atlas_lines = np.asarray(self.atlas.get_lines())

        matched_peaks = []
        matched_atlas = []

        # Find all atlas lines within tolerance of each peak
        for p in self.peaks:
            x = self.polyval(p, fit_coeff_new)
            within = np.abs(atlas_lines - x) < tolerance

            if within.any():
                matched_peaks.append(p)
                matched_atlas.append(list(atlas_lines[within]))

        if not matched_peaks:
            self.logger.warning(
                "No peak is within tolerance of any atlas line. "
                "Input solution is returned."
            )
            return fit_coeff, None, None, None, None, None, None

        # Build every assignment of one atlas line per matched peak. A peak
        # with a single match extends every candidate. A peak with several
        # matches forks each candidate, preferring lines that the candidate
        # does not already use. NOTE: the number of candidates is the product
        # of the fork sizes, so it grows quickly with ambiguous matches.
        candidates = [[]]
        for options in matched_atlas:
            self.logger.info("matched: {}".format(options))

            if len(options) == 1:
                candidates = [c + options for c in candidates]
                continue

            forked = []
            for c in candidates:
                unused = [m for m in options if m not in c]
                forked.extend(c + [m] for m in (unused or options))
            candidates = forked

        if len(candidates) > 1:
            self.logger.info(
                "More than one match solution found, checking permutations."
            )

        self.matched_peaks = np.array(matched_peaks)

        # Keep the candidate whose least-squares fit has the smallest total
        # absolute residual.
        best_err = np.inf
        best_atlas = None
        best_residuals = None

        for candidate in candidates:
            candidate_atlas = np.array(candidate)
            candidate_coeff = self.polyfit(
                matched_peaks, candidate_atlas, fit_deg
            )
            residuals = np.abs(
                candidate_atlas - self.polyval(matched_peaks, candidate_coeff)
            )
            err = np.sum(residuals)

            if err < best_err:
                best_err = err
                best_atlas = candidate_atlas
                best_residuals = residuals

        assert best_atlas is not None
        assert best_residuals is not None

        self.matched_atlas = best_atlas
        self.residuals = best_residuals
        self.rms = np.sqrt(
            np.nansum(self.residuals**2.0) / len(self.residuals)
        )

        self.peak_utilisation = len(self.matched_peaks) / len(self.peaks)
        self.atlas_utilisation = len(self.matched_atlas) / len(self.atlas)

        if robust_refit:
            robust_coeff = models.robust_polyfit(
                np.asarray(self.matched_peaks),
                np.asarray(self.matched_atlas),
                fit_deg,
            )

            if np.any(np.isnan(robust_coeff)):
                self.logger.warning(
                    "robust_polyfit() returns None. "
                    "Input solution is returned."
                )
                # Do not leave NaN coefficients on the Calibrator
                self.fit_coeff = fit_coeff_new
                return (
                    fit_coeff_new,
                    self.matched_peaks,
                    self.matched_atlas,
                    self.rms,
                    self.residuals,
                    self.peak_utilisation,
                    self.atlas_utilisation,
                )

            self.fit_coeff = robust_coeff
            self.residuals = self.matched_atlas - self.polyval(
                self.matched_peaks, self.fit_coeff
            )
            self.rms = np.sqrt(
                np.nansum(self.residuals**2.0) / len(self.residuals)
            )

        else:
            self.fit_coeff = fit_coeff_new

        return (
            self.fit_coeff,
            self.matched_peaks,
            self.matched_atlas,
            self.rms,
            self.residuals,
            self.peak_utilisation,
            self.atlas_utilisation,
        )
2132

2133
    def get_pix_wave_pairs(self):
        """
        List the currently matched pixel-wavelength pairs by position.

        Each pair is also reported through self.logger at INFO level.

        Return
        ------
        pw_pairs: list
            One (position, peak pixel, atlas wavelength) tuple per pair,
            in the order of self.matched_peaks / self.matched_atlas.

        """

        pairs = [
            (position, pixel, wavelength)
            for position, (pixel, wavelength) in enumerate(
                zip(self.matched_peaks, self.matched_atlas)
            )
        ]

        for position, pixel, wavelength in pairs:
            self.logger.info(
                "Position {}: pixel {} is matched to wavelength {}".format(
                    position, pixel, wavelength
                )
            )

        return pairs
2160

2161
    def add_pix_wave_pair(self, pix, wave):
        """
        Add an extra pixel-wavelength pair to the Calibrator for refitting.
        This DOES NOT work before the Calibrator has fitted a solution:
        use set_known_pairs() for that purpose.

        The pair is inserted at the position that keeps matched_peaks in
        ascending order, and matched_atlas receives the wavelength at the
        same position. This assumes matched_peaks is already sorted in
        ascending order (TODO confirm for every caller).

        Parameters
        ----------
        pix: float
            pixel position
        wave: float
            wavelength

        """

        # Index of the first matched peak >= pix. It equals
        # len(matched_peaks) when pix is beyond every matched peak, so the
        # pair is appended at the end instead of raising IndexError.
        arg = np.searchsorted(self.matched_peaks, pix)

        # Only update the attributes once both insertions have succeeded
        matched_peaks = np.insert(self.matched_peaks, arg, pix)
        matched_atlas = np.insert(self.matched_atlas, arg, wave)

        self.matched_peaks = matched_peaks
        self.matched_atlas = matched_atlas
2184

2185
    def remove_pix_wave_pair(self, arg):
        """
        Drop one fitted pixel-wavelength pair from the Calibrator before
        refitting.

        Look up the position of the pair with get_pix_wave_pairs(). Pairs
        are removed one at a time.

        Parameters
        ----------
        arg: int
            Position of the pair in matched_peaks / matched_atlas.

        """

        # Both deletions are computed before either attribute is rebound,
        # so an invalid position leaves the Calibrator untouched.
        trimmed_peaks, trimmed_atlas = [
            np.delete(values, arg)
            for values in (self.matched_peaks, self.matched_atlas)
        ]

        self.matched_peaks, self.matched_atlas = trimmed_peaks, trimmed_atlas
2203

2204
    def manual_refit(
        self, matched_peaks=None, matched_atlas=None, degree=None, x0=None
    ):
        """
        Refit the solution to a set of matched peaks and atlas lines.

        matched_peaks and matched_atlas default to the pairs stored on the
        Calibrator (self.matched_peaks / self.matched_atlas). Pairs edited
        with add_pix_wave_pair() / remove_pix_wave_pair() are therefore
        picked up when the arguments are omitted.

        The pairs are fitted with models.robust_polyfit, using a polynomial
        of the requested degree. Optionally, an initial fit x0 can be
        provided to condition the optimiser. When neither x0 nor degree is
        given, the current self.fit_coeff is used as x0.

        The pairs may have been edited by hand, so peak_utilisation and
        atlas_utilisation would be meaningless. Unlike fit() and
        match_peaks(), this function does not return them.

        Parameters
        ----------
        matched_peaks: list or numpy 1D array
            List of matched peaks
        matched_atlas: list or numpy 1D array
            List of matched atlas lines
        degree: int
            Polynomial fit degree (only needed if x0 is None)
        x0: list, tuple or numpy 1D array
            Initial fit coefficients

        Returns
        -------
        fit_coeff: list
            List of best fit polynomial coefficients
        matched_peaks: list
            List of matched peaks
        matched_atlas: list
            List of matched atlas lines
        rms: float
            The root-mean-squared of the residuals
        residuals: numpy 1D array
            Residual match error per-peak (atlas - solution)

        Raises
        ------
        AssertionError
            If x0 / degree are of the wrong type or inconsistent with each
            other, if the peak and atlas lists differ in length or are
            empty, or if degree is not between 1 and len(matched_peaks) - 1.

        """

        if matched_peaks is None:
            matched_peaks = self.matched_peaks

        if matched_atlas is None:
            matched_atlas = self.matched_atlas

        # numpy arrays and numpy integers are accepted as well as plain
        # lists/ints, so coefficients held as arrays can be passed back in.
        coeff_types = (list, tuple, np.ndarray)
        int_types = (int, np.integer)

        if (x0 is None) and (degree is None):
            x0 = self.fit_coeff
            degree = len(x0) - 1

        elif (x0 is not None) and (degree is None):
            assert isinstance(x0, coeff_types)
            degree = len(x0) - 1

        elif (x0 is None) and (degree is not None):
            assert isinstance(degree, int_types)

        else:
            assert isinstance(x0, coeff_types)
            assert isinstance(degree, int_types)
            assert len(x0) == degree + 1

        x = np.asarray(matched_peaks)
        y = np.asarray(matched_atlas)

        assert len(x) == len(y)
        assert len(x) > 0
        assert degree > 0
        assert degree <= len(x) - 1

        # Run robust fitting again
        fit_coeff_new = models.robust_polyfit(x, y, degree, x0)
        self.logger.info("Input fit_coeff is {}.".format(x0))
        self.logger.info("Refit fit_coeff is {}.".format(fit_coeff_new))

        self.fit_coeff = fit_coeff_new
        self.matched_peaks = copy.deepcopy(matched_peaks)
        self.matched_atlas = copy.deepcopy(matched_atlas)
        self.residuals = y - self.polyval(x, fit_coeff_new)
        self.rms = np.sqrt(
            np.nansum(self.residuals**2.0) / len(self.residuals)
        )

        return (
            self.fit_coeff,
            self.matched_peaks,
            self.matched_atlas,
            self.rms,
            self.residuals,
        )
2307

2308
    def plot_arc(
        self,
        pixel_list=None,
        log_spectrum=False,
        save_fig=False,
        fig_type="png",
        filename=None,
        return_jsonstring=False,
        renderer="default",
        display=True,
    ):
        """
        Plot the 1D spectrum of the extracted arc.

        This is a thin wrapper: every argument is forwarded unchanged to
        plotting.plot_arc together with this Calibrator.

        Parameters
        ----------
        pixel_list: array (default: None)
            Pixel values of the spectrum. Only needed if the spectrum spans
            multiple detector arrays.
        log_spectrum: boolean (default: False)
            Set to True to display the arc spectrum in logarithmic space.
        save_fig: boolean (default: False)
            Save an image if set to True. matplotlib uses pyplot.save_fig(),
            while plotly uses pio.write_html() or pio.write_image(). The
            supported format types are given in fig_type.
        fig_type: string (default: 'png')
            Image type(s) to save, chosen from jpg, png, svg, pdf and
            iframe, joined with '+'.
        filename: string (default: None)
            Filename or full path. If no extension is given, png is used.
        return_jsonstring: boolean (default: False)
            Set to True to return JSON strings when plotly is the plotting
            library.
        renderer: string (default: 'default')
            The Plotly renderer. Nothing is displayed if JSON is returned.
        display: boolean (default: True)
            Set to True to display the diagnostic plot.

        Returns
        -------
        Whatever plotting.plot_arc returns: JSON strings when plotly is the
        plotting library and return_jsonstring is True.

        """

        plot_options = dict(
            pixel_list=pixel_list,
            log_spectrum=log_spectrum,
            save_fig=save_fig,
            fig_type=fig_type,
            filename=filename,
            return_jsonstring=return_jsonstring,
            renderer=renderer,
            display=display,
        )

        return plotting.plot_arc(self, **plot_options)
2367

2368
    def plot_search_space(
        self,
        fit_coeff=None,
        top_n_candidate=3,
        weighted=True,
        save_fig=False,
        fig_type="png",
        filename=None,
        return_jsonstring=False,
        renderer="default",
        display=True,
    ):
        """
        Plot the peak/arc line pairs that are considered as potential match
        candidates.

        This is a thin wrapper: every argument is forwarded unchanged to
        plotting.plot_search_space together with this Calibrator. If fit
        coefficients are supplied, they are forwarded so that the model
        solution can be overplotted.

        Parameters
        ----------
        fit_coeff: list (default: None)
            List of best polynomial fit coefficients.
        top_n_candidate: int (default: 3)
            Top ranked lines to be fitted.
        weighted: boolean (default: True)
            Draw samples based on the distance from the matched known
            wavelength of the atlas.
        save_fig: boolean (default: False)
            Save an image if set to True. matplotlib uses pyplot.save_fig(),
            while plotly uses pio.write_html() or pio.write_image(). The
            supported format types are given in fig_type.
        fig_type: string (default: 'png')
            Image type(s) to save, chosen from jpg, png, svg, pdf and
            iframe, joined with '+'.
        filename: string (default: None)
            The destination to save the image.
        return_jsonstring: boolean (default: False)
            Set to True to return the plotly figure as a JSON string.
            Ignored if matplotlib is used.
        renderer: string (default: 'default')
            The plotly renderer. Ignored if matplotlib is used.
        display: boolean (default: True)
            Set to True to display the diagnostic plot.

        Return
        ------
        Whatever plotting.plot_search_space returns: a JSON object when
        return_jsonstring is True.

        """

        plot_options = dict(
            fit_coeff=fit_coeff,
            top_n_candidate=top_n_candidate,
            weighted=weighted,
            save_fig=save_fig,
            fig_type=fig_type,
            filename=filename,
            return_jsonstring=return_jsonstring,
            renderer=renderer,
            display=display,
        )

        return plotting.plot_search_space(self, **plot_options)
2432

2433
    def plot_fit(
        self,
        fit_coeff=None,
        spectrum=None,
        tolerance=5.0,
        plot_atlas=True,
        log_spectrum=False,
        save_fig=False,
        fig_type="png",
        filename=None,
        return_jsonstring=False,
        renderer="default",
        display=True,
    ):
        """
        Plot the wavelength calibrated arc spectrum, the residuals and the
        pixel-to-wavelength solution.

        This is a thin wrapper around plotting.plot_fit. When fit_coeff is
        not supplied, the Calibrator's own self.fit_coeff is plotted.

        Parameters
        ----------
        fit_coeff: 1D numpy array or list (default: None)
            Best fit polynomial coefficients. Defaults to self.fit_coeff.
        spectrum: 1D numpy array (N)
            Array of length N pixels.
        tolerance: float (default: 5)
            Absolute difference between model and fitted wavelengths, in
            units of angstrom.
        plot_atlas: boolean (default: True)
            Display all the relevant lines available in the atlas library.
        log_spectrum: boolean (default: False)
            Display the arc in log-space if set to True.
        save_fig: boolean (default: False)
            Save an image if set to True. matplotlib uses pyplot.save_fig(),
            while plotly uses pio.write_html() or pio.write_image(). The
            supported format types are given in fig_type.
        fig_type: string (default: 'png')
            Image type(s) to save, chosen from jpg, png, svg, pdf and
            iframe, joined with '+'.
        filename: string (default: None)
            Filename or full path. If no extension is given, png is used.
        return_jsonstring: boolean (default: False)
            Set to True to return JSON strings when plotly is the plotting
            library.
        renderer: string (default: 'default')
            The Plotly renderer. Nothing is displayed if JSON is returned.
        display: boolean (default: True)
            Set to True to display the diagnostic plot.

        Returns
        -------
        Whatever plotting.plot_fit returns: JSON strings when plotly is the
        plotting library and return_jsonstring is True.

        """

        chosen_coeff = self.fit_coeff if fit_coeff is None else fit_coeff

        plot_options = dict(
            fit_coeff=chosen_coeff,
            spectrum=spectrum,
            tolerance=tolerance,
            plot_atlas=plot_atlas,
            log_spectrum=log_spectrum,
            save_fig=save_fig,
            fig_type=fig_type,
            filename=filename,
            return_jsonstring=return_jsonstring,
            renderer=renderer,
            display=display,
        )

        return plotting.plot_fit(self, **plot_options)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc