• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jveitchmichaelis / rascal / 4501833058

pending completion
4501833058

Pull #89

github

GitHub
Merge ded7eebc8 into cec48f2a6
Pull Request #89: main catching up to v0.3.9

3 of 8 new or added lines in 3 files covered. (37.5%)

1884 of 2056 relevant lines covered (91.63%)

3.65 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.08
/src/rascal/plotting.py
1
import logging
4✔
2
import numpy as np
4✔
3
from scipy import signal
4✔
4

5
from .util import load_calibration_lines
4✔
6
from .util import gauss
4✔
7
from .util import filter_intensity
4✔
8
from .util import filter_separation
4✔
9

10
# Module logger. It is registered under the fixed name "plotting" rather
# than __name__, so it is a top-level logger and not a child of the
# package's logger hierarchy.
logger = logging.getLogger("plotting")
4✔
11

12

13
def _import_matplotlib():
    """
    Lazily import ``matplotlib.pyplot`` and bind it to the module-level
    name ``plt``.

    If matplotlib is not installed, an error is logged and ``plt`` is left
    unbound, so any later use of ``plt`` will fail with a NameError.
    """

    global plt

    try:
        import matplotlib.pyplot as plt
    except ImportError:
        logger.error("matplotlib package not available.")
27

28

29
def _import_plotly():
    """
    Lazily import plotly and bind ``go`` (plotly.graph_objects), ``pio``
    (plotly.io) and ``psp`` (plotly.subplots) as module-level names.

    After a successful import, a template named "CN" is registered and set
    as plotly's default template. Its colorway is the matplotlib default
    colour cycle, so plotly figures use the same colours as the "C0".."C9"
    colours of the matplotlib plots. Any ImportError raised along the way
    is logged, and the names that were not yet imported stay unbound.
    """

    global go, pio, psp

    # matplotlib's default property cycle ("C0" ... "C9", the tab10 palette).
    matplotlib_cycle = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]

    try:
        import plotly.graph_objects as go
        import plotly.io as pio
        import plotly.subplots as psp

        pio.templates["CN"] = go.layout.Template(
            layout_colorway=matplotlib_cycle
        )
        pio.templates.default = "CN"

    except ImportError:
        logger.error("plotly package not available.")
65

66

67
def plot_calibration_lines(
    elements=None,
    min_atlas_wavelength=3000.0,
    max_atlas_wavelength=15000.0,
    min_intensity=5.0,
    min_distance=0.0,
    brightest_n_lines=None,
    pixel_scale=1.0,
    vacuum=False,
    pressure=101325.0,
    temperature=273.15,
    relative_humidity=0.0,
    label=False,
    log=False,
    save_fig=False,
    fig_type="png",
    filename=None,
    display=True,
    linelist="nist",
    fig_kwarg=None,
):
    """
    Plot the expected arc spectrum. Only available in matplotlib at the moment.

    All lines in the requested range are loaded with no intensity or
    separation cut and placed on a 0.001 A wavelength grid. The grid is
    convolved with a Gaussian profile to build a simulated arc spectrum.
    Only the lines that pass ``min_intensity`` and ``min_distance`` are then
    drawn as stems (and optionally labelled) on top of it.

    Parameters
    ----------
    elements: list (default: None, treated as an empty list)
        List of short element names, e.g. He as per NIST
    min_atlas_wavelength: float (default: 3000.0)
        Minimum wavelength to search, Angstrom
    max_atlas_wavelength: float (default: 15000.0)
        Maximum wavelength to search, Angstrom
    min_intensity: float (default: 5.0)
        Minimum intensity of the lines that are drawn, per NIST. It does
        not affect the simulated spectrum. When ``log`` is True, the lower
        y-limit is set to 0.75 times this value.
    min_distance: float (default: 0.0)
        All lines within this distance from other lines are treated
        as unresolved, all of them get removed from the drawn lines.
        It does not affect the simulated spectrum.
    brightest_n_lines: int (default: None)
        Only return the n brightest lines
    pixel_scale: float (default: 1.0)
        Scales the width of the simulated line profile. The Gaussian is
        built with sigma = 2.5 * pixel_scale on the same 0.001 A grid as
        the spectrum.
    vacuum: bool (default: False)
        Return vacuum wavelengths
    pressure: float (default: 101325.0)
        Atmospheric pressure, Pascal
    temperature: float (default: 273.15)
        Temperature in Kelvin (the default corresponds to 0 degC)
    relative_humidity: float (default: 0.0)
        Relative humidity, percent
    label: bool (default: False)
        Annotate every drawn line with its element and wavelength.
    log: bool (default: False)
        Plot intensities in log scale
    save_fig: boolean (default: False)
        Save an image with matplotlib's pyplot.savefig() if set to True.
        The formats should be provided in fig_type.
    fig_type: string (default: 'png')
        Image type to be saved, choose from: jpg, png, svg and pdf.
        Delimiter is '+'. Other types (e.g. iframe) are ignored here.
    filename: string (default: None)
        Provide a filename or full path without an extension. One file is
        written per type in fig_type as ``filename + "." + type``. Defaults
        to "rascal_arc".
    display: boolean (default: True)
        Set to True to display the diagnostic plot.
    linelist: str (default: 'nist')
        Either 'nist' to use the default lines or path to a linelist file.
    fig_kwarg: dict (default: None, treated as {"figsize": (12, 8)})
        Keyword arguments passed to matplotlib.pyplot.figure().

    Returns
    -------
    fig: matplotlib figure object
    """

    _import_matplotlib()

    # Avoid shared mutable defaults in the signature.
    if elements is None:
        elements = []
    if fig_kwarg is None:
        fig_kwarg = {"figsize": (12, 8)}

    # the min_intensity and min_distance are set to 0.0 because the
    # simulated spectrum would contain them. These arguments only
    # affect the labelling.
    element_list, wavelength_list, intensity_list = load_calibration_lines(
        elements=elements,
        linelist=linelist,
        min_atlas_wavelength=min_atlas_wavelength,
        max_atlas_wavelength=max_atlas_wavelength,
        min_intensity=0.0,
        min_distance=0.0,
        brightest_n_lines=brightest_n_lines,
        vacuum=vacuum,
        pressure=pressure,
        temperature=temperature,
        relative_humidity=relative_humidity,
    )

    # Nyquist sampling rate (2.5) for CCD at seeing of 1 arcsec
    sigma = pixel_scale * 2.5 * 1.0
    x = np.arange(-100, 100.001, 0.001)
    gaussian = gauss(x, a=1.0, x0=0.0, sigma=sigma)

    # Generate the equally spaced-wavelength array, and the
    # corresponding intensity
    w = np.around(
        np.arange(min_atlas_wavelength, max_atlas_wavelength + 0.001, 0.001),
        decimals=3,
    ).astype("float64")
    i = np.zeros_like(w)

    for e in elements:
        e_mask = element_list == e
        line_wavelength = np.around(wavelength_list[e_mask], decimals=3)
        line_intensity = intensity_list[e_mask]
        # Find each line's own grid index; w is sorted, so searchsorted
        # works. Adding intensities by position through an np.isin mask is
        # only correct when the lines are sorted, unique after rounding and
        # all on the grid. Otherwise intensities land on the wrong
        # wavelengths or the shapes do not match.
        idx = np.clip(np.searchsorted(w, line_wavelength), 0, w.size - 1)
        on_grid = w[idx] == line_wavelength
        # np.add.at accumulates lines that share the same grid bin.
        np.add.at(i, idx[on_grid], line_intensity[on_grid])

    # Convolve to simulate the arc spectrum
    model_spectrum = signal.convolve(i, gaussian, mode="same")
    # Computed once: the spectrum has one sample per 0.001 A, so calling
    # Python's builtin max() on it inside the labelling loop is very slow.
    model_max = np.max(model_spectrum)

    # now clean up by min_intensity and min_distance
    intensity_mask = filter_intensity(
        elements,
        np.column_stack((element_list, wavelength_list, intensity_list)),
        min_intensity=min_intensity,
    )
    wavelength_list = wavelength_list[intensity_mask]
    intensity_list = intensity_list[intensity_mask]
    element_list = element_list[intensity_mask]

    distance_mask = filter_separation(
        wavelength_list, min_separation=min_distance
    )
    wavelength_list = wavelength_list[distance_mask]
    intensity_list = intensity_list[distance_mask]
    element_list = element_list[distance_mask]

    fig = plt.figure(**fig_kwarg)

    for j, e in enumerate(elements):
        e_mask = element_list == e
        e_wavelength = wavelength_list[e_mask]
        e_intensity = intensity_list[e_mask]
        markerline, stemline, _ = plt.stem(
            e_wavelength,
            e_intensity,
            label=e,
            linefmt="C{}-".format(j),
        )
        plt.setp(stemline, linewidth=2.0)
        plt.setp(markerline, markersize=2.5, color="C{}".format(j))

        if label:

            for _w in e_wavelength:

                plt.text(
                    _w,
                    model_max * 1.05,
                    s="{}: {:1.2f}".format(e, _w),
                    rotation=90,
                    bbox=dict(facecolor="white", alpha=1),
                )

            plt.vlines(
                e_wavelength,
                e_intensity,
                model_max * 1.25,
                linestyles="dashed",
                lw=0.5,
                color="grey",
            )

    plt.plot(w, model_spectrum, lw=1.0, c="k", label="Simulated Arc Spectrum")
    if vacuum:
        plt.xlabel("Vacuum Wavelength / A")
    else:
        plt.xlabel("Air Wavelength / A")
    plt.ylabel("NIST intensity")
    plt.grid()
    plt.xlim(w.min(), w.max())
    plt.ylim(0, model_max * 1.25)
    plt.legend()
    plt.tight_layout()
    if log:
        # "bottom" is the documented name of the former "ymin" keyword.
        plt.ylim(bottom=min_intensity * 0.75)
        plt.yscale("log")

    if save_fig:

        fig_type = fig_type.split("+")

        if filename is None:

            filename_output = "rascal_arc"

        else:

            filename_output = filename

        for t in fig_type:

            if t in ["jpg", "png", "svg", "pdf"]:

                plt.savefig(filename_output + "." + t, format=t)

    if display:

        plt.show()

    return fig
265

266

267
def plot_search_space(
    calibrator,
    fit_coeff=None,
    top_n_candidate=3,
    weighted=True,
    save_fig=False,
    fig_type="png",
    filename=None,
    return_jsonstring=False,
    renderer="default",
    display=True,
):
    """
    Plots the peak/arc line pairs that are considered as potential match
    candidates.

    If fit coefficients are provided, the model solution will be
    overplotted.

    The horizontal axis is the pixel position and the vertical axis is the
    wavelength. The search space is outlined by:
    - the user-supplied minimum and maximum wavelengths with their
      +/- range_tolerance bands;
    - the straight line from (0, min_wavelength) to
      (max pixel, max_wavelength), with the same tolerance band.

    Parameters
    ----------
    calibrator: Calibrator object
        Supplies the pairs, candidates, pixel list, wavelength range,
        range tolerance, peaks, polyval() and the choice of plotting
        library (plot_with_matplotlib / plot_with_plotly).
    fit_coeff: list (default: None)
        List of best polynomial fit_coefficients
    top_n_candidate: int (default: 3)
        Top ranked lines to be fitted.
    weighted: (default: True)
        Draw sample based on the distance from the matched known wavelength
        of the atlas.
    save_fig: boolean (default: False)
        Save an image if set to True. matplotlib uses the pyplot.save_fig()
        while the plotly uses the pio.write_html() or pio.write_image().
        The support format types should be provided in fig_type.
    fig_type: string (default: 'png')
        Image type to be saved, choose from:
        jpg, png, svg, pdf and iframe. Delimiter is '+'.
    filename: (default: None)
        The destination to save the image, without extension. Defaults
        to "rascal_hough_search_space".
    return_jsonstring: (default: False)
        Set to True to save the plotly figure as json string. Ignored if
        matplotlib is used.
    renderer: (default: 'default')
        Set the rendered for the plotly display. Ignored if matplotlib is
        used.
    display: boolean (default: True)
        Set to True to display the diagnostic plot.

    Return
    ------
    The matplotlib figure when matplotlib is used. When plotly is used,
    the figure as a json string if return_jsonstring is True, otherwise
    None. None if neither plotting library is selected.

    """

    # Get top linear estimates and combine
    candidate_peak, candidate_arc = calibrator._get_most_common_candidates(
        calibrator.candidates,
        top_n_candidate=top_n_candidate,
        weighted=weighted,
    )

    # Get the search space boundaries
    x = calibrator.pixel_list
    pixel_max = calibrator.pixel_list.max()
    min_wavelength = calibrator.min_wavelength
    max_wavelength = calibrator.max_wavelength
    tolerance = calibrator.range_tolerance

    # Straight line from (0, min_wavelength) to (pixel_max, max_wavelength).
    # Shifting both end points by the same tolerance keeps the slope, so
    # the edges of the tolerance region are parallel copies of this line.
    m_1 = (max_wavelength - min_wavelength) / pixel_max
    y_1 = m_1 * x + min_wavelength
    y_2 = y_1 + tolerance
    y_3 = y_1 - tolerance

    if calibrator.plot_with_matplotlib:
        _import_matplotlib()

        merged_candidates = calibrator._merge_candidates(calibrator.candidates)

        fig = plt.figure(figsize=(10, 10))

        # Plot all-pairs
        plt.scatter(
            *calibrator.pairs.T, alpha=0.2, color="C0", label="All pairs"
        )

        plt.scatter(
            merged_candidates[:, 0],
            merged_candidates[:, 1],
            alpha=0.2,
            color="C1",
            label="Candidate Pairs",
        )

        # Tolerance regions around the minimum and maximum wavelengths:
        # a solid line at the limit, dashed lines at +/- range_tolerance.
        for limit, text in (
            (min_wavelength, "Min wavelength (user-supplied)"),
            (max_wavelength, "Max wavelength (user-supplied)"),
        ):
            plt.text(5, limit + 100, text)
            plt.hlines(limit, 0, pixel_max, color="k")
            for edge in (limit + tolerance, limit - tolerance):
                plt.hlines(
                    edge,
                    0,
                    pixel_max,
                    linestyle="dashed",
                    alpha=0.5,
                    color="k",
                )

        # The line from (first pixel, minimum wavelength) to
        # (last pixel, maximum wavelength), and the two lines defining the
        # tolerance region.
        plt.plot(x, y_1, label="Linear Fit", color="C3")
        plt.plot(
            x, y_2, linestyle="dashed", label="Tolerance Region", color="C3"
        )
        plt.plot(x, y_3, linestyle="dashed", color="C3")

        if fit_coeff is not None:

            plt.scatter(
                calibrator.peaks,
                calibrator.polyval(calibrator.peaks, fit_coeff),
                color="C4",
                label="Solution",
            )

        plt.scatter(
            candidate_peak,
            candidate_arc,
            color="C2",
            label="Best Candidate Pairs",
        )

        plt.xlim(0, pixel_max)
        plt.ylim(min_wavelength - tolerance, max_wavelength + tolerance)

        # x holds pixel positions and y holds wavelengths (see the limits).
        plt.xlabel("Pixel")
        plt.ylabel("Wavelength / A")
        plt.legend()
        plt.grid()
        plt.tight_layout()

        if save_fig:

            fig_type = fig_type.split("+")

            if filename is None:

                filename_output = "rascal_hough_search_space"

            else:

                filename_output = filename

            for t in fig_type:

                if t in ["jpg", "png", "svg", "pdf"]:

                    plt.savefig(filename_output + "." + t, format=t)

        if display:

            plt.show()

        return fig

    elif calibrator.plot_with_plotly:
        _import_plotly()

        merged_candidates = calibrator._merge_candidates(calibrator.candidates)
        colorway = pio.templates["CN"].layout.colorway

        def _hline(y, **kwargs):
            # Horizontal line at wavelength y spanning the full pixel range.
            return go.Scatter(
                x=[0, pixel_max], y=[y, y], mode="lines", **kwargs
            )

        fig = go.Figure()

        # Plot all-pairs
        fig.add_trace(
            go.Scatter(
                x=calibrator.pairs[:, 0],
                y=calibrator.pairs[:, 1],
                mode="markers",
                name="All Pairs",
                marker=dict(color=colorway[0], opacity=0.2),
            )
        )

        fig.add_trace(
            go.Scatter(
                x=merged_candidates[:, 0],
                y=merged_candidates[:, 1],
                mode="markers",
                name="Candidate Pairs",
                marker=dict(color=colorway[1], opacity=0.2),
            )
        )
        fig.add_trace(
            go.Scatter(
                x=candidate_peak,
                y=candidate_arc,
                mode="markers",
                name="Best Candidate Pairs",
                marker=dict(color=colorway[2]),
            )
        )

        # Tolerance region around the minimum wavelength
        fig.add_trace(
            _hline(
                min_wavelength,
                name="Min/Maximum",
                line=dict(color="black"),
            )
        )
        fig.add_trace(
            _hline(
                min_wavelength + tolerance,
                name="Tolerance Range",
                line=dict(color="black", dash="dash"),
            )
        )
        fig.add_trace(
            _hline(
                min_wavelength - tolerance,
                showlegend=False,
                line=dict(color="black", dash="dash"),
            )
        )

        # Tolerance region around the maximum wavelength
        fig.add_trace(
            _hline(
                max_wavelength,
                showlegend=False,
                line=dict(color="black"),
            )
        )
        for edge in (max_wavelength + tolerance, max_wavelength - tolerance):
            fig.add_trace(
                _hline(
                    edge,
                    showlegend=False,
                    line=dict(color="black", dash="dash"),
                )
            )

        # The line from (first pixel, minimum wavelength) to
        # (last pixel, maximum wavelength), and the two lines defining the
        # tolerance region.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y_1,
                mode="lines",
                name="Linear Fit",
                line=dict(color=colorway[3]),
            )
        )
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y_2,
                mode="lines",
                name="Tolerance Region",
                line=dict(color=colorway[3], dash="dashdot"),
            )
        )
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y_3,
                showlegend=False,
                mode="lines",
                line=dict(color=colorway[3], dash="dashdot"),
            )
        )

        if fit_coeff is not None:

            fig.add_trace(
                go.Scatter(
                    x=calibrator.peaks,
                    y=calibrator.polyval(calibrator.peaks, fit_coeff),
                    mode="markers",
                    name="Solution",
                    marker=dict(color=colorway[4]),
                )
            )

        # Layout, Title, Grid config. x holds pixel positions and y holds
        # wavelengths.
        fig.update_layout(
            autosize=True,
            yaxis=dict(
                title="Wavelength / A",
                range=[
                    min_wavelength - tolerance * 1.1,
                    max_wavelength + tolerance * 1.1,
                ],
                showgrid=True,
            ),
            xaxis=dict(
                title="Pixel",
                zeroline=False,
                range=[0.0, pixel_max],
                showgrid=True,
            ),
            hovermode="closest",
            showlegend=True,
            height=800,
            width=1000,
        )

        if save_fig:

            fig_type = fig_type.split("+")

            if filename is None:

                filename_output = "rascal_hough_search_space"

            else:

                filename_output = filename

            for t in fig_type:

                if t == "iframe":

                    pio.write_html(fig, filename_output + "." + t)

                elif t in ["jpg", "png", "svg", "pdf"]:

                    pio.write_image(fig, filename_output + "." + t)

        if display:

            if renderer == "default":

                fig.show()

            else:

                fig.show(renderer)

        if return_jsonstring:

            return fig.to_json()
704

705

706
def plot_fit(
4✔
707
    calibrator,
708
    fit_coeff,
709
    spectrum=None,
710
    tolerance=5.0,
711
    plot_atlas=True,
712
    log_spectrum=False,
713
    save_fig=False,
714
    fig_type="png",
715
    filename=None,
716
    return_jsonstring=False,
717
    renderer="default",
718
    display=True,
719
):
720
    """
721
    Plots of the wavelength calibrated arc spectrum, the residual and the
722
    pixel-to-wavelength solution.
723

724
    Parameters
725
    ----------
726
    fit_coeff: 1D numpy array or list
727
        Best fit polynomail fit_coefficients
728
    spectrum: 1D numpy array (N)
729
        Array of length N pixels
730
    tolerance: float (default: 5)
731
        Absolute difference between model and fitted wavelengths in unit
732
        of angstrom.
733
    plot_atlas: boolean (default: True)
734
        Display all the relavent lines available in the atlas library.
735
    log_spectrum: boolean (default: False)
736
        Display the arc in log-space if set to True.
737
    save_fig: boolean (default: False)
738
        Save an image if set to True. matplotlib uses the pyplot.save_fig()
739
        while the plotly uses the pio.write_html() or pio.write_image().
740
        The support format types should be provided in fig_type.
741
    fig_type: string (default: 'png')
742
        Image type to be saved, choose from:
743
        jpg, png, svg, pdf and iframe. Delimiter is '+'.
744
    filename: string (default: None)
745
        Provide a filename or full path. If the extension is not provided
746
        it is defaulted to png.
747
    return_jsonstring: boolean (default: False)
748
        Set to True to return json strings if using plotly as the plotting
749
        library.
750
    renderer: string (default: 'default')
751
        Indicate the Plotly renderer. Nothing gets displayed if
752
        return_jsonstring is set to True.
753
    display: boolean (Default: False)
754
        Set to True to display disgnostic plot.
755

756
    Returns
757
    -------
758
    Return json strings if using plotly as the plotting library and json
759
    is True.
760

761
    """
762

763
    if spectrum is None:
4✔
764

765
        try:
4✔
766

767
            spectrum = calibrator.spectrum
4✔
768

769
        except Exception as e:
×
770

771
            calibrator.logger.error(e)
×
772
            calibrator.logger.error(
×
773
                "Spectrum is not provided, it cannot be " "plotted."
774
            )
775

776
    if spectrum is not None:
4✔
777

778
        if log_spectrum:
4✔
779

780
            spectrum[spectrum < 0] = 1e-100
4✔
781
            spectrum = np.log10(spectrum)
4✔
782
            vline_max = np.nanmax(spectrum) * 2.0
4✔
783
            text_box_pos = 1.2 * max(spectrum)
4✔
784

785
        else:
786

787
            vline_max = np.nanmax(spectrum) * 1.2
4✔
788
            text_box_pos = 0.8 * max(spectrum)
4✔
789

790
    else:
791

792
        vline_max = 1.0
4✔
793
        text_box_pos = 0.5
4✔
794

795
    wave = calibrator.polyval(calibrator.pixel_list, fit_coeff)
4✔
796

797
    if calibrator.plot_with_matplotlib:
4✔
798
        _import_matplotlib()
4✔
799

800
        fig, (ax1, ax2, ax3) = plt.subplots(
4✔
801
            nrows=3, sharex=True, gridspec_kw={"hspace": 0.0}, figsize=(15, 9)
802
        )
803
        fig.tight_layout()
4✔
804

805
        # Plot fitted spectrum
806
        if spectrum is not None:
4✔
807

808
            ax1.plot(wave, spectrum, label="Arc Spectrum")
4✔
809
            ax1.vlines(
4✔
810
                calibrator.polyval(calibrator.peaks, fit_coeff),
811
                np.array(spectrum)[
812
                    calibrator.pix_to_rawpix(calibrator.peaks).astype("int")
813
                ],
814
                vline_max,
815
                linestyles="dashed",
816
                colors="C1",
817
                label="Detected Peaks",
818
            )
819

820
        # Plot the atlas
821
        if plot_atlas:
4✔
822

823
            # spec = SyntheticSpectrum(
824
            #    fit, model_type='poly', degree=len(fit)-1)
825
            # x_locs = spec.get_pixels(calibrator.atlas)
826
            ax1.vlines(
4✔
827
                calibrator.atlas.get_lines(),
828
                0,
829
                vline_max,
830
                colors="C2",
831
                label="Given Lines",
832
            )
833

834
        fitted_peaks = []
4✔
835
        fitted_diff = []
4✔
836
        all_diff = []
4✔
837

838
        first_one = True
4✔
839
        for p, x in zip(calibrator.matched_peaks, calibrator.matched_atlas):
4✔
840

841
            diff = calibrator.atlas.get_lines() - x
4✔
842
            idx = np.argmin(np.abs(diff))
4✔
843
            all_diff.append(diff[idx])
4✔
844

845
            calibrator.logger.info("Peak at: {} A".format(x))
4✔
846

847
            fitted_peaks.append(p)
4✔
848
            fitted_diff.append(diff[idx])
4✔
849
            calibrator.logger.info(
4✔
850
                "- matched to {} A".format(calibrator.atlas.get_lines()[idx])
851
            )
852

853
            if spectrum is not None:
4✔
854

855
                if first_one:
4✔
856
                    ax1.vlines(
4✔
857
                        calibrator.polyval(p, fit_coeff),
858
                        spectrum[calibrator.pix_to_rawpix(p).astype("int")],
859
                        vline_max,
860
                        colors="C1",
861
                        label="Fitted Peaks",
862
                    )
863
                    first_one = False
4✔
864

865
                else:
866
                    ax1.vlines(
4✔
867
                        calibrator.polyval(p, fit_coeff),
868
                        spectrum[calibrator.pix_to_rawpix(p).astype("int")],
869
                        vline_max,
870
                        colors="C1",
871
                    )
872

873
            ax1.text(
4✔
874
                x - 3,
875
                text_box_pos,
876
                s="{}:{:1.2f}".format(
877
                    calibrator.atlas.get_elements()[idx],
878
                    calibrator.atlas.get_lines()[idx],
879
                ),
880
                rotation=90,
881
                bbox=dict(facecolor="white", alpha=1),
882
            )
883

884
        rms = np.sqrt(np.mean(np.array(fitted_diff) ** 2.0))
4✔
885

886
        ax1.grid(linestyle=":")
4✔
887
        ax1.set_ylabel("Electron Count / e-")
4✔
888

889
        if spectrum is not None:
4✔
890

891
            if log_spectrum:
4✔
892

893
                ax1.set_ylim(0, vline_max)
4✔
894

895
            else:
896

897
                ax1.set_ylim(np.nanmin(spectrum), vline_max)
4✔
898

899
        ax1.legend(loc="center right")
4✔
900

901
        # Plot the residuals
902
        ax2.scatter(
4✔
903
            calibrator.polyval(fitted_peaks, fit_coeff),
904
            fitted_diff,
905
            marker="+",
906
            color="C1",
907
        )
908
        ax2.hlines(0, wave.min(), wave.max(), linestyles="dashed")
4✔
909
        ax2.hlines(
4✔
910
            rms,
911
            wave.min(),
912
            wave.max(),
913
            linestyles="dashed",
914
            color="k",
915
            label="RMS",
916
        )
917
        ax2.hlines(
4✔
918
            -rms, wave.min(), wave.max(), linestyles="dashed", color="k"
919
        )
920
        ax2.grid(linestyle=":")
4✔
921
        ax2.set_ylabel("Residual / A")
4✔
922
        ax2.legend()
4✔
923
        """
2✔
924
        ax2.text(
925
            min(wave) + np.ptp(wave) * 0.05,
926
            max(spectrum),
927
            'RMS =' + str(rms)[:6]
928
            )
929
        """
930

931
        # Plot the polynomial
932
        ax3.scatter(
4✔
933
            calibrator.polyval(fitted_peaks, fit_coeff),
934
            fitted_peaks,
935
            marker="+",
936
            color="C1",
937
            label="Fitted Peaks",
938
        )
939
        ax3.plot(wave, calibrator.pixel_list, color="C2", label="Solution")
4✔
940
        ax3.grid(linestyle=":")
4✔
941
        ax3.set_xlabel("Wavelength / A")
4✔
942
        ax3.set_ylabel("Pixel")
4✔
943
        ax3.legend(loc="lower right")
4✔
944
        w_min = calibrator.polyval(min(fitted_peaks), fit_coeff)
4✔
945
        w_max = calibrator.polyval(max(fitted_peaks), fit_coeff)
4✔
946
        ax3.set_xlim(w_min * 0.95, w_max * 1.05)
4✔
947

948
        plt.tight_layout()
4✔
949

950
        if save_fig:
4✔
951

952
            fig_type = fig_type.split("+")
4✔
953

954
            if filename is None:
4✔
955

956
                filename_output = "rascal_solution"
×
957

958
            else:
959

960
                filename_output = filename
4✔
961

962
            for t in fig_type:
4✔
963

964
                if t in ["jpg", "png", "svg", "pdf"]:
4✔
965

966
                    plt.savefig(filename_output + "." + t, format=t)
4✔
967

968
        if display:
4✔
969

970
            plt.show()
4✔
971

972
        return fig
4✔
973

974
    elif calibrator.plot_with_plotly:
4✔
975
        _import_plotly()
4✔
976

977
        fig = go.Figure()
4✔
978

979
        # Top plot - arc spectrum and matched peaks
980
        if spectrum is not None:
4✔
981
            fig.add_trace(
4✔
982
                go.Scatter(
983
                    x=wave,
984
                    y=spectrum,
985
                    mode="lines",
986
                    yaxis="y3",
987
                    name="Arc Spectrum",
988
                )
989
            )
990

991
            spec_max = np.nanmax(spectrum) * 1.05
4✔
992

993
        else:
994

NEW
995
            spec_max = vline_max
×
996

997
        fitted_peaks = []
4✔
998
        fitted_peaks_adu = []
4✔
999
        fitted_diff = []
4✔
1000
        all_diff = []
4✔
1001

1002
        for p in calibrator.peaks:
4✔
1003

1004
            x = calibrator.polyval(p, fit_coeff)
4✔
1005

1006
            # Add vlines
1007
            fig.add_shape(
4✔
1008
                type="line",
1009
                xref="x",
1010
                yref="y3",
1011
                x0=x,
1012
                y0=0,
1013
                x1=x,
1014
                y1=spec_max,
1015
                line=dict(
1016
                    color=pio.templates["CN"].layout.colorway[1], width=1
1017
                ),
1018
            )
1019

1020
            diff = calibrator.atlas.get_lines() - x
4✔
1021
            idx = np.argmin(np.abs(diff))
4✔
1022
            all_diff.append(diff[idx])
4✔
1023

1024
            calibrator.logger.info("Peak at: {} A".format(x))
4✔
1025

1026
            if np.abs(diff[idx]) < tolerance:
4✔
1027

1028
                fitted_peaks.append(p)
4✔
1029
                if spectrum is not None:
4✔
1030
                    fitted_peaks_adu.append(
4✔
1031
                        spectrum[int(calibrator.pix_to_rawpix(p))]
1032
                    )
1033
                fitted_diff.append(diff[idx])
4✔
1034
                calibrator.logger.info(
4✔
1035
                    "- matched to {} A".format(
1036
                        calibrator.atlas.get_lines()[idx]
1037
                    )
1038
                )
1039

1040
        x_fitted = calibrator.polyval(fitted_peaks, fit_coeff)
4✔
1041

1042
        fig.add_trace(
4✔
1043
            go.Scatter(
1044
                x=x_fitted,
1045
                y=fitted_peaks_adu,
1046
                mode="markers",
1047
                marker=dict(color=pio.templates["CN"].layout.colorway[1]),
1048
                yaxis="y3",
1049
                showlegend=False,
1050
            )
1051
        )
1052

1053
        # Middle plot - Residual plot
1054
        rms = np.sqrt(np.mean(np.array(fitted_diff) ** 2.0))
4✔
1055
        fig.add_trace(
4✔
1056
            go.Scatter(
1057
                x=x_fitted,
1058
                y=fitted_diff,
1059
                mode="markers",
1060
                marker=dict(color=pio.templates["CN"].layout.colorway[1]),
1061
                yaxis="y2",
1062
                showlegend=False,
1063
            )
1064
        )
1065
        fig.add_trace(
4✔
1066
            go.Scatter(
1067
                x=[wave.min(), wave.max()],
1068
                y=[0, 0],
1069
                mode="lines",
1070
                line=dict(
1071
                    color=pio.templates["CN"].layout.colorway[0], dash="dash"
1072
                ),
1073
                yaxis="y2",
1074
                showlegend=False,
1075
            )
1076
        )
1077
        fig.add_trace(
4✔
1078
            go.Scatter(
1079
                x=[wave.min(), wave.max()],
1080
                y=[rms, rms],
1081
                mode="lines",
1082
                line=dict(color="black", dash="dash"),
1083
                yaxis="y2",
1084
                showlegend=False,
1085
            )
1086
        )
1087
        fig.add_trace(
4✔
1088
            go.Scatter(
1089
                x=[wave.min(), wave.max()],
1090
                y=[-rms, -rms],
1091
                mode="lines",
1092
                line=dict(color="black", dash="dash"),
1093
                yaxis="y2",
1094
                name="RMS",
1095
            )
1096
        )
1097

1098
        # Bottom plot - Polynomial fit for Pixel to Wavelength
1099
        fig.add_trace(
4✔
1100
            go.Scatter(
1101
                x=x_fitted,
1102
                y=fitted_peaks,
1103
                mode="markers",
1104
                marker=dict(color=pio.templates["CN"].layout.colorway[1]),
1105
                yaxis="y1",
1106
                name="Fitted Peaks",
1107
            )
1108
        )
1109
        fig.add_trace(
4✔
1110
            go.Scatter(
1111
                x=wave,
1112
                y=calibrator.pixel_list,
1113
                mode="lines",
1114
                line=dict(color=pio.templates["CN"].layout.colorway[2]),
1115
                yaxis="y1",
1116
                name="Solution",
1117
            )
1118
        )
1119

1120
        # Layout, Title, Grid config
1121
        if spectrum is not None:
4✔
1122

1123
            if log_spectrum:
4✔
1124

1125
                fig.update_layout(
×
1126
                    yaxis3=dict(
1127
                        title="Electron Count / e-",
1128
                        range=[
1129
                            np.log10(np.percentile(spectrum, 15)),
1130
                            np.log10(spec_max),
1131
                        ],
1132
                        domain=[0.67, 1.0],
1133
                        showgrid=True,
1134
                        type="log",
1135
                    )
1136
                )
1137

1138
            else:
1139

1140
                fig.update_layout(
4✔
1141
                    yaxis3=dict(
1142
                        title="Electron Count / e-",
1143
                        range=[np.percentile(spectrum, 15), spec_max],
1144
                        domain=[0.67, 1.0],
1145
                        showgrid=True,
1146
                    )
1147
                )
1148

1149
        fig.update_layout(
4✔
1150
            autosize=True,
1151
            yaxis2=dict(
1152
                title="Residual / A",
1153
                range=[min(fitted_diff), max(fitted_diff)],
1154
                domain=[0.33, 0.66],
1155
                showgrid=True,
1156
            ),
1157
            yaxis=dict(
1158
                title="Pixel",
1159
                range=[0.0, max(calibrator.pixel_list)],
1160
                domain=[0.0, 0.32],
1161
                showgrid=True,
1162
            ),
1163
            xaxis=dict(
1164
                title="Wavelength / A",
1165
                zeroline=False,
1166
                range=[
1167
                    calibrator.polyval(min(fitted_peaks), fit_coeff) * 0.95,
1168
                    calibrator.polyval(max(fitted_peaks), fit_coeff) * 1.05,
1169
                ],
1170
                showgrid=True,
1171
            ),
1172
            hovermode="closest",
1173
            showlegend=True,
1174
            height=800,
1175
            width=1000,
1176
        )
1177

1178
        if save_fig:
4✔
1179

1180
            fig_type = fig_type.split("+")
4✔
1181

1182
            if filename is None:
4✔
1183

1184
                filename_output = "rascal_solution"
×
1185

1186
            else:
1187

1188
                filename_output = filename
4✔
1189

1190
            for t in fig_type:
4✔
1191

1192
                if t == "iframe":
4✔
1193

1194
                    pio.write_html(fig, filename_output + "." + t)
×
1195

1196
                elif t in ["jpg", "png", "svg", "pdf"]:
4✔
1197

1198
                    pio.write_image(fig, filename_output + "." + t)
4✔
1199

1200
        if display:
4✔
1201

1202
            if renderer == "default":
×
1203

1204
                fig.show()
×
1205

1206
            else:
1207

1208
                fig.show(renderer)
×
1209

1210
        if return_jsonstring:
4✔
1211

1212
            return fig.to_json()
×
1213

1214
    else:
1215

1216
        assert calibrator.matplotlib_imported, (
×
1217
            "matplotlib package not available. " + "Plot cannot be generated."
1218
        )
1219
        assert calibrator.plotly_imported, (
×
1220
            "plotly package is not available. " + "Plot cannot be generated."
1221
        )
1222

1223

1224
def plot_arc(
    calibrator,
    pixel_list=None,
    log_spectrum=False,
    save_fig=False,
    fig_type="png",
    filename=None,
    return_jsonstring=False,
    renderer="default",
    display=True,
):
    """
    Plots the 1D spectrum of the extracted arc and its detected peaks.

    parameters
    ----------
    calibrator: object
        Provides ``spectrum`` (array or None), ``peaks`` (pixel positions of
        the detected peaks), ``plot_with_matplotlib`` and
        ``plot_with_plotly``.
    pixel_list: array (default: None)
        Pixel value of each element of the spectrum, this is only needed if
        the spectrum spans multiple detector arrays. When None and a
        spectrum is present, ``np.arange(len(calibrator.spectrum))`` is
        used.
    log_spectrum: boolean (default: False)
        Set to True to display the normalised arc spectrum in logarithmic
        (log10) space.
    save_fig: boolean (default: False)
        Save an image if set to True. matplotlib uses pyplot.savefig()
        while plotly uses pio.write_html() or pio.write_image().
        The supported format types should be provided in fig_type.
    fig_type: string (default: 'png')
        Image type(s) to be saved, choose from:
        jpg, png, svg, pdf and iframe (plotly only). Delimiter is '+'.
    filename: string (default: None)
        Filename or full path WITHOUT extension; "." plus each entry of
        fig_type is appended. Defaults to 'rascal_arc'.
    return_jsonstring: boolean (default: False)
        Set to True to return json strings if using plotly as the plotting
        library.
    renderer: string (default: 'default')
        Indicate the Plotly renderer passed to ``fig.show``.
    display: boolean (default: True)
        Set to True to display the diagnostic plot.

    Returns
    -------
    The matplotlib figure when plotting with matplotlib; the plotly JSON
    string when plotting with plotly and return_jsonstring is True;
    otherwise None.

    """

    spectrum = calibrator.spectrum

    # Only derive the default pixel axis when there is a spectrum to index.
    # Calling len(None) here previously made the "no spectrum" handling
    # below unreachable.
    if pixel_list is None and spectrum is not None:
        pixel_list = np.arange(len(spectrum))

    if calibrator.plot_with_matplotlib:
        _import_matplotlib()

        fig = plt.figure(figsize=(18, 5))

        if spectrum is not None:
            normalised = spectrum / spectrum.max()

            if log_spectrum:
                plt.plot(
                    pixel_list, np.log10(normalised), label="Arc Spectrum"
                )
                plt.vlines(
                    calibrator.peaks, -2, 0, label="Detected Peaks", color="C1"
                )
                plt.ylabel("log(Normalised Count)")
                plt.ylim(-2, 0)

            else:
                plt.plot(pixel_list, normalised, label="Arc Spectrum")
                plt.ylabel("Normalised Count")
                plt.vlines(
                    calibrator.peaks,
                    0,
                    1.05,
                    label="Detected Peaks",
                    color="C1",
                )

            plt.title("Number of pixels: " + str(spectrum.shape[0]))
            plt.xlim(0, spectrum.shape[0])

        else:
            # Without a spectrum the detected peak positions are the only
            # content available to show.
            plt.vlines(
                calibrator.peaks, 0, 1.05, label="Detected Peaks", color="C1"
            )
            plt.xlim(0, max(calibrator.peaks))

        plt.legend()
        plt.xlabel("Pixel (Spectral Direction)")
        plt.grid()
        plt.tight_layout()

        if save_fig:
            filename_output = "rascal_arc" if filename is None else filename

            for t in fig_type.split("+"):
                if t in ["jpg", "png", "svg", "pdf"]:
                    plt.savefig(filename_output + "." + t, format=t)

        if display:
            plt.show()

        return fig

    if calibrator.plot_with_plotly:
        _import_plotly()

        fig = go.Figure()
        peak_colour = pio.templates["CN"].layout.colorway[1]

        if spectrum is not None:
            normalised = spectrum / spectrum.max()

            if log_spectrum:
                y_values = np.log10(normalised)
                y_title = "log(Normalised Count)"
            else:
                y_values = normalised
                y_title = "Normalised Count"

            fig.add_trace(
                go.Scatter(
                    x=list(pixel_list),
                    y=list(y_values),
                    mode="lines",
                    name="Arc",
                )
            )

            # Non-positive counts become -inf/nan in log space; keep them out
            # of the axis limits so plotly receives a usable range.
            y_array = np.asarray(y_values)
            finite = y_array[np.isfinite(y_array)]
            if finite.size > 0:
                ymin, ymax = finite.min(), finite.max()
            else:
                ymin, ymax = 0.0, 1.05
            x_max = len(spectrum)

        else:
            y_title = "Normalised Count"
            ymin, ymax = 0.0, 1.05
            x_max = max(calibrator.peaks)

        # The log10 of a normalised spectrum lies at or below 0, so [0, 1.05]
        # would place the peak markers outside the visible y-range.
        if spectrum is not None and log_spectrum:
            vline_bottom, vline_top = ymin, ymax
        else:
            vline_bottom, vline_top = 0, 1.05

        # Add vlines
        for peak in calibrator.peaks:
            fig.add_shape(
                type="line",
                xref="x",
                yref="y",
                x0=peak,
                y0=vline_bottom,
                x1=peak,
                y1=vline_top,
                line=dict(color=peak_colour, width=1),
            )

        fig.update_layout(
            autosize=True,
            yaxis=dict(title=y_title, range=[ymin, ymax], showgrid=True),
            xaxis=dict(
                title="Pixel",
                zeroline=False,
                range=[0.0, x_max],
                showgrid=True,
            ),
            hovermode="closest",
            showlegend=True,
            height=800,
            width=1000,
        )

        fig.update_xaxes(
            showline=True, linewidth=1, linecolor="black", mirror=True
        )

        fig.update_yaxes(
            showline=True, linewidth=1, linecolor="black", mirror=True
        )

        if save_fig:
            filename_output = "rascal_arc" if filename is None else filename

            for t in fig_type.split("+"):
                if t == "iframe":
                    pio.write_html(fig, filename_output + "." + t)

                elif t in ["jpg", "png", "svg", "pdf"]:
                    pio.write_image(fig, filename_output + "." + t)

        if display:
            if renderer == "default":
                fig.show()

            else:
                fig.show(renderer)

        if return_jsonstring:
            return fig.to_json()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc