• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jveitchmichaelis / rascal / 18631349511

19 Oct 2025 01:49PM UTC coverage: 91.691%. First build
18631349511

Pull #95

github

web-flow
Merge 8bdeaec78 into c6da64f2f
Pull Request #95: switch to uv

157 of 174 new or added lines in 14 files covered. (90.23%)

1876 of 2046 relevant lines covered (91.69%)

3.67 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.08
/src/rascal/plotting.py
1
import logging
4✔
2
import numpy as np
4✔
3
from scipy import signal
4✔
4

5
from .util import load_calibration_lines
4✔
6
from .util import gauss
4✔
7
from .util import filter_intensity
4✔
8
from .util import filter_separation
4✔
9

10
# Module-level logger; the optional-import helpers below use it to report a
# missing matplotlib/plotly installation.
logger = logging.getLogger("plotting")
4✔
11

12

13
def _import_matplotlib():
    """
    Import ``matplotlib.pyplot`` on demand and publish it as the module
    global ``plt``.

    matplotlib is treated as an optional dependency: if the import fails an
    error is logged and ``plt`` is not (re)bound.

    """
    global plt

    try:
        import matplotlib.pyplot as _pyplot
    except ImportError:
        logger.error("matplotlib package not available.")
    else:
        plt = _pyplot
25

26

27
def _import_plotly():
    """
    Import plotly on demand and expose it through the module globals
    ``go`` (plotly.graph_objects), ``pio`` (plotly.io) and ``psp``
    (plotly.subplots).

    After a successful import a layout template named "CN" is registered
    and made the plotly default. Its colorway is matplotlib's default
    "tab10" colour cycle, so ``colorway[N]`` in a plotly figure matches the
    "CN" colour of a matplotlib figure. If an import fails, an error is
    logged instead and the template is not registered.

    """
    global go, pio, psp

    try:
        import plotly.graph_objects as go
        import plotly.io as pio
        import plotly.subplots as psp

        # matplotlib's default colour cycle, C0 ... C9 ("tab10")
        colorway = [
            "#1f77b4",  # C0 blue
            "#ff7f0e",  # C1 orange
            "#2ca02c",  # C2 green
            "#d62728",  # C3 red
            "#9467bd",  # C4 purple
            "#8c564b",  # C5 brown
            "#e377c2",  # C6 pink
            "#7f7f7f",  # C7 grey
            "#bcbd22",  # C8 olive
            "#17becf",  # C9 cyan
        ]
        pio.templates["CN"] = go.layout.Template(layout_colorway=colorway)

        # Make the matplotlib-like colorway the default for every new figure
        pio.templates.default = "CN"

    except ImportError:
        logger.error("plotly package not available.")
61

62

63
def plot_calibration_lines(
    elements=None,
    min_atlas_wavelength=3000.0,
    max_atlas_wavelength=15000.0,
    min_intensity=5.0,
    min_distance=0.0,
    brightest_n_lines=None,
    pixel_scale=1.0,
    vacuum=False,
    pressure=101325.0,
    temperature=273.15,
    relative_humidity=0.0,
    label=False,
    log=False,
    save_fig=False,
    fig_type="png",
    filename=None,
    display=True,
    linelist="nist",
    fig_kwarg=None,
):
    """
    Plot the expected arc spectrum. Only available in matplotlib at the moment.

    All lines in the wavelength range are binned onto a 0.001 A wavelength
    grid and convolved with a Gaussian to simulate the arc spectrum. The
    lines that pass the ``min_intensity`` and ``min_distance`` cuts are drawn
    on top of it as stems, one colour per element.

    Parameters
    ----------
    elements: list of str (default: None, treated as an empty list)
        List of short element names, e.g. He as per NIST
    min_atlas_wavelength: float
        Minimum wavelength to search, Angstrom
    max_atlas_wavelength: float
        Maximum wavelength to search, Angstrom
    min_intensity: float
        Minimum intensity, per NIST, of the lines that are drawn/labelled.
        The simulated spectrum itself is built from all lines. Also sets the
        lower y limit (0.75 * min_intensity) when ``log`` is True.
    min_distance: float
        All lines within this distance from other lines are treated
        as unresolved, all of them get removed from the drawn/labelled list.
    brightest_n_lines: int
        Only return the n brightest lines
    pixel_scale: float
        Scales the width of the Gaussian used to broaden the lines:
        sigma = 2.5 * pixel_scale, in Angstrom.
    vacuum: bool
        Return vacuum wavelengths
    pressure: float
        Atmospheric pressure, Pascal
    temperature: float
        Temperature in Kelvin (default 273.15 K, i.e. 0 degC)
    relative_humidity: float
        Relative humidity, percent
    label: bool
        Annotate each drawn line with "<element>: <wavelength>" and a dashed
        grey guide line.
    log: bool
        Plot intensities in log scale
    save_fig: boolean (default: False)
        Save an image if set to True, using pyplot.savefig().
    fig_type: string (default: 'png')
        Image type(s) to be saved, choose from jpg, png, svg and pdf.
        Delimiter is '+'. Any other type (e.g. iframe) is ignored here.
    filename: string (default: None)
        Filename or full path WITHOUT extension; "." + type is always
        appended. Defaults to "rascal_arc".
    display: boolean (default: True)
        Set to True to display the diagnostic plot.
    linelist: str
        Either 'nist' to use the default lines or path to a linelist file.
    fig_kwarg: dict (default: None, meaning {"figsize": (12, 8)})
        Keyword arguments passed to pyplot.figure().

    Returns
    -------
    fig: matplotlib figure object
    """

    # None sentinels instead of shared mutable default arguments
    if elements is None:
        elements = []

    if fig_kwarg is None:
        fig_kwarg = {"figsize": (12, 8)}

    _import_matplotlib()

    # min_intensity and min_distance are set to 0.0 here because the
    # simulated spectrum should contain every line. These two arguments
    # only affect which lines are drawn/labelled (see the filtering below).
    element_list, wavelength_list, intensity_list = load_calibration_lines(
        elements=elements,
        linelist=linelist,
        min_atlas_wavelength=min_atlas_wavelength,
        max_atlas_wavelength=max_atlas_wavelength,
        min_intensity=0.0,
        min_distance=0.0,
        brightest_n_lines=brightest_n_lines,
        vacuum=vacuum,
        pressure=pressure,
        temperature=temperature,
        relative_humidity=relative_humidity,
    )

    # Nyquist sampling rate (2.5) for CCD at seeing of 1 arcsec
    sigma = pixel_scale * 2.5 * 1.0
    # Kernel sampled on the same 0.001 A step as the wavelength grid below,
    # so sigma is in Angstrom.
    x = np.arange(-100, 100.001, 0.001)
    gaussian = gauss(x, a=1.0, x0=0.0, sigma=sigma)

    # Equally spaced wavelength grid (0.001 A step) and the binned
    # line intensities on it
    w = np.around(
        np.arange(min_atlas_wavelength, max_atlas_wavelength + 0.001, 0.001),
        decimals=3,
    ).astype("float64")
    binned_intensity = np.zeros_like(w)

    for e in elements:
        e_mask = element_list == e
        # Map each line to its nearest grid cell. The previous
        # `binned[np.isin(w, lines)] += intensities` produced a mask in
        # grid (ascending wavelength) order, so intensities were paired with
        # the wrong lines whenever the line list was not wavelength-sorted,
        # and it raised on a shape mismatch when two lines fell in one cell.
        line_wave = np.around(
            np.asarray(wavelength_list[e_mask], dtype="float64"), decimals=3
        )
        line_idx = np.rint((line_wave - w[0]) / 0.001).astype(np.int64)
        in_grid = (line_idx >= 0) & (line_idx < w.size)
        line_intensity = np.asarray(intensity_list[e_mask], dtype="float64")
        # np.add.at accumulates correctly when several lines share a cell
        np.add.at(binned_intensity, line_idx[in_grid], line_intensity[in_grid])

    # Convolve to simulate the arc spectrum
    model_spectrum = signal.convolve(binned_intensity, gaussian, mode="same")
    # Computed once: the grid holds ~10^7 samples, and reducing it with the
    # builtin max() (a Python-level loop) once per label was very slow.
    model_peak = np.max(model_spectrum)

    # now clean up by min_intensity and min_distance
    intensity_mask = filter_intensity(
        elements,
        np.column_stack((element_list, wavelength_list, intensity_list)),
        min_intensity=min_intensity,
    )
    wavelength_list = wavelength_list[intensity_mask]
    intensity_list = intensity_list[intensity_mask]
    element_list = element_list[intensity_mask]

    distance_mask = filter_separation(wavelength_list, min_separation=min_distance)
    wavelength_list = wavelength_list[distance_mask]
    intensity_list = intensity_list[distance_mask]
    element_list = element_list[distance_mask]

    fig = plt.figure(**fig_kwarg)

    for j, e in enumerate(elements):
        e_mask = element_list == e
        markerline, stemline, _ = plt.stem(
            wavelength_list[e_mask],
            intensity_list[e_mask],
            label=e,
            linefmt="C{}-".format(j),
        )
        plt.setp(stemline, linewidth=2.0)
        plt.setp(markerline, markersize=2.5, color="C{}".format(j))

        if label:
            for _w in wavelength_list[e_mask]:
                plt.text(
                    _w,
                    model_peak * 1.05,
                    s="{}: {:1.2f}".format(e, _w),
                    rotation=90,
                    bbox=dict(facecolor="white", alpha=1),
                )

            plt.vlines(
                wavelength_list[e_mask],
                intensity_list[e_mask],
                model_peak * 1.25,
                linestyles="dashed",
                lw=0.5,
                color="grey",
            )

    plt.plot(w, model_spectrum, lw=1.0, c="k", label="Simulated Arc Spectrum")
    if vacuum:
        plt.xlabel("Vacuum Wavelength / A")
    else:
        plt.xlabel("Air Wavelength / A")
    plt.ylabel("NIST intensity")
    plt.grid()
    plt.xlim(w.min(), w.max())
    plt.ylim(0, model_peak * 1.25)
    plt.legend()
    plt.tight_layout()
    if log:
        # Only the lower limit changes; the upper limit set above is kept
        plt.ylim(bottom=min_intensity * 0.75)
        plt.yscale("log")

    if save_fig:
        filename_output = "rascal_arc" if filename is None else filename

        for t in fig_type.split("+"):
            if t in ("jpg", "png", "svg", "pdf"):
                plt.savefig(filename_output + "." + t, format=t)

    if display:
        plt.show()

    return fig
247

248

249
def plot_search_space(
    calibrator,
    fit_coeff=None,
    top_n_candidate=3,
    weighted=True,
    save_fig=False,
    fig_type="png",
    filename=None,
    return_jsonstring=False,
    renderer="default",
    display=True,
):
    """
    Plots the peak/arc line pairs that are considered as potential match
    candidates.

    Pairs are shown in (pixel, wavelength) space together with the
    user-supplied minimum/maximum wavelengths, their +/- range_tolerance
    band, and the straight line from (first pixel, minimum wavelength) to
    (last pixel, maximum wavelength) with its tolerance region. If fit
    coefficients are provided, the model solution will be overplotted.
    The backend is selected by ``calibrator.plot_with_matplotlib`` /
    ``calibrator.plot_with_plotly``.

    Parameters
    ----------
    calibrator: object
        The calibrator to visualise. Attributes read: candidates, pairs,
        pixel_list, peaks, min_wavelength, max_wavelength, range_tolerance,
        plot_with_matplotlib, plot_with_plotly. Methods called:
        _get_most_common_candidates, _merge_candidates, polyval.
    fit_coeff: list (default: None)
        List of best polynomial fit coefficients
    top_n_candidate: int (default: 3)
        Top ranked lines to be fitted.
    weighted: (default: True)
        Draw sample based on the distance from the matched known wavelength
        of the atlas.
    save_fig: boolean (default: False)
        Save an image if set to True. matplotlib uses pyplot.savefig()
        while plotly uses pio.write_html() or pio.write_image().
        The supported format types should be provided in fig_type.
    fig_type: string (default: 'png')
        Image type to be saved, choose from:
        jpg, png, svg, pdf and iframe. Delimiter is '+'. iframe is plotly
        only and writes an HTML file with the ".iframe" extension.
    filename: (default: None)
        The destination to save the image, without extension ("." + type is
        appended). Defaults to "rascal_hough_search_space".
    return_jsonstring: (default: False)
        Set to True to return the plotly figure as a json string. Ignored if
        matplotlib is used.
    renderer: (default: 'default')
        Set the renderer for the plotly display. Ignored if matplotlib is
        used.
    display: boolean (default: True)
        Set to True to display the diagnostic plot.

    Return
    ------
    The matplotlib figure if matplotlib is used; the plotly figure as a json
    string if plotly is used and return_jsonstring is True; otherwise None.

    """

    # Get top linear estimates and combine
    candidate_peak, candidate_arc = calibrator._get_most_common_candidates(
        calibrator.candidates,
        top_n_candidate=top_n_candidate,
        weighted=weighted,
    )

    # Get the search space boundaries
    x = calibrator.pixel_list
    pixel_max = calibrator.pixel_list.max()
    min_wavelength = calibrator.min_wavelength
    max_wavelength = calibrator.max_wavelength
    tolerance = calibrator.range_tolerance

    # Line from (first pixel, minimum wavelength) to (last pixel, maximum
    # wavelength). The tolerance-region edges are the same line shifted by
    # +/- range_tolerance: shifting both end points by the same amount leaves
    # the slope unchanged, so all three lines share it.
    slope = (max_wavelength - min_wavelength) / pixel_max
    y_1 = slope * x + min_wavelength
    y_2 = slope * x + min_wavelength + tolerance
    y_3 = slope * x + (min_wavelength - tolerance)

    if calibrator.plot_with_matplotlib:
        _import_matplotlib()

        fig = plt.figure(figsize=(10, 10))

        # Plot all-pairs
        plt.scatter(*calibrator.pairs.T, alpha=0.2, color="C0", label="All pairs")

        merged_candidates = calibrator._merge_candidates(calibrator.candidates)
        plt.scatter(
            merged_candidates[:, 0],
            merged_candidates[:, 1],
            alpha=0.2,
            color="C1",
            label="Candidate Pairs",
        )

        # Each user-supplied wavelength limit (solid) and its tolerance
        # region (dashed), minimum first then maximum
        for limit, text in (
            (min_wavelength, "Min wavelength (user-supplied)"),
            (max_wavelength, "Max wavelength (user-supplied)"),
        ):
            plt.text(5, limit + 100, text)
            plt.hlines(limit, 0, pixel_max, color="k")
            for edge in (limit + tolerance, limit - tolerance):
                plt.hlines(
                    edge,
                    0,
                    pixel_max,
                    linestyle="dashed",
                    alpha=0.5,
                    color="k",
                )

        # The linear guess and the two lines defining its tolerance region
        plt.plot(x, y_1, label="Linear Fit", color="C3")
        plt.plot(x, y_2, linestyle="dashed", label="Tolerance Region", color="C3")
        plt.plot(x, y_3, linestyle="dashed", color="C3")

        if fit_coeff is not None:
            plt.scatter(
                calibrator.peaks,
                calibrator.polyval(calibrator.peaks, fit_coeff),
                color="C4",
                label="Solution",
            )

        plt.scatter(
            candidate_peak,
            candidate_arc,
            color="C2",
            label="Best Candidate Pairs",
        )

        plt.xlim(0, pixel_max)
        plt.ylim(min_wavelength - tolerance, max_wavelength + tolerance)

        # x holds pixels and y holds wavelengths (the labels used to be swapped)
        plt.xlabel("Pixel")
        plt.ylabel("Wavelength / A")
        plt.legend()
        plt.grid()
        plt.tight_layout()

        if save_fig:
            if filename is None:
                filename_output = "rascal_hough_search_space"
            else:
                filename_output = filename

            for t in fig_type.split("+"):
                if t in ("jpg", "png", "svg", "pdf"):
                    plt.savefig(filename_output + "." + t, format=t)

        if display:
            plt.show()

        return fig

    elif calibrator.plot_with_plotly:
        _import_plotly()

        colorway = pio.templates["CN"].layout.colorway
        fig = go.Figure()

        # Plot all-pairs
        fig.add_trace(
            go.Scatter(
                x=calibrator.pairs[:, 0],
                y=calibrator.pairs[:, 1],
                mode="markers",
                name="All Pairs",
                marker=dict(color=colorway[0], opacity=0.2),
            )
        )

        merged_candidates = calibrator._merge_candidates(calibrator.candidates)
        fig.add_trace(
            go.Scatter(
                x=merged_candidates[:, 0],
                y=merged_candidates[:, 1],
                mode="markers",
                name="Candidate Pairs",
                marker=dict(color=colorway[1], opacity=0.2),
            )
        )
        fig.add_trace(
            go.Scatter(
                x=candidate_peak,
                y=candidate_arc,
                mode="markers",
                name="Best Candidate Pairs",
                marker=dict(color=colorway[2]),
            )
        )

        # Horizontal lines at the minimum and maximum wavelengths (solid) and
        # their tolerance regions (dashed). Entries: (level, dash style or
        # None for solid, legend name or None to hide from the legend).
        horizontal_lines = (
            (min_wavelength, None, "Min/Maximum"),
            (min_wavelength + tolerance, "dash", "Tolerance Range"),
            (min_wavelength - tolerance, "dash", None),
            (max_wavelength, None, None),
            (max_wavelength + tolerance, "dash", None),
            (max_wavelength - tolerance, "dash", None),
        )
        for level, dash, name in horizontal_lines:
            if dash is None:
                line_style = dict(color="black")
            else:
                line_style = dict(color="black", dash=dash)

            if name is None:
                legend_kwargs = dict(showlegend=False)
            else:
                legend_kwargs = dict(name=name)

            fig.add_trace(
                go.Scatter(
                    x=[0, pixel_max],
                    y=[level, level],
                    mode="lines",
                    line=line_style,
                    **legend_kwargs,
                )
            )

        # The line from (first pixel, minimum wavelength) to
        # (last pixel, maximum wavelength), and the two lines defining the
        # tolerance region.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y_1,
                mode="lines",
                name="Linear Fit",
                line=dict(color=colorway[3]),
            )
        )
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y_2,
                mode="lines",
                name="Tolerance Region",
                line=dict(color=colorway[3], dash="dashdot"),
            )
        )
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y_3,
                showlegend=False,
                mode="lines",
                line=dict(color=colorway[3], dash="dashdot"),
            )
        )

        if fit_coeff is not None:
            fig.add_trace(
                go.Scatter(
                    x=calibrator.peaks,
                    y=calibrator.polyval(calibrator.peaks, fit_coeff),
                    mode="markers",
                    name="Solution",
                    marker=dict(color=colorway[4]),
                )
            )

        # Layout, Title, Grid config. x holds pixels, y holds wavelengths
        # (the axis titles used to be swapped).
        fig.update_layout(
            autosize=True,
            yaxis=dict(
                title="Wavelength / A",
                range=[
                    min_wavelength - tolerance * 1.1,
                    max_wavelength + tolerance * 1.1,
                ],
                showgrid=True,
            ),
            xaxis=dict(
                title="Pixel",
                zeroline=False,
                range=[0.0, pixel_max],
                showgrid=True,
            ),
            hovermode="closest",
            showlegend=True,
            height=800,
            width=1000,
        )

        if save_fig:
            if filename is None:
                filename_output = "rascal_hough_search_space"
            else:
                filename_output = filename

            for t in fig_type.split("+"):
                if t == "iframe":
                    pio.write_html(fig, filename_output + "." + t)

                elif t in ("jpg", "png", "svg", "pdf"):
                    pio.write_image(fig, filename_output + "." + t)

        if display:
            if renderer == "default":
                fig.show()

            else:
                fig.show(renderer)

        if return_jsonstring:
            return fig.to_json()
652

653

654
def plot_fit(
4✔
655
    calibrator,
656
    fit_coeff,
657
    spectrum=None,
658
    tolerance=5.0,
659
    plot_atlas=True,
660
    log_spectrum=False,
661
    save_fig=False,
662
    fig_type="png",
663
    filename=None,
664
    return_jsonstring=False,
665
    renderer="default",
666
    display=True,
667
):
668
    """
669
    Plots of the wavelength calibrated arc spectrum, the residual and the
670
    pixel-to-wavelength solution.
671

672
    Parameters
673
    ----------
674
    fit_coeff: 1D numpy array or list
675
        Best fit polynomail fit_coefficients
676
    spectrum: 1D numpy array (N)
677
        Array of length N pixels
678
    tolerance: float (default: 5)
679
        Absolute difference between model and fitted wavelengths in unit
680
        of angstrom.
681
    plot_atlas: boolean (default: True)
682
        Display all the relavent lines available in the atlas library.
683
    log_spectrum: boolean (default: False)
684
        Display the arc in log-space if set to True.
685
    save_fig: boolean (default: False)
686
        Save an image if set to True. matplotlib uses the pyplot.save_fig()
687
        while the plotly uses the pio.write_html() or pio.write_image().
688
        The support format types should be provided in fig_type.
689
    fig_type: string (default: 'png')
690
        Image type to be saved, choose from:
691
        jpg, png, svg, pdf and iframe. Delimiter is '+'.
692
    filename: string (default: None)
693
        Provide a filename or full path. If the extension is not provided
694
        it is defaulted to png.
695
    return_jsonstring: boolean (default: False)
696
        Set to True to return json strings if using plotly as the plotting
697
        library.
698
    renderer: string (default: 'default')
699
        Indicate the Plotly renderer. Nothing gets displayed if
700
        return_jsonstring is set to True.
701
    display: boolean (Default: False)
702
        Set to True to display disgnostic plot.
703

704
    Returns
705
    -------
706
    Return json strings if using plotly as the plotting library and json
707
    is True.
708

709
    """
710

711
    if spectrum is None:
4✔
712
        try:
4✔
713
            spectrum = calibrator.spectrum
4✔
714

715
        except Exception as e:
×
716
            calibrator.logger.error(e)
×
NEW
717
            calibrator.logger.error("Spectrum is not provided, it cannot be plotted.")
×
718

719
    if spectrum is not None:
4✔
720
        if log_spectrum:
4✔
721
            spectrum[spectrum < 0] = 1e-100
4✔
722
            spectrum = np.log10(spectrum)
4✔
723
            vline_max = np.nanmax(spectrum) * 2.0
4✔
724
            text_box_pos = 1.2 * max(spectrum)
4✔
725

726
        else:
727
            vline_max = np.nanmax(spectrum) * 1.2
4✔
728
            text_box_pos = 0.8 * max(spectrum)
4✔
729

730
    else:
731
        vline_max = 1.0
4✔
732
        text_box_pos = 0.5
4✔
733

734
    wave = calibrator.polyval(calibrator.pixel_list, fit_coeff)
4✔
735

736
    if calibrator.plot_with_matplotlib:
4✔
737
        _import_matplotlib()
4✔
738

739
        fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, sharex=True, gridspec_kw={"hspace": 0.0}, figsize=(15, 9))
4✔
740
        fig.tight_layout()
4✔
741

742
        # Plot fitted spectrum
743
        if spectrum is not None:
4✔
744
            ax1.plot(wave, spectrum, label="Arc Spectrum")
4✔
745
            ax1.vlines(
4✔
746
                calibrator.polyval(calibrator.peaks, fit_coeff),
747
                np.array(spectrum)[calibrator.pix_to_rawpix(calibrator.peaks).astype("int")],
748
                vline_max,
749
                linestyles="dashed",
750
                colors="C1",
751
                label="Detected Peaks",
752
            )
753

754
        # Plot the atlas
755
        if plot_atlas:
4✔
756
            # spec = SyntheticSpectrum(
757
            #    fit, model_type='poly', degree=len(fit)-1)
758
            # x_locs = spec.get_pixels(calibrator.atlas)
759
            ax1.vlines(
4✔
760
                calibrator.atlas.get_lines(),
761
                0,
762
                vline_max,
763
                colors="C2",
764
                label="Given Lines",
765
            )
766

767
        fitted_peaks = []
4✔
768
        fitted_diff = []
4✔
769
        all_diff = []
4✔
770

771
        first_one = True
4✔
772
        for p, x in zip(calibrator.matched_peaks, calibrator.matched_atlas):
4✔
773
            diff = calibrator.atlas.get_lines() - x
4✔
774
            idx = np.argmin(np.abs(diff))
4✔
775
            all_diff.append(diff[idx])
4✔
776

777
            calibrator.logger.info("Peak at: {} A".format(x))
4✔
778

779
            fitted_peaks.append(p)
4✔
780
            fitted_diff.append(diff[idx])
4✔
781
            calibrator.logger.info("- matched to {} A".format(calibrator.atlas.get_lines()[idx]))
4✔
782

783
            if spectrum is not None:
4✔
784
                if first_one:
4✔
785
                    ax1.vlines(
4✔
786
                        calibrator.polyval(p, fit_coeff),
787
                        spectrum[calibrator.pix_to_rawpix(p).astype("int")],
788
                        vline_max,
789
                        colors="C1",
790
                        label="Fitted Peaks",
791
                    )
792
                    first_one = False
4✔
793

794
                else:
795
                    ax1.vlines(
4✔
796
                        calibrator.polyval(p, fit_coeff),
797
                        spectrum[calibrator.pix_to_rawpix(p).astype("int")],
798
                        vline_max,
799
                        colors="C1",
800
                    )
801

802
            ax1.text(
4✔
803
                x - 3,
804
                text_box_pos,
805
                s="{}:{:1.2f}".format(
806
                    calibrator.atlas.get_elements()[idx],
807
                    calibrator.atlas.get_lines()[idx],
808
                ),
809
                rotation=90,
810
                bbox=dict(facecolor="white", alpha=1),
811
            )
812

813
        rms = np.sqrt(np.mean(np.array(fitted_diff) ** 2.0))
4✔
814

815
        ax1.grid(linestyle=":")
4✔
816
        ax1.set_ylabel("Electron Count / e-")
4✔
817

818
        if spectrum is not None:
4✔
819
            if log_spectrum:
4✔
820
                ax1.set_ylim(0, vline_max)
4✔
821

822
            else:
823
                ax1.set_ylim(np.nanmin(spectrum), vline_max)
4✔
824

825
        ax1.legend(loc="center right")
4✔
826

827
        # Plot the residuals
828
        ax2.scatter(
4✔
829
            calibrator.polyval(fitted_peaks, fit_coeff),
830
            fitted_diff,
831
            marker="+",
832
            color="C1",
833
        )
834
        ax2.hlines(0, wave.min(), wave.max(), linestyles="dashed")
4✔
835
        ax2.hlines(
4✔
836
            rms,
837
            wave.min(),
838
            wave.max(),
839
            linestyles="dashed",
840
            color="k",
841
            label="RMS",
842
        )
843
        ax2.hlines(-rms, wave.min(), wave.max(), linestyles="dashed", color="k")
4✔
844
        ax2.grid(linestyle=":")
4✔
845
        ax2.set_ylabel("Residual / A")
4✔
846
        ax2.legend()
4✔
847
        """
4✔
848
        ax2.text(
849
            min(wave) + np.ptp(wave) * 0.05,
850
            max(spectrum),
851
            'RMS =' + str(rms)[:6]
852
            )
853
        """
854

855
        # Plot the polynomial
856
        ax3.scatter(
4✔
857
            calibrator.polyval(fitted_peaks, fit_coeff),
858
            fitted_peaks,
859
            marker="+",
860
            color="C1",
861
            label="Fitted Peaks",
862
        )
863
        ax3.plot(wave, calibrator.pixel_list, color="C2", label="Solution")
4✔
864
        ax3.grid(linestyle=":")
4✔
865
        ax3.set_xlabel("Wavelength / A")
4✔
866
        ax3.set_ylabel("Pixel")
4✔
867
        ax3.legend(loc="lower right")
4✔
868
        w_min = calibrator.polyval(min(fitted_peaks), fit_coeff)
4✔
869
        w_max = calibrator.polyval(max(fitted_peaks), fit_coeff)
4✔
870
        ax3.set_xlim(w_min * 0.95, w_max * 1.05)
4✔
871

872
        plt.tight_layout()
4✔
873

874
        if save_fig:
4✔
875
            fig_type = fig_type.split("+")
4✔
876

877
            if filename is None:
4✔
878
                filename_output = "rascal_solution"
×
879

880
            else:
881
                filename_output = filename
4✔
882

883
            for t in fig_type:
4✔
884
                if t in ["jpg", "png", "svg", "pdf"]:
4✔
885
                    plt.savefig(filename_output + "." + t, format=t)
4✔
886

887
        if display:
4✔
888
            plt.show()
4✔
889

890
        return fig
4✔
891

892
    elif calibrator.plot_with_plotly:
4✔
893
        _import_plotly()
4✔
894

895
        fig = go.Figure()
4✔
896

897
        # Top plot - arc spectrum and matched peaks
898
        if spectrum is not None:
4✔
899
            fig.add_trace(
4✔
900
                go.Scatter(
901
                    x=wave,
902
                    y=spectrum,
903
                    mode="lines",
904
                    yaxis="y3",
905
                    name="Arc Spectrum",
906
                )
907
            )
908

909
            spec_max = np.nanmax(spectrum) * 1.05
4✔
910

911
        else:
912
            spec_max = vline_max
×
913

914
        fitted_peaks = []
4✔
915
        fitted_peaks_adu = []
4✔
916
        fitted_diff = []
4✔
917
        all_diff = []
4✔
918

919
        for p in calibrator.peaks:
4✔
920
            x = calibrator.polyval(p, fit_coeff)
4✔
921

922
            # Add vlines
923
            fig.add_shape(
4✔
924
                type="line",
925
                xref="x",
926
                yref="y3",
927
                x0=x,
928
                y0=0,
929
                x1=x,
930
                y1=spec_max,
931
                line=dict(color=pio.templates["CN"].layout.colorway[1], width=1),
932
            )
933

934
            diff = calibrator.atlas.get_lines() - x
4✔
935
            idx = np.argmin(np.abs(diff))
4✔
936
            all_diff.append(diff[idx])
4✔
937

938
            calibrator.logger.info("Peak at: {} A".format(x))
4✔
939

940
            if np.abs(diff[idx]) < tolerance:
4✔
941
                fitted_peaks.append(p)
4✔
942
                if spectrum is not None:
4✔
943
                    fitted_peaks_adu.append(spectrum[int(calibrator.pix_to_rawpix(p))])
4✔
944
                fitted_diff.append(diff[idx])
4✔
945
                calibrator.logger.info("- matched to {} A".format(calibrator.atlas.get_lines()[idx]))
4✔
946

947
        x_fitted = calibrator.polyval(fitted_peaks, fit_coeff)
4✔
948

949
        fig.add_trace(
4✔
950
            go.Scatter(
951
                x=x_fitted,
952
                y=fitted_peaks_adu,
953
                mode="markers",
954
                marker=dict(color=pio.templates["CN"].layout.colorway[1]),
955
                yaxis="y3",
956
                showlegend=False,
957
            )
958
        )
959

960
        # Middle plot - Residual plot
961
        rms = np.sqrt(np.mean(np.array(fitted_diff) ** 2.0))
4✔
962
        fig.add_trace(
4✔
963
            go.Scatter(
964
                x=x_fitted,
965
                y=fitted_diff,
966
                mode="markers",
967
                marker=dict(color=pio.templates["CN"].layout.colorway[1]),
968
                yaxis="y2",
969
                showlegend=False,
970
            )
971
        )
972
        fig.add_trace(
4✔
973
            go.Scatter(
974
                x=[wave.min(), wave.max()],
975
                y=[0, 0],
976
                mode="lines",
977
                line=dict(color=pio.templates["CN"].layout.colorway[0], dash="dash"),
978
                yaxis="y2",
979
                showlegend=False,
980
            )
981
        )
982
        fig.add_trace(
4✔
983
            go.Scatter(
984
                x=[wave.min(), wave.max()],
985
                y=[rms, rms],
986
                mode="lines",
987
                line=dict(color="black", dash="dash"),
988
                yaxis="y2",
989
                showlegend=False,
990
            )
991
        )
992
        fig.add_trace(
4✔
993
            go.Scatter(
994
                x=[wave.min(), wave.max()],
995
                y=[-rms, -rms],
996
                mode="lines",
997
                line=dict(color="black", dash="dash"),
998
                yaxis="y2",
999
                name="RMS",
1000
            )
1001
        )
1002

1003
        # Bottom plot - Polynomial fit for Pixel to Wavelength
1004
        fig.add_trace(
4✔
1005
            go.Scatter(
1006
                x=x_fitted,
1007
                y=fitted_peaks,
1008
                mode="markers",
1009
                marker=dict(color=pio.templates["CN"].layout.colorway[1]),
1010
                yaxis="y1",
1011
                name="Fitted Peaks",
1012
            )
1013
        )
1014
        fig.add_trace(
4✔
1015
            go.Scatter(
1016
                x=wave,
1017
                y=calibrator.pixel_list,
1018
                mode="lines",
1019
                line=dict(color=pio.templates["CN"].layout.colorway[2]),
1020
                yaxis="y1",
1021
                name="Solution",
1022
            )
1023
        )
1024

1025
        # Layout, Title, Grid config
1026
        if spectrum is not None:
4✔
1027
            if log_spectrum:
4✔
1028
                fig.update_layout(
×
1029
                    yaxis3=dict(
1030
                        title="Electron Count / e-",
1031
                        range=[
1032
                            np.log10(np.percentile(spectrum, 15)),
1033
                            np.log10(spec_max),
1034
                        ],
1035
                        domain=[0.67, 1.0],
1036
                        showgrid=True,
1037
                        type="log",
1038
                    )
1039
                )
1040

1041
            else:
1042
                fig.update_layout(
4✔
1043
                    yaxis3=dict(
1044
                        title="Electron Count / e-",
1045
                        range=[np.percentile(spectrum, 15), spec_max],
1046
                        domain=[0.67, 1.0],
1047
                        showgrid=True,
1048
                    )
1049
                )
1050

1051
        fig.update_layout(
4✔
1052
            autosize=True,
1053
            yaxis2=dict(
1054
                title="Residual / A",
1055
                range=[min(fitted_diff), max(fitted_diff)],
1056
                domain=[0.33, 0.66],
1057
                showgrid=True,
1058
            ),
1059
            yaxis=dict(
1060
                title="Pixel",
1061
                range=[0.0, max(calibrator.pixel_list)],
1062
                domain=[0.0, 0.32],
1063
                showgrid=True,
1064
            ),
1065
            xaxis=dict(
1066
                title="Wavelength / A",
1067
                zeroline=False,
1068
                range=[
1069
                    calibrator.polyval(min(fitted_peaks), fit_coeff) * 0.95,
1070
                    calibrator.polyval(max(fitted_peaks), fit_coeff) * 1.05,
1071
                ],
1072
                showgrid=True,
1073
            ),
1074
            hovermode="closest",
1075
            showlegend=True,
1076
            height=800,
1077
            width=1000,
1078
        )
1079

1080
        if save_fig:
4✔
1081
            fig_type = fig_type.split("+")
4✔
1082

1083
            if filename is None:
4✔
1084
                filename_output = "rascal_solution"
×
1085

1086
            else:
1087
                filename_output = filename
4✔
1088

1089
            for t in fig_type:
4✔
1090
                if t == "iframe":
4✔
1091
                    pio.write_html(fig, filename_output + "." + t)
×
1092

1093
                elif t in ["jpg", "png", "svg", "pdf"]:
4✔
1094
                    pio.write_image(fig, filename_output + "." + t)
4✔
1095

1096
        if display:
4✔
1097
            if renderer == "default":
×
1098
                fig.show()
×
1099

1100
            else:
1101
                fig.show(renderer)
×
1102

1103
        if return_jsonstring:
4✔
1104
            return fig.to_json()
×
1105

1106
    else:
NEW
1107
        assert calibrator.matplotlib_imported, "matplotlib package not available. " + "Plot cannot be generated."
×
NEW
1108
        assert calibrator.plotly_imported, "plotly package is not available. " + "Plot cannot be generated."
×
1109

1110

1111
def plot_arc(
    calibrator,
    pixel_list=None,
    log_spectrum=False,
    save_fig=False,
    fig_type="png",
    filename=None,
    return_jsonstring=False,
    renderer="default",
    display=True,
):
    """
    Plots the 1D spectrum of the extracted arc.

    parameters
    ----------
    pixel_list: array (default: None)
        Pixel value of each element of the spectrum; this is only needed
        if the spectrum spans multiple detector arrays. When None it
        defaults to np.arange(len(calibrator.spectrum)) (only computed when
        calibrator.spectrum is not None).
    log_spectrum: boolean (default: False)
        Set to True to display log10 of the arc spectrum normalised by its
        maximum, instead of the linearly normalised spectrum.
    save_fig: boolean (default: False)
        Save an image if set to True. matplotlib uses pyplot.savefig()
        while plotly uses pio.write_html() or pio.write_image().
        The supported format types should be provided in fig_type.
    fig_type: string (default: 'png')
        Image type(s) to be saved, choose from:
        jpg, png, svg, pdf and iframe (iframe is plotly only).
        Delimiter is '+', e.g. 'png+pdf'.
    filename: string (default: None)
        Filename or full path WITHOUT extension; "." + type is appended
        for every type listed in fig_type. Defaults to 'rascal_arc'.
    return_jsonstring: boolean (default: False)
        Set to True to return json strings if using plotly as the plotting
        library.
    renderer: string (default: 'default')
        The Plotly renderer used when display is True.
    display: boolean (default: True)
        Set to True to display the diagnostic plot.

    Returns
    -------
    With matplotlib: the matplotlib Figure.
    With plotly: the figure's json string if return_jsonstring is True,
    otherwise None.
    None if neither plotting library is selected on the calibrator.

    """

    spectrum = calibrator.spectrum

    # Only derive the default pixel axis when there is a spectrum to size it
    # from; otherwise the "no spectrum" branches below would be unreachable.
    if pixel_list is None and spectrum is not None:
        pixel_list = np.arange(len(spectrum))

    # Normalise once; both backends plot the spectrum relative to its peak.
    if spectrum is not None:
        normalised = spectrum / spectrum.max()
        if log_spectrum:
            normalised = np.log10(normalised)
    else:
        normalised = None

    if calibrator.plot_with_matplotlib:
        _import_matplotlib()

        fig = plt.figure(figsize=(18, 5))

        if spectrum is not None:
            plt.plot(pixel_list, normalised, label="Arc Spectrum")
            if log_spectrum:
                plt.vlines(calibrator.peaks, -2, 0, label="Detected Peaks", color="C1")
                plt.ylabel("log(Normalised Count)")
                plt.ylim(-2, 0)
            else:
                plt.ylabel("Normalised Count")
                plt.vlines(
                    calibrator.peaks,
                    0,
                    1.05,
                    label="Detected Peaks",
                    color="C1",
                )
            plt.title("Number of pixels: " + str(spectrum.shape[0]))
            plt.xlim(0, spectrum.shape[0])
            plt.legend()

        else:
            plt.xlim(0, max(calibrator.peaks))

        plt.xlabel("Pixel (Spectral Direction)")
        plt.grid()
        plt.tight_layout()

        if save_fig:
            filename_output = "rascal_arc" if filename is None else filename

            for t in fig_type.split("+"):
                if t in ["jpg", "png", "svg", "pdf"]:
                    plt.savefig(filename_output + "." + t, format=t)

        if display:
            plt.show()

        return fig

    if calibrator.plot_with_plotly:
        _import_plotly()

        fig = go.Figure()

        # Default y-window when there is no (finite) spectrum to size it;
        # mirrors the matplotlib backend.
        if log_spectrum:
            y_title = "log(Normalised Count)"
            y_range = [-2.0, 0.0]
        else:
            y_title = "Normalised Count"
            y_range = [0.0, 1.05]

        if spectrum is not None:
            fig.add_trace(
                go.Scatter(
                    x=list(pixel_list),
                    y=list(normalised),
                    mode="lines",
                    name="Arc",
                )
            )
            # log10 of zero/negative counts yields -inf/nan; exclude those so
            # the axis range stays finite and order-independent.
            finite = normalised[np.isfinite(normalised)]
            if finite.size > 0:
                y_range = [float(np.min(finite)), float(np.max(finite))]
            x_max = len(spectrum)
        else:
            x_max = max(calibrator.peaks)

        # In log space the normalised values are <= 0, so lines spanning
        # [0, 1.05] would sit entirely outside the axis; span the data range.
        if log_spectrum:
            vline_y0, vline_y1 = y_range
        else:
            vline_y0, vline_y1 = 0.0, 1.05

        peak_colour = pio.templates["CN"].layout.colorway[1]
        for peak in calibrator.peaks:
            fig.add_shape(
                type="line",
                xref="x",
                yref="y",
                x0=peak,
                y0=vline_y0,
                x1=peak,
                y1=vline_y1,
                line=dict(color=peak_colour, width=1),
            )

        fig.update_layout(
            autosize=True,
            yaxis=dict(title=y_title, range=y_range, showgrid=True),
            xaxis=dict(
                title="Pixel",
                zeroline=False,
                range=[0.0, x_max],
                showgrid=True,
            ),
            hovermode="closest",
            showlegend=True,
            height=800,
            width=1000,
        )

        fig.update_xaxes(showline=True, linewidth=1, linecolor="black", mirror=True)

        fig.update_yaxes(showline=True, linewidth=1, linecolor="black", mirror=True)

        if save_fig:
            filename_output = "rascal_arc" if filename is None else filename

            for t in fig_type.split("+"):
                if t == "iframe":
                    pio.write_html(fig, filename_output + "." + t)

                elif t in ["jpg", "png", "svg", "pdf"]:
                    pio.write_image(fig, filename_output + "." + t)

        if display:
            if renderer == "default":
                fig.show()

            else:
                fig.show(renderer)

        if return_jsonstring:
            return fig.to_json()

    return None
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc