• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

jveitchmichaelis / rascal / 4515862454

pending completion
4515862454

push

github

cylammarco
added remarklint to pre-commit. linted everything.

1884 of 2056 relevant lines covered (91.63%)

3.65 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.08
/src/rascal/plotting.py
1
import logging
4✔
2
import numpy as np
4✔
3
from scipy import signal
4✔
4

5
from .util import load_calibration_lines
4✔
6
from .util import gauss
4✔
7
from .util import filter_intensity
4✔
8
from .util import filter_separation
4✔
9

10
logger = logging.getLogger("plotting")
4✔
11

12

13
def _import_matplotlib():
    """
    Import matplotlib.pyplot on demand.

    On success the module is bound to the module-level name ``plt`` so the
    plotting functions in this file can use it. If matplotlib cannot be
    imported, an error is logged and ``plt`` is left as it was.

    """

    # ``global`` is a compile-time declaration, so it does not need to live
    # inside the ``try`` body.
    global plt

    try:
        import matplotlib.pyplot as plt

    except ImportError:
        logger.error("matplotlib package not available.")
25

26

27
def _import_plotly():
    """
    Import the plotly modules on demand and register the "CN" template.

    On success ``plotly.graph_objects``, ``plotly.io`` and
    ``plotly.subplots`` are bound to the module-level names ``go``, ``pio``
    and ``psp``. A template called "CN" carrying a ten-colour colorway is
    registered and set as the default plotly template. If plotly cannot be
    imported, an error is logged instead.

    """

    global go, pio, psp

    # The same ten colours as matplotlib's default "C0"-"C9" cycle, so
    # that both plotting backends share one colour scheme.
    cn_colorway = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]

    try:
        import plotly.graph_objects as go
        import plotly.io as pio
        import plotly.subplots as psp

        # Register the colorway as a named template and make it the default
        pio.templates["CN"] = go.layout.Template(layout_colorway=cn_colorway)
        pio.templates.default = "CN"

    except ImportError:
        logger.error("plotly package not available.")
61

62

63
def plot_calibration_lines(
    elements=None,
    min_atlas_wavelength=3000.0,
    max_atlas_wavelength=15000.0,
    min_intensity=5.0,
    min_distance=0.0,
    brightest_n_lines=None,
    pixel_scale=1.0,
    vacuum=False,
    pressure=101325.0,
    temperature=273.15,
    relative_humidity=0.0,
    label=False,
    log=False,
    save_fig=False,
    fig_type="png",
    filename=None,
    display=True,
    linelist="nist",
    fig_kwarg=None,
):
    """
    Plot the expected arc spectrum. Only available in matplotlib at the moment.

    All lines returned by load_calibration_lines() contribute to the
    simulated spectrum; min_intensity and min_distance only decide which
    lines are drawn as stems and (optionally) labelled.

    Parameters
    ----------
    elements: list (default: None)
        List of short element names, e.g. He as per NIST. None is treated
        as an empty list.
    min_atlas_wavelength: float (default: 3000.0)
        Minimum wavelength to search, Angstrom
    max_atlas_wavelength: float (default: 15000.0)
        Maximum wavelength to search, Angstrom
    min_intensity: float (default: 5.0)
        Minimum intensity of the lines to be marked, per NIST
    min_distance: float (default: 0.0)
        All lines within this distance from other lines are treated
        as unresolved, all of them get removed from the marked lines.
    brightest_n_lines: int (default: None)
        Only return the n brightest lines
    pixel_scale: float (default: 1.0)
        Scales the width of the Gaussian used to broaden the lines
        (sigma = pixel_scale * 2.5).
    vacuum: bool (default: False)
        Return vacuum wavelengths
    pressure: float (default: 101325.0)
        Atmospheric pressure, Pascal
    temperature: float (default: 273.15)
        Temperature in Kelvin
    relative_humidity: float (default: 0.0)
        Relative humidity, percent
    label: bool (default: False)
        Set to True to annotate every marked line with its element and
        wavelength.
    log: bool (default: False)
        Plot intensities in log scale
    save_fig: boolean (default: False)
        Save an image with pyplot.savefig() if set to True. The formats
        to write should be provided in fig_type.
    fig_type: string (default: 'png')
        Image type to be saved, choose from:
        jpg, png, svg and pdf. Delimiter is '+'. Other types are ignored.
    filename: string (default: None)
        Output filename or full path WITHOUT extension; '.<type>' is
        appended for every requested type. Defaults to 'rascal_arc'.
    display: boolean (default: True)
        Set to True to display the diagnostic plot.
    linelist: str (default: 'nist')
        Either 'nist' to use the default lines or path to a linelist file.
    fig_kwarg: dict (default: None)
        Keyword arguments forwarded to pyplot.figure(). None is treated as
        {"figsize": (12, 8)}.

    Returns
    -------
    fig: matplotlib figure object

    """

    # None sentinels instead of mutable default arguments
    if elements is None:
        elements = []

    if fig_kwarg is None:
        fig_kwarg = {"figsize": (12, 8)}

    _import_matplotlib()

    # the min_intensity and min_distance are set to 0.0 because the
    # simulated spectrum would contain them. These arguments only
    # affect the labelling.
    element_list, wavelength_list, intensity_list = load_calibration_lines(
        elements=elements,
        linelist=linelist,
        min_atlas_wavelength=min_atlas_wavelength,
        max_atlas_wavelength=max_atlas_wavelength,
        min_intensity=0.0,
        min_distance=0.0,
        brightest_n_lines=brightest_n_lines,
        vacuum=vacuum,
        pressure=pressure,
        temperature=temperature,
        relative_humidity=relative_humidity,
    )

    # Nyquist sampling rate (2.5) for CCD at seeing of 1 arcsec
    sigma = pixel_scale * 2.5 * 1.0
    x = np.arange(-100, 100.001, 0.001)
    gaussian = gauss(x, a=1.0, x0=0.0, sigma=sigma)

    # Generate the equally spaced-wavelength array, and the
    # corresponding intensity
    w = np.around(
        np.arange(min_atlas_wavelength, max_atlas_wavelength + 0.001, 0.001),
        decimals=3,
    ).astype("float64")
    i = np.zeros_like(w)

    for e in elements:
        in_element = element_list == e
        # NOTE(review): the boolean mask picks grid points in ascending
        # wavelength order, so this presumably relies on the lines of each
        # element being sorted, unique after rounding and inside the atlas
        # range - confirm against load_calibration_lines.
        i[
            np.isin(w, np.around(wavelength_list[in_element], decimals=3))
        ] += intensity_list[in_element]

    # Convolve to simulate the arc spectrum
    model_spectrum = signal.convolve(i, gaussian, mode="same")

    # The arrays hold millions of samples: reduce them once with numpy
    # instead of calling the Python builtins repeatedly.
    spectrum_peak = model_spectrum.max()

    # now clean up by min_intensity and min_distance
    intensity_mask = filter_intensity(
        elements,
        np.column_stack((element_list, wavelength_list, intensity_list)),
        min_intensity=min_intensity,
    )
    wavelength_list = wavelength_list[intensity_mask]
    intensity_list = intensity_list[intensity_mask]
    element_list = element_list[intensity_mask]

    distance_mask = filter_separation(
        wavelength_list, min_separation=min_distance
    )
    wavelength_list = wavelength_list[distance_mask]
    intensity_list = intensity_list[distance_mask]
    element_list = element_list[distance_mask]

    fig = plt.figure(**fig_kwarg)

    for j, e in enumerate(elements):
        e_mask = element_list == e
        markerline, stemline, _ = plt.stem(
            wavelength_list[e_mask],
            intensity_list[e_mask],
            label=e,
            linefmt="C{}-".format(j),
        )
        plt.setp(stemline, linewidth=2.0)
        plt.setp(markerline, markersize=2.5, color="C{}".format(j))

        if label:
            for _w in wavelength_list[e_mask]:
                plt.text(
                    _w,
                    spectrum_peak * 1.05,
                    s="{}: {:1.2f}".format(e, _w),
                    rotation=90,
                    bbox=dict(facecolor="white", alpha=1),
                )

            # Guide lines from the top of each stem up to its label
            plt.vlines(
                wavelength_list[e_mask],
                intensity_list[e_mask],
                spectrum_peak * 1.25,
                linestyles="dashed",
                lw=0.5,
                color="grey",
            )

    plt.plot(w, model_spectrum, lw=1.0, c="k", label="Simulated Arc Spectrum")
    if vacuum:
        plt.xlabel("Vacuum Wavelength / A")
    else:
        plt.xlabel("Air Wavelength / A")
    plt.ylabel("NIST intensity")
    plt.grid()
    plt.xlim(w.min(), w.max())
    plt.ylim(0, spectrum_peak * 1.25)
    plt.legend()
    plt.tight_layout()
    if log:
        plt.ylim(ymin=min_intensity * 0.75)
        plt.yscale("log")

    if save_fig:
        if filename is None:
            filename_output = "rascal_arc"

        else:
            filename_output = filename

        # Unsupported types are silently skipped
        for t in fig_type.split("+"):
            if t in ["jpg", "png", "svg", "pdf"]:
                plt.savefig(filename_output + "." + t, format=t)

    if display:
        plt.show()

    return fig
253

254

255
def plot_search_space(
    calibrator,
    fit_coeff=None,
    top_n_candidate=3,
    weighted=True,
    save_fig=False,
    fig_type="png",
    filename=None,
    return_jsonstring=False,
    renderer="default",
    display=True,
):
    """
    Plots the peak/arc line pairs that are considered as potential match
    candidates.

    The horizontal axis is the pixel position and the vertical axis is the
    wavelength. If fit_coefficients are provided, the model solution will
    be overplotted.

    Parameters
    ----------
    calibrator: object
        Calibrator providing the pairs, candidates, pixel_list, peaks,
        min_wavelength, max_wavelength and range_tolerance to be plotted.
        Its plot_with_matplotlib / plot_with_plotly flags select the
        plotting library.
    fit_coeff: list (default: None)
        List of best polynomial fit_coefficients
    top_n_candidate: int (default: 3)
        Top ranked lines to be fitted.
    weighted: (default: True)
        Draw sample based on the distance from the matched known wavelength
        of the atlas.
    save_fig: boolean (default: False)
        Save an image if set to True. matplotlib uses the pyplot.savefig()
        while the plotly uses the pio.write_html() or pio.write_image().
        The support format types should be provided in fig_type.
    fig_type: string (default: 'png')
        Image type to be saved, choose from:
        jpg, png, svg, pdf and iframe. Delimiter is '+'. iframe is only
        handled when plotting with plotly.
    filename: (default: None)
        The destination to save the image, without extension. Defaults to
        'rascal_hough_search_space'.
    return_jsonstring: (default: False)
        Set to True to return the plotly figure as json string. Ignored if
        matplotlib is used.
    renderer: (default: 'default')
        Set the renderer for the plotly display. Ignored if matplotlib is
        used.
    display: boolean (default: True)
        Set to True to display the diagnostic plot.

    Return
    ------
    The matplotlib figure object if matplotlib is used. With plotly, the
    json string if return_jsonstring is True, None otherwise. None if
    neither plotting library is selected.

    """

    # Get top linear estimates and combine
    candidate_peak, candidate_arc = calibrator._get_most_common_candidates(
        calibrator.candidates,
        top_n_candidate=top_n_candidate,
        weighted=weighted,
    )

    x = calibrator.pixel_list
    max_pixel = calibrator.pixel_list.max()
    min_wave = calibrator.min_wavelength
    max_wave = calibrator.max_wavelength
    tolerance = calibrator.range_tolerance

    # Search space boundaries: the line from (first pixel, minimum
    # wavelength) to (last pixel, maximum wavelength), and the same line
    # shifted up and down by the range tolerance. The tolerance cancels
    # in the slope, so all three lines share it.
    slope = (max_wave - min_wave) / max_pixel
    y_1 = slope * x + min_wave
    y_2 = slope * x + (min_wave + tolerance)
    y_3 = slope * x + (min_wave - tolerance)

    if filename is None:
        filename_output = "rascal_hough_search_space"

    else:
        filename_output = filename

    # Only parse fig_type when something is going to be saved
    fig_types = fig_type.split("+") if save_fig else []

    if calibrator.plot_with_matplotlib:
        _import_matplotlib()

        merged_candidates = calibrator._merge_candidates(calibrator.candidates)

        fig = plt.figure(figsize=(10, 10))

        # Plot all-pairs
        plt.scatter(
            *calibrator.pairs.T, alpha=0.2, color="C0", label="All pairs"
        )

        plt.scatter(
            merged_candidates[:, 0],
            merged_candidates[:, 1],
            alpha=0.2,
            color="C1",
            label="Candidate Pairs",
        )

        # Tolerance regions around the minimum and the maximum wavelength
        for bound, caption in (
            (min_wave, "Min wavelength (user-supplied)"),
            (max_wave, "Max wavelength (user-supplied)"),
        ):
            plt.text(5, bound + 100, caption)
            plt.hlines(bound, 0, max_pixel, color="k")

            for offset in (tolerance, -tolerance):
                plt.hlines(
                    bound + offset,
                    0,
                    max_pixel,
                    linestyle="dashed",
                    alpha=0.5,
                    color="k",
                )

        # The linear estimate and the two lines defining the tolerance
        # region.
        plt.plot(x, y_1, label="Linear Fit", color="C3")
        plt.plot(
            x, y_2, linestyle="dashed", label="Tolerance Region", color="C3"
        )
        plt.plot(x, y_3, linestyle="dashed", color="C3")

        if fit_coeff is not None:
            plt.scatter(
                calibrator.peaks,
                calibrator.polyval(calibrator.peaks, fit_coeff),
                color="C4",
                label="Solution",
            )

        plt.scatter(
            candidate_peak,
            candidate_arc,
            color="C2",
            label="Best Candidate Pairs",
        )

        plt.xlim(0, max_pixel)
        plt.ylim(min_wave - tolerance, max_wave + tolerance)

        # x carries the pixels and y the wavelengths
        plt.xlabel("Pixel")
        plt.ylabel("Wavelength / A")
        plt.legend()
        plt.grid()
        plt.tight_layout()

        for t in fig_types:
            if t in ["jpg", "png", "svg", "pdf"]:
                plt.savefig(filename_output + "." + t, format=t)

        if display:
            plt.show()

        return fig

    elif calibrator.plot_with_plotly:
        _import_plotly()

        merged_candidates = calibrator._merge_candidates(calibrator.candidates)
        colorway = pio.templates["CN"].layout.colorway

        fig = go.Figure()

        # Plot all-pairs
        fig.add_trace(
            go.Scatter(
                x=calibrator.pairs[:, 0],
                y=calibrator.pairs[:, 1],
                mode="markers",
                name="All Pairs",
                marker=dict(color=colorway[0], opacity=0.2),
            )
        )

        fig.add_trace(
            go.Scatter(
                x=merged_candidates[:, 0],
                y=merged_candidates[:, 1],
                mode="markers",
                name="Candidate Pairs",
                marker=dict(color=colorway[1], opacity=0.2),
            )
        )
        fig.add_trace(
            go.Scatter(
                x=candidate_peak,
                y=candidate_arc,
                mode="markers",
                name="Best Candidate Pairs",
                marker=dict(color=colorway[2]),
            )
        )

        # Tolerance regions around the minimum and the maximum wavelength.
        # Only the first solid and the first dashed line get a legend
        # entry, the remaining ones are hidden from the legend.
        horizontal_lines = (
            (min_wave, None, "Min/Maximum"),
            (min_wave + tolerance, "dash", "Tolerance Range"),
            (min_wave - tolerance, "dash", None),
            (max_wave, None, None),
            (max_wave + tolerance, "dash", None),
            (max_wave - tolerance, "dash", None),
        )

        for level, dash, legend_name in horizontal_lines:
            line_style = dict(color="black")

            if dash is not None:
                line_style["dash"] = dash

            if legend_name is None:
                legend_kwargs = dict(showlegend=False)

            else:
                legend_kwargs = dict(name=legend_name)

            fig.add_trace(
                go.Scatter(
                    x=[0, max_pixel],
                    y=[level, level],
                    mode="lines",
                    line=line_style,
                    **legend_kwargs,
                )
            )

        # The linear estimate and the two lines defining the tolerance
        # region.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y_1,
                mode="lines",
                name="Linear Fit",
                line=dict(color=colorway[3]),
            )
        )
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y_2,
                mode="lines",
                name="Tolerance Region",
                line=dict(color=colorway[3], dash="dashdot"),
            )
        )
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y_3,
                showlegend=False,
                mode="lines",
                line=dict(color=colorway[3], dash="dashdot"),
            )
        )

        if fit_coeff is not None:
            fig.add_trace(
                go.Scatter(
                    x=calibrator.peaks,
                    y=calibrator.polyval(calibrator.peaks, fit_coeff),
                    mode="markers",
                    name="Solution",
                    marker=dict(color=colorway[4]),
                )
            )

        # Layout, Title, Grid config
        # x carries the pixels and y the wavelengths
        fig.update_layout(
            autosize=True,
            yaxis=dict(
                title="Wavelength / A",
                range=[
                    min_wave - tolerance * 1.1,
                    max_wave + tolerance * 1.1,
                ],
                showgrid=True,
            ),
            xaxis=dict(
                title="Pixel",
                zeroline=False,
                range=[0.0, max_pixel],
                showgrid=True,
            ),
            hovermode="closest",
            showlegend=True,
            height=800,
            width=1000,
        )

        for t in fig_types:
            if t == "iframe":
                pio.write_html(fig, filename_output + "." + t)

            elif t in ["jpg", "png", "svg", "pdf"]:
                pio.write_image(fig, filename_output + "." + t)

        if display:
            if renderer == "default":
                fig.show()

            else:
                fig.show(renderer)

        if return_jsonstring:
            return fig.to_json()
674

675

676
def plot_fit(
4✔
677
    calibrator,
678
    fit_coeff,
679
    spectrum=None,
680
    tolerance=5.0,
681
    plot_atlas=True,
682
    log_spectrum=False,
683
    save_fig=False,
684
    fig_type="png",
685
    filename=None,
686
    return_jsonstring=False,
687
    renderer="default",
688
    display=True,
689
):
690
    """
691
    Plots of the wavelength calibrated arc spectrum, the residual and the
692
    pixel-to-wavelength solution.
693

694
    Parameters
695
    ----------
696
    fit_coeff: 1D numpy array or list
697
        Best fit polynomail fit_coefficients
698
    spectrum: 1D numpy array (N)
699
        Array of length N pixels
700
    tolerance: float (default: 5)
701
        Absolute difference between model and fitted wavelengths in unit
702
        of angstrom.
703
    plot_atlas: boolean (default: True)
704
        Display all the relavent lines available in the atlas library.
705
    log_spectrum: boolean (default: False)
706
        Display the arc in log-space if set to True.
707
    save_fig: boolean (default: False)
708
        Save an image if set to True. matplotlib uses the pyplot.save_fig()
709
        while the plotly uses the pio.write_html() or pio.write_image().
710
        The support format types should be provided in fig_type.
711
    fig_type: string (default: 'png')
712
        Image type to be saved, choose from:
713
        jpg, png, svg, pdf and iframe. Delimiter is '+'.
714
    filename: string (default: None)
715
        Provide a filename or full path. If the extension is not provided
716
        it is defaulted to png.
717
    return_jsonstring: boolean (default: False)
718
        Set to True to return json strings if using plotly as the plotting
719
        library.
720
    renderer: string (default: 'default')
721
        Indicate the Plotly renderer. Nothing gets displayed if
722
        return_jsonstring is set to True.
723
    display: boolean (Default: False)
724
        Set to True to display disgnostic plot.
725

726
    Returns
727
    -------
728
    Return json strings if using plotly as the plotting library and json
729
    is True.
730

731
    """
732

733
    if spectrum is None:
4✔
734
        try:
4✔
735
            spectrum = calibrator.spectrum
4✔
736

737
        except Exception as e:
×
738
            calibrator.logger.error(e)
×
739
            calibrator.logger.error(
×
740
                "Spectrum is not provided, it cannot be " "plotted."
741
            )
742

743
    if spectrum is not None:
4✔
744
        if log_spectrum:
4✔
745
            spectrum[spectrum < 0] = 1e-100
4✔
746
            spectrum = np.log10(spectrum)
4✔
747
            vline_max = np.nanmax(spectrum) * 2.0
4✔
748
            text_box_pos = 1.2 * max(spectrum)
4✔
749

750
        else:
751
            vline_max = np.nanmax(spectrum) * 1.2
4✔
752
            text_box_pos = 0.8 * max(spectrum)
4✔
753

754
    else:
755
        vline_max = 1.0
4✔
756
        text_box_pos = 0.5
4✔
757

758
    wave = calibrator.polyval(calibrator.pixel_list, fit_coeff)
4✔
759

760
    if calibrator.plot_with_matplotlib:
4✔
761
        _import_matplotlib()
4✔
762

763
        fig, (ax1, ax2, ax3) = plt.subplots(
4✔
764
            nrows=3, sharex=True, gridspec_kw={"hspace": 0.0}, figsize=(15, 9)
765
        )
766
        fig.tight_layout()
4✔
767

768
        # Plot fitted spectrum
769
        if spectrum is not None:
4✔
770
            ax1.plot(wave, spectrum, label="Arc Spectrum")
4✔
771
            ax1.vlines(
4✔
772
                calibrator.polyval(calibrator.peaks, fit_coeff),
773
                np.array(spectrum)[
774
                    calibrator.pix_to_rawpix(calibrator.peaks).astype("int")
775
                ],
776
                vline_max,
777
                linestyles="dashed",
778
                colors="C1",
779
                label="Detected Peaks",
780
            )
781

782
        # Plot the atlas
783
        if plot_atlas:
4✔
784
            # spec = SyntheticSpectrum(
785
            #    fit, model_type='poly', degree=len(fit)-1)
786
            # x_locs = spec.get_pixels(calibrator.atlas)
787
            ax1.vlines(
4✔
788
                calibrator.atlas.get_lines(),
789
                0,
790
                vline_max,
791
                colors="C2",
792
                label="Given Lines",
793
            )
794

795
        fitted_peaks = []
4✔
796
        fitted_diff = []
4✔
797
        all_diff = []
4✔
798

799
        first_one = True
4✔
800
        for p, x in zip(calibrator.matched_peaks, calibrator.matched_atlas):
4✔
801
            diff = calibrator.atlas.get_lines() - x
4✔
802
            idx = np.argmin(np.abs(diff))
4✔
803
            all_diff.append(diff[idx])
4✔
804

805
            calibrator.logger.info("Peak at: {} A".format(x))
4✔
806

807
            fitted_peaks.append(p)
4✔
808
            fitted_diff.append(diff[idx])
4✔
809
            calibrator.logger.info(
4✔
810
                "- matched to {} A".format(calibrator.atlas.get_lines()[idx])
811
            )
812

813
            if spectrum is not None:
4✔
814
                if first_one:
4✔
815
                    ax1.vlines(
4✔
816
                        calibrator.polyval(p, fit_coeff),
817
                        spectrum[calibrator.pix_to_rawpix(p).astype("int")],
818
                        vline_max,
819
                        colors="C1",
820
                        label="Fitted Peaks",
821
                    )
822
                    first_one = False
4✔
823

824
                else:
825
                    ax1.vlines(
4✔
826
                        calibrator.polyval(p, fit_coeff),
827
                        spectrum[calibrator.pix_to_rawpix(p).astype("int")],
828
                        vline_max,
829
                        colors="C1",
830
                    )
831

832
            ax1.text(
4✔
833
                x - 3,
834
                text_box_pos,
835
                s="{}:{:1.2f}".format(
836
                    calibrator.atlas.get_elements()[idx],
837
                    calibrator.atlas.get_lines()[idx],
838
                ),
839
                rotation=90,
840
                bbox=dict(facecolor="white", alpha=1),
841
            )
842

843
        rms = np.sqrt(np.mean(np.array(fitted_diff) ** 2.0))
4✔
844

845
        ax1.grid(linestyle=":")
4✔
846
        ax1.set_ylabel("Electron Count / e-")
4✔
847

848
        if spectrum is not None:
4✔
849
            if log_spectrum:
4✔
850
                ax1.set_ylim(0, vline_max)
4✔
851

852
            else:
853
                ax1.set_ylim(np.nanmin(spectrum), vline_max)
4✔
854

855
        ax1.legend(loc="center right")
4✔
856

857
        # Plot the residuals
858
        ax2.scatter(
4✔
859
            calibrator.polyval(fitted_peaks, fit_coeff),
860
            fitted_diff,
861
            marker="+",
862
            color="C1",
863
        )
864
        ax2.hlines(0, wave.min(), wave.max(), linestyles="dashed")
4✔
865
        ax2.hlines(
4✔
866
            rms,
867
            wave.min(),
868
            wave.max(),
869
            linestyles="dashed",
870
            color="k",
871
            label="RMS",
872
        )
873
        ax2.hlines(
4✔
874
            -rms, wave.min(), wave.max(), linestyles="dashed", color="k"
875
        )
876
        ax2.grid(linestyle=":")
4✔
877
        ax2.set_ylabel("Residual / A")
4✔
878
        ax2.legend()
4✔
879
        """
2✔
880
        ax2.text(
881
            min(wave) + np.ptp(wave) * 0.05,
882
            max(spectrum),
883
            'RMS =' + str(rms)[:6]
884
            )
885
        """
886

887
        # Plot the polynomial
888
        ax3.scatter(
4✔
889
            calibrator.polyval(fitted_peaks, fit_coeff),
890
            fitted_peaks,
891
            marker="+",
892
            color="C1",
893
            label="Fitted Peaks",
894
        )
895
        ax3.plot(wave, calibrator.pixel_list, color="C2", label="Solution")
4✔
896
        ax3.grid(linestyle=":")
4✔
897
        ax3.set_xlabel("Wavelength / A")
4✔
898
        ax3.set_ylabel("Pixel")
4✔
899
        ax3.legend(loc="lower right")
4✔
900
        w_min = calibrator.polyval(min(fitted_peaks), fit_coeff)
4✔
901
        w_max = calibrator.polyval(max(fitted_peaks), fit_coeff)
4✔
902
        ax3.set_xlim(w_min * 0.95, w_max * 1.05)
4✔
903

904
        plt.tight_layout()
4✔
905

906
        if save_fig:
4✔
907
            fig_type = fig_type.split("+")
4✔
908

909
            if filename is None:
4✔
910
                filename_output = "rascal_solution"
×
911

912
            else:
913
                filename_output = filename
4✔
914

915
            for t in fig_type:
4✔
916
                if t in ["jpg", "png", "svg", "pdf"]:
4✔
917
                    plt.savefig(filename_output + "." + t, format=t)
4✔
918

919
        if display:
4✔
920
            plt.show()
4✔
921

922
        return fig
4✔
923

924
    elif calibrator.plot_with_plotly:
4✔
925
        _import_plotly()
4✔
926

927
        fig = go.Figure()
4✔
928

929
        # Top plot - arc spectrum and matched peaks
930
        if spectrum is not None:
4✔
931
            fig.add_trace(
4✔
932
                go.Scatter(
933
                    x=wave,
934
                    y=spectrum,
935
                    mode="lines",
936
                    yaxis="y3",
937
                    name="Arc Spectrum",
938
                )
939
            )
940

941
            spec_max = np.nanmax(spectrum) * 1.05
4✔
942

943
        else:
944
            spec_max = vline_max
×
945

946
        fitted_peaks = []
4✔
947
        fitted_peaks_adu = []
4✔
948
        fitted_diff = []
4✔
949
        all_diff = []
4✔
950

951
        for p in calibrator.peaks:
4✔
952
            x = calibrator.polyval(p, fit_coeff)
4✔
953

954
            # Add vlines
955
            fig.add_shape(
4✔
956
                type="line",
957
                xref="x",
958
                yref="y3",
959
                x0=x,
960
                y0=0,
961
                x1=x,
962
                y1=spec_max,
963
                line=dict(
964
                    color=pio.templates["CN"].layout.colorway[1], width=1
965
                ),
966
            )
967

968
            diff = calibrator.atlas.get_lines() - x
4✔
969
            idx = np.argmin(np.abs(diff))
4✔
970
            all_diff.append(diff[idx])
4✔
971

972
            calibrator.logger.info("Peak at: {} A".format(x))
4✔
973

974
            if np.abs(diff[idx]) < tolerance:
4✔
975
                fitted_peaks.append(p)
4✔
976
                if spectrum is not None:
4✔
977
                    fitted_peaks_adu.append(
4✔
978
                        spectrum[int(calibrator.pix_to_rawpix(p))]
979
                    )
980
                fitted_diff.append(diff[idx])
4✔
981
                calibrator.logger.info(
4✔
982
                    "- matched to {} A".format(
983
                        calibrator.atlas.get_lines()[idx]
984
                    )
985
                )
986

987
        x_fitted = calibrator.polyval(fitted_peaks, fit_coeff)
4✔
988

989
        fig.add_trace(
4✔
990
            go.Scatter(
991
                x=x_fitted,
992
                y=fitted_peaks_adu,
993
                mode="markers",
994
                marker=dict(color=pio.templates["CN"].layout.colorway[1]),
995
                yaxis="y3",
996
                showlegend=False,
997
            )
998
        )
999

1000
        # Middle plot - Residual plot
1001
        rms = np.sqrt(np.mean(np.array(fitted_diff) ** 2.0))
4✔
1002
        fig.add_trace(
4✔
1003
            go.Scatter(
1004
                x=x_fitted,
1005
                y=fitted_diff,
1006
                mode="markers",
1007
                marker=dict(color=pio.templates["CN"].layout.colorway[1]),
1008
                yaxis="y2",
1009
                showlegend=False,
1010
            )
1011
        )
1012
        fig.add_trace(
4✔
1013
            go.Scatter(
1014
                x=[wave.min(), wave.max()],
1015
                y=[0, 0],
1016
                mode="lines",
1017
                line=dict(
1018
                    color=pio.templates["CN"].layout.colorway[0], dash="dash"
1019
                ),
1020
                yaxis="y2",
1021
                showlegend=False,
1022
            )
1023
        )
1024
        fig.add_trace(
4✔
1025
            go.Scatter(
1026
                x=[wave.min(), wave.max()],
1027
                y=[rms, rms],
1028
                mode="lines",
1029
                line=dict(color="black", dash="dash"),
1030
                yaxis="y2",
1031
                showlegend=False,
1032
            )
1033
        )
1034
        fig.add_trace(
4✔
1035
            go.Scatter(
1036
                x=[wave.min(), wave.max()],
1037
                y=[-rms, -rms],
1038
                mode="lines",
1039
                line=dict(color="black", dash="dash"),
1040
                yaxis="y2",
1041
                name="RMS",
1042
            )
1043
        )
1044

1045
        # Bottom plot - Polynomial fit for Pixel to Wavelength
1046
        fig.add_trace(
4✔
1047
            go.Scatter(
1048
                x=x_fitted,
1049
                y=fitted_peaks,
1050
                mode="markers",
1051
                marker=dict(color=pio.templates["CN"].layout.colorway[1]),
1052
                yaxis="y1",
1053
                name="Fitted Peaks",
1054
            )
1055
        )
1056
        fig.add_trace(
4✔
1057
            go.Scatter(
1058
                x=wave,
1059
                y=calibrator.pixel_list,
1060
                mode="lines",
1061
                line=dict(color=pio.templates["CN"].layout.colorway[2]),
1062
                yaxis="y1",
1063
                name="Solution",
1064
            )
1065
        )
1066

1067
        # Layout, Title, Grid config
1068
        if spectrum is not None:
4✔
1069
            if log_spectrum:
4✔
1070
                fig.update_layout(
×
1071
                    yaxis3=dict(
1072
                        title="Electron Count / e-",
1073
                        range=[
1074
                            np.log10(np.percentile(spectrum, 15)),
1075
                            np.log10(spec_max),
1076
                        ],
1077
                        domain=[0.67, 1.0],
1078
                        showgrid=True,
1079
                        type="log",
1080
                    )
1081
                )
1082

1083
            else:
1084
                fig.update_layout(
4✔
1085
                    yaxis3=dict(
1086
                        title="Electron Count / e-",
1087
                        range=[np.percentile(spectrum, 15), spec_max],
1088
                        domain=[0.67, 1.0],
1089
                        showgrid=True,
1090
                    )
1091
                )
1092

1093
        fig.update_layout(
4✔
1094
            autosize=True,
1095
            yaxis2=dict(
1096
                title="Residual / A",
1097
                range=[min(fitted_diff), max(fitted_diff)],
1098
                domain=[0.33, 0.66],
1099
                showgrid=True,
1100
            ),
1101
            yaxis=dict(
1102
                title="Pixel",
1103
                range=[0.0, max(calibrator.pixel_list)],
1104
                domain=[0.0, 0.32],
1105
                showgrid=True,
1106
            ),
1107
            xaxis=dict(
1108
                title="Wavelength / A",
1109
                zeroline=False,
1110
                range=[
1111
                    calibrator.polyval(min(fitted_peaks), fit_coeff) * 0.95,
1112
                    calibrator.polyval(max(fitted_peaks), fit_coeff) * 1.05,
1113
                ],
1114
                showgrid=True,
1115
            ),
1116
            hovermode="closest",
1117
            showlegend=True,
1118
            height=800,
1119
            width=1000,
1120
        )
1121

1122
        if save_fig:
4✔
1123
            fig_type = fig_type.split("+")
4✔
1124

1125
            if filename is None:
4✔
1126
                filename_output = "rascal_solution"
×
1127

1128
            else:
1129
                filename_output = filename
4✔
1130

1131
            for t in fig_type:
4✔
1132
                if t == "iframe":
4✔
1133
                    pio.write_html(fig, filename_output + "." + t)
×
1134

1135
                elif t in ["jpg", "png", "svg", "pdf"]:
4✔
1136
                    pio.write_image(fig, filename_output + "." + t)
4✔
1137

1138
        if display:
4✔
1139
            if renderer == "default":
×
1140
                fig.show()
×
1141

1142
            else:
1143
                fig.show(renderer)
×
1144

1145
        if return_jsonstring:
4✔
1146
            return fig.to_json()
×
1147

1148
    else:
1149
        assert calibrator.matplotlib_imported, (
×
1150
            "matplotlib package not available. " + "Plot cannot be generated."
1151
        )
1152
        assert calibrator.plotly_imported, (
×
1153
            "plotly package is not available. " + "Plot cannot be generated."
1154
        )
1155

1156

1157
def plot_arc(
    calibrator,
    pixel_list=None,
    log_spectrum=False,
    save_fig=False,
    fig_type="png",
    filename=None,
    return_jsonstring=False,
    renderer="default",
    display=True,
):
    """
    Plots the 1D spectrum of the extracted arc with the detected peaks.

    The spectrum is normalised by its maximum before plotting. The
    plotting library is chosen from the calibrator flags: matplotlib is
    used if ``calibrator.plot_with_matplotlib`` is set, otherwise plotly
    is used if ``calibrator.plot_with_plotly`` is set. If neither flag is
    set nothing is plotted and None is returned.

    Parameters
    ----------
    calibrator: object
        Object providing the attributes read here: ``spectrum`` (array or
        None), ``peaks``, ``plot_with_matplotlib`` and
        ``plot_with_plotly``.
    pixel_list: array (default: None)
        Pixel value of the spectrum, this is only needed if the
        spectrum spans multiple detector arrays. If None and a spectrum
        is available, 0 .. len(spectrum) - 1 is used.
    log_spectrum: boolean (default: False)
        Set to True to display the normalised arc spectrum in
        logarithmic (base 10) space.
    save_fig: boolean (default: False)
        Save an image if set to True. matplotlib uses pyplot.savefig()
        while plotly uses pio.write_html() or pio.write_image().
        The format types should be provided in fig_type.
    fig_type: string (default: 'png')
        Image type to be saved, choose from:
        jpg, png, svg, pdf and iframe. Delimiter is '+'. iframe is only
        handled when plotting with plotly.
    filename: string (default: None)
        Filename or full path WITHOUT extension; each requested type in
        fig_type is appended as the extension. If None, 'rascal_arc' is
        used.
    return_jsonstring: boolean (default: False)
        Set to True to return json strings if using plotly as the plotting
        library.
    renderer: string (default: 'default')
        Indicate the Plotly renderer. Only used if display is True.
    display: boolean (default: True)
        Set to True to display the diagnostic plot.

    Returns
    -------
    The matplotlib Figure if matplotlib is used; the json string of the
    figure if plotly is used and return_jsonstring is True; None otherwise.

    """

    if not (calibrator.plot_with_matplotlib or calibrator.plot_with_plotly):
        # No plotting library has been selected: nothing can be drawn.
        return None

    spectrum = calibrator.spectrum

    # Normalise once and share the result between both plotting branches.
    # A calibrator without a spectrum is allowed: only the axes (and, with
    # plotly, the peak markers) are drawn in that case.
    if spectrum is None:
        normalised = None

    else:
        if pixel_list is None:
            pixel_list = np.arange(len(spectrum))

        normalised = spectrum / spectrum.max()

        if log_spectrum:
            normalised = np.log10(normalised)

    if calibrator.plot_with_matplotlib:
        _import_matplotlib()

        fig = plt.figure(figsize=(18, 5))

        if normalised is not None:
            plt.plot(pixel_list, normalised, label="Arc Spectrum")

            if log_spectrum:
                plt.vlines(
                    calibrator.peaks, -2, 0, label="Detected Peaks", color="C1"
                )
                plt.ylabel("log(Normalised Count)")
                plt.ylim(-2, 0)

            else:
                plt.ylabel("Normalised Count")
                plt.vlines(
                    calibrator.peaks,
                    0,
                    1.05,
                    label="Detected Peaks",
                    color="C1",
                )

            plt.title("Number of pixels: " + str(spectrum.shape[0]))
            plt.xlim(0, spectrum.shape[0])
            plt.legend()

        else:
            plt.xlim(0, max(calibrator.peaks))

        plt.xlabel("Pixel (Spectral Direction)")
        plt.grid()
        plt.tight_layout()

        if save_fig:
            filename_output = "rascal_arc" if filename is None else filename

            for t in fig_type.split("+"):
                if t in ("jpg", "png", "svg", "pdf"):
                    plt.savefig(filename_output + "." + t, format=t)

        if display:
            plt.show()

        return fig

    # Only plotly can be selected from this point onwards.
    _import_plotly()

    fig = go.Figure()

    # The y-axis is in log space only if there is a spectrum to transform.
    log_axis = log_spectrum and normalised is not None

    if normalised is not None:
        fig.add_trace(
            go.Scatter(
                x=list(pixel_list),
                y=list(normalised),
                mode="lines",
                name="Arc",
            )
        )

        # log10 of zero or negative counts gives -inf or NaN, which cannot
        # be used as axis limits; derive the range from finite values only.
        values = np.asarray(normalised)
        finite = values[np.isfinite(values)]

        if finite.size > 0:
            ymin = float(finite.min())
            ymax = float(finite.max())

        elif log_axis:
            ymin, ymax = -2.0, 0.0

        else:
            ymin, ymax = 0.0, 1.05

        x_upper = len(spectrum)

    else:
        ymin, ymax = 0.0, 1.05
        x_upper = max(calibrator.peaks)

    # In log space the normalised spectrum lies at or below zero, so the
    # peak markers have to span the plotted range instead of 0 to 1.05.
    if log_axis:
        line_bottom, line_top = ymin, ymax

    else:
        line_bottom, line_top = 0, 1.05

    peak_colour = pio.templates["CN"].layout.colorway[1]

    # Add vlines
    for peak in calibrator.peaks:
        fig.add_shape(
            type="line",
            xref="x",
            yref="y",
            x0=peak,
            y0=line_bottom,
            x1=peak,
            y1=line_top,
            line=dict(color=peak_colour, width=1),
        )

    fig.update_layout(
        autosize=True,
        yaxis=dict(
            title=(
                "log(Normalised Count)" if log_axis else "Normalised Count"
            ),
            range=[ymin, ymax],
            showgrid=True,
        ),
        xaxis=dict(
            title="Pixel",
            zeroline=False,
            range=[0.0, x_upper],
            showgrid=True,
        ),
        hovermode="closest",
        showlegend=True,
        height=800,
        width=1000,
    )

    fig.update_xaxes(
        showline=True, linewidth=1, linecolor="black", mirror=True
    )

    fig.update_yaxes(
        showline=True, linewidth=1, linecolor="black", mirror=True
    )

    if save_fig:
        filename_output = "rascal_arc" if filename is None else filename

        for t in fig_type.split("+"):
            if t == "iframe":
                pio.write_html(fig, filename_output + "." + t)

            elif t in ("jpg", "png", "svg", "pdf"):
                pio.write_image(fig, filename_output + "." + t)

    if display:
        if renderer == "default":
            fig.show()

        else:
            fig.show(renderer)

    if return_jsonstring:
        return fig.to_json()
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc