• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nikhil-sarin / redback / 25683986170

11 May 2026 04:45PM UTC coverage: 86.691% (-0.1%) from 86.801%
25683986170

Pull #361

github

web-flow
Merge 5b7acbf72 into 6cb1b083e
Pull Request #361: Non-detection interface overhaul + misc

251 of 366 new or added lines in 9 files covered. (68.58%)

7 existing lines in 4 files now uncovered.

15229 of 17567 relevant lines covered (86.69%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

81.2
/redback/plotting.py
1
from __future__ import annotations
1✔
2

3
from os.path import join
1✔
4
from typing import Any, Union
1✔
5

6
import matplotlib
1✔
7
import matplotlib.pyplot as plt
1✔
8
import numpy as np
1✔
9
import pandas as pd
1✔
10

11
import redback
1✔
12
from redback.utils import KwargsAccessorWithDefault, logger
1✔
13

14
class _FilenameGetter(object):
    """
    Descriptor that resolves the file name of a plot.

    The name is obtained from the owning instance's `get_filename` method, with
    `<transient name>_<suffix>.png` passed as the default.
    """

    def __init__(self, suffix: str) -> None:
        """
        :param suffix: Text placed between the transient name and the `.png` extension in the default file name.
        :type suffix: str
        """
        self.suffix = suffix

    def __get__(self, instance: Union[Plotter, None], owner: object) -> Union[str, _FilenameGetter]:
        """
        :param instance: The object the attribute is looked up on, or None when it is looked up on the class.
        :param owner: The class owning the descriptor (unused).

        :return: The resolved file name, or the descriptor itself for class-level access.
        """
        # Class-level access (e.g. `Plotter._data_plot_filename`) passes `instance=None`;
        # return the descriptor itself instead of failing on `None.get_filename`.
        if instance is None:
            return self
        return instance.get_filename(default=f"{instance.transient.name}_{self.suffix}.png")

    def __set__(self, instance: Plotter, value: object) -> None:
        """
        Ignore assignments, so the attribute always resolves through `__get__`.

        :param instance: The object the attribute is assigned on.
        :param value: The assigned value, which is discarded.
        """
        pass
1✔
23

24

25
class _FilePathGetter(object):
    """
    Descriptor that joins a directory attribute and a file name attribute of the
    owning instance into a single path.
    """

    def __init__(self, directory_property: str, filename_property: str) -> None:
        """
        :param directory_property: Name of the instance attribute holding the directory.
        :type directory_property: str
        :param filename_property: Name of the instance attribute holding the file name.
        :type filename_property: str
        """
        self.directory_property = directory_property
        self.filename_property = filename_property

    def __get__(self, instance: Union[Plotter, None], owner: object) -> Union[str, _FilePathGetter]:
        """
        :param instance: The object the attribute is looked up on, or None when it is looked up on the class.
        :param owner: The class owning the descriptor (unused).

        :return: The joined path, or the descriptor itself for class-level access.
        """
        # Class-level access passes `instance=None`; return the descriptor itself
        # instead of failing while reading the attributes from `None`.
        if instance is None:
            return self
        directory = getattr(instance, self.directory_property)
        filename = getattr(instance, self.filename_property)
        return join(directory, filename)
1✔
33

34

35
class Plotter(object):
    """
    Base class for all lightcurve plotting classes in redback.

    Plot options are exposed as class attributes backed by `KwargsAccessorWithDefault`,
    each created with the keyword name and its default value.
    """

    capsize = KwargsAccessorWithDefault("capsize", 0.)
    legend_location = KwargsAccessorWithDefault("legend_location", "best")
    legend_cols = KwargsAccessorWithDefault("legend_cols", 2)
    band_colors = KwargsAccessorWithDefault("band_colors", None)
    color = KwargsAccessorWithDefault("color", "k")
    band_labels = KwargsAccessorWithDefault("band_labels", None)
    band_scaling = KwargsAccessorWithDefault("band_scaling", {})
    dpi = KwargsAccessorWithDefault("dpi", 300)
    elinewidth = KwargsAccessorWithDefault("elinewidth", 2)
    errorbar_fmt = KwargsAccessorWithDefault("errorbar_fmt", "o")
    model = KwargsAccessorWithDefault("model", None)
    ms = KwargsAccessorWithDefault("ms", 5)
    axis_tick_params_pad = KwargsAccessorWithDefault("axis_tick_params_pad", 10)

    max_likelihood_alpha = KwargsAccessorWithDefault("max_likelihood_alpha", 0.65)
    random_sample_alpha = KwargsAccessorWithDefault("random_sample_alpha", 0.05)
    uncertainty_band_alpha = KwargsAccessorWithDefault("uncertainty_band_alpha", 0.4)
    max_likelihood_color = KwargsAccessorWithDefault("max_likelihood_color", "blue")
    random_sample_color = KwargsAccessorWithDefault("random_sample_color", "red")

    bbox_inches = KwargsAccessorWithDefault("bbox_inches", "tight")
    linewidth = KwargsAccessorWithDefault("linewidth", 2)
    zorder = KwargsAccessorWithDefault("zorder", -1)

    xy = KwargsAccessorWithDefault("xy", (0.95, 0.9))
    xycoords = KwargsAccessorWithDefault("xycoords", "axes fraction")
    horizontalalignment = KwargsAccessorWithDefault("horizontalalignment", "right")
    annotation_size = KwargsAccessorWithDefault("annotation_size", 20)

    fontsize_axes = KwargsAccessorWithDefault("fontsize_axes", 18)
    fontsize_figure = KwargsAccessorWithDefault("fontsize_figure", 30)
    fontsize_legend = KwargsAccessorWithDefault("fontsize_legend", 18)
    fontsize_ticks = KwargsAccessorWithDefault("fontsize_ticks", 16)
    hspace = KwargsAccessorWithDefault("hspace", 0.04)
    wspace = KwargsAccessorWithDefault("wspace", 0.15)

    plot_others = KwargsAccessorWithDefault("plot_others", True)
    random_models = KwargsAccessorWithDefault("random_models", 100)
    uncertainty_mode = KwargsAccessorWithDefault("uncertainty_mode", "random_models")
    credible_interval_level = KwargsAccessorWithDefault("credible_interval_level", 0.9)
    plot_max_likelihood = KwargsAccessorWithDefault("plot_max_likelihood", True)
    set_same_color_per_subplot = KwargsAccessorWithDefault("set_same_color_per_subplot", True)

    xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 2.0)
    xlim_low_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.5)
    ylim_high_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 2.0)
    ylim_low_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.5)

    # Upper limit plot options
    upper_limit_marker = KwargsAccessorWithDefault("upper_limit_marker", r"$\downarrow$")
    upper_limit_markersize = KwargsAccessorWithDefault("upper_limit_markersize", 14)
    upper_limit_alpha = KwargsAccessorWithDefault("upper_limit_alpha", 0.7)

    # Grid options
    show_grid = KwargsAccessorWithDefault("show_grid", False)
    grid_alpha = KwargsAccessorWithDefault("grid_alpha", 0.3)
    grid_color = KwargsAccessorWithDefault("grid_color", "gray")
    grid_linestyle = KwargsAccessorWithDefault("grid_linestyle", "--")
    grid_linewidth = KwargsAccessorWithDefault("grid_linewidth", 0.5)

    # Save format and transparency
    save_format = KwargsAccessorWithDefault("save_format", "png")
    transparent = KwargsAccessorWithDefault("transparent", False)

    # Axis scale options
    xscale = KwargsAccessorWithDefault("xscale", None)
    yscale = KwargsAccessorWithDefault("yscale", None)

    # Title options
    title = KwargsAccessorWithDefault("title", None)
    title_fontsize = KwargsAccessorWithDefault("title_fontsize", 20)

    # Line style options
    linestyle = KwargsAccessorWithDefault("linestyle", "-")
    max_likelihood_linestyle = KwargsAccessorWithDefault("max_likelihood_linestyle", "-")
    random_sample_linestyle = KwargsAccessorWithDefault("random_sample_linestyle", "-")

    # Marker options
    markerfillstyle = KwargsAccessorWithDefault("markerfillstyle", "full")
    markeredgecolor = KwargsAccessorWithDefault("markeredgecolor", None)
    markeredgewidth = KwargsAccessorWithDefault("markeredgewidth", 1.0)

    # Legend customization
    legend_frameon = KwargsAccessorWithDefault("legend_frameon", True)
    legend_shadow = KwargsAccessorWithDefault("legend_shadow", False)
    legend_fancybox = KwargsAccessorWithDefault("legend_fancybox", True)
    legend_framealpha = KwargsAccessorWithDefault("legend_framealpha", 0.8)

    # Tick customization
    tick_direction = KwargsAccessorWithDefault("tick_direction", "in")
    tick_length = KwargsAccessorWithDefault("tick_length", None)
    tick_width = KwargsAccessorWithDefault("tick_width", None)

    # Spine options
    show_spines = KwargsAccessorWithDefault("show_spines", True)
    spine_linewidth = KwargsAccessorWithDefault("spine_linewidth", None)

    def __init__(self, transient: Union[redback.transient.Transient, None], **kwargs) -> None:
        """
        :param transient: An instance of `redback.transient.Transient`. Contains the data to be plotted.
        :param kwargs: Additional kwargs the plotter uses. -------
        :keyword capsize: Same as matplotlib capsize.
        :keyword bands_to_plot: List of bands to plot in plot lightcurve and multiband lightcurve. Default is active bands.
        :keyword legend_location: Same as matplotlib legend location.
        :keyword legend_cols: Same as matplotlib legend columns.
        :keyword color: Color of the data points.
        :keyword band_colors: A dictionary with the colors of the bands.
        :keyword band_labels: List with the names of the bands.
        :keyword band_scaling: Dict with the scaling for each band. First entry should be {type: '+' or 'x'} for different types.
        :keyword dpi: Same as matplotlib dpi.
        :keyword elinewidth: Same as matplotlib elinewidth.
        :keyword errorbar_fmt: 'fmt' argument of `ax.errorbar`.
        :keyword model: str or callable, the model to plot.
        :keyword ms: Same as matplotlib markersize.
        :keyword axis_tick_params_pad: `pad` argument in calls to `ax.tick_params` when setting the axes.
        :keyword max_likelihood_alpha: `alpha` argument, i.e. transparency, when plotting the max likelihood curve.
        :keyword random_sample_alpha: `alpha` argument, i.e. transparency, when plotting random sample curves.
        :keyword uncertainty_band_alpha: `alpha` argument, i.e. transparency, when plotting a credible band.
        :keyword max_likelihood_color: Color of the maximum likelihood curve.
        :keyword random_sample_color: Color of the random sample curves.
        :keyword bbox_inches: Setting for saving plots. Default is 'tight'.
        :keyword linewidth: Same as matplotlib linewidth.
        :keyword zorder: Same as matplotlib zorder.
        :keyword xy: For `ax.annotate`: x and y coordinates of the point to annotate.
        :keyword xycoords: The coordinate system `xy` is given in. Default is 'axes fraction'.
        :keyword horizontalalignment: Horizontal alignment of the annotation. Default is 'right'.
        :keyword annotation_size: `size` argument of `ax.annotate`.
        :keyword fontsize_axes: Font size of the x and y labels.
        :keyword fontsize_legend: Font size of the legend.
        :keyword fontsize_figure: Font size of the figure. Relevant for multiband plots.
                                  Used on `supxlabel` and `supylabel`.
        :keyword fontsize_ticks: Font size of the axis ticks.
        :keyword hspace: Argument for `subplots_adjust`, sets vertical spacing between panels.
        :keyword wspace: Argument for `subplots_adjust`, sets horizontal spacing between panels.
        :keyword plot_others: Whether to plot additional bands in the data plot, all in the same colors.
        :keyword random_models: Number of random draws to use to calculate credible bands or to plot.
        :keyword uncertainty_mode: 'random_models': Plot random draws from the available parameter sets.
                                   'credible_intervals': Plot a credible interval that is calculated based
                                   on the available parameter sets.
        :keyword reference_mjd_date: Date to use as reference point for the x axis.
                                    Default is the first date in the data.
        :keyword credible_interval_level: 0.9: Plot the 90% credible interval.
        :keyword plot_max_likelihood: Plots the draw corresponding to the maximum likelihood. Default is 'True'.
        :keyword set_same_color_per_subplot: Sets the lightcurve to be the same color as the data per subplot. Default is 'True'.
        :keyword xlim_high_multiplier: Adjust the maximum xlim based on available x values.
        :keyword xlim_low_multiplier: Adjust the minimum xlim based on available x values.
        :keyword ylim_high_multiplier: Adjust the maximum ylim based on available y values.
        :keyword ylim_low_multiplier: Adjust the minimum ylim based on available y values.
        :keyword upper_limit_marker: Marker used for upper limits. Default is a downward arrow.
        :keyword upper_limit_markersize: Marker size used for upper limits. Default is 14.
        :keyword upper_limit_alpha: Transparency of upper limit markers. Default is 0.7.
        :keyword show_grid: Whether to show grid lines. Default is False.
        :keyword grid_alpha: Transparency of grid lines. Default is 0.3.
        :keyword grid_color: Color of grid lines. Default is 'gray'.
        :keyword grid_linestyle: Line style of grid lines. Default is '--'.
        :keyword grid_linewidth: Line width of grid lines. Default is 0.5.
        :keyword save_format: Format for saving plots (e.g., 'png', 'pdf', 'svg', 'eps'). Default is 'png'.
        :keyword transparent: Whether to save plots with transparent background. Default is False.
        :keyword xscale: X-axis scale ('linear', 'log', 'symlog', 'logit'). Default is None (auto-determined).
        :keyword yscale: Y-axis scale ('linear', 'log', 'symlog', 'logit'). Default is None (auto-determined).
        :keyword title: Title for the plot. Default is None (no title).
        :keyword title_fontsize: Font size for the title. Default is 20.
        :keyword linestyle: Line style for model curves. Default is '-'.
        :keyword max_likelihood_linestyle: Line style for max likelihood curve. Default is '-'.
        :keyword random_sample_linestyle: Line style for random sample curves. Default is '-'.
        :keyword markerfillstyle: Fill style for markers ('full', 'left', 'right', 'bottom', 'top', 'none'). Default is 'full'.
        :keyword markeredgecolor: Edge color for markers. Default is None (same as face color).
        :keyword markeredgewidth: Edge width for markers. Default is 1.0.
        :keyword legend_frameon: Whether to draw a frame around the legend. Default is True.
        :keyword legend_shadow: Whether to draw a shadow behind the legend. Default is False.
        :keyword legend_fancybox: Whether to use rounded corners for legend frame. Default is True.
        :keyword legend_framealpha: Transparency of legend frame. Default is 0.8.
        :keyword tick_direction: Direction of tick marks ('in', 'out', 'inout'). Default is 'in'.
        :keyword tick_length: Length of tick marks. Default is None (matplotlib default).
        :keyword tick_width: Width of tick marks. Default is None (matplotlib default).
        :keyword show_spines: Whether to show plot spines (borders). Default is True.
        :keyword spine_linewidth: Width of plot spines. Default is None (matplotlib default).
        """
        self.transient = transient
        self.kwargs = kwargs or dict()
        # Set to True once the posterior has been sorted by log-likelihood (see `_posterior`).
        self._posterior_sorted = False

    # Keyword section of the `__init__` docstring, i.e. everything after the "-------" marker.
    # Docstrings are stripped under `python -OO`; fall back to an empty keyword section
    # instead of failing at import time.
    keyword_docstring = (__init__.__doc__ or "-------").split("-------")[1]

    def _get_times(self, axes: matplotlib.axes.Axes) -> np.ndarray:
        """
        :param axes: The axes used in the plotting procedure. If an array of axes is passed,
                     its first element is used.
        :type axes: matplotlib.axes.Axes

        :return: 200 linearly or logarithmically spaced time values between the x limits,
                 depending on the y scale used in the plot. Shifted by the reference MJD date
                 if the transient uses a phase model.
        :rtype: np.ndarray
        """
        if isinstance(axes, np.ndarray):
            ax = axes[0]
        else:
            ax = axes

        # NOTE(review): the spacing of the time grid is keyed on the y-axis scale,
        # not on the x-axis scale -- confirm this is intended.
        if ax.get_yscale() == 'linear':
            times = np.linspace(self._xlim_low, self._xlim_high, 200)
        else:
            times = np.exp(np.linspace(np.log(self._xlim_low), np.log(self._xlim_high), 200))

        if self.transient.use_phase_model:
            times = times + self._reference_mjd_date
        return times

    @property
    def _xlim_low(self) -> float:
        """
        :return: The `xlim_low` kwarg if given, else the first x value scaled by `xlim_low_multiplier`.
        :rtype: float
        """
        default = self.xlim_low_multiplier * self.transient.x[0]
        if default == 0:
            # Nudge an exactly-zero default to a small positive value.
            default += 1e-3
        return self.kwargs.get("xlim_low", default)

    @property
    def _xlim_high(self) -> float:
        """
        :return: The `xlim_high` kwarg if given, else the last x value (plus its error, if x errors
                 are available) scaled by `xlim_high_multiplier`.
        :rtype: float
        """
        x_err = self._x_err
        if x_err is None:
            default = self.xlim_high_multiplier * self.transient.x[-1]
        else:
            default = self.xlim_high_multiplier * (self.transient.x[-1] + x_err[1][-1])
        return self.kwargs.get("xlim_high", default)

    @property
    def _ylim_low(self) -> float:
        """
        :return: The `ylim_low` kwarg if given, else the smallest finite y value scaled by
                 `ylim_low_multiplier`, or 0 if there are no finite y values.
        :rtype: float
        """
        y_valid = self.transient.y[np.isfinite(self.transient.y)]
        if len(y_valid) == 0:
            # No finite data to derive a limit from; still honour a user-supplied limit below.
            default = 0
        else:
            default = self.ylim_low_multiplier * np.min(y_valid)
        return self.kwargs.get("ylim_low", default)

    @property
    def _ylim_high(self) -> float:
        """
        :return: The `ylim_high` kwarg if given, else the largest finite y value scaled by
                 `ylim_high_multiplier`, or 1 if there are no finite y values.
        :rtype: float
        """
        y_valid = self.transient.y[np.isfinite(self.transient.y)]
        if len(y_valid) == 0:
            # No finite data to derive a limit from; still honour a user-supplied limit below.
            default = 1
        else:
            default = self.ylim_high_multiplier * np.max(y_valid)
        return self.kwargs.get("ylim_high", default)

    @property
    def _x_err(self) -> Union[np.ndarray, None]:
        """
        :return: None if the transient has no x errors, else a two-row array holding the absolute
                 values of row 1 of the transient's x errors followed by row 0.
        :rtype: Union[np.ndarray, None]
        """
        if self.transient.x_err is not None:
            return np.array([np.abs(self.transient.x_err[1, :]), self.transient.x_err[0, :]])
        else:
            return None

    @property
    def _y_err(self) -> np.ndarray:
        """
        :return: For multi-dimensional y errors, a two-row array holding the absolute values of row 1
                 followed by row 0. Otherwise a one-row array with the absolute y errors.
        :rtype: np.ndarray
        """
        if self.transient.y_err.ndim > 1:
            return np.array([np.abs(self.transient.y_err[1, :]), self.transient.y_err[0, :]])
        else:
            return np.array([np.abs(self.transient.y_err)])

    @property
    def _lightcurve_plot_outdir(self) -> str:
        """
        :return: The `outdir` kwarg if set, else the transient's directory path joined with the model name.
        :rtype: str
        """
        return self._get_outdir(join(self.transient.directory_structure.directory_path, self.model.__name__))

    @property
    def _data_plot_outdir(self) -> str:
        """
        :return: The `outdir` kwarg if set, else the transient's directory path.
        :rtype: str
        """
        return self._get_outdir(self.transient.directory_structure.directory_path)

    def _get_outdir(self, default: str) -> str:
        """
        :param default: Directory to use when no `outdir` kwarg is set.
        :return: The output directory.
        :rtype: str
        """
        return self._get_kwarg_with_default(kwarg="outdir", default=default)

    def get_filename(self, default: str) -> str:
        """
        :param default: File name to use when no `filename` kwarg is set.
        :return: The file name.
        :rtype: str
        """
        return self._get_kwarg_with_default(kwarg="filename", default=default)

    def _get_kwarg_with_default(self, kwarg: str, default: Any) -> Any:
        """
        :param kwarg: Name of the kwarg to look up.
        :param default: Value to use if the kwarg is missing or falsy (e.g. None or an empty string).
        :return: The kwarg value or the default.
        """
        return self.kwargs.get(kwarg, default) or default

    @property
    def _model_kwargs(self) -> dict:
        """
        :return: The `model_kwargs` kwarg, or an empty dict if it is missing or falsy.
        :rtype: dict
        """
        return self._get_kwarg_with_default("model_kwargs", dict())

    @property
    def _posterior(self) -> pd.DataFrame:
        """
        :return: The `posterior` kwarg (an empty DataFrame if not given). On first access it is sorted
                 in place by its `log_likelihood` column, in ascending order.
        :rtype: pd.DataFrame
        """
        posterior = self.kwargs.get("posterior", pd.DataFrame())
        if not self._posterior_sorted and posterior is not None:
            # In-place sort: the DataFrame passed in by the caller is reordered as well.
            posterior.sort_values(by='log_likelihood', inplace=True)
            self._posterior_sorted = True
        return posterior

    @property
    def _max_like_params(self) -> pd.core.series.Series:
        """
        :return: The last row of the sorted posterior, i.e. the sample with the highest log-likelihood.
        :rtype: pd.core.series.Series
        """
        return self._posterior.iloc[-1]

    def _get_random_parameters(self) -> list[pd.core.series.Series]:
        """
        :return: `random_models` posterior rows, drawn at random with replacement.
        :rtype: list[pd.core.series.Series]
        """
        integers = np.arange(len(self._posterior))
        indices = np.random.choice(integers, size=self.random_models)
        return [self._posterior.iloc[idx] for idx in indices]

    _data_plot_filename = _FilenameGetter(suffix="data")
    _lightcurve_plot_filename = _FilenameGetter(suffix="lightcurve")
    _residual_plot_filename = _FilenameGetter(suffix="residual")
    _multiband_data_plot_filename = _FilenameGetter(suffix="multiband_data")
    _multiband_lightcurve_plot_filename = _FilenameGetter(suffix="multiband_lightcurve")

    _data_plot_filepath = _FilePathGetter(
        directory_property="_data_plot_outdir", filename_property="_data_plot_filename")
    _lightcurve_plot_filepath = _FilePathGetter(
        directory_property="_lightcurve_plot_outdir", filename_property="_lightcurve_plot_filename")
    _residual_plot_filepath = _FilePathGetter(
        directory_property="_lightcurve_plot_outdir", filename_property="_residual_plot_filename")
    _multiband_data_plot_filepath = _FilePathGetter(
        directory_property="_data_plot_outdir", filename_property="_multiband_data_plot_filename")
    _multiband_lightcurve_plot_filepath = _FilePathGetter(
        directory_property="_lightcurve_plot_outdir", filename_property="_multiband_lightcurve_plot_filename")

    def _save_and_show(self, filepath: str, save: bool, show: bool) -> None:
        """
        Apply a tight layout, then optionally save and/or show the current figure.

        :param filepath: Path to save the figure to. Its extension is replaced by `save_format`
                         (or appended if the file name has none).
        :param save: Whether to save the figure.
        :param show: Whether to show the figure.
        """
        plt.tight_layout()
        if save:
            from os.path import splitext

            # `splitext` only looks at the last path component, so a dot in a directory
            # name (e.g. './plots/myplot') is not mistaken for the start of an extension.
            root, _ = splitext(filepath)
            filepath = f'{root}.{self.save_format}'

            facecolor = 'none' if self.transparent else 'white'
            plt.savefig(filepath, dpi=self.dpi, bbox_inches=self.bbox_inches,
                        transparent=self.transparent, facecolor=facecolor)
        if show:
            plt.show()

    def _apply_axis_customizations(self, ax: matplotlib.axes.Axes) -> None:
        """
        Apply common axis customizations like grid, title, ticks, and spines.

        :param ax: The axes to customize.
        :type ax: matplotlib.axes.Axes
        """
        # Grid
        if self.show_grid:
            ax.grid(True, alpha=self.grid_alpha, color=self.grid_color,
                    linestyle=self.grid_linestyle, linewidth=self.grid_linewidth)

        # Title
        if self.title is not None:
            ax.set_title(self.title, fontsize=self.title_fontsize)

        # Tick customization; length and width are only passed on when set explicitly.
        tick_params = {'axis': 'both', 'which': 'both',
                       'pad': self.axis_tick_params_pad,
                       'labelsize': self.fontsize_ticks,
                       'direction': self.tick_direction}
        if self.tick_length is not None:
            tick_params['length'] = self.tick_length
        if self.tick_width is not None:
            tick_params['width'] = self.tick_width
        ax.tick_params(**tick_params)

        # Spine customization; the line width only matters when the spines are visible.
        if not self.show_spines:
            for spine in ax.spines.values():
                spine.set_visible(False)
        elif self.spine_linewidth is not None:
            for spine in ax.spines.values():
                spine.set_linewidth(self.spine_linewidth)
1✔
386

387
class SpecPlotter(object):
1✔
388
    """
389
    Base class for all lightcurve plotting classes in redback.
390
    """
391

392
    capsize = KwargsAccessorWithDefault("capsize", 0.)
1✔
393
    elinewidth = KwargsAccessorWithDefault("elinewidth", 2)
1✔
394
    errorbar_fmt = KwargsAccessorWithDefault("errorbar_fmt", "x")
1✔
395
    legend_location = KwargsAccessorWithDefault("legend_location", "best")
1✔
396
    legend_cols = KwargsAccessorWithDefault("legend_cols", 2)
1✔
397
    color = KwargsAccessorWithDefault("color", "k")
1✔
398
    dpi = KwargsAccessorWithDefault("dpi", 300)
1✔
399
    model = KwargsAccessorWithDefault("model", None)
1✔
400
    ms = KwargsAccessorWithDefault("ms", 1)
1✔
401
    axis_tick_params_pad = KwargsAccessorWithDefault("axis_tick_params_pad", 10)
1✔
402

403
    max_likelihood_alpha = KwargsAccessorWithDefault("max_likelihood_alpha", 0.65)
1✔
404
    random_sample_alpha = KwargsAccessorWithDefault("random_sample_alpha", 0.05)
1✔
405
    uncertainty_band_alpha = KwargsAccessorWithDefault("uncertainty_band_alpha", 0.4)
1✔
406
    max_likelihood_color = KwargsAccessorWithDefault("max_likelihood_color", "blue")
1✔
407
    random_sample_color = KwargsAccessorWithDefault("random_sample_color", "red")
1✔
408

409
    bbox_inches = KwargsAccessorWithDefault("bbox_inches", "tight")
1✔
410
    linewidth = KwargsAccessorWithDefault("linewidth", 2)
1✔
411
    zorder = KwargsAccessorWithDefault("zorder", -1)
1✔
412
    yscale = KwargsAccessorWithDefault("yscale", "linear")
1✔
413

414
    xy = KwargsAccessorWithDefault("xy", (0.95, 0.9))
1✔
415
    xycoords = KwargsAccessorWithDefault("xycoords", "axes fraction")
1✔
416
    horizontalalignment = KwargsAccessorWithDefault("horizontalalignment", "right")
1✔
417
    annotation_size = KwargsAccessorWithDefault("annotation_size", 20)
1✔
418

419
    fontsize_axes = KwargsAccessorWithDefault("fontsize_axes", 18)
1✔
420
    fontsize_figure = KwargsAccessorWithDefault("fontsize_figure", 30)
1✔
421
    fontsize_legend = KwargsAccessorWithDefault("fontsize_legend", 18)
1✔
422
    fontsize_ticks = KwargsAccessorWithDefault("fontsize_ticks", 16)
1✔
423
    hspace = KwargsAccessorWithDefault("hspace", 0.04)
1✔
424
    wspace = KwargsAccessorWithDefault("wspace", 0.15)
1✔
425

426
    random_models = KwargsAccessorWithDefault("random_models", 100)
1✔
427
    uncertainty_mode = KwargsAccessorWithDefault("uncertainty_mode", "random_models")
1✔
428
    credible_interval_level = KwargsAccessorWithDefault("credible_interval_level", 0.9)
1✔
429
    plot_max_likelihood = KwargsAccessorWithDefault("plot_max_likelihood", True)
1✔
430
    set_same_color_per_subplot = KwargsAccessorWithDefault("set_same_color_per_subplot", True)
1✔
431

432
    xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.05)
1✔
433
    xlim_low_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.9)
1✔
434
    ylim_high_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 1.1)
1✔
435
    ylim_low_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.5)
1✔
436

437
    # Upper limit plot options
438
    upper_limit_marker = KwargsAccessorWithDefault("upper_limit_marker", r"$\downarrow$")
1✔
439
    upper_limit_markersize = KwargsAccessorWithDefault("upper_limit_markersize", 14)
1✔
440
    upper_limit_alpha = KwargsAccessorWithDefault("upper_limit_alpha", 0.7)
1✔
441

442
    # Grid options
443
    show_grid = KwargsAccessorWithDefault("show_grid", False)
1✔
444
    grid_alpha = KwargsAccessorWithDefault("grid_alpha", 0.3)
1✔
445
    grid_color = KwargsAccessorWithDefault("grid_color", "gray")
1✔
446
    grid_linestyle = KwargsAccessorWithDefault("grid_linestyle", "--")
1✔
447
    grid_linewidth = KwargsAccessorWithDefault("grid_linewidth", 0.5)
1✔
448

449
    # Save format and transparency
450
    save_format = KwargsAccessorWithDefault("save_format", "png")
1✔
451
    transparent = KwargsAccessorWithDefault("transparent", False)
1✔
452

453
    # Axis scale options (xscale can be customized too)
454
    xscale = KwargsAccessorWithDefault("xscale", None)
1✔
455

456
    # Title options
457
    title = KwargsAccessorWithDefault("title", None)
1✔
458
    title_fontsize = KwargsAccessorWithDefault("title_fontsize", 20)
1✔
459

460
    # Line style options
461
    linestyle = KwargsAccessorWithDefault("linestyle", "-")
1✔
462
    max_likelihood_linestyle = KwargsAccessorWithDefault("max_likelihood_linestyle", "-")
1✔
463
    random_sample_linestyle = KwargsAccessorWithDefault("random_sample_linestyle", "-")
1✔
464

465
    # Marker options
466
    markerfillstyle = KwargsAccessorWithDefault("markerfillstyle", "full")
1✔
467
    markeredgecolor = KwargsAccessorWithDefault("markeredgecolor", None)
1✔
468
    markeredgewidth = KwargsAccessorWithDefault("markeredgewidth", 1.0)
1✔
469

470
    # Legend customization
471
    legend_frameon = KwargsAccessorWithDefault("legend_frameon", True)
1✔
472
    legend_shadow = KwargsAccessorWithDefault("legend_shadow", False)
1✔
473
    legend_fancybox = KwargsAccessorWithDefault("legend_fancybox", True)
1✔
474
    legend_framealpha = KwargsAccessorWithDefault("legend_framealpha", 0.8)
1✔
475

476
    # Tick customization
477
    tick_direction = KwargsAccessorWithDefault("tick_direction", "in")
1✔
478
    tick_length = KwargsAccessorWithDefault("tick_length", None)
1✔
479
    tick_width = KwargsAccessorWithDefault("tick_width", None)
1✔
480

481
    # Spine options
482
    show_spines = KwargsAccessorWithDefault("show_spines", True)
1✔
483
    spine_linewidth = KwargsAccessorWithDefault("spine_linewidth", None)
1✔
484

485
    def __init__(self, spectrum: Union[redback.transient.Spectrum, None], **kwargs) -> None:
1✔
486
        """
487
        :param spectrum: An instance of `redback.transient.Spectrum`. Contains the data to be plotted.
488
        :param kwargs: Additional kwargs the plotter uses. -------
489
        :keyword capsize: Same as matplotlib capsize.
490
        :keyword elinewidth: same as matplotlib elinewidth
491
        :keyword errorbar_fmt: 'fmt' argument of `ax.errorbar`.
492
        :keyword ms: Same as matplotlib markersize.
493
        :keyword legend_location: Same as matplotlib legend location.
494
        :keyword legend_cols: Same as matplotlib legend columns.
495
        :keyword color: Color of the data points.
496
        :keyword dpi: Same as matplotlib dpi.
497
        :keyword model: str or callable, the model to plot.
498
        :keyword ms: Same as matplotlib markersize.
499
        :keyword axis_tick_params_pad: `pad` argument in calls to `ax.tick_params` when setting the axes.
500
        :keyword max_likelihood_alpha: `alpha` argument, i.e. transparency, when plotting the max likelihood curve.
501
        :keyword random_sample_alpha: `alpha` argument, i.e. transparency, when plotting random sample curves.
502
        :keyword uncertainty_band_alpha: `alpha` argument, i.e. transparency, when plotting a credible band.
503
        :keyword max_likelihood_color: Color of the maximum likelihood curve.
504
        :keyword random_sample_color: Color of the random sample curves.
505
        :keyword bbox_inches: Setting for saving plots. Default is 'tight'.
506
        :keyword linewidth: Same as matplotlib linewidth
507
        :keyword zorder: Same as matplotlib zorder
508
        :keyword yscale: Same as matplotlib yscale, default is linear
509
        :keyword xy: For `ax.annotate' x and y coordinates of the point to annotate.
510
        :keyword xycoords: The coordinate system `xy` is given in. Default is 'axes fraction'
511
        :keyword horizontalalignment: Horizontal alignment of the annotation. Default is 'right'
512
        :keyword annotation_size: `size` argument of of `ax.annotate`.
513
        :keyword fontsize_axes: Font size of the x and y labels.
514
        :keyword fontsize_legend: Font size of the legend.
515
        :keyword fontsize_figure: Font size of the figure. Relevant for multiband plots.
516
                                  Used on `supxlabel` and `supylabel`.
517
        :keyword fontsize_ticks: Font size of the axis ticks.
518
        :keyword hspace: Argument for `subplots_adjust`, sets horizontal spacing between panels.
519
        :keyword wspace: Argument for `subplots_adjust`, sets horizontal spacing between panels.
520
        :keyword plot_others: Whether to plot additional bands in the data plot, all in the same colors
521
        :keyword random_models: Number of random draws to use to calculate credible bands or to plot.
522
        :keyword uncertainty_mode: 'random_models': Plot random draws from the available parameter sets.
523
                                   'credible_intervals': Plot a credible interval that is calculated based
524
                                   on the available parameter sets.
525
        :keyword credible_interval_level: 0.9: Plot the 90% credible interval.
526
        :keyword plot_max_likelihood: Plots the draw corresponding to the maximum likelihood. Default is 'True'.
527
        :keyword set_same_color_per_subplot: Sets the lightcurve to be the same color as the data per subplot. Default is 'True'.
528
        :keyword xlim_high_multiplier: Adjust the maximum xlim based on available x values.
529
        :keyword xlim_low_multiplier: Adjust the minimum xlim based on available x values.
530
        :keyword ylim_high_multiplier: Adjust the maximum ylim based on available x values.
531
        :keyword ylim_low_multiplier: Adjust the minimum ylim based on available x values.
532
        :keyword show_grid: Whether to show grid lines. Default is False.
533
        :keyword grid_alpha: Transparency of grid lines. Default is 0.3.
534
        :keyword grid_color: Color of grid lines. Default is 'gray'.
535
        :keyword grid_linestyle: Line style of grid lines. Default is '--'.
536
        :keyword grid_linewidth: Line width of grid lines. Default is 0.5.
537
        :keyword save_format: Format for saving plots (e.g., 'png', 'pdf', 'svg', 'eps'). Default is 'png'.
538
        :keyword transparent: Whether to save plots with transparent background. Default is False.
539
        :keyword xscale: X-axis scale ('linear', 'log', 'symlog', 'logit'). Default is None (auto-determined).
540
        :keyword title: Title for the plot. Default is None (no title).
541
        :keyword title_fontsize: Font size for the title. Default is 20.
542
        :keyword linestyle: Line style for model curves. Default is '-'.
543
        :keyword max_likelihood_linestyle: Line style for max likelihood curve. Default is '-'.
544
        :keyword random_sample_linestyle: Line style for random sample curves. Default is '-'.
545
        :keyword markerfillstyle: Fill style for markers ('full', 'left', 'right', 'bottom', 'top', 'none'). Default is 'full'.
546
        :keyword markeredgecolor: Edge color for markers. Default is None (same as face color).
547
        :keyword markeredgewidth: Edge width for markers. Default is 1.0.
548
        :keyword legend_frameon: Whether to draw a frame around the legend. Default is True.
549
        :keyword legend_shadow: Whether to draw a shadow behind the legend. Default is False.
550
        :keyword legend_fancybox: Whether to use rounded corners for legend frame. Default is True.
551
        :keyword legend_framealpha: Transparency of legend frame. Default is 0.8.
552
        :keyword tick_direction: Direction of tick marks ('in', 'out', 'inout'). Default is 'in'.
553
        :keyword tick_length: Length of tick marks. Default is None (matplotlib default).
554
        :keyword tick_width: Width of tick marks. Default is None (matplotlib default).
555
        :keyword show_spines: Whether to show plot spines (borders). Default is True.
556
        :keyword spine_linewidth: Width of plot spines. Default is None (matplotlib default).
557
        """
558
        self.transient = spectrum
1✔
559
        self.kwargs = kwargs or dict()
1✔
560
        self._posterior_sorted = False
1✔
561

562
    # Keep the part of the ``__init__`` docstring that follows its first "-------" separator.
    keyword_docstring = __init__.__doc__.split("-------")[1]
1✔
563

564
    def _get_angstroms(self, axes: matplotlib.axes.Axes) -> np.ndarray:
        """
        Build the wavelength grid spanning the x limits of the plot.

        :param axes: The axes used in the plotting procedure. If an array of axes is passed,
                     its first element is inspected.
        :type axes: matplotlib.axes.Axes

        :return: 200 angstrom values between the lower and upper x limit, linearly spaced if the
                 y scale of the axes is 'linear' and logarithmically spaced otherwise.
        :rtype: np.ndarray
        """
        reference_ax = axes[0] if isinstance(axes, np.ndarray) else axes

        if reference_ax.get_yscale() == 'linear':
            return np.linspace(self._xlim_low, self._xlim_high, 200)

        log_low = np.log(self._xlim_low)
        log_high = np.log(self._xlim_high)
        return np.exp(np.linspace(log_low, log_high, 200))
1✔
583

584
    @property
    def _xlim_low(self) -> float:
        """Lower x limit: the ``xlim_low`` kwarg, else the scaled first wavelength (nudged off zero)."""
        fallback = self.xlim_low_multiplier * self.transient.angstroms[0]
        if fallback == 0:
            # Move an exactly-zero limit slightly upwards.
            fallback = fallback + 1e-3
        return self.kwargs.get("xlim_low", fallback)
1✔
590

591
    @property
    def _xlim_high(self) -> float:
        """Upper x limit: the ``xlim_high`` kwarg, else the scaled last wavelength."""
        last_wavelength = self.transient.angstroms[-1]
        fallback = self.xlim_high_multiplier * last_wavelength
        return self.kwargs.get("xlim_high", fallback)
1✔
595

596
    @property
    def _ylim_low(self) -> float:
        """Lower y limit: the ``ylim_low`` kwarg, else the scaled minimum flux density divided by 1e-17."""
        scaled_minimum = self.ylim_low_multiplier * min(self.transient.flux_density)
        fallback = scaled_minimum / 1e-17
        return self.kwargs.get("ylim_low", fallback)
1✔
600

601
    @property
    def _ylim_high(self) -> float:
        """Upper y limit: the ``ylim_high`` kwarg, else the scaled maximum flux density divided by 1e-17."""
        scaled_maximum = self.ylim_high_multiplier * np.max(self.transient.flux_density)
        fallback = scaled_maximum / 1e-17
        return self.kwargs.get("ylim_high", fallback)
1✔
605

606
    @property
    def _y_err(self) -> np.ndarray:
        """Absolute flux density errors wrapped in an outer array of length one."""
        absolute_errors = np.abs(self.transient.flux_density_err)
        return np.array([absolute_errors])
1✔
609

610
    @property
    def _data_plot_outdir(self) -> str:
        """Output directory for plots; defaults to the directory path of the transient."""
        transient_directory = self.transient.directory_structure.directory_path
        return self._get_outdir(transient_directory)
1✔
613

614
    def _get_outdir(self, default: str) -> str:
        """Return the ``outdir`` kwarg, or ``default`` if it is missing or falsy."""
        outdir = self._get_kwarg_with_default(kwarg="outdir", default=default)
        return outdir
1✔
616

617
    def get_filename(self, default: str) -> str:
        """Return the ``filename`` kwarg, or ``default`` if it is missing or falsy."""
        filename = self._get_kwarg_with_default(kwarg="filename", default=default)
        return filename
1✔
619

620
    def _get_kwarg_with_default(self, kwarg: str, default: Any) -> Any:
        """
        Look up ``kwarg`` in the plotter kwargs.

        :param kwarg: Name of the keyword argument.
        :param default: Value used when the keyword is missing or its value is falsy (e.g. None).

        :return: The stored value if it is truthy, otherwise ``default``.
        """
        value = self.kwargs.get(kwarg, default)
        return value if value else default
1✔
622

623
    @property
    def _model_kwargs(self) -> dict:
        """The ``model_kwargs`` kwarg, or a new empty dict if it is missing or falsy."""
        empty = dict()
        return self._get_kwarg_with_default("model_kwargs", empty)
1✔
626

627
    @property
    def _posterior(self) -> pd.DataFrame:
        """
        The ``posterior`` kwarg (an empty DataFrame if not given).

        On first access the DataFrame is sorted in place by 'log_likelihood' in ascending order,
        so the last row has the highest log likelihood. A posterior of None is returned unsorted.
        """
        samples = self.kwargs.get("posterior", pd.DataFrame())
        if samples is None or self._posterior_sorted:
            return samples
        samples.sort_values(by='log_likelihood', inplace=True)
        self._posterior_sorted = True
        return samples
1✔
634

635
    @property
    def _max_like_params(self) -> pd.core.series.Series:
        """Last row of the sorted posterior, i.e. the sample with the highest log likelihood."""
        sorted_posterior = self._posterior
        return sorted_posterior.iloc[-1]
1✔
638

639
    def _get_random_parameters(self) -> list[pd.core.series.Series]:
        """Draw ``random_models`` posterior rows at random (with replacement) and return them as a list."""
        posterior = self._posterior
        row_numbers = np.arange(len(posterior))
        drawn_rows = np.random.choice(row_numbers, size=self.random_models)
        samples = []
        for row in drawn_rows:
            samples.append(posterior.iloc[row])
        return samples
1✔
643

644
    # File names of the individual plots; each defaults to "<transient name>_<suffix>.png".
    _data_plot_filename = _FilenameGetter("data")
    _spectrum_ppd_plot_filename = _FilenameGetter("spectrum_ppd")
    _residual_plot_filename = _FilenameGetter("residual")

    # Full paths: the data plot output directory joined with the matching file name.
    _data_plot_filepath = _FilePathGetter("_data_plot_outdir", "_data_plot_filename")
    _spectrum_ppd_plot_filepath = _FilePathGetter("_data_plot_outdir", "_spectrum_ppd_plot_filename")
    _residual_plot_filepath = _FilePathGetter("_data_plot_outdir", "_residual_plot_filename")
654

655
    def _save_and_show(self, filepath: str, save: bool, show: bool) -> None:
        """
        Apply a tight layout, then optionally save and/or show the current figure.

        :param filepath: Target path of the plot. Its extension (if any) is replaced by ``save_format``.
        :type filepath: str
        :param save: Whether to save the plot.
        :type save: bool
        :param show: Whether to show the plot.
        :type show: bool
        """
        from os.path import splitext

        plt.tight_layout()
        if save:
            # Only the extension of the final path component is swapped, so a dot in a
            # directory name (e.g. "./plots/name") no longer truncates the path.
            root, _ = splitext(filepath)
            filepath = f'{root}.{self.save_format}'

            facecolor = 'none' if self.transparent else 'white'
            plt.savefig(filepath, dpi=self.dpi, bbox_inches=self.bbox_inches,
                        transparent=self.transparent, facecolor=facecolor)
        if show:
            plt.show()
1✔
669

670
    def _apply_axis_customizations(self, ax: matplotlib.axes.Axes) -> None:
        """
        Apply the grid, title, tick and spine settings of this plotter to an axes.

        :param ax: The axes to customise.
        :type ax: matplotlib.axes.Axes
        """
        if self.show_grid:
            grid_style = dict(alpha=self.grid_alpha, color=self.grid_color,
                              linestyle=self.grid_linestyle, linewidth=self.grid_linewidth)
            ax.grid(True, **grid_style)

        if self.title is not None:
            ax.set_title(self.title, fontsize=self.title_fontsize)

        # Tick length and width are only passed on when set, so matplotlib defaults apply otherwise.
        tick_settings = dict(axis='both', which='both', pad=self.axis_tick_params_pad,
                             labelsize=self.fontsize_ticks, direction=self.tick_direction)
        for key, value in (('length', self.tick_length), ('width', self.tick_width)):
            if value is not None:
                tick_settings[key] = value
        ax.tick_params(**tick_settings)

        # Spines are either hidden entirely or, if requested, given a custom line width.
        if not self.show_spines:
            for spine in ax.spines.values():
                spine.set_visible(False)
            return
        if self.spine_linewidth is not None:
            for spine in ax.spines.values():
                spine.set_linewidth(self.spine_linewidth)
1✔
699

700

701
class IntegratedFluxPlotter(Plotter):
    """Plotter for integrated flux data: data points, fitted lightcurves and residuals."""

    @property
    def _xlabel(self) -> str:
        """Label of the x axis."""
        return r"Time since burst [s]"

    @property
    def _ylabel(self) -> str:
        """Label of the y axis, taken from the transient."""
        return self.transient.ylabel

    def plot_data(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Integrated flux data and returns Axes.

        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
                     The current axes are used if none are given.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        ax = axes or plt.gca()

        point_style = dict(
            fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize,
            fillstyle=self.markerfillstyle, markeredgecolor=self.markeredgecolor,
            markeredgewidth=self.markeredgewidth)
        ax.errorbar(self.transient.x, self.transient.y, xerr=self._x_err, yerr=self._y_err, **point_style)

        # Both axes are logarithmic unless a scale was requested explicitly.
        ax.set_xscale('log' if self.xscale is None else self.xscale)
        ax.set_yscale('log' if self.yscale is None else self.yscale)

        ax.set_xlim(self._xlim_low, self._xlim_high)
        ax.set_ylim(self._ylim_low, self._ylim_high)
        ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes)
        ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes)

        # The transient name is written into the plot as an annotation.
        ax.annotate(
            self.transient.name, xy=self.xy, xycoords=self.xycoords,
            horizontalalignment=self.horizontalalignment, size=self.annotation_size)

        self._apply_axis_customizations(ax)

        self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show)
        return ax

    def plot_lightcurve(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Integrated flux data and the lightcurve and returns Axes.

        :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot.
                     The current axes are used if none are given.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        axes = axes or plt.gca()

        # Draw the data first without saving or showing, then overlay the model curves.
        axes = self.plot_data(axes=axes, save=False, show=False)
        model_times = self._get_times(axes)
        self._plot_lightcurves(axes, model_times)

        self._save_and_show(filepath=self._lightcurve_plot_filepath, save=save, show=show)
        return axes

    def _plot_lightcurves(self, axes: matplotlib.axes.Axes, times: np.ndarray) -> None:
        """
        Draw the maximum likelihood curve (if enabled) and the uncertainty representation.

        :param axes: The axes to draw into.
        :param times: Times at which the model is evaluated.
        """
        if self.plot_max_likelihood:
            best_fit = self.model(times, **self._max_like_params, **self._model_kwargs)
            axes.plot(times, best_fit, color=self.max_likelihood_color, alpha=self.max_likelihood_alpha,
                      lw=self.linewidth, linestyle=self.max_likelihood_linestyle)

        # The model is evaluated for every random draw before the uncertainty mode is checked.
        sampled_curves = []
        for draw in self._get_random_parameters():
            sampled_curves.append(self.model(times, **draw, **self._model_kwargs))

        if self.uncertainty_mode == "random_models":
            for curve in sampled_curves:
                axes.plot(times, curve, color=self.random_sample_color, alpha=self.random_sample_alpha,
                          lw=self.linewidth, linestyle=self.random_sample_linestyle, zorder=self.zorder)
        elif self.uncertainty_mode == "credible_intervals":
            lower, upper, _ = redback.utils.calc_credible_intervals(
                samples=sampled_curves, interval=self.credible_interval_level)
            axes.fill_between(
                times, lower, upper, alpha=self.uncertainty_band_alpha, color=self.max_likelihood_color)

    def _plot_single_lightcurve(self, axes: matplotlib.axes.Axes, times: np.ndarray, params: dict) -> None:
        """
        Evaluate the model for one parameter set and draw it in the random sample style.

        :param axes: The axes to draw into.
        :param times: Times at which the model is evaluated.
        :param params: Model parameters passed to the model as keyword arguments.
        """
        curve = self.model(times, **params, **self._model_kwargs)
        axes.plot(times, curve, color=self.random_sample_color, alpha=self.random_sample_alpha,
                  lw=self.linewidth, linestyle=self.random_sample_linestyle, zorder=self.zorder)

    def plot_residuals(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the residual of the Integrated flux data returns Axes.

        The first axes holds the data with the lightcurve, the second the difference between the
        data and the maximum likelihood model.

        :param axes: Pair of Matplotlib axes to plot into. A two-row figure is created if none are given.
        :param save: Whether to save the plot. (Default value = True)
        :param show: Whether to show the plot. (Default value = True)

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        if axes is None:
            _, axes = plt.subplots(
                nrows=2, ncols=1, sharex=True, sharey=False, figsize=(10, 8), gridspec_kw=dict(height_ratios=[2, 1]))

        axes[0] = self.plot_lightcurve(axes=axes[0], save=False, show=False)
        # Move the x label from the upper panel to the residual panel.
        axes[1].set_xlabel(axes[0].get_xlabel(), fontsize=self.fontsize_axes)
        axes[0].set_xlabel("")

        best_fit = self.model(self.transient.x, **self._max_like_params, **self._model_kwargs)
        residuals = self.transient.y - best_fit
        point_style = dict(
            fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize,
            fillstyle=self.markerfillstyle, markeredgecolor=self.markeredgecolor,
            markeredgewidth=self.markeredgewidth)
        axes[1].errorbar(self.transient.x, residuals, xerr=self._x_err, yerr=self._y_err, **point_style)
        # NOTE(review): residuals can be negative, which a log scale cannot display - confirm this is intended.
        axes[1].set_yscale("log")
        axes[1].set_ylabel("Residual", fontsize=self.fontsize_axes)

        self._apply_axis_customizations(axes[1])

        self._save_and_show(filepath=self._residual_plot_filepath, save=save, show=show)
        return axes
1✔
830

831

832
class LuminosityOpticalPlotter(IntegratedFluxPlotter):
    """Integrated flux plotter with axis labels for optical bolometric luminosity data."""

    @property
    def _xlabel(self) -> str:
        """Label of the x axis."""
        label = r"Time since explosion [days]"
        return label

    @property
    def _ylabel(self) -> str:
        """Label of the y axis."""
        label = r"L$_{\rm bol}$ [$10^{50}$ erg s$^{-1}$]"
        return label
1✔
841

842
class LuminosityPlotter(IntegratedFluxPlotter):
    """Plotter for luminosity data; inherits everything from IntegratedFluxPlotter unchanged."""
1✔
844

845

846
class MagnitudePlotter(Plotter):
1✔
847

848
    # x limit multipliers. The phase model variants read the same kwargs
    # ("xlim_low_multiplier" / "xlim_high_multiplier") but fall back to tighter defaults.
    xlim_low_phase_model_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.9)
    xlim_high_phase_model_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.1)
    xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.2)

    # y limit multipliers used for magnitude data; they read the generic "ylim_*_multiplier" kwargs.
    ylim_low_magnitude_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.95)
    ylim_high_magnitude_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 1.05)

    # Number of panel columns.
    ncols = KwargsAccessorWithDefault("ncols", 2)
1✔
854

855
    @property
    def _colors(self) -> str:
        """The ``colors`` kwarg, else the colours the transient assigns to the plotted filters."""
        # The transient colours are always looked up, even when the kwarg is given.
        transient_colors = self.transient.get_colors(self._filters)
        return self.kwargs.get("colors", transient_colors)
1✔
858

859
    @property
    def _xlabel(self) -> str:
        """The ``xlabel`` kwarg, else a reference-date label (phase model) or the transient's x label."""
        if not self.transient.use_phase_model:
            return self.kwargs.get("xlabel", self.transient.xlabel)
        phase_label = f"Time since {self._reference_mjd_date} MJD [days]"
        return self.kwargs.get("xlabel", phase_label)
1✔
866

867
    @property
    def _ylabel(self) -> str:
        """The ``ylabel`` kwarg, else the y label of the transient."""
        transient_label = self.transient.ylabel
        return self.kwargs.get("ylabel", transient_label)
1✔
870

871
    @property
    def _get_bands_to_plot(self) -> list[str]:
        """The ``bands_to_plot`` kwarg, else the active bands of the transient."""
        active_bands = self.transient.active_bands
        return self.kwargs.get("bands_to_plot", active_bands)
1✔
874

875
    @property
    def _xlim_low(self) -> float:
        """
        Lower x limit: the ``xlim_low`` kwarg, else the scaled first x value.

        With the phase model the first x value is taken relative to the reference MJD date.
        A fallback of exactly zero is raised by 1e-3.
        """
        first_x = self.transient.x[0]
        if self.transient.use_phase_model:
            fallback = (first_x - self._reference_mjd_date) * self.xlim_low_phase_model_multiplier
        else:
            fallback = self.xlim_low_multiplier * first_x
        if fallback == 0:
            fallback = fallback + 1e-3
        return self.kwargs.get("xlim_low", fallback)
1✔
884

885
    @property
    def _xlim_high(self) -> float:
        """
        Upper x limit: the ``xlim_high`` kwarg, else the scaled last x value.

        With the phase model the last x value is taken relative to the reference MJD date.
        """
        last_x = self.transient.x[-1]
        if self.transient.use_phase_model:
            fallback = (last_x - self._reference_mjd_date) * self.xlim_high_phase_model_multiplier
        else:
            fallback = self.xlim_high_multiplier * last_x
        return self.kwargs.get("xlim_high", fallback)
1✔
892

893
    @property
    def _ylim_low_magnitude(self) -> float:
        """
        Lower y limit for magnitude data.

        :return: The ``ylim_low`` kwarg if given. Otherwise the smallest finite y value scaled by
                 ``ylim_low_magnitude_multiplier``, or 0 if there is no finite y value.
        """
        y_valid = self.transient.y[np.isfinite(self.transient.y)]
        if len(y_valid) == 0:
            # Nothing to scale; an explicit user limit still takes precedence over the fallback.
            default = 0
        else:
            default = self.ylim_low_magnitude_multiplier * min(y_valid)
        return self.kwargs.get("ylim_low", default)
1✔
900

901
    @property
    def _ylim_high_magnitude(self) -> float:
        """
        Upper y limit for magnitude data.

        :return: The ``ylim_high`` kwarg if given. Otherwise the largest finite y value scaled by
                 ``ylim_high_magnitude_multiplier``, or 1 if there is no finite y value.
        """
        y_valid = self.transient.y[np.isfinite(self.transient.y)]
        if len(y_valid) == 0:
            # Nothing to scale; an explicit user limit still takes precedence over the fallback.
            default = 1
        else:
            default = self.ylim_high_magnitude_multiplier * np.max(y_valid)
        return self.kwargs.get("ylim_high", default)
1✔
908

909
    def _get_ylim_low_with_indices(self, indices: list) -> float:
        """
        Lower y limit for a subset of the data.

        :param indices: Indices selecting the data points to consider.

        :return: The smallest finite selected y value scaled by ``ylim_low_multiplier``,
                 or 0 if none of the selected values is finite.
        """
        selected_y = self.transient.y[indices]
        finite_y = selected_y[np.isfinite(selected_y)]
        if len(finite_y) > 0:
            return self.ylim_low_multiplier * min(finite_y)
        return 0
1✔
914

915
    def _get_ylim_high_with_indices(self, indices: list) -> float:
        """
        Upper y limit for a subset of the data.

        :param indices: Indices selecting the data points to consider.

        :return: The largest finite selected y value scaled by ``ylim_high_multiplier``,
                 or 1 if none of the selected values is finite.
        """
        selected_y = self.transient.y[indices]
        finite_y = selected_y[np.isfinite(selected_y)]
        if len(finite_y) > 0:
            return self.ylim_high_multiplier * np.max(finite_y)
        return 1
1✔
920

921
    def _get_x_err(self, indices: list) -> np.ndarray:
        """Return the x errors at ``indices``, or None if the transient has no x errors."""
        x_err = self.transient.x_err
        if x_err is None:
            return x_err
        return x_err[indices]
1✔
923

924
    def _set_y_axis_data(self, ax: matplotlib.axes.Axes) -> None:
        """
        Set y limits and y scale of a single-panel data plot.

        Magnitude data gets an inverted, by default linear y axis; other data a by default
        logarithmic one. An explicit ``yscale`` always wins.

        :param ax: The axes to configure.
        """
        is_magnitude = self.transient.magnitude_data
        if is_magnitude:
            ax.set_ylim(self._ylim_low_magnitude, self._ylim_high_magnitude)
            ax.invert_yaxis()
            fallback_scale = 'linear'
        else:
            ax.set_ylim(self._ylim_low, self._ylim_high)
            fallback_scale = "log"
        ax.set_yscale(fallback_scale if self.yscale is None else self.yscale)
1✔
934

935
    def _set_y_axis_multiband_data(self, ax: matplotlib.axes.Axes, indices: list) -> None:
        """
        Set y limits and y scale of one panel of a multiband plot.

        Magnitude data uses the limits of the whole data set and an inverted, by default linear
        y axis. Other data uses limits derived from the selected points and a by default
        logarithmic y axis. An explicit ``yscale`` always wins.

        :param ax: The axes to configure.
        :param indices: Indices of the data points shown in this panel.
        """
        is_magnitude = self.transient.magnitude_data
        if is_magnitude:
            ax.set_ylim(self._ylim_low_magnitude, self._ylim_high_magnitude)
            ax.invert_yaxis()
            fallback_scale = 'linear'
        else:
            panel_low = self._get_ylim_low_with_indices(indices=indices)
            panel_high = self._get_ylim_high_with_indices(indices=indices)
            ax.set_ylim(panel_low, panel_high)
            fallback_scale = "log"
        ax.set_yscale(fallback_scale if self.yscale is None else self.yscale)
1✔
946

947
    def _set_x_axis(self, axes: matplotlib.axes.Axes) -> None:
        """
        Set x scale and x limits of an axes.

        An explicit ``xscale`` is applied as is. Without one, the phase model forces a linear
        scale and any other case leaves the scale of the axes untouched.

        :param axes: The axes to configure.
        """
        requested_scale = self.xscale
        if requested_scale is None:
            if self.transient.use_phase_model:
                axes.set_xscale("linear")
        else:
            axes.set_xscale(requested_scale)
        axes.set_xlim(self._xlim_low, self._xlim_high)
1✔
954

955
    @property
    def _nrows(self) -> int:
        """
        Number of panel rows.

        :return: The ``nrows`` kwarg if given and truthy, otherwise the number of rows needed to
                 fit one panel per filter into ``ncols`` columns.
        """
        # Divide by the configured column count rather than a fixed 2, so the default
        # always provides enough panels for every filter.
        default = int(np.ceil(len(self._filters) / self.ncols))
        return self._get_kwarg_with_default("nrows", default=default)
1✔
959

960
    @property
    def _npanels(self) -> int:
        """
        Total number of panels (rows times columns).

        :raises ValueError: If there are fewer panels than filters.
        """
        n_filters = len(self._filters)
        n_available = self._nrows * self.ncols
        if n_available >= n_filters:
            return n_available
        raise ValueError(f"Insufficient number of panels. {n_available} panels were given "
                         f"but {n_filters} panels are needed.")
×
967

968
    @property
    def _figsize(self) -> tuple:
        """The ``figsize`` kwarg if given and truthy, else a size growing with the panel grid."""
        width = 4 + 4 * self.ncols
        height = 2 + 2 * self._nrows
        return self._get_kwarg_with_default("figsize", default=(width, height))
1✔
972

973
    @property
    def _reference_mjd_date(self) -> int:
        """
        Reference date subtracted from the x values.

        :return: 0 without the phase model. With it, the ``reference_mjd_date`` kwarg or,
                 if that is not given, the first x value truncated to an integer.
        """
        if not self.transient.use_phase_model:
            return 0
        first_epoch = int(self.transient.x[0])
        return self.kwargs.get("reference_mjd_date", first_epoch)
1✔
978

979
    @property
    def band_label_generator(self):
        """A new generator over ``band_labels``, or None if no band labels are set."""
        labels = self.band_labels
        if labels is None:
            return None
        return (label for label in labels)
×
983

984
    def _get_data_plot_label(self, band):
        """
        Build the legend label of a band.

        :param band: The band; a float is rendered in scientific notation with two decimals.

        :return: The band itself if no scaling is registered for it or the scaling is the
                 identity ('x' with 1, '+' with 0). Otherwise "<band> <type> <value>".
        """
        # Format float bands up front so they can also be combined with a scaling suffix;
        # concatenating a float with a string would raise a TypeError.
        base_label = f"{band:.2e}" if isinstance(band, float) else band
        if band not in self.band_scaling:
            return base_label

        scaling_type = self.band_scaling.get("type")
        scaling_value = self.band_scaling.get(band)
        multiplies_by_one = scaling_type == 'x' and scaling_value == 1
        adds_zero = scaling_type == '+' and scaling_value == 0
        if multiplies_by_one or adds_zero:
            return base_label
        return f"{base_label} {scaling_type} {scaling_value}"
1✔
998

999
    def _get_data_plot_labels(self) -> dict:
        """
        Map every band that will be plotted to its legend label.

        Bands among the filters get the next user supplied band label if band labels are set,
        otherwise a generated label. Other bands are only included when ``plot_others`` is set
        and get no label (None).

        :return: Dictionary from band to label.
        """
        user_labels = self.band_label_generator
        labels = dict()
        for band in self.transient.unique_bands:
            if band in self._filters:
                if user_labels is not None:
                    labels[band] = next(user_labels)
                else:
                    labels[band] = self._get_data_plot_label(band)
            elif self.plot_others:
                labels[band] = None
        return labels
1✔
1011

1012
    def plot_data(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Magnitude data and returns Axes.

        Detections are drawn as error bars, one series per band. If the transient has upper
        limits, those with finite y values are drawn as markers on top. Band scaling
        ('x' multiplies, '+' shifts) is applied to both.

        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        ax = axes if axes is not None else plt.gca()

        band_labels = self._get_data_plot_labels()
        labeled_bands = set()
        has_upper_limits = self.transient.has_upper_limits is True
        skip_band = object()

        def resolve_color(band):
            """Return the colour of a band, or ``skip_band`` if the band must not be drawn."""
            # A band missing from ``band_colors`` falls back to the filter colours.
            if self.band_colors is not None and band in self.band_colors:
                return self.band_colors[band]
            if band in self._filters:
                return self._colors[list(self._filters).index(band)]
            if self.plot_others:
                return "black"
            return skip_band

        if has_upper_limits:
            detection_mask = ~self.transient.upper_limits

        for indices, band in zip(self.transient.list_of_band_indices, self.transient.unique_bands):
            if band not in band_labels:
                continue
            # Filter out upper limit indices from detection data
            if has_upper_limits:
                indices = np.asarray(indices)
                indices = indices[detection_mask[indices]]
                if len(indices) == 0:
                    continue
            color = resolve_color(band)
            if color is skip_band:
                continue

            y = self.transient.y[indices]
            y_err = self.transient.y_err[indices]
            if band in self.band_scaling:
                scaling_type = self.band_scaling.get("type")
                scaling_value = self.band_scaling.get(band)
                if scaling_type == 'x':
                    # A multiplicative scaling stretches the errors as well.
                    y = y * scaling_value
                    y_err = y_err * scaling_value
                elif scaling_type == '+':
                    y = y + scaling_value
                else:
                    # Unsupported scaling types are not drawn; the band stays unlabelled so a
                    # later series of the same band can still carry the legend entry.
                    logger.warning(
                        f"Unknown band scaling type '{scaling_type}' for band '{band}'; "
                        f"detections in this band are not plotted. Use 'x' or '+'."
                    )
                    continue

            label = band_labels[band]
            ax.errorbar(
                self.transient.x[indices] - self._reference_mjd_date, y,
                xerr=self._get_x_err(indices), yerr=y_err,
                fmt=self.errorbar_fmt, ms=self.ms, color=color,
                elinewidth=self.elinewidth, capsize=self.capsize, label=label,
                fillstyle=self.markerfillstyle, markeredgecolor=self.markeredgecolor,
                markeredgewidth=self.markeredgewidth)
            if label is not None:
                labeled_bands.add(band)

        # Plot upper limits if present
        if has_upper_limits:
            ul_boolean = self.transient.upper_limits
            ul_marker = self.upper_limit_marker
            for indices, band in zip(self.transient.list_of_band_indices, self.transient.unique_bands):
                if band not in band_labels:
                    continue
                ul_band_indices = np.array([i for i in indices if ul_boolean[i]])
                if len(ul_band_indices) == 0:
                    continue
                # Upper limits must have finite y-values to be plotted at a position
                finite_mask = np.isfinite(self.transient.y[ul_band_indices])
                if not np.all(finite_mask):
                    n_nan = int(np.sum(~finite_mask))
                    logger.warning(
                        f"{n_nan} upper limit(s) in band '{band}' have NaN y-values and "
                        f"will not be plotted. Provide finite upper limit values (e.g. the "
                        f"limiting magnitude or flux) to plot them."
                    )
                    ul_band_indices = ul_band_indices[finite_mask]
                if len(ul_band_indices) == 0:
                    continue
                color = resolve_color(band)
                if color is skip_band:
                    continue
                ul_y = self.transient.y[ul_band_indices]
                if band in self.band_scaling:
                    if self.band_scaling.get("type") == 'x':
                        ul_y = ul_y * self.band_scaling.get(band)
                    elif self.band_scaling.get("type") == '+':
                        ul_y = ul_y + self.band_scaling.get(band)
                # Only label the upper limits if the detections did not already label the band.
                label = band_labels[band] if band not in labeled_bands else None
                ax.plot(self.transient.x[ul_band_indices] - self._reference_mjd_date,
                        ul_y, marker=ul_marker, ms=self.upper_limit_markersize,
                        color=color, linestyle='none', alpha=self.upper_limit_alpha,
                        label=label, zorder=3)

        self._set_x_axis(axes=ax)
        self._set_y_axis_data(ax)

        ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes)
        ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes)

        self._apply_axis_customizations(ax)

        ax.legend(ncol=self.legend_cols, loc=self.legend_location, fontsize=self.fontsize_legend,
                  frameon=self.legend_frameon, shadow=self.legend_shadow,
                  fancybox=self.legend_fancybox, framealpha=self.legend_framealpha)

        self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show)
        return ax
1✔
1136

1137
    def plot_lightcurve(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True)\
            -> matplotlib.axes.Axes:
        """Plots the Magnitude data together with the model curves and returns Axes.

        The data are drawn first through ``plot_data``. For every band in ``_get_bands_to_plot`` the model
        is then evaluated on the plotting time grid and drawn as

        - the maximum likelihood curve, if ``plot_max_likelihood`` is set, and
        - either one curve per random posterior sample (``uncertainty_mode == "random_models"``) or a filled
          credible interval (``uncertainty_mode == "credible_intervals"``).

        Bands that have an entry in ``band_scaling`` are multiplied (``type == 'x'``) or offset
        (``type == '+'``) by that entry before drawing. If a band has an entry but the scaling type is
        neither of the two, no model curves are drawn for it and a warning is logged.

        :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        if axes is None:
            axes = plt.gca()

        axes = self.plot_data(axes=axes, save=False, show=False)

        times = self._get_times(axes)
        bands_to_plot = self._get_bands_to_plot

        def _scaled(values, band):
            # Values with the band's scaling applied; unchanged when the band has no scaling entry;
            # None when the band has an entry but the scaling type is not 'x' or '+'.
            if band not in self.band_scaling:
                return values
            scaling_type = self.band_scaling.get("type")
            if scaling_type == 'x':
                return values * self.band_scaling.get(band)
            if scaling_type == '+':
                return values + self.band_scaling.get(band)
            return None

        color_max = self.max_likelihood_color
        color_sample = self.random_sample_color
        for band, color in zip(bands_to_plot, self.transient.get_colors(bands_to_plot)):
            if self.set_same_color_per_subplot is True:
                if self.band_colors is not None:
                    color = self.band_colors[band]
                color_max = color
                color_sample = color
            sn_cosmo_band = redback.utils.sncosmo_bandname_from_band([band])
            self._model_kwargs["bands"] = [sn_cosmo_band[0] for _ in range(len(times))]
            if isinstance(band, str):
                frequency = redback.utils.bands_to_frequency([band])
            else:
                # Non-string bands are used directly as the frequency value.
                frequency = band
            self._model_kwargs['frequency'] = np.ones(len(times)) * frequency

            if band in self.band_scaling and self.band_scaling.get("type") not in ('x', '+'):
                # Previously this case silently skipped the curves and, for credible intervals, used
                # unbound (or previous-band) bounds. Make the skip explicit and uniform.
                logger.warning(
                    f"Unsupported band_scaling type {self.band_scaling.get('type')!r}; "
                    f"model curves for band '{band}' will not be plotted.")

            if self.plot_max_likelihood:
                ys = self.model(times, **self._max_like_params, **self._model_kwargs)
                scaled_ys = _scaled(ys, band)
                if scaled_ys is not None:
                    axes.plot(times - self._reference_mjd_date, scaled_ys, color=color_max,
                              alpha=self.max_likelihood_alpha, lw=self.linewidth,
                              linestyle=self.max_likelihood_linestyle)

            random_ys_list = [self.model(times, **random_params, **self._model_kwargs)
                              for random_params in self._get_random_parameters()]
            if self.uncertainty_mode == "random_models":
                for ys in random_ys_list:
                    scaled_ys = _scaled(ys, band)
                    if scaled_ys is None:
                        continue
                    # zorder=-1 keeps the sample curves behind the data and the best-fit curve.
                    axes.plot(times - self._reference_mjd_date, scaled_ys, color=color_sample,
                              alpha=self.random_sample_alpha, lw=self.linewidth,
                              linestyle=self.random_sample_linestyle, zorder=-1)
            elif self.uncertainty_mode == "credible_intervals":
                samples = _scaled(np.array(random_ys_list), band)
                if samples is not None:
                    lower_bound, upper_bound, _median = redback.utils.calc_credible_intervals(
                        samples=samples, interval=self.credible_interval_level)
                    axes.fill_between(
                        times - self._reference_mjd_date, lower_bound, upper_bound,
                        alpha=self.uncertainty_band_alpha, color=color_sample)

        self._save_and_show(filepath=self._lightcurve_plot_filepath, save=save, show=show)
        return axes
1✔
1215

1216
    def _check_valid_multiband_data_mode(self) -> bool:
1✔
1217
        if self.transient.luminosity_data:
1✔
1218
            redback.utils.logger.warning(
×
1219
                f"Plotting multiband lightcurve/data not possible for {self.transient.data_mode}. Returning.")
1220
            return False
×
1221
        return True
1✔
1222

1223
    def plot_multiband(
            self, figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None, save: bool = True,
            show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Magnitude multiband data, one panel per band, and returns the Axes.

        Bands that are not in ``_filters`` are skipped. Within each panel detections are drawn with error
        bars and, if the transient has upper limits, the finite upper limits are drawn with
        ``upper_limit_marker``. Upper limits with non-finite y-values are dropped with a warning.

        :param figure: Matplotlib figure to plot the data into.
        :type figure: matplotlib.figure.Figure
        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The flattened array of axes with the plot, or None if the data mode is not valid
                 for multiband plotting.
        :rtype: matplotlib.axes.Axes
        """
        if not self._check_valid_multiband_data_mode():
            return

        if figure is None or axes is None:
            figure, axes = plt.subplots(ncols=self.ncols, nrows=self._nrows, sharex='all', figsize=self._figsize)
        # plt.subplots returns a bare Axes for a 1x1 grid; atleast_1d makes that case indexable too.
        axes = np.atleast_1d(axes).ravel()

        band_label_generator = self.band_label_generator

        ii = 0
        for indices, band, freq in zip(
                self.transient.list_of_band_indices, self.transient.unique_bands, self.transient.unique_frequencies):
            if band not in self._filters:
                continue

            color = self._colors[list(self._filters).index(band)]
            if self.band_colors is not None:
                color = self.band_colors[band]
            if band_label_generator is None:
                label = self._get_multiband_plot_label(band, freq)
            else:
                label = next(band_label_generator)

            # Separate detections and upper limits for this band
            if self.transient.has_upper_limits is True:
                det_mask = ~self.transient.upper_limits[indices]
                det_indices = indices[det_mask]
                ul_indices_band = indices[~det_mask]
            else:
                det_indices = indices
                ul_indices_band = np.array([], dtype=int)

            # Plot detections with error bars
            if len(det_indices) > 0:
                axes[ii].errorbar(
                    self.transient.x[det_indices] - self._reference_mjd_date,
                    self.transient.y[det_indices], xerr=self._get_x_err(det_indices),
                    yerr=self.transient.y_err[det_indices], fmt=self.errorbar_fmt, ms=self.ms, color=color,
                    elinewidth=self.elinewidth, capsize=self.capsize, label=label,
                    fillstyle=self.markerfillstyle, markeredgecolor=self.markeredgecolor,
                    markeredgewidth=self.markeredgewidth)

            # Plot upper limits with limit markers
            if len(ul_indices_band) > 0:
                finite_mask = np.isfinite(self.transient.y[ul_indices_band])
                if not np.all(finite_mask):
                    n_nan = int(np.sum(~finite_mask))
                    logger.warning(
                        f"{n_nan} upper limit(s) in band '{band}' have NaN y-values and "
                        f"will not be plotted."
                    )
                    ul_indices_band = ul_indices_band[finite_mask]
                if len(ul_indices_band) > 0:
                    # Every band has its own panel and the detections above are drawn unscaled, so the
                    # upper limits are drawn unscaled as well. Applying band_scaling only to the limits
                    # would put them on a different scale than the detections in the same panel.
                    axes[ii].plot(
                        self.transient.x[ul_indices_band] - self._reference_mjd_date,
                        self.transient.y[ul_indices_band],
                        marker=self.upper_limit_marker, ms=self.upper_limit_markersize,
                        color=color, linestyle='none', alpha=self.upper_limit_alpha,
                        # Only label the limits when no detection already carries the band label.
                        label=label if len(det_indices) == 0 else None, zorder=3)

            self._set_x_axis(axes[ii])
            self._set_y_axis_multiband_data(axes[ii], indices)

            # Apply new customizations
            self._apply_axis_customizations(axes[ii])

            # Legend with new customization options
            axes[ii].legend(ncol=self.legend_cols, loc=self.legend_location, fontsize=self.fontsize_legend,
                            frameon=self.legend_frameon, shadow=self.legend_shadow,
                            fancybox=self.legend_fancybox, framealpha=self.legend_framealpha)
            ii += 1

        figure.supxlabel(self._xlabel, fontsize=self.fontsize_figure)
        figure.supylabel(self._ylabel, fontsize=self.fontsize_figure)
        plt.subplots_adjust(wspace=self.wspace, hspace=self.hspace)

        self._save_and_show(filepath=self._multiband_data_plot_filepath, save=save, show=show)
        return axes
1✔
1324

1325
    @staticmethod
1✔
1326
    def _get_multiband_plot_label(band: str, freq: float) -> str:
1✔
1327
        if isinstance(band, str):
1✔
1328
            if 1e10 < float(freq) < 1e16:
1✔
1329
                label = band
1✔
1330
            else:
1331
                label = f"{freq:.2e}"
×
1332
        else:
1333
            label = f"{band:.2e}"
×
1334
        return label
1✔
1335

1336
    @property
1✔
1337
    def _filters(self) -> list[str]:
1✔
1338
        filters = self.kwargs.get("filters", self.transient.active_bands)
1✔
1339
        if 'bands_to_plot' in self.kwargs:
1✔
1340
            filters = self.kwargs['bands_to_plot']
×
1341
        if filters is None:
1✔
1342
            return self.transient.active_bands
×
1343
        elif str(filters) == 'default':
1✔
1344
            return self.transient.default_filters
×
1345
        return filters
1✔
1346

1347
    def plot_multiband_lightcurve(
        self, figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Magnitude multiband data with the model curves on top and returns Axes.

        The per-band panels are produced by ``plot_multiband``. For every band in ``_filters`` the model is
        then evaluated with that band's frequency and band name and drawn into the matching panel.

        :param figure: Matplotlib figure to plot the data into.
        :type figure: matplotlib.figure.Figure
        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot, or None if the data mode is not valid for multiband plotting.
        :rtype: matplotlib.axes.Axes
        """
        if not self._check_valid_multiband_data_mode():
            return

        if figure is None or axes is None:
            figure, axes = plt.subplots(ncols=self.ncols, nrows=self._nrows, sharex='all', figsize=self._figsize)

        axes = self.plot_multiband(figure=figure, axes=axes, save=False, show=False)
        times = self._get_times(axes)

        color_max = self.max_likelihood_color
        color_sample = self.random_sample_color
        panel = 0  # index of the next panel; advances only for bands that are plotted
        for band, freq in zip(self.transient.unique_bands, self.transient.unique_frequencies):
            if band not in self._filters:
                continue

            # Work on a copy so the per-band entries do not end up in the shared model kwargs.
            band_kwargs = self._model_kwargs.copy()
            band_kwargs['frequency'] = freq
            sncosmo_name = redback.utils.sncosmo_bandname_from_band([band])[0]
            band_kwargs['bands'] = [sncosmo_name] * len(times)

            if self.set_same_color_per_subplot is True:
                color = self._colors[list(self._filters).index(band)]
                if self.band_colors is not None:
                    color = self.band_colors[band]
                color_max = color
                color_sample = color

            shifted_times = times - self._reference_mjd_date
            if self.plot_max_likelihood:
                best_fit = self.model(times, **self._max_like_params, **band_kwargs)
                axes[panel].plot(
                    shifted_times, best_fit, color=color_max,
                    alpha=self.max_likelihood_alpha, lw=self.linewidth, linestyle=self.max_likelihood_linestyle)

            sample_curves = [self.model(times, **random_params, **band_kwargs)
                             for random_params in self._get_random_parameters()]
            if self.uncertainty_mode == "random_models":
                for curve in sample_curves:
                    axes[panel].plot(shifted_times, curve, color=color_sample,
                                     alpha=self.random_sample_alpha, lw=self.linewidth,
                                     linestyle=self.random_sample_linestyle, zorder=self.zorder)
            elif self.uncertainty_mode == "credible_intervals":
                lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(
                    samples=sample_curves, interval=self.credible_interval_level)
                axes[panel].fill_between(
                    shifted_times, lower_bound, upper_bound,
                    alpha=self.uncertainty_band_alpha, color=color_sample)
            panel += 1

        self._save_and_show(filepath=self._multiband_lightcurve_plot_filepath, save=save, show=show)
        return axes
1✔
1410

1411

1412
class FluxDensityPlotter(MagnitudePlotter):
    """MagnitudePlotter subclass that adds no overrides; all plotting behaviour is inherited unchanged."""
1✔
1414

1415
class IntegratedFluxOpticalPlotter(MagnitudePlotter):
    """MagnitudePlotter subclass that adds no overrides; all plotting behaviour is inherited unchanged."""
1✔
1417

1418
class SpectrumPlotter(SpecPlotter):
    """Plots a spectrum (flux density against wavelength in angstroms) and the model fitted to it.

    Flux densities are divided by 1e-17 before being drawn in the data and spectrum plots.
    """

    @property
    def _xlabel(self) -> str:
        """The x-axis label, taken from the transient."""
        return self.transient.xlabel

    @property
    def _ylabel(self) -> str:
        """The y-axis label, taken from the transient."""
        return self.transient.ylabel

    def plot_data(
            self, plot_format: str = 'standard',
            axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the spectrum data and returns Axes.

        :param plot_format: The format to plot the data in. 'standard' draws a line; any other value draws
            error bars. (Default value = 'standard')
        :type plot_format: str
        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        ax = axes or plt.gca()

        # The annotation shows either the spectrum's time or the transient's name.
        annotation = self.transient.time if self.transient.plot_with_time_label else self.transient.name

        # Use the custom x scale if one is given, otherwise fall back to linear.
        ax.set_xscale('linear' if self.xscale is None else self.xscale)

        wavelengths = self.transient.angstroms
        scaled_flux = self.transient.flux_density / 1e-17
        if plot_format == 'standard':
            ax.plot(wavelengths, scaled_flux, color=self.color, lw=self.linewidth, linestyle=self.linestyle)
        else:
            ax.errorbar(
                wavelengths, scaled_flux, yerr=self.transient.flux_density_err / 1e-17,
                color=self.color, fmt=self.errorbar_fmt, ms=self.ms,
                elinewidth=self.elinewidth, capsize=self.capsize)
        ax.set_yscale(self.yscale)

        ax.set_xlim(self._xlim_low, self._xlim_high)
        ax.set_ylim(self._ylim_low, self._ylim_high)
        ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes)
        ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes)

        ax.annotate(
            annotation, xy=self.xy, xycoords=self.xycoords,
            horizontalalignment=self.horizontalalignment, size=self.annotation_size)

        # Apply new customizations
        self._apply_axis_customizations(ax)

        self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show)
        return ax

    def plot_spectrum(
            self, plot_format: str = 'standard',
            axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the spectrum data and the fit and returns Axes.

        :param plot_format: The format to plot the data in, passed on to ``plot_data``.
            (Default value = 'standard')
        :type plot_format: str
        :param axes: Matplotlib axes to plot the spectrum into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        axes = axes or plt.gca()

        # Data first (without saving/showing), then the model curves on the same axes.
        axes = self.plot_data(axes=axes, save=False, show=False, plot_format=plot_format)
        self._plot_spectrums(axes, self._get_angstroms(axes))

        self._save_and_show(filepath=self._spectrum_ppd_plot_filepath, save=save, show=show)
        return axes

    def _plot_spectrums(self, axes: matplotlib.axes.Axes, angstroms: np.ndarray) -> None:
        """Draw the model on ``axes``: the maximum likelihood curve plus the configured uncertainty display.

        :param axes: The axes to draw into.
        :param angstroms: The wavelengths at which the model is evaluated.
        """
        flux_unit = 1e-17  # same divisor as applied to the data in plot_data

        if self.plot_max_likelihood:
            best_fit = self.model(angstroms, **self._max_like_params, **self._model_kwargs)
            axes.plot(angstroms, best_fit/flux_unit, color=self.max_likelihood_color,
                      alpha=self.max_likelihood_alpha, lw=self.linewidth,
                      linestyle=self.max_likelihood_linestyle)

        sample_spectra = [self.model(angstroms, **random_params, **self._model_kwargs)
                          for random_params in self._get_random_parameters()]
        mode = self.uncertainty_mode
        if mode == "random_models":
            for spectrum in sample_spectra:
                axes.plot(angstroms, spectrum/flux_unit, color=self.random_sample_color,
                          alpha=self.random_sample_alpha, lw=self.linewidth,
                          linestyle=self.random_sample_linestyle, zorder=self.zorder)
        elif mode == "credible_intervals":
            lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(
                samples=sample_spectra, interval=self.credible_interval_level)
            axes.fill_between(
                angstroms, lower_bound/flux_unit, upper_bound/flux_unit,
                alpha=self.uncertainty_band_alpha, color=self.max_likelihood_color)

    def plot_residuals(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the spectrum with the fit and, below it, the residuals; returns the Axes.

        The first panel is produced by ``plot_spectrum``. The second panel shows the flux density minus the
        maximum likelihood model, with the flux density errors as error bars.

        :param axes: A pair of Matplotlib axes (spectrum panel, residual panel). If None, a new
            two-row figure is created.
        :param save: Whether to save the plot. (Default value = True)
        :param show: Whether to show the plot. (Default value = True)

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        if axes is None:
            _, axes = plt.subplots(
                nrows=2, ncols=1, sharex=True, sharey=False, figsize=(10, 8), gridspec_kw=dict(height_ratios=[2, 1]))

        axes[0] = self.plot_spectrum(axes=axes[0], save=False, show=False)
        # Move the x label from the spectrum panel down to the residual panel.
        axes[1].set_xlabel(axes[0].get_xlabel(), fontsize=self.fontsize_axes)
        axes[0].set_xlabel("")

        best_fit = self.model(self.transient.angstroms, **self._max_like_params, **self._model_kwargs)
        # NOTE(review): the residuals are not divided by 1e-17, unlike the values in the top panel.
        residuals = self.transient.flux_density - best_fit
        axes[1].errorbar(
            self.transient.angstroms, residuals, yerr=self.transient.flux_density_err,
            fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize,
            fillstyle=self.markerfillstyle, markeredgecolor=self.markeredgecolor,
            markeredgewidth=self.markeredgewidth)
        axes[1].set_yscale('linear')
        axes[1].set_ylabel("Residual", fontsize=self.fontsize_axes)

        # Apply new customizations
        self._apply_axis_customizations(axes[1])

        self._save_and_show(filepath=self._residual_plot_filepath, save=save, show=show)
        return axes
1✔
1566

1567

1568
def get_plotter_kwargs_docs(plotter_class=None):
    """Return a formatted string documenting all kwarg options for a Plotter class.

    Introspects the class for KwargsAccessorWithDefault descriptors and lists
    each option with its default value. This is the canonical way to discover
    what can be passed as **kwargs to transient.plot_data(), plot_multiband(),
    plot_lightcurve() etc.

    When a subclass redefines a descriptor that a base class already declares,
    the default reported is the subclass's, i.e. the one in effect on
    ``plotter_class``. The option keeps the position of its first declaration.

    :param plotter_class: A Plotter subclass to inspect. Defaults to Plotter.
    :type plotter_class: type, optional
    :return: Formatted documentation string.
    :rtype: str

    **Example**::

        import redback.plotting
        print(redback.plotting.get_plotter_kwargs_docs())
        # or for MagnitudePlotter specifically:
        print(redback.plotting.get_plotter_kwargs_docs(redback.plotting.MagnitudePlotter))

    """
    if plotter_class is None:
        plotter_class = Plotter

    # Walk the MRO from the most generic class down to ``plotter_class``. A dict keeps the position of
    # the first (base-class) declaration while a later, more derived declaration replaces the stored
    # default, so overrides are reported with the value that actually applies.
    defaults = {}
    for cls in reversed(plotter_class.__mro__):
        for name, obj in vars(cls).items():
            if isinstance(obj, KwargsAccessorWithDefault):
                defaults[name] = obj.default

    lines = [f"Keyword arguments accepted by {plotter_class.__name__} (pass via **kwargs):",
             "=" * 65]
    lines.extend(f"  {name} (default: {default!r})" for name, default in defaults.items())
    lines.append("")
    lines.append("Pass any of these as keyword arguments, e.g.:")
    lines.append("  transient.plot_data(ms=8, fontsize_axes=16, band_colors={'g': 'green'})")
    return "\n".join(lines)
1✔
1608

1609

1610
def plot_binned_count_lightcurve(
        binned=None, time_bins=None, counts=None, background=None, selection=None,
        rate=None, error=None,
        axes: matplotlib.axes.Axes = None, filename: str = None, outdir: str = None,
        save: bool = True, show: bool = True, color: str = "tab:blue", marker: str = "o",
        markersize: float = 4.0, xscale: str = "linear", yscale: str = "linear",
        min_counts: int = None, annotate_min_counts: bool = True) -> matplotlib.axes.Axes:
    """
    Plot count-rate light curve (counts/s vs time).

    Inputs (ThreeML-like):
    - time_bins: bin edges
    - counts: counts per bin
    - background: background counts per bin (optional)
    - selection: boolean mask for bins (optional)
    Or provide a DataFrame via `binned` with columns:
      time_start/time_end or time_center + dt, and counts (or count_rate).
      Optional error columns: count_rate_err or rate_err.

    :param binned: Table of binned data; takes precedence over `time_bins`/`counts`.
    :param time_bins: Bin edges (array-like, one more entry than `counts`).
    :param counts: Counts per bin (array-like).
    :param background: Background counts per bin (array-like), plotted as a second series.
    :param selection: Index/mask applied to all per-bin arrays before plotting.
    :param rate: Precomputed count rate per bin; plotted instead of counts / bin width.
    :param error: Error on `rate`; only used when `rate` is given together with `time_bins`/`counts`.
    :param axes: Matplotlib axes to plot into; the current axes are used if None.
    :param filename: File name to save to; nothing is saved if None.
    :param outdir: Directory joined in front of `filename` when given.
    :param save: Whether to save the plot (requires `filename`).
    :param show: Whether to show the plot.
    :param color: Colour of the count-rate series.
    :param marker: Marker format passed to `errorbar`.
    :param markersize: Marker size.
    :param xscale: Scale of the x axis.
    :param yscale: Scale of the y axis.
    :param min_counts: If positive, consecutive bins are merged until each merged bin holds at least
        this many counts; a trailing incomplete group is kept as its own bin.
    :param annotate_min_counts: Whether to print `min_counts` in the top-left corner of the axes.
    :raises ValueError: If neither `binned` nor both `time_bins` and `counts` are given.
    :return: The axes with the plot.
    :rtype: matplotlib.axes.Axes
    """
    ax = plt.gca() if axes is None else axes
    logger.info("Plotting binned lightcurve (min_counts=%s)", str(min_counts))

    # Coerce array-likes up front so that plain lists (or pandas Series, whose arithmetic aligns on
    # the index) behave like positional arrays in the slicing and masking below.
    rate_err = None
    rate_in = None if rate is None else np.asarray(rate, dtype=float)
    if background is not None:
        background = np.asarray(background, dtype=float)

    if binned is not None:
        if "dt" in binned:
            dt = binned["dt"].to_numpy()
        else:
            dt = (binned["time_end"] - binned["time_start"]).to_numpy()
        if "time_center" in binned:
            t = binned["time_center"].to_numpy()
        else:
            t = 0.5 * (binned["time_start"] + binned["time_end"]).to_numpy()
        if "counts" in binned:
            cts = binned["counts"].to_numpy()
        else:
            cts = (binned["count_rate"] * dt).to_numpy()
        if "count_rate_err" in binned:
            rate_err = binned["count_rate_err"].to_numpy()
        elif "rate_err" in binned:
            rate_err = binned["rate_err"].to_numpy()
    else:
        if time_bins is None or counts is None:
            raise ValueError("Provide either `binned` or (`time_bins` and `counts`).")
        time_bins = np.asarray(time_bins)
        t = 0.5 * (time_bins[:-1] + time_bins[1:])
        dt = time_bins[1:] - time_bins[:-1]
        cts = np.asarray(counts)
        if rate_in is not None and error is not None:
            rate_err = np.asarray(error, dtype=float)

    if selection is not None:
        t = t[selection]
        dt = dt[selection]
        cts = cts[selection]
        if background is not None:
            background = background[selection]
        if rate_in is not None:
            rate_in = rate_in[selection]
        if rate_err is not None:
            rate_err = rate_err[selection]

    if min_counts is not None and min_counts > 0:
        grouped_t = []
        grouped_dt = []
        grouped_cts = []
        grouped_bkg = [] if background is not None else None
        grouped_err2 = [] if rate_err is not None else None
        grouped_rate = [] if rate_in is not None else None
        acc_cts = 0.0
        acc_t = 0.0     # sum of t * dt, so acc_t / acc_dt is the duration-weighted mean time
        acc_dt = 0.0
        acc_bkg = 0.0
        acc_err2 = 0.0  # sum of (rate_err * dt)^2, i.e. squared count errors added in quadrature
        acc_rate = 0.0  # sum of rate * dt, so acc_rate / acc_dt is the duration-weighted mean rate
        for i in range(len(cts)):
            acc_cts += float(cts[i])
            acc_t += float(t[i]) * float(dt[i])
            acc_dt += float(dt[i])
            if background is not None:
                acc_bkg += float(background[i])
            if rate_err is not None:
                acc_err2 += float(rate_err[i]) ** 2 * float(dt[i]) ** 2
            if rate_in is not None:
                acc_rate += float(rate_in[i]) * float(dt[i])
            if acc_cts >= min_counts:
                grouped_t.append(acc_t / acc_dt)
                grouped_dt.append(acc_dt)
                grouped_cts.append(acc_cts)
                if background is not None:
                    grouped_bkg.append(acc_bkg)
                if rate_err is not None:
                    grouped_err2.append(acc_err2)
                if rate_in is not None:
                    grouped_rate.append(acc_rate / acc_dt)
                acc_cts = 0.0
                acc_t = 0.0
                acc_dt = 0.0
                acc_bkg = 0.0
                acc_err2 = 0.0
                acc_rate = 0.0
        # Keep whatever is left over as a final (under-filled) bin.
        if acc_dt > 0:
            grouped_t.append(acc_t / acc_dt)
            grouped_dt.append(acc_dt)
            grouped_cts.append(acc_cts)
            if background is not None:
                grouped_bkg.append(acc_bkg)
            if rate_err is not None:
                grouped_err2.append(acc_err2)
            if rate_in is not None:
                grouped_rate.append(acc_rate / acc_dt)

        t = np.asarray(grouped_t, dtype=float)
        dt = np.asarray(grouped_dt, dtype=float)
        cts = np.asarray(grouped_cts, dtype=float)
        if background is not None:
            background = np.asarray(grouped_bkg, dtype=float)
        if rate_err is not None:
            rate_err = np.sqrt(np.asarray(grouped_err2, dtype=float)) / dt
        if rate_in is not None:
            rate_in = np.asarray(grouped_rate, dtype=float)

    rate = cts / dt if rate_in is None else rate_in
    # Use the supplied errors (from `binned` columns or `error`) when there are any; previously errors
    # read from `binned` were discarded whenever `rate` was not given. Fall back to Poisson errors.
    if rate_err is None:
        rate_err = np.sqrt(np.maximum(cts, 0.0)) / dt

    ax.errorbar(
        t, rate, yerr=rate_err, fmt=marker, markersize=markersize,
        color=color, elinewidth=1.0, capsize=2, label="count rate"
    )

    if background is not None:
        bkg_rate = background / dt
        bkg_err = np.sqrt(np.maximum(background, 0.0)) / dt
        ax.errorbar(
            t, bkg_rate, yerr=bkg_err, fmt=marker, markersize=markersize,
            color="0.5", elinewidth=1.0, capsize=2, label="background"
        )

    ax.set_xscale(xscale)
    ax.set_yscale(yscale)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Counts/s")
    ax.legend()

    if annotate_min_counts and min_counts is not None:
        text = f"min counts/bin: {min_counts}"
        ax.text(
            0.02, 0.98, text,
            transform=ax.transAxes,
            ha="left", va="top",
            fontsize=9,
            bbox=dict(boxstyle="round,pad=0.2", facecolor="white", alpha=0.6, edgecolor="none")
        )

    if save and filename is not None:
        path = filename if outdir is None else join(outdir, filename)
        plt.savefig(path, dpi=150, bbox_inches="tight")
    if show:
        plt.show()
    return ax
1✔
1774

1775

1776
def plot_spectrum_data(dataset, **kwargs):
    """
    Forward to ``dataset.plot_spectrum_data`` so the call is available from the plotting module.

    :param dataset: Object providing a ``plot_spectrum_data`` method.
    :param kwargs: Keyword arguments passed through unchanged.
    :return: Whatever ``dataset.plot_spectrum_data`` returns.
    """
    plot_method = dataset.plot_spectrum_data
    return plot_method(**kwargs)
×
1781

1782

1783
def plot_spectrum_fit(dataset, **kwargs):
    """
    Forward to ``dataset.plot_spectrum_fit`` so the call is available from the plotting module.

    :param dataset: Object providing a ``plot_spectrum_fit`` method.
    :param kwargs: Keyword arguments passed through unchanged.
    :return: Whatever ``dataset.plot_spectrum_fit`` returns.
    """
    plot_method = dataset.plot_spectrum_fit
    return plot_method(**kwargs)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc