• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jveitchmichaelis / rascal / 3907653473

pending completion
3907653473

push

github

cylammarco
Extended the config to store peaks_effective. Fixed a best-fit inliers comparison. Fixed multiple plotting bugs. Improved automatic allocation of num_pix and effective_pixel.

109 of 120 new or added lines in 17 files covered. (90.83%)

2708 of 2927 relevant lines covered (92.52%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.23
/src/rascal/plotting.py
1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3

4
"""
1✔
5
Some plotting functions for diagnostic and inspection.
6

7
"""
8

9
import logging
1✔
10
from typing import Union
1✔
11

12
import numpy as np
1✔
13
from rascal import calibrator, util
1✔
14
from scipy import signal
1✔
15

16
# Module-level logger used by the lazy-import helpers below. It is named
# "plotting" (not ``__name__``), so it is a top-level logger rather than a
# child of a "rascal.*" logger.
logger = logging.getLogger("plotting")
1✔
17

18

19
def _import_matplotlib():
    """
    Lazily import ``matplotlib.pyplot`` and publish it as the module-level
    name ``plt``.

    If matplotlib cannot be imported, an error is logged and ``plt`` is left
    unbound (callers that go on to use ``plt`` will then fail).

    """

    global plt

    try:
        import matplotlib.pyplot as _pyplot
    except ImportError:
        logger.error("matplotlib package not available.")
    else:
        plt = _pyplot
33

34

35
def _import_plotly():
    """
    Lazily import plotly and publish ``go`` (graph_objects), ``pio`` (io)
    and ``psp`` (subplots) as module-level names.

    A layout template called "CN" is registered and made the plotly default,
    and its colorway is exposed as the module-level ``pio_color``. If plotly
    cannot be imported, an error is logged instead.

    """

    global go, pio, psp, pio_color

    try:
        import plotly.graph_objects as go
        import plotly.io as pio
        import plotly.subplots as psp

        # matplotlib's default "tab10" cycle, so pio_color[i] matches the
        # "Ci" colours used by the matplotlib branches of this module.
        tab10_colorway = [
            "#1f77b4",
            "#ff7f0e",
            "#2ca02c",
            "#d62728",
            "#9467bd",
            "#8c564b",
            "#e377c2",
            "#7f7f7f",
            "#bcbd22",
            "#17becf",
        ]
        pio.templates["CN"] = go.layout.Template(
            layout_colorway=tab10_colorway
        )
        pio.templates.default = "CN"
        pio_color = pio.templates["CN"].layout.colorway

    except ImportError:
        logger.error("plotly package not available.")
73

74

75
def plot_search_space(
    calibrator: "calibrator.Calibrator",
    fit_coeff: Union[list, np.ndarray] = None,
    top_n_candidate: int = 3,
    weighted: bool = True,
    save_fig: bool = False,
    fig_type: str = "png",
    filename: str = None,
    return_jsonstring: bool = False,
    renderer: str = "default",
    display: bool = True,
):
    """
    Plots the peak/arc line pairs that are considered as potential match
    candidates.

    If fit coefficients are provided, the model solution is overplotted.

    Parameters
    ----------
    calibrator: Calibrator
        The calibrator holding the pairs, candidates, wavelength range,
        tolerance and plotting-backend flags.
    fit_coeff: list (default: None)
        List of best polynomial fit coefficients.
    top_n_candidate: int (default: 3)
        Top ranked lines to be fitted.
    weighted: boolean (default: True)
        Draw sample based on the distance from the matched known wavelength
        of the atlas.
    save_fig: boolean (default: False)
        Save an image if set to True. matplotlib uses the pyplot.savefig()
        while the plotly uses the pio.write_html() or pio.write_image().
        The support format types should be provided in fig_type.
    fig_type: string (default: 'png')
        Image type to be saved, choose from:
        jpg, png, svg, pdf and iframe. Delimiter is '+'.
    filename: string (default: None)
        The destination to save the image, without extension (the
        extension is appended from each fig_type entry). Defaults to
        "rascal_hough_search_space".
    return_jsonstring: boolean (default: False)
        Set to True to return the plotly figure as json string. Ignored if
        matplotlib is used.
    renderer: string (default: 'default')
        Set the renderer for the plotly display. Ignored if matplotlib is
        used.
    display: boolean (default: True)
        Set to True to display the diagnostic plot.

    Returns
    -------
    The matplotlib figure when plotting with matplotlib; the plotly figure
    as a json string when plotting with plotly and return_jsonstring is
    True; otherwise None.

    """

    # Get top linear estimates and combine
    candidate_peak, candidate_arc = calibrator._get_most_common_candidates(
        calibrator.candidates,
        top_n_candidate=top_n_candidate,
        weighted=weighted,
    )

    min_wavelength = calibrator.min_wavelength
    max_wavelength = calibrator.max_wavelength
    tolerance = calibrator.range_tolerance

    # Get the search space boundaries
    x = calibrator.effective_pixel
    pix_max = calibrator.effective_pixel.max()

    # The line from (pixel 0, minimum wavelength) to (last effective pixel,
    # maximum wavelength). Both ends of each tolerance line are shifted by
    # the same range_tolerance, so those lines have the same slope and are
    # simply the central line offset by +/- range_tolerance.
    slope = (max_wavelength - min_wavelength) / pix_max
    y_1 = slope * x + min_wavelength
    y_2 = y_1 + tolerance
    y_3 = y_1 - tolerance

    if fit_coeff is not None:
        # The x-axis is the effective pixel and the fit is evaluated on
        # effective pixels, so the solution must be placed at the effective
        # peak positions rather than the raw detector pixels.
        solution_x = calibrator.peaks_effective
        solution_y = calibrator.polyval(solution_x, fit_coeff)

    if filename is None:
        filename_output = "rascal_hough_search_space"
    else:
        filename_output = filename

    fig_types = fig_type.split("+") if save_fig else []

    if calibrator.plot_with_matplotlib:
        _import_matplotlib()

        merged = calibrator._merge_candidates(calibrator.candidates)

        fig = plt.figure(figsize=(10, 10))

        # Plot all-pairs
        plt.scatter(
            *calibrator.pairs.T, alpha=0.2, color="C0", label="All pairs"
        )
        plt.scatter(
            merged[:, 0],
            merged[:, 1],
            alpha=0.2,
            color="C1",
            label="Candidate Pairs",
        )

        # User-supplied min/max wavelengths (solid) and their tolerance
        # regions (dashed).
        for wavelength, text in (
            (min_wavelength, "Min wavelength (user-supplied)"),
            (max_wavelength, "Max wavelength (user-supplied)"),
        ):
            plt.text(5, wavelength + 100, text)
            plt.hlines(wavelength, 0, pix_max, color="k")
            for bound in (wavelength + tolerance, wavelength - tolerance):
                plt.hlines(
                    bound,
                    0,
                    pix_max,
                    linestyle="dashed",
                    alpha=0.5,
                    color="k",
                )

        # The linear fit and the two lines defining the tolerance region.
        plt.plot(x, y_1, label="Linear Fit", color="C3")
        plt.plot(
            x, y_2, linestyle="dashed", label="Tolerance Region", color="C3"
        )
        plt.plot(x, y_3, linestyle="dashed", color="C3")

        if fit_coeff is not None:
            plt.scatter(
                solution_x,
                solution_y,
                color="C4",
                label="Solution",
            )

        plt.scatter(
            candidate_peak,
            candidate_arc,
            color="C2",
            label="Best Candidate Pairs",
        )

        plt.xlim(0, pix_max)
        plt.ylim(min_wavelength - tolerance, max_wavelength + tolerance)

        plt.ylabel("Wavelength / A")
        plt.xlabel("Pixel")
        plt.legend()
        plt.grid()
        plt.tight_layout()

        for t in fig_types:
            if t in ["jpg", "png", "svg", "pdf"]:
                plt.savefig(filename_output + "." + t, format=t)

        if display:
            plt.show()

        return fig

    elif calibrator.plot_with_plotly:
        _import_plotly()

        merged = calibrator._merge_candidates(calibrator.candidates)

        fig = go.Figure()

        def _add_hline(y, **trace_kwargs):
            # Horizontal line across the full effective-pixel range.
            fig.add_trace(
                go.Scatter(
                    x=[0, pix_max], y=[y, y], mode="lines", **trace_kwargs
                )
            )

        # Plot all-pairs
        fig.add_trace(
            go.Scatter(
                x=calibrator.pairs[:, 0],
                y=calibrator.pairs[:, 1],
                mode="markers",
                name="All Pairs",
                marker=dict(color=pio_color[0], opacity=0.2),
            )
        )
        fig.add_trace(
            go.Scatter(
                x=merged[:, 0],
                y=merged[:, 1],
                mode="markers",
                name="Candidate Pairs",
                marker=dict(color=pio_color[1], opacity=0.2),
            )
        )
        fig.add_trace(
            go.Scatter(
                x=candidate_peak,
                y=candidate_arc,
                mode="markers",
                name="Best Candidate Pairs",
                marker=dict(color=pio_color[2]),
            )
        )

        # Tolerance region around the minimum wavelength
        _add_hline(min_wavelength, name="Min/Maximum", line=dict(color="black"))
        _add_hline(
            min_wavelength + tolerance,
            name="Tolerance Range",
            line=dict(color="black", dash="dash"),
        )
        _add_hline(
            min_wavelength - tolerance,
            showlegend=False,
            line=dict(color="black", dash="dash"),
        )

        # Tolerance region around the maximum wavelength
        _add_hline(max_wavelength, showlegend=False, line=dict(color="black"))
        _add_hline(
            max_wavelength + tolerance,
            showlegend=False,
            line=dict(color="black", dash="dash"),
        )
        _add_hline(
            max_wavelength - tolerance,
            showlegend=False,
            line=dict(color="black", dash="dash"),
        )

        # The linear fit and the two lines defining the tolerance region.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y_1,
                mode="lines",
                name="Linear Fit",
                line=dict(color=pio_color[3]),
            )
        )
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y_2,
                mode="lines",
                name="Tolerance Region",
                line=dict(color=pio_color[3], dash="dashdot"),
            )
        )
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y_3,
                showlegend=False,
                mode="lines",
                line=dict(color=pio_color[3], dash="dashdot"),
            )
        )

        if fit_coeff is not None:
            fig.add_trace(
                go.Scatter(
                    x=solution_x,
                    y=solution_y,
                    mode="markers",
                    name="Solution",
                    marker=dict(color=pio_color[4]),
                )
            )

        # Layout, Title, Grid config
        fig.update_layout(
            autosize=True,
            yaxis=dict(
                title="Wavelength / A",
                range=[
                    min_wavelength - tolerance * 1.1,
                    max_wavelength + tolerance * 1.1,
                ],
                showgrid=True,
            ),
            xaxis=dict(
                title="Pixel",
                zeroline=False,
                range=[0.0, pix_max],
                showgrid=True,
            ),
            hovermode="closest",
            showlegend=True,
            height=800,
            width=1000,
        )

        for t in fig_types:
            if t == "iframe":
                pio.write_html(fig, filename_output + "." + t)
            elif t in ["jpg", "png", "svg", "pdf"]:
                pio.write_image(fig, filename_output + "." + t)

        if display:
            if renderer == "default":
                fig.show()
            else:
                fig.show(renderer)

        if return_jsonstring:
            return fig.to_json()

    else:
        # Previously this fell through silently and returned None.
        calibrator.logger.warning(
            "Neither plot_with_matplotlib nor plot_with_plotly is set; "
            "the search space is not plotted."
        )
508

509

510
def plot_fit(
    calibrator: "calibrator.Calibrator",
    fit_coeff: Union[list, np.ndarray],
    spectrum: Union[list, np.ndarray] = None,
    plot_atlas: bool = True,
    log_spectrum: bool = False,
    save_fig: bool = False,
    fig_type: str = "png",
    filename: str = None,
    return_jsonstring: bool = False,
    renderer: str = "default",
    display: bool = True,
):
    """
    Plots of the wavelength calibrated arc spectrum, the residual and the
    pixel-to-wavelength solution.

    Parameters
    ----------
    calibrator: Calibrator
        The calibrator holding the peaks, matched peaks/atlas lines, atlas,
        effective pixels and plotting-backend flags.
    fit_coeff: 1D numpy array or list
        Best fit polynomial coefficients.
    spectrum: 1D numpy array or list (N) (default: None)
        Array of length N pixels. If None, calibrator.spectrum is used. The
        values are copied, so the input array is never modified.
    plot_atlas: boolean (default: True)
        Display all the relevant lines available in the atlas library
        (matplotlib only).
    log_spectrum: boolean (default: False)
        Display the arc in log10-space if set to True. Non-positive values
        are replaced by 1e-100 before taking the logarithm.
    save_fig: boolean (default: False)
        Save an image if set to True. matplotlib uses the pyplot.savefig()
        while the plotly uses the pio.write_html() or pio.write_image().
        The support format types should be provided in fig_type.
    fig_type: string (default: 'png')
        Image type to be saved, choose from:
        jpg, png, svg, pdf and iframe. Delimiter is '+'.
    filename: string (default: None)
        Provide a filename or full path, without extension; the extension is
        appended from each fig_type entry. Defaults to "rascal_solution".
    return_jsonstring: boolean (default: False)
        Set to True to return json strings if using plotly as the plotting
        library.
    renderer: string (default: 'default')
        Indicate the Plotly renderer. Ignored if matplotlib is used.
    display: boolean (default: True)
        Set to True to display the diagnostic plot.

    Returns
    -------
    The matplotlib figure if plotting with matplotlib; the plotly figure as
    a json string if plotting with plotly and return_jsonstring is True;
    otherwise None.

    Raises
    ------
    AssertionError
        If neither backend is selected and calibrator.matplotlib_imported or
        calibrator.plotly_imported is falsy.

    """

    if spectrum is None:
        try:
            spectrum = calibrator.spectrum
        except Exception as e:
            calibrator.logger.error(e)
            calibrator.logger.error(
                "Spectrum is not provided, it cannot be plotted."
            )

    if spectrum is not None:
        # Work on a float copy: the log transform below must not modify the
        # caller's array (or calibrator.spectrum), and list input needs to
        # support boolean-mask indexing.
        spectrum = np.array(spectrum, dtype=float)

        if log_spectrum:
            # Replace non-positive values (including zeros, which would give
            # -inf) so that log10 stays finite.
            spectrum[spectrum <= 0] = 1e-100
            spectrum = np.log10(spectrum)
            vline_max = np.nanmax(spectrum) * 2.0
            text_box_pos = 1.2 * np.nanmax(spectrum)
        else:
            vline_max = np.nanmax(spectrum) * 1.2
            text_box_pos = 0.8 * np.nanmax(spectrum)

    else:
        vline_max = 1.0
        text_box_pos = 0.5

    wave = calibrator.polyval(calibrator.effective_pixel, fit_coeff)

    atlas_lines = np.asarray(calibrator.atlas.get_lines())

    # For every matched peak, find the atlas line nearest to its fitted
    # wavelength; keep the residual and that line's index (the index is
    # reused for the per-peak labels in the matplotlib plot).
    fitted_diff = []
    matched_line_idx = []

    for p, x in zip(calibrator.matched_peaks, calibrator.matched_atlas):
        diff = atlas_lines - calibrator.polyval(p, fit_coeff)
        idx = np.argmin(np.abs(diff))

        calibrator.logger.info(f"Peak at: {x} A")

        fitted_diff.append(diff[idx])
        matched_line_idx.append(idx)
        calibrator.logger.info(f"- matched to {atlas_lines[idx]} A")

    def _peak_index(p):
        # Detector pixel (as an int index into the spectrum) of the detected
        # peak whose effective pixel equals p.
        return int(
            calibrator.peaks[np.flatnonzero(calibrator.peaks_effective == p)[0]]
        )

    if filename is None:
        filename_output = "rascal_solution"
    else:
        filename_output = filename

    fig_types = fig_type.split("+") if save_fig else []

    if calibrator.plot_with_matplotlib:
        _import_matplotlib()

        fig, (ax1, ax2, ax3) = plt.subplots(
            nrows=3, sharex=True, gridspec_kw={"hspace": 0.0}, figsize=(15, 9)
        )
        fig.tight_layout()

        # Plot fitted spectrum
        if spectrum is not None:
            ax1.plot(wave, spectrum, label="Arc Spectrum")
            ax1.vlines(
                calibrator.polyval(calibrator.peaks_effective, fit_coeff),
                spectrum[calibrator.peaks.astype("int")],
                vline_max,
                linestyles="dashed",
                colors="C1",
                label="Detected Peaks",
            )

        # Plot the atlas
        if plot_atlas:
            ax1.vlines(
                atlas_lines,
                0,
                vline_max,
                colors="C2",
                label="Given Lines",
            )

        atlas_elements = calibrator.atlas.get_elements()

        first_one = True
        for p, x, idx in zip(
            calibrator.matched_peaks, calibrator.matched_atlas, matched_line_idx
        ):
            if spectrum is not None:
                # Only the first fitted-peak line carries a legend label.
                label_kw = {"label": "Fitted Peaks"} if first_one else {}
                ax1.vlines(
                    calibrator.polyval(p, fit_coeff),
                    spectrum[_peak_index(p)],
                    vline_max,
                    colors="C1",
                    **label_kw,
                )
                first_one = False

            # Label with the atlas line matched to THIS peak.
            ax1.text(
                x - 3,
                text_box_pos,
                s=f"{atlas_elements[idx]}:{atlas_lines[idx]:1.2f}",
                rotation=90,
                bbox=dict(facecolor="white", alpha=1),
            )

        rms = np.sqrt(np.mean(np.array(fitted_diff) ** 2.0))

        ax1.grid(linestyle=":")
        ax1.set_ylabel("Electron Count / e-")

        if spectrum is not None:
            if log_spectrum:
                ax1.set_ylim(0, vline_max)
            else:
                ax1.set_ylim(np.nanmin(spectrum), vline_max)

        ax1.legend(loc="center right")

        # Plot the residuals
        ax2.scatter(
            calibrator.polyval(calibrator.matched_peaks, fit_coeff),
            fitted_diff,
            marker="+",
            color="C1",
        )
        ax2.hlines(0, wave.min(), wave.max(), linestyles="dashed")
        ax2.hlines(
            rms,
            wave.min(),
            wave.max(),
            linestyles="dashed",
            color="k",
            label="RMS",
        )
        ax2.hlines(
            -rms, wave.min(), wave.max(), linestyles="dashed", color="k"
        )
        ax2.grid(linestyle=":")
        ax2.set_ylabel("Residual / A")
        ax2.legend()

        # Plot the polynomial
        ax3.scatter(
            calibrator.polyval(calibrator.matched_peaks, fit_coeff),
            calibrator.matched_peaks,
            marker="+",
            color="C1",
            label="Fitted Peaks",
        )
        ax3.plot(
            wave, calibrator.effective_pixel, color="C2", label="Solution"
        )
        ax3.grid(linestyle=":")
        ax3.set_xlabel("Wavelength / A")
        ax3.set_ylabel("Pixel")
        ax3.legend(loc="lower right")
        w_min = calibrator.polyval(min(calibrator.matched_peaks), fit_coeff)
        w_max = calibrator.polyval(max(calibrator.matched_peaks), fit_coeff)
        ax3.set_xlim(w_min * 0.95, w_max * 1.05)

        plt.tight_layout()

        for t in fig_types:
            if t in ["jpg", "png", "svg", "pdf"]:
                plt.savefig(filename_output + "." + t, format=t)

        if display:
            fig.show()

        return fig

    elif calibrator.plot_with_plotly:
        _import_plotly()

        fig = go.Figure()

        x_fitted = calibrator.polyval(calibrator.matched_peaks, fit_coeff)

        # Top plot - arc spectrum and matched peaks. Every element here needs
        # spectrum values, so it is skipped when no spectrum is available.
        if spectrum is not None:
            fig.add_trace(
                go.Scatter(
                    x=wave,
                    y=spectrum,
                    mode="lines",
                    yaxis="y3",
                    name="Arc Spectrum",
                )
            )

            spec_max = np.nanmax(spectrum) * 1.05

            # Vertical line from each detected peak up to the top of the
            # spectrum panel.
            for p in calibrator.peaks_effective:
                x = calibrator.polyval(p, fit_coeff)
                fig.add_shape(
                    type="line",
                    xref="x",
                    yref="y3",
                    x0=x,
                    y0=spectrum[_peak_index(p)],
                    x1=x,
                    y1=spec_max,
                    line=dict(color=pio_color[1], width=1),
                )

            # Built from matched_peaks so it lines up one-to-one with
            # x_fitted.
            y_fitted = [
                spectrum[_peak_index(p)] for p in calibrator.matched_peaks
            ]

            fig.add_trace(
                go.Scatter(
                    x=x_fitted,
                    y=y_fitted,
                    mode="markers",
                    marker=dict(color=pio_color[1]),
                    yaxis="y3",
                    showlegend=False,
                )
            )

        # Middle plot - Residual plot
        rms = np.sqrt(np.mean(np.array(fitted_diff) ** 2.0))
        fig.add_trace(
            go.Scatter(
                x=x_fitted,
                y=fitted_diff,
                mode="markers",
                marker=dict(color=pio_color[1]),
                yaxis="y2",
                showlegend=False,
            )
        )
        fig.add_trace(
            go.Scatter(
                x=[wave.min(), wave.max()],
                y=[0, 0],
                mode="lines",
                line=dict(color=pio_color[0], dash="dash"),
                yaxis="y2",
                showlegend=False,
            )
        )
        fig.add_trace(
            go.Scatter(
                x=[wave.min(), wave.max()],
                y=[rms, rms],
                mode="lines",
                line=dict(color="black", dash="dash"),
                yaxis="y2",
                showlegend=False,
            )
        )
        fig.add_trace(
            go.Scatter(
                x=[wave.min(), wave.max()],
                y=[-rms, -rms],
                mode="lines",
                line=dict(color="black", dash="dash"),
                yaxis="y2",
                name="RMS",
            )
        )

        # Bottom plot - Polynomial fit for Pixel to Wavelength
        fig.add_trace(
            go.Scatter(
                x=x_fitted,
                y=calibrator.matched_peaks,
                mode="markers",
                marker=dict(color=pio_color[1]),
                yaxis="y1",
                name="Fitted Peaks",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=wave,
                y=calibrator.effective_pixel,
                mode="lines",
                line=dict(color=pio_color[2]),
                yaxis="y1",
                name="Solution",
            )
        )

        # Layout, Title, Grid config
        if spectrum is not None:
            # With log_spectrum the data are already log10 values, so a
            # linear axis is correct in both modes; a plotly "log" axis would
            # apply the logarithm a second time.
            fig.update_layout(
                yaxis3=dict(
                    title="Electron Count / e-",
                    range=[np.percentile(spectrum, 15), spec_max],
                    domain=[0.67, 1.0],
                    showgrid=True,
                )
            )

        fig.update_layout(
            autosize=True,
            yaxis2=dict(
                title="Residual / A",
                range=[min(fitted_diff), max(fitted_diff)],
                domain=[0.33, 0.66],
                showgrid=True,
            ),
            yaxis=dict(
                title="Pixel",
                range=[0.0, max(calibrator.effective_pixel)],
                domain=[0.0, 0.32],
                showgrid=True,
            ),
            xaxis=dict(
                title="Wavelength / A",
                zeroline=False,
                range=[
                    calibrator.polyval(
                        min(calibrator.matched_peaks), fit_coeff
                    )
                    * 0.95,
                    calibrator.polyval(
                        max(calibrator.matched_peaks), fit_coeff
                    )
                    * 1.05,
                ],
                showgrid=True,
            ),
            hovermode="closest",
            showlegend=True,
            height=800,
            width=1000,
        )

        for t in fig_types:
            if t == "iframe":
                pio.write_html(fig, filename_output + "." + t)
            elif t in ["jpg", "png", "svg", "pdf"]:
                pio.write_image(fig, filename_output + "." + t)

        if display:
            if renderer == "default":
                fig.show()
            else:
                fig.show(renderer)

        if return_jsonstring:
            return fig.to_json()

    else:
        assert (
            calibrator.matplotlib_imported
        ), "matplotlib package not available. Plot cannot be generated."
        assert (
            calibrator.plotly_imported
        ), "plotly package is not available. Plot cannot be generated."
1006

1007

1008
def plot_arc(
    calibrator: "calibrator.Calibrator",
    effective_pixel: Union[list, np.ndarray] = None,
    log_spectrum: bool = False,
    save_fig: bool = False,
    fig_type: str = "png",
    filename: str = None,
    return_jsonstring: bool = False,
    renderer: str = "default",
    display: bool = True,
):
    """
    Plots the 1D spectrum of the extracted arc.

    The spectrum is normalised by its maximum and the detected peaks
    (``calibrator.peaks_effective``) are overplotted as vertical lines. The
    plotting library is selected by ``calibrator.plot_with_matplotlib`` /
    ``calibrator.plot_with_plotly``.

    parameters
    ----------
    calibrator: Calibrator
        The calibrator holding the arc ``spectrum`` (may be None, in which
        case only the detected peaks are drawn) and ``peaks_effective``.
    effective_pixel: array (default: None)
        pixel value of each element of the spectrum, this is only needed if
        the spectrum spans multiple detector arrays. Defaults to
        ``calibrator.effective_pixel``.
    log_spectrum: boolean (default: False)
        Set to True to display the arc spectrum in logarithmic space.
    save_fig: boolean (default: False)
        Save an image if set to True. matplotlib uses the pyplot.savefig()
        while the plotly uses the pio.write_html() or pio.write_image().
        The support format types should be provided in fig_type.
    fig_type: string (default: 'png')
        Image type to be saved, choose from:
        jpg, png, svg, pdf and iframe (plotly only). Delimiter is '+'.
    filename: string (default: None)
        Filename or full path without extension; "." + fig_type is appended
        for every requested type. Defaults to "rascal_arc".
    return_jsonstring: boolean (default: False)
        Set to True to return json strings if using plotly as the plotting
        library.
    renderer: string (default: 'default')
        Indicate the Plotly renderer used when display is True.
    display: boolean (default: True)
        Set to True to display the diagnostic plot.

    Returns
    -------
    With matplotlib: the matplotlib figure object.
    With plotly: the json string of the figure if return_jsonstring is True,
    otherwise None.

    """

    if effective_pixel is None:
        effective_pixel = calibrator.effective_pixel

    spectrum = calibrator.spectrum
    peaks = calibrator.peaks_effective

    if spectrum is not None:
        # Normalise once and reuse the result in every branch below.
        normalised = spectrum / spectrum.max()
        if log_spectrum:
            normalised = np.log10(normalised)
        # Never narrower than [0, number of pixels], but widened when the
        # effective pixels extend beyond it (e.g. gaps between detector
        # arrays) so that no part of the spectrum is clipped.
        x_range = [
            min(0.0, float(np.min(effective_pixel))),
            max(float(spectrum.shape[0]), float(np.max(effective_pixel))),
        ]
    else:
        normalised = None
        x_range = [0.0, max(peaks)]

    if calibrator.plot_with_matplotlib:

        _import_matplotlib()

        fig = plt.figure(figsize=(18, 5))

        if spectrum is not None:
            plt.plot(effective_pixel, normalised, label="Arc Spectrum")
            if log_spectrum:
                plt.vlines(peaks, -2, 0, label="Detected Peaks", color="C1")
                plt.ylabel("log(Normalised Count)")
                plt.ylim(-2, 0)
            else:
                plt.ylabel("Normalised Count")
                plt.vlines(peaks, 0, 1.05, label="Detected Peaks", color="C1")
            plt.title("Number of pixels: " + str(spectrum.shape[0]))
        else:
            # Without a spectrum only the detected peaks can be shown.
            plt.vlines(peaks, 0, 1.05, label="Detected Peaks", color="C1")

        plt.legend()
        plt.xlim(x_range[0], x_range[1])
        plt.xlabel("Pixel (Spectral Direction)")
        plt.grid()
        plt.tight_layout()

        if save_fig:

            filename_output = "rascal_arc" if filename is None else filename

            for t in fig_type.split("+"):
                if t in ["jpg", "png", "svg", "pdf"]:
                    plt.savefig(filename_output + "." + t, format=t)

        if display:
            plt.show()

        return fig

    if calibrator.plot_with_plotly:

        _import_plotly()

        fig = go.Figure()

        plot_log = log_spectrum and spectrum is not None

        if spectrum is not None:
            fig.add_trace(
                go.Scatter(
                    x=list(effective_pixel),
                    y=list(normalised),
                    mode="lines",
                    name="Arc",
                )
            )
            # log10 of empty pixels gives -inf, which would make the axis
            # range unusable, so only finite values define it.
            finite = normalised[np.isfinite(normalised)]
            if finite.size > 0:
                y_range = [float(finite.min()), float(finite.max())]
            else:
                y_range = [0.0, 1.05]
        else:
            y_range = [0.0, 1.05]

        # In log space the normalised counts are <= 0, so peak markers drawn
        # from 0 to 1.05 would sit outside the axis; span the data instead.
        if plot_log:
            peak_y0, peak_y1 = y_range
        else:
            peak_y0, peak_y1 = 0, 1.05

        # Add vlines
        for p in peaks:
            fig.add_shape(
                type="line",
                xref="x",
                yref="y",
                x0=p,
                y0=peak_y0,
                x1=p,
                y1=peak_y1,
                line=dict(color=pio_color[1], width=1),
            )

        fig.update_layout(
            autosize=True,
            yaxis=dict(
                title=(
                    "log(Normalised Count)" if plot_log else "Normalised Count"
                ),
                range=y_range,
                showgrid=True,
            ),
            xaxis=dict(
                title="Pixel",
                zeroline=False,
                range=x_range,
                showgrid=True,
            ),
            hovermode="closest",
            showlegend=True,
            height=800,
            width=1000,
        )

        fig.update_xaxes(
            showline=True, linewidth=1, linecolor="black", mirror=True
        )

        fig.update_yaxes(
            showline=True, linewidth=1, linecolor="black", mirror=True
        )

        if save_fig:

            filename_output = "rascal_arc" if filename is None else filename

            for t in fig_type.split("+"):
                if t == "iframe":
                    pio.write_html(fig, filename_output + "." + t)
                elif t in ["jpg", "png", "svg", "pdf"]:
                    pio.write_image(fig, filename_output + "." + t)

        if display:
            if renderer == "default":
                fig.show()
            else:
                fig.show(renderer)

        if return_jsonstring:
            return fig.to_json()
1247

1248

1249
def plot_calibration_lines(
    elements: Union[list, np.ndarray] = None,
    linelist: str = "nist",
    min_atlas_wavelength: float = 3000.0,
    max_atlas_wavelength: float = 15000.0,
    min_intensity: float = 5.0,
    min_distance: float = 0.0,
    brightest_n_lines: int = None,
    pixel_scale: float = 1.0,
    vacuum: bool = False,
    pressure: float = 101325.0,
    temperature: float = 273.15,
    relative_humidity: float = 0.0,
    label: bool = False,
    log: bool = False,
    save_fig: bool = False,
    fig_type: str = "png",
    filename: str = None,
    display: bool = True,
    fig_kwarg: dict = None,
):
    """
    Plot the expected arc spectrum. Currently only available with matplotlib.

    The catalogue lines of every element are deposited on a 0.001 A
    wavelength grid and convolved with a Gaussian to simulate an arc
    spectrum. The simulated spectrum is drawn together with a stem plot of
    the lines that pass the min_intensity and min_distance cuts.

    Parameters
    ----------
    elements: list (default: None, treated as an empty list)
        List of short element names, e.g. He as per NIST
    linelist: str
        Either 'nist' to use the default lines or path to a linelist file.
    min_atlas_wavelength: float
        Minimum wavelength to search, Angstrom
    max_atlas_wavelength: float
        Maximum wavelength to search, Angstrom
    min_intensity: float
        Minimum intensity, per NIST, of the lines drawn as stems/labels.
        It does not affect the simulated spectrum.
    min_distance: float
        All lines within this distance from other lines are treated
        as unresolved, and all of them are removed from the lines drawn as
        stems/labels. It does not affect the simulated spectrum.
    brightest_n_lines: int
        Only return the n brightest lines
    pixel_scale: float
        Scales the width of the Gaussian used to broaden each line
        (sigma = 2.5 * pixel_scale).
    vacuum: bool
        Return vacuum wavelengths
    pressure: float
        Atmospheric pressure, Pascal
    temperature: float
        Temperature in Kelvin (default 273.15 K, i.e. 0 degC)
    relative_humidity: float
        Relative humidity, percent
    label: bool
        Annotate every drawn line with its element and wavelength.
    log: bool
        Plot intensities in log scale
    save_fig: boolean (default: False)
        Save an image with pyplot.savefig() if set to True. The format
        types should be provided in fig_type.
    fig_type: string (default: 'png')
        Image type to be saved, choose from: jpg, png, svg and pdf.
        Delimiter is '+'.
    filename: string (default: None)
        Filename or full path without extension; "." + fig_type is appended
        for every requested type. Defaults to "rascal_arc".
    display: boolean (default: True)
        Set to True to display the plot.
    fig_kwarg: dict (default: None, treated as {"figsize": (12, 8)})
        Keyword arguments passed to pyplot.figure().

    Returns
    -------
    fig: matplotlib figure object

    """

    # Fresh objects per call instead of shared mutable defaults.
    if elements is None:
        elements = []
    if fig_kwarg is None:
        fig_kwarg = {"figsize": (12, 8)}

    _import_matplotlib()

    # the min_intensity and min_distance are set to 0.0 because the
    # simulated spectrum would contain them. These arguments only
    # affect the labelling.
    (
        element_list,
        wavelength_list,
        intensity_list,
    ) = util.load_calibration_lines(
        elements=elements,
        linelist=linelist,
        min_atlas_wavelength=min_atlas_wavelength,
        max_atlas_wavelength=max_atlas_wavelength,
        min_intensity=0.0,
        min_distance=0.0,
        brightest_n_lines=brightest_n_lines,
        vacuum=vacuum,
        pressure=pressure,
        temperature=temperature,
        relative_humidity=relative_humidity,
    )

    # Nyquist sampling rate (2.5) for CCD at seeing of 1 arcsec
    sigma = pixel_scale * 2.5 * 1.0
    x = np.arange(-100, 100.001, 0.001)
    gaussian = util.gauss(x, a=1.0, x0=0.0, sigma=sigma)

    # Generate the equally spaced-wavelength array, and the
    # corresponding intensity
    w = np.around(
        np.arange(min_atlas_wavelength, max_atlas_wavelength + 0.001, 0.001),
        decimals=3,
    ).astype("float64")
    i = np.zeros_like(w)

    for e in elements:
        e_mask = element_list == e
        line_wavelengths = np.around(wavelength_list[e_mask], decimals=3)
        line_intensities = np.asarray(intensity_list[e_mask], dtype="float64")
        # Locate every line on the (ascending) grid by its own wavelength,
        # so intensities stay paired with their lines whatever the order of
        # the line list; lines that do not fall on the grid are ignored.
        idx = np.clip(np.searchsorted(w, line_wavelengths), 0, w.size - 1)
        on_grid = w[idx] == line_wavelengths
        # np.add.at accumulates repeated indices, so lines rounding to the
        # same grid sample add up instead of raising a shape mismatch.
        np.add.at(i, idx[on_grid], line_intensities[on_grid])

    # Convolve to simulate the arc spectrum
    model_spectrum = signal.convolve(i, gaussian, mode="same")
    # The grid has one sample per 0.001 A, so this array can be large;
    # compute its maximum once and reuse it below.
    model_max = np.max(model_spectrum)

    # now clean up by min_intensity and min_distance
    intensity_mask = util.filter_intensity(
        elements,
        np.column_stack((element_list, wavelength_list, intensity_list)),
        min_intensity=min_intensity,
    )
    wavelength_list = wavelength_list[intensity_mask]
    intensity_list = intensity_list[intensity_mask]
    element_list = element_list[intensity_mask]

    distance_mask = util.filter_distance(
        wavelength_list, min_distance=min_distance
    )
    wavelength_list = wavelength_list[distance_mask]
    intensity_list = intensity_list[distance_mask]
    element_list = element_list[distance_mask]

    fig = plt.figure(**fig_kwarg)

    for j, e in enumerate(elements):
        e_mask = element_list == e
        markerline, stemline, _ = plt.stem(
            wavelength_list[e_mask],
            intensity_list[e_mask],
            label=e,
            linefmt=f"C{j}-",
        )
        plt.setp(stemline, linewidth=2.0)
        plt.setp(markerline, markersize=2.5, color=f"C{j}")

        if label:

            for _w in wavelength_list[e_mask]:

                plt.text(
                    _w,
                    model_max * 1.05,
                    s=f"{e}: {_w:1.2f}",
                    rotation=90,
                    bbox=dict(facecolor="white", alpha=1),
                )

            # Dashed guides from each stem top up to the label row.
            plt.vlines(
                wavelength_list[e_mask],
                intensity_list[e_mask],
                model_max * 1.25,
                linestyles="dashed",
                lw=0.5,
                color="grey",
            )

    plt.plot(w, model_spectrum, lw=1.0, c="k", label="Simulated Arc Spectrum")
    if vacuum:
        plt.xlabel("Vacuum Wavelength / A")
    else:
        plt.xlabel("Air Wavelength / A")
    plt.ylabel("NIST intensity")
    plt.grid()
    plt.xlim(w.min(), w.max())
    plt.ylim(0, model_max * 1.25)
    plt.legend()
    plt.tight_layout()
    if log:
        plt.ylim(bottom=min_intensity * 0.75)
        plt.yscale("log")

    if save_fig:

        filename_output = "rascal_arc" if filename is None else filename

        for t in fig_type.split("+"):
            if t in ["jpg", "png", "svg", "pdf"]:
                plt.savefig(filename_output + "." + t, format=t)

    if display:
        plt.show()

    return fig
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc