• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nikhil-sarin / redback / 19598860677

22 Nov 2025 05:27PM UTC coverage: 87.099% (+0.2%) from 86.869%
19598860677

Pull #316

github

web-flow
Merge 865af20b2 into 11d71ec81
Pull Request #316: [WIP] - Add comprehensive customization options to plotter classes

118 of 123 new or added lines in 1 file covered. (95.93%)

6 existing lines in 2 files now uncovered.

10755 of 12348 relevant lines covered (87.1%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.12
/redback/plotting.py
1
from __future__ import annotations
1✔
2

3
from os.path import join
1✔
4
from typing import Any, Union
1✔
5

6
import matplotlib
1✔
7
import matplotlib.pyplot as plt
1✔
8
import numpy as np
1✔
9
import pandas as pd
1✔
10

11
import redback
1✔
12
from redback.utils import KwargsAccessorWithDefault
1✔
13

14
class _FilenameGetter(object):
    """
    Descriptor that resolves a plot filename for a :class:`Plotter` instance.

    The default filename is ``"<transient name>_<suffix>.png"``. It is passed to
    ``instance.get_filename``, which returns the user-supplied ``filename`` kwarg when one is set.
    """

    def __init__(self, suffix: str) -> None:
        """
        :param suffix: Suffix appended to the transient name in the default filename.
        """
        self.suffix = suffix

    def __get__(self, instance: Plotter, owner: object) -> Union[str, _FilenameGetter]:
        """
        :param instance: The plotter the attribute is read from, or None for class-level access.
        :param owner: The owning class.
        :return: The resolved filename, or the descriptor itself when accessed on the class.
        """
        # Class-level access (e.g. ``Plotter._data_plot_filename``, ``help``/``inspect``) passes
        # ``instance=None``; return the descriptor instead of dereferencing ``None.transient``.
        if instance is None:
            return self
        return instance.get_filename(default=f"{instance.transient.name}_{self.suffix}.png")

    def __set__(self, instance: Plotter, value: object) -> None:
        # Assignments are silently ignored. Because ``__set__`` is defined this is a data
        # descriptor, so an instance attribute can never shadow the computed filename.
        pass
23

24

25
class _FilePathGetter(object):
    """
    Descriptor that joins two attributes of its owner into a file path.

    On read it returns ``join(getattr(instance, directory_property), getattr(instance, filename_property))``.
    """

    def __init__(self, directory_property: str, filename_property: str) -> None:
        """
        :param directory_property: Name of the owner attribute holding the output directory.
        :param filename_property: Name of the owner attribute holding the filename.
        """
        self.directory_property = directory_property
        self.filename_property = filename_property

    def __get__(self, instance: Plotter, owner: object) -> Union[str, _FilePathGetter]:
        """
        :param instance: The plotter the attribute is read from, or None for class-level access.
        :param owner: The owning class.
        :return: The joined path, or the descriptor itself when accessed on the class.
        """
        # Class-level access passes ``instance=None``; there are no attributes to join then.
        if instance is None:
            return self
        directory = getattr(instance, self.directory_property)
        filename = getattr(instance, self.filename_property)
        return join(directory, filename)
33

34

35
class Plotter(object):
    """
    Base class for all lightcurve plotting classes in redback.

    Each plotting option is declared below as a ``KwargsAccessorWithDefault`` class attribute that
    takes the keyword name and its default value (the accessor presumably resolves the value from
    ``self.kwargs``; see ``redback.utils``). The keywords are documented on ``__init__``.
    """

    capsize = KwargsAccessorWithDefault("capsize", 0.)
    legend_location = KwargsAccessorWithDefault("legend_location", "best")
    legend_cols = KwargsAccessorWithDefault("legend_cols", 2)
    band_colors = KwargsAccessorWithDefault("band_colors", None)
    color = KwargsAccessorWithDefault("color", "k")
    band_labels = KwargsAccessorWithDefault("band_labels", None)
    # NOTE(review): this default dict is a single object shared by all plotters; if the accessor
    # returns it unchanged, mutating it would leak between instances — confirm in redback.utils.
    band_scaling = KwargsAccessorWithDefault("band_scaling", {})
    dpi = KwargsAccessorWithDefault("dpi", 300)
    elinewidth = KwargsAccessorWithDefault("elinewidth", 2)
    errorbar_fmt = KwargsAccessorWithDefault("errorbar_fmt", "o")
    model = KwargsAccessorWithDefault("model", None)
    ms = KwargsAccessorWithDefault("ms", 5)
    axis_tick_params_pad = KwargsAccessorWithDefault("axis_tick_params_pad", 10)

    max_likelihood_alpha = KwargsAccessorWithDefault("max_likelihood_alpha", 0.65)
    random_sample_alpha = KwargsAccessorWithDefault("random_sample_alpha", 0.05)
    uncertainty_band_alpha = KwargsAccessorWithDefault("uncertainty_band_alpha", 0.4)
    max_likelihood_color = KwargsAccessorWithDefault("max_likelihood_color", "blue")
    random_sample_color = KwargsAccessorWithDefault("random_sample_color", "red")

    bbox_inches = KwargsAccessorWithDefault("bbox_inches", "tight")
    linewidth = KwargsAccessorWithDefault("linewidth", 2)
    zorder = KwargsAccessorWithDefault("zorder", -1)

    xy = KwargsAccessorWithDefault("xy", (0.95, 0.9))
    xycoords = KwargsAccessorWithDefault("xycoords", "axes fraction")
    horizontalalignment = KwargsAccessorWithDefault("horizontalalignment", "right")
    annotation_size = KwargsAccessorWithDefault("annotation_size", 20)

    fontsize_axes = KwargsAccessorWithDefault("fontsize_axes", 18)
    fontsize_figure = KwargsAccessorWithDefault("fontsize_figure", 30)
    fontsize_legend = KwargsAccessorWithDefault("fontsize_legend", 18)
    fontsize_ticks = KwargsAccessorWithDefault("fontsize_ticks", 16)
    hspace = KwargsAccessorWithDefault("hspace", 0.04)
    wspace = KwargsAccessorWithDefault("wspace", 0.15)

    plot_others = KwargsAccessorWithDefault("plot_others", True)
    random_models = KwargsAccessorWithDefault("random_models", 100)
    uncertainty_mode = KwargsAccessorWithDefault("uncertainty_mode", "random_models")
    credible_interval_level = KwargsAccessorWithDefault("credible_interval_level", 0.9)
    plot_max_likelihood = KwargsAccessorWithDefault("plot_max_likelihood", True)
    set_same_color_per_subplot = KwargsAccessorWithDefault("set_same_color_per_subplot", True)

    xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 2.0)
    xlim_low_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.5)
    ylim_high_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 2.0)
    ylim_low_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.5)

    # Grid options
    show_grid = KwargsAccessorWithDefault("show_grid", False)
    grid_alpha = KwargsAccessorWithDefault("grid_alpha", 0.3)
    grid_color = KwargsAccessorWithDefault("grid_color", "gray")
    grid_linestyle = KwargsAccessorWithDefault("grid_linestyle", "--")
    grid_linewidth = KwargsAccessorWithDefault("grid_linewidth", 0.5)

    # Save format and transparency
    save_format = KwargsAccessorWithDefault("save_format", "png")
    transparent = KwargsAccessorWithDefault("transparent", False)

    # Axis scale options
    xscale = KwargsAccessorWithDefault("xscale", None)
    yscale = KwargsAccessorWithDefault("yscale", None)

    # Title options
    title = KwargsAccessorWithDefault("title", None)
    title_fontsize = KwargsAccessorWithDefault("title_fontsize", 20)

    # Line style options
    linestyle = KwargsAccessorWithDefault("linestyle", "-")
    max_likelihood_linestyle = KwargsAccessorWithDefault("max_likelihood_linestyle", "-")
    random_sample_linestyle = KwargsAccessorWithDefault("random_sample_linestyle", "-")

    # Marker options
    markerfillstyle = KwargsAccessorWithDefault("markerfillstyle", "full")
    markeredgecolor = KwargsAccessorWithDefault("markeredgecolor", None)
    markeredgewidth = KwargsAccessorWithDefault("markeredgewidth", 1.0)

    # Legend customization
    legend_frameon = KwargsAccessorWithDefault("legend_frameon", True)
    legend_shadow = KwargsAccessorWithDefault("legend_shadow", False)
    legend_fancybox = KwargsAccessorWithDefault("legend_fancybox", True)
    legend_framealpha = KwargsAccessorWithDefault("legend_framealpha", 0.8)

    # Tick customization
    tick_direction = KwargsAccessorWithDefault("tick_direction", "in")
    tick_length = KwargsAccessorWithDefault("tick_length", None)
    tick_width = KwargsAccessorWithDefault("tick_width", None)

    # Spine options
    show_spines = KwargsAccessorWithDefault("show_spines", True)
    spine_linewidth = KwargsAccessorWithDefault("spine_linewidth", None)

    def __init__(self, transient: Union[redback.transient.Transient, None], **kwargs) -> None:
        """
        :param transient: An instance of `redback.transient.Transient`. Contains the data to be plotted.
        :param kwargs: Additional kwargs the plotter uses. -------
        :keyword capsize: Same as matplotlib capsize.
        :keyword bands_to_plot: List of bands to plot in plot lightcurve and multiband lightcurve. Default is active bands.
        :keyword legend_location: Same as matplotlib legend location.
        :keyword legend_cols: Same as matplotlib legend columns.
        :keyword color: Color of the data points.
        :keyword band_colors: A dictionary with the colors of the bands.
        :keyword band_labels: List with the names of the bands.
        :keyword band_scaling: Dict with the scaling for each band. First entry should be {type: '+' or 'x'} for different types.
        :keyword dpi: Same as matplotlib dpi.
        :keyword elinewidth: Same as matplotlib elinewidth.
        :keyword errorbar_fmt: 'fmt' argument of `ax.errorbar`.
        :keyword model: str or callable, the model to plot.
        :keyword ms: Same as matplotlib markersize.
        :keyword axis_tick_params_pad: `pad` argument in calls to `ax.tick_params` when setting the axes.
        :keyword max_likelihood_alpha: `alpha` argument, i.e. transparency, when plotting the max likelihood curve.
        :keyword random_sample_alpha: `alpha` argument, i.e. transparency, when plotting random sample curves.
        :keyword uncertainty_band_alpha: `alpha` argument, i.e. transparency, when plotting a credible band.
        :keyword max_likelihood_color: Color of the maximum likelihood curve.
        :keyword random_sample_color: Color of the random sample curves.
        :keyword bbox_inches: Setting for saving plots. Default is 'tight'.
        :keyword linewidth: Same as matplotlib linewidth.
        :keyword zorder: Same as matplotlib zorder.
        :keyword xy: For `ax.annotate`, x and y coordinates of the point to annotate.
        :keyword xycoords: The coordinate system `xy` is given in. Default is 'axes fraction'.
        :keyword horizontalalignment: Horizontal alignment of the annotation. Default is 'right'.
        :keyword annotation_size: `size` argument of `ax.annotate`.
        :keyword fontsize_axes: Font size of the x and y labels.
        :keyword fontsize_legend: Font size of the legend.
        :keyword fontsize_figure: Font size of the figure. Relevant for multiband plots.
                                  Used on `supxlabel` and `supylabel`.
        :keyword fontsize_ticks: Font size of the axis ticks.
        :keyword hspace: Argument for `subplots_adjust`, sets vertical (height) spacing between panels.
        :keyword wspace: Argument for `subplots_adjust`, sets horizontal (width) spacing between panels.
        :keyword plot_others: Whether to plot additional bands in the data plot, all in the same colors.
        :keyword random_models: Number of random draws to use to calculate credible bands or to plot.
        :keyword uncertainty_mode: 'random_models': Plot random draws from the available parameter sets.
                                   'credible_intervals': Plot a credible interval that is calculated based
                                   on the available parameter sets.
        :keyword reference_mjd_date: Date to use as reference point for the x axis.
                                    Default is the first date in the data.
        :keyword credible_interval_level: 0.9: Plot the 90% credible interval.
        :keyword plot_max_likelihood: Plots the draw corresponding to the maximum likelihood. Default is 'True'.
        :keyword set_same_color_per_subplot: Sets the lightcurve to be the same color as the data per subplot. Default is 'True'.
        :keyword xlim_high_multiplier: Adjust the maximum xlim based on available x values.
        :keyword xlim_low_multiplier: Adjust the minimum xlim based on available x values.
        :keyword ylim_high_multiplier: Adjust the maximum ylim based on available y values.
        :keyword ylim_low_multiplier: Adjust the minimum ylim based on available y values.
        :keyword show_grid: Whether to show grid lines. Default is False.
        :keyword grid_alpha: Transparency of grid lines. Default is 0.3.
        :keyword grid_color: Color of grid lines. Default is 'gray'.
        :keyword grid_linestyle: Line style of grid lines. Default is '--'.
        :keyword grid_linewidth: Line width of grid lines. Default is 0.5.
        :keyword save_format: Format for saving plots (e.g., 'png', 'pdf', 'svg', 'eps'). Default is 'png'.
        :keyword transparent: Whether to save plots with transparent background. Default is False.
        :keyword xscale: X-axis scale ('linear', 'log', 'symlog', 'logit'). Default is None (auto-determined).
        :keyword yscale: Y-axis scale ('linear', 'log', 'symlog', 'logit'). Default is None (auto-determined).
        :keyword title: Title for the plot. Default is None (no title).
        :keyword title_fontsize: Font size for the title. Default is 20.
        :keyword linestyle: Line style for model curves. Default is '-'.
        :keyword max_likelihood_linestyle: Line style for max likelihood curve. Default is '-'.
        :keyword random_sample_linestyle: Line style for random sample curves. Default is '-'.
        :keyword markerfillstyle: Fill style for markers ('full', 'left', 'right', 'bottom', 'top', 'none'). Default is 'full'.
        :keyword markeredgecolor: Edge color for markers. Default is None (same as face color).
        :keyword markeredgewidth: Edge width for markers. Default is 1.0.
        :keyword legend_frameon: Whether to draw a frame around the legend. Default is True.
        :keyword legend_shadow: Whether to draw a shadow behind the legend. Default is False.
        :keyword legend_fancybox: Whether to use rounded corners for legend frame. Default is True.
        :keyword legend_framealpha: Transparency of legend frame. Default is 0.8.
        :keyword tick_direction: Direction of tick marks ('in', 'out', 'inout'). Default is 'in'.
        :keyword tick_length: Length of tick marks. Default is None (matplotlib default).
        :keyword tick_width: Width of tick marks. Default is None (matplotlib default).
        :keyword show_spines: Whether to show plot spines (borders). Default is True.
        :keyword spine_linewidth: Width of plot spines. Default is None (matplotlib default).
        """
        self.transient = transient
        self.kwargs = kwargs or dict()
        # Set once the posterior kwarg has been sorted in place (see ``_posterior``).
        self._posterior_sorted = False

    # Keyword section of the ``__init__`` docstring (everything after the "-------" marker).
    # Guarded because ``python -OO`` strips docstrings, which would otherwise fail at import time.
    keyword_docstring = __init__.__doc__.split("-------")[1] if __init__.__doc__ else ""

    def _get_times(self, axes: matplotlib.axes.Axes) -> np.ndarray:
        """
        :param axes: The axes used in the plotting procedure. If an array of axes is given,
                     the first one is inspected.
        :type axes: matplotlib.axes.Axes

        :return: 200 time values between ``_xlim_low`` and ``_xlim_high``, linearly spaced when the
                 y scale of the plot is linear and logarithmically spaced otherwise. When the transient
                 uses a phase model, the values are shifted by ``_reference_mjd_date``.
        :rtype: np.ndarray
        """
        ax = axes[0] if isinstance(axes, np.ndarray) else axes

        # NOTE(review): the spacing of this time (x) grid is chosen from the *y* scale; confirm this
        # is intended rather than ``ax.get_xscale()``.
        if ax.get_yscale() == 'linear':
            times = np.linspace(self._xlim_low, self._xlim_high, 200)
        else:
            times = np.exp(np.linspace(np.log(self._xlim_low), np.log(self._xlim_high), 200))

        if self.transient.use_phase_model:
            # ``_reference_mjd_date`` is not defined in this class; presumably provided by subclasses.
            times = times + self._reference_mjd_date
        return times

    @property
    def _xlim_low(self) -> float:
        """Lower x limit: the ``xlim_low`` kwarg, else ``xlim_low_multiplier`` times the first x value."""
        default = self.xlim_low_multiplier * self.transient.x[0]
        if default == 0:
            # Nudge a zero limit up so the logarithmic grid in ``_get_times`` stays finite.
            default += 1e-3
        return self.kwargs.get("xlim_low", default)

    @property
    def _xlim_high(self) -> float:
        """Upper x limit: the ``xlim_high`` kwarg, else ``xlim_high_multiplier`` times the last x value
        (plus the last point's ``_x_err[1]`` error when x errors are available)."""
        if self._x_err is None:
            default = self.xlim_high_multiplier * self.transient.x[-1]
        else:
            default = self.xlim_high_multiplier * (self.transient.x[-1] + self._x_err[1][-1])
        return self.kwargs.get("xlim_high", default)

    @property
    def _ylim_low(self) -> float:
        """Lower y limit: the ``ylim_low`` kwarg, else ``ylim_low_multiplier`` times the smallest y value."""
        default = self.ylim_low_multiplier * np.min(self.transient.y)
        return self.kwargs.get("ylim_low", default)

    @property
    def _ylim_high(self) -> float:
        """Upper y limit: the ``ylim_high`` kwarg, else ``ylim_high_multiplier`` times the largest y value."""
        default = self.ylim_high_multiplier * np.max(self.transient.y)
        return self.kwargs.get("ylim_high", default)

    @property
    def _x_err(self) -> Union[np.ndarray, None]:
        """
        x errors rearranged as ``[abs(x_err[1]), x_err[0]]``, or None when the transient has none.

        NOTE(review): presumably this reorders the transient's two rows into matplotlib's
        ``(lower, upper)`` ``xerr`` layout — confirm against ``Transient``.
        """
        if self.transient.x_err is not None:
            return np.array([np.abs(self.transient.x_err[1, :]), self.transient.x_err[0, :]])
        return None

    @property
    def _y_err(self) -> np.ndarray:
        """y errors: two-row errors rearranged like ``_x_err``; one-dimensional errors wrapped as one row."""
        if self.transient.y_err.ndim > 1:
            return np.array([np.abs(self.transient.y_err[1, :]), self.transient.y_err[0, :]])
        return np.array([np.abs(self.transient.y_err)])

    @property
    def _lightcurve_plot_outdir(self) -> str:
        """Output directory for model plots: the ``outdir`` kwarg, else a sub-directory of the
        transient's directory named after the model."""
        # ``model`` may be a str or a callable (see ``__init__``); only callables have ``__name__``.
        model_name = self.model if isinstance(self.model, str) else self.model.__name__
        return self._get_outdir(join(self.transient.directory_structure.directory_path, model_name))

    @property
    def _data_plot_outdir(self) -> str:
        """Output directory for data-only plots: the ``outdir`` kwarg, else the transient's directory."""
        return self._get_outdir(self.transient.directory_structure.directory_path)

    def _get_outdir(self, default: str) -> str:
        """Return the ``outdir`` kwarg if it is set and truthy, otherwise ``default``."""
        return self._get_kwarg_with_default(kwarg="outdir", default=default)

    def get_filename(self, default: str) -> str:
        """Return the ``filename`` kwarg if it is set and truthy, otherwise ``default``."""
        return self._get_kwarg_with_default(kwarg="filename", default=default)

    def _get_kwarg_with_default(self, kwarg: str, default: Any) -> Any:
        """Return ``self.kwargs[kwarg]``, falling back to ``default`` when missing or falsy (e.g. None, "")."""
        return self.kwargs.get(kwarg, default) or default

    @property
    def _model_kwargs(self) -> dict:
        """The ``model_kwargs`` kwarg, or an empty dict."""
        return self._get_kwarg_with_default("model_kwargs", dict())

    @property
    def _posterior(self) -> pd.DataFrame:
        """
        The ``posterior`` kwarg sorted by ascending ``log_likelihood``.

        The sort is done in place on the caller's DataFrame, only on first access.
        NOTE(review): when no ``posterior`` kwarg is given, the empty default DataFrame has no
        ``log_likelihood`` column, so ``sort_values`` raises KeyError.
        """
        posterior = self.kwargs.get("posterior", pd.DataFrame())
        if not self._posterior_sorted and posterior is not None:
            posterior.sort_values(by='log_likelihood', inplace=True)
            self._posterior_sorted = True
        return posterior

    @property
    def _max_like_params(self) -> pd.core.series.Series:
        """Posterior row with the largest log likelihood (last row after the ascending sort)."""
        return self._posterior.iloc[-1]

    def _get_random_parameters(self) -> list[pd.core.series.Series]:
        """Draw ``random_models`` posterior rows uniformly at random, with replacement."""
        integers = np.arange(len(self._posterior))
        indices = np.random.choice(integers, size=self.random_models)
        return [self._posterior.iloc[idx] for idx in indices]

    _data_plot_filename = _FilenameGetter(suffix="data")
    _lightcurve_plot_filename = _FilenameGetter(suffix="lightcurve")
    _residual_plot_filename = _FilenameGetter(suffix="residual")
    _multiband_data_plot_filename = _FilenameGetter(suffix="multiband_data")
    _multiband_lightcurve_plot_filename = _FilenameGetter(suffix="multiband_lightcurve")

    _data_plot_filepath = _FilePathGetter(
        directory_property="_data_plot_outdir", filename_property="_data_plot_filename")
    _lightcurve_plot_filepath = _FilePathGetter(
        directory_property="_lightcurve_plot_outdir", filename_property="_lightcurve_plot_filename")
    _residual_plot_filepath = _FilePathGetter(
        directory_property="_lightcurve_plot_outdir", filename_property="_residual_plot_filename")
    _multiband_data_plot_filepath = _FilePathGetter(
        directory_property="_data_plot_outdir", filename_property="_multiband_data_plot_filename")
    _multiband_lightcurve_plot_filepath = _FilePathGetter(
        directory_property="_lightcurve_plot_outdir", filename_property="_multiband_lightcurve_plot_filename")

    def _save_and_show(self, filepath: str, save: bool, show: bool) -> None:
        """
        Apply ``tight_layout`` to the current figure, then optionally save and/or show it.

        :param filepath: Target path. Any extension on the final path component is replaced
                         with ``save_format``.
        :param save: Whether to save the figure.
        :param show: Whether to call ``plt.show()``.
        """
        from os.path import splitext

        plt.tight_layout()
        if save:
            # ``splitext`` only inspects the final path component, so dots in directory names
            # (e.g. "./plots/lc" or "my.dir/lc") are not mistaken for an extension.
            root, _ = splitext(filepath)
            filepath = f'{root}.{self.save_format}'

            # An opaque save forces a white background rather than the figure's facecolor.
            facecolor = 'none' if self.transparent else 'white'
            plt.savefig(filepath, dpi=self.dpi, bbox_inches=self.bbox_inches,
                        transparent=self.transparent, facecolor=facecolor)
        if show:
            plt.show()

    def _apply_axis_customizations(self, ax: matplotlib.axes.Axes) -> None:
        """Apply common axis customizations: grid, title, tick parameters, and spines."""
        # Grid
        if self.show_grid:
            ax.grid(True, alpha=self.grid_alpha, color=self.grid_color,
                    linestyle=self.grid_linestyle, linewidth=self.grid_linewidth)

        # Title
        if self.title is not None:
            ax.set_title(self.title, fontsize=self.title_fontsize)

        # Tick customization; length/width are only passed when set, so matplotlib defaults apply otherwise.
        tick_params = {'axis': 'both', 'which': 'both',
                       'pad': self.axis_tick_params_pad,
                       'labelsize': self.fontsize_ticks,
                       'direction': self.tick_direction}
        if self.tick_length is not None:
            tick_params['length'] = self.tick_length
        if self.tick_width is not None:
            tick_params['width'] = self.tick_width
        ax.tick_params(**tick_params)

        # Spine customization; hidden spines take precedence over spine_linewidth.
        if not self.show_spines:
            for spine in ax.spines.values():
                spine.set_visible(False)
        elif self.spine_linewidth is not None:
            for spine in ax.spines.values():
                spine.set_linewidth(self.spine_linewidth)
375

376
class SpecPlotter(object):
1✔
377
    """
    Base class for all spectrum plotting classes in redback.

    Each plotting option is declared below as a ``KwargsAccessorWithDefault`` class attribute that
    takes the keyword name and its default value (the accessor presumably resolves the value from
    ``self.kwargs``; see ``redback.utils``). The keywords are documented on ``__init__``.
    """

    capsize = KwargsAccessorWithDefault("capsize", 0.)
    elinewidth = KwargsAccessorWithDefault("elinewidth", 2)
    errorbar_fmt = KwargsAccessorWithDefault("errorbar_fmt", "x")
    legend_location = KwargsAccessorWithDefault("legend_location", "best")
    legend_cols = KwargsAccessorWithDefault("legend_cols", 2)
    color = KwargsAccessorWithDefault("color", "k")
    dpi = KwargsAccessorWithDefault("dpi", 300)
    model = KwargsAccessorWithDefault("model", None)
    ms = KwargsAccessorWithDefault("ms", 1)
    axis_tick_params_pad = KwargsAccessorWithDefault("axis_tick_params_pad", 10)

    max_likelihood_alpha = KwargsAccessorWithDefault("max_likelihood_alpha", 0.65)
    random_sample_alpha = KwargsAccessorWithDefault("random_sample_alpha", 0.05)
    uncertainty_band_alpha = KwargsAccessorWithDefault("uncertainty_band_alpha", 0.4)
    max_likelihood_color = KwargsAccessorWithDefault("max_likelihood_color", "blue")
    random_sample_color = KwargsAccessorWithDefault("random_sample_color", "red")

    bbox_inches = KwargsAccessorWithDefault("bbox_inches", "tight")
    linewidth = KwargsAccessorWithDefault("linewidth", 2)
    zorder = KwargsAccessorWithDefault("zorder", -1)
    yscale = KwargsAccessorWithDefault("yscale", "linear")

    xy = KwargsAccessorWithDefault("xy", (0.95, 0.9))
    xycoords = KwargsAccessorWithDefault("xycoords", "axes fraction")
    horizontalalignment = KwargsAccessorWithDefault("horizontalalignment", "right")
    annotation_size = KwargsAccessorWithDefault("annotation_size", 20)

    fontsize_axes = KwargsAccessorWithDefault("fontsize_axes", 18)
    fontsize_figure = KwargsAccessorWithDefault("fontsize_figure", 30)
    fontsize_legend = KwargsAccessorWithDefault("fontsize_legend", 18)
    fontsize_ticks = KwargsAccessorWithDefault("fontsize_ticks", 16)
    hspace = KwargsAccessorWithDefault("hspace", 0.04)
    wspace = KwargsAccessorWithDefault("wspace", 0.15)

    random_models = KwargsAccessorWithDefault("random_models", 100)
    uncertainty_mode = KwargsAccessorWithDefault("uncertainty_mode", "random_models")
    credible_interval_level = KwargsAccessorWithDefault("credible_interval_level", 0.9)
    plot_max_likelihood = KwargsAccessorWithDefault("plot_max_likelihood", True)
    set_same_color_per_subplot = KwargsAccessorWithDefault("set_same_color_per_subplot", True)

    # Axis-limit multipliers; tighter defaults than the lightcurve Plotter.
    xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.05)
    xlim_low_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.9)
    ylim_high_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 1.1)
    ylim_low_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.5)

    # Grid options
    show_grid = KwargsAccessorWithDefault("show_grid", False)
    grid_alpha = KwargsAccessorWithDefault("grid_alpha", 0.3)
    grid_color = KwargsAccessorWithDefault("grid_color", "gray")
    grid_linestyle = KwargsAccessorWithDefault("grid_linestyle", "--")
    grid_linewidth = KwargsAccessorWithDefault("grid_linewidth", 0.5)

    # Save format and transparency
    save_format = KwargsAccessorWithDefault("save_format", "png")
    transparent = KwargsAccessorWithDefault("transparent", False)

    # Axis scale options (yscale is declared above with a 'linear' default)
    xscale = KwargsAccessorWithDefault("xscale", None)

    # Title options
    title = KwargsAccessorWithDefault("title", None)
    title_fontsize = KwargsAccessorWithDefault("title_fontsize", 20)

    # Line style options
    linestyle = KwargsAccessorWithDefault("linestyle", "-")
    max_likelihood_linestyle = KwargsAccessorWithDefault("max_likelihood_linestyle", "-")
    random_sample_linestyle = KwargsAccessorWithDefault("random_sample_linestyle", "-")

    # Marker options
    markerfillstyle = KwargsAccessorWithDefault("markerfillstyle", "full")
    markeredgecolor = KwargsAccessorWithDefault("markeredgecolor", None)
    markeredgewidth = KwargsAccessorWithDefault("markeredgewidth", 1.0)

    # Legend customization
    legend_frameon = KwargsAccessorWithDefault("legend_frameon", True)
    legend_shadow = KwargsAccessorWithDefault("legend_shadow", False)
    legend_fancybox = KwargsAccessorWithDefault("legend_fancybox", True)
    legend_framealpha = KwargsAccessorWithDefault("legend_framealpha", 0.8)

    # Tick customization
    tick_direction = KwargsAccessorWithDefault("tick_direction", "in")
    tick_length = KwargsAccessorWithDefault("tick_length", None)
    tick_width = KwargsAccessorWithDefault("tick_width", None)

    # Spine options
    show_spines = KwargsAccessorWithDefault("show_spines", True)
    spine_linewidth = KwargsAccessorWithDefault("spine_linewidth", None)
468

469
    def __init__(self, spectrum: Union[redback.transient.Spectrum, None], **kwargs) -> None:
1✔
470
        """
471
        :param spectrum: An instance of `redback.transient.Spectrum`. Contains the data to be plotted.
472
        :param kwargs: Additional kwargs the plotter uses. -------
473
        :keyword capsize: Same as matplotlib capsize.
474
        :keyword elinewidth: same as matplotlib elinewidth
475
        :keyword errorbar_fmt: 'fmt' argument of `ax.errorbar`.
476
        :keyword ms: Same as matplotlib markersize.
477
        :keyword legend_location: Same as matplotlib legend location.
478
        :keyword legend_cols: Same as matplotlib legend columns.
479
        :keyword color: Color of the data points.
480
        :keyword dpi: Same as matplotlib dpi.
481
        :keyword model: str or callable, the model to plot.
482
        :keyword ms: Same as matplotlib markersize.
483
        :keyword axis_tick_params_pad: `pad` argument in calls to `ax.tick_params` when setting the axes.
484
        :keyword max_likelihood_alpha: `alpha` argument, i.e. transparency, when plotting the max likelihood curve.
485
        :keyword random_sample_alpha: `alpha` argument, i.e. transparency, when plotting random sample curves.
486
        :keyword uncertainty_band_alpha: `alpha` argument, i.e. transparency, when plotting a credible band.
487
        :keyword max_likelihood_color: Color of the maximum likelihood curve.
488
        :keyword random_sample_color: Color of the random sample curves.
489
        :keyword bbox_inches: Setting for saving plots. Default is 'tight'.
490
        :keyword linewidth: Same as matplotlib linewidth
491
        :keyword zorder: Same as matplotlib zorder
492
        :keyword yscale: Same as matplotlib yscale, default is linear
493
        :keyword xy: For `ax.annotate' x and y coordinates of the point to annotate.
494
        :keyword xycoords: The coordinate system `xy` is given in. Default is 'axes fraction'
495
        :keyword horizontalalignment: Horizontal alignment of the annotation. Default is 'right'
496
        :keyword annotation_size: `size` argument of of `ax.annotate`.
497
        :keyword fontsize_axes: Font size of the x and y labels.
498
        :keyword fontsize_legend: Font size of the legend.
499
        :keyword fontsize_figure: Font size of the figure. Relevant for multiband plots.
500
                                  Used on `supxlabel` and `supylabel`.
501
        :keyword fontsize_ticks: Font size of the axis ticks.
502
        :keyword hspace: Argument for `subplots_adjust`, sets horizontal spacing between panels.
503
        :keyword wspace: Argument for `subplots_adjust`, sets horizontal spacing between panels.
504
        :keyword plot_others: Whether to plot additional bands in the data plot, all in the same colors
505
        :keyword random_models: Number of random draws to use to calculate credible bands or to plot.
506
        :keyword uncertainty_mode: 'random_models': Plot random draws from the available parameter sets.
507
                                   'credible_intervals': Plot a credible interval that is calculated based
508
                                   on the available parameter sets.
509
        :keyword credible_interval_level: 0.9: Plot the 90% credible interval.
510
        :keyword plot_max_likelihood: Plots the draw corresponding to the maximum likelihood. Default is 'True'.
511
        :keyword set_same_color_per_subplot: Sets the lightcurve to be the same color as the data per subplot. Default is 'True'.
512
        :keyword xlim_high_multiplier: Adjust the maximum xlim based on available x values.
513
        :keyword xlim_low_multiplier: Adjust the minimum xlim based on available x values.
514
        :keyword ylim_high_multiplier: Adjust the maximum ylim based on available x values.
515
        :keyword ylim_low_multiplier: Adjust the minimum ylim based on available x values.
516
        :keyword show_grid: Whether to show grid lines. Default is False.
517
        :keyword grid_alpha: Transparency of grid lines. Default is 0.3.
518
        :keyword grid_color: Color of grid lines. Default is 'gray'.
519
        :keyword grid_linestyle: Line style of grid lines. Default is '--'.
520
        :keyword grid_linewidth: Line width of grid lines. Default is 0.5.
521
        :keyword save_format: Format for saving plots (e.g., 'png', 'pdf', 'svg', 'eps'). Default is 'png'.
522
        :keyword transparent: Whether to save plots with transparent background. Default is False.
523
        :keyword xscale: X-axis scale ('linear', 'log', 'symlog', 'logit'). Default is None (auto-determined).
524
        :keyword title: Title for the plot. Default is None (no title).
525
        :keyword title_fontsize: Font size for the title. Default is 20.
526
        :keyword linestyle: Line style for model curves. Default is '-'.
527
        :keyword max_likelihood_linestyle: Line style for max likelihood curve. Default is '-'.
528
        :keyword random_sample_linestyle: Line style for random sample curves. Default is '-'.
529
        :keyword markerfillstyle: Fill style for markers ('full', 'left', 'right', 'bottom', 'top', 'none'). Default is 'full'.
530
        :keyword markeredgecolor: Edge color for markers. Default is None (same as face color).
531
        :keyword markeredgewidth: Edge width for markers. Default is 1.0.
532
        :keyword legend_frameon: Whether to draw a frame around the legend. Default is True.
533
        :keyword legend_shadow: Whether to draw a shadow behind the legend. Default is False.
534
        :keyword legend_fancybox: Whether to use rounded corners for legend frame. Default is True.
535
        :keyword legend_framealpha: Transparency of legend frame. Default is 0.8.
536
        :keyword tick_direction: Direction of tick marks ('in', 'out', 'inout'). Default is 'in'.
537
        :keyword tick_length: Length of tick marks. Default is None (matplotlib default).
538
        :keyword tick_width: Width of tick marks. Default is None (matplotlib default).
539
        :keyword show_spines: Whether to show plot spines (borders). Default is True.
540
        :keyword spine_linewidth: Width of plot spines. Default is None (matplotlib default).
541
        """
542
        self.transient = spectrum
1✔
543
        self.kwargs = kwargs or dict()
1✔
544
        self._posterior_sorted = False
1✔
545

546
    # Class-level copy of the keyword documentation: the part of the ``__init__`` docstring that follows the
    # first "-------" separator (up to the next separator, if there is one).
    keyword_docstring = __init__.__doc__.split("-------")[1]
1✔
547

548
    def _get_angstroms(self, axes: matplotlib.axes.Axes) -> np.ndarray:
        """
        Build a 200-point wavelength grid spanning the current x limits.

        :param axes: The axes used in the plotting procedure. When an array of axes is given, the first one
                     decides the spacing.
        :type axes: matplotlib.axes.Axes

        :return: Wavelengths between ``_xlim_low`` and ``_xlim_high``, spaced linearly when the y axis of the
                 reference axes is linear and logarithmically (geometric spacing) otherwise.
        :rtype: np.ndarray
        """
        reference_ax = axes[0] if isinstance(axes, np.ndarray) else axes
        use_linear_spacing = reference_ax.get_yscale() == 'linear'

        low, high = self._xlim_low, self._xlim_high
        if use_linear_spacing:
            return np.linspace(low, high, 200)
        # Evenly spaced exponents -> evenly spaced points on a logarithmic axis.
        return np.exp(np.linspace(np.log(low), np.log(high), 200))
1✔
567

568
    @property
    def _xlim_low(self) -> float:
        """Lower wavelength limit of the plot.

        Uses the ``xlim_low`` kwarg when given. Otherwise it is the first wavelength of the spectrum times
        ``xlim_low_multiplier``, nudged up by 1e-3 when that product is exactly zero (the limit may be passed
        through ``np.log`` by ``_get_angstroms``).
        """
        first_wavelength = self.transient.angstroms[0]
        fallback = self.xlim_low_multiplier * first_wavelength
        fallback = fallback + 1e-3 if fallback == 0 else fallback
        return self.kwargs.get("xlim_low", fallback)
1✔
574

575
    @property
    def _xlim_high(self) -> float:
        """Upper wavelength limit: the ``xlim_high`` kwarg, else the last wavelength times ``xlim_high_multiplier``."""
        last_wavelength = self.transient.angstroms[-1]
        return self.kwargs.get("xlim_high", last_wavelength * self.xlim_high_multiplier)
1✔
579

580
    @property
    def _ylim_low(self) -> float:
        """Lower y limit of the spectrum plot.

        The computed default is ``ylim_low_multiplier`` times the minimum flux density, divided by 1e-17
        (presumably because the flux is drawn in units of 1e-17 -- confirm against the plotting code).
        A user-supplied ``ylim_low`` kwarg is returned as-is, without that rescaling.
        """
        flux_floor = min(self.transient.flux_density)
        fallback = self.ylim_low_multiplier * flux_floor / 1e-17
        return self.kwargs.get("ylim_low", fallback)
1✔
584

585
    @property
    def _ylim_high(self) -> float:
        """Upper y limit of the spectrum plot.

        The computed default is ``ylim_high_multiplier`` times the maximum flux density, divided by 1e-17;
        a user-supplied ``ylim_high`` kwarg is returned unchanged.
        """
        flux_peak = np.max(self.transient.flux_density)
        fallback = self.ylim_high_multiplier * flux_peak / 1e-17
        return self.kwargs.get("ylim_high", fallback)
1✔
589

590
    @property
    def _y_err(self) -> np.ndarray:
        """Absolute flux-density errors wrapped in an extra leading axis (shape ``(1, N)`` for 1-D errors)."""
        magnitudes = np.abs(self.transient.flux_density_err)
        return np.array([magnitudes])
1✔
593

594
    @property
    def _data_plot_outdir(self) -> str:
        """Directory plots are written to: the ``outdir`` kwarg, else the spectrum's own directory path."""
        fallback_dir = self.transient.directory_structure.directory_path
        return self._get_outdir(fallback_dir)
1✔
597

598
    def _get_outdir(self, default: str) -> str:
        """Return the ``outdir`` kwarg, or ``default`` when it is missing or falsy."""
        key = "outdir"
        return self._get_kwarg_with_default(kwarg=key, default=default)
1✔
600

601
    def get_filename(self, default: str) -> str:
        """Return the ``filename`` kwarg, or ``default`` when it is missing or falsy."""
        key = "filename"
        return self._get_kwarg_with_default(kwarg=key, default=default)
1✔
603

604
    def _get_kwarg_with_default(self, kwarg: str, default: Any) -> Any:
        """Look ``kwarg`` up in ``self.kwargs``; fall back to ``default`` when it is absent or falsy (e.g. ``None``)."""
        value = self.kwargs.get(kwarg, default)
        return value if value else default
1✔
606

607
    @property
1✔
608
    def _model_kwargs(self) -> dict:
1✔
609
        return self._get_kwarg_with_default("model_kwargs", dict())
1✔
610

611
    @property
    def _posterior(self) -> pd.DataFrame:
        """Posterior samples from the ``posterior`` kwarg (an empty DataFrame when absent).

        On the first access the samples are sorted ascending by ``log_likelihood`` *in place* (so the caller's
        DataFrame is reordered) and ``self._posterior_sorted`` is set; later accesses skip the sort.
        A ``None`` posterior is returned untouched.
        """
        samples = self.kwargs.get("posterior", pd.DataFrame())
        needs_sorting = samples is not None and not self._posterior_sorted
        if needs_sorting:
            samples.sort_values(by='log_likelihood', inplace=True)
            self._posterior_sorted = True
        return samples
1✔
618

619
    @property
    def _max_like_params(self) -> pd.core.series.Series:
        """Highest log-likelihood sample: the last row of the ascending-sorted ``_posterior``."""
        sorted_samples = self._posterior
        best_row = sorted_samples.iloc[-1]
        return best_row
1✔
622

623
    def _get_random_parameters(self) -> list[pd.core.series.Series]:
        """Return ``self.random_models`` posterior rows drawn uniformly at random, with replacement.

        The draws use NumPy's global random state (``np.random.choice``), so ``np.random.seed`` makes them
        reproducible.
        """
        samples = self._posterior
        row_numbers = np.random.choice(np.arange(len(samples)), size=self.random_models)
        return [samples.iloc[row] for row in row_numbers]
1✔
627

628
    # File names default to "<transient name>_<suffix>.png" and can be overridden with the ``filename`` kwarg.
    _data_plot_filename = _FilenameGetter("data")
    _spectrum_ppd_plot_filename = _FilenameGetter("spectrum_ppd")
    _residual_plot_filename = _FilenameGetter("residual")

    # Full output paths: ``_data_plot_outdir`` joined with the matching file name above.
    _data_plot_filepath = _FilePathGetter("_data_plot_outdir", "_data_plot_filename")
    _spectrum_ppd_plot_filepath = _FilePathGetter("_data_plot_outdir", "_spectrum_ppd_plot_filename")
    _residual_plot_filepath = _FilePathGetter("_data_plot_outdir", "_residual_plot_filename")
638

639
    def _save_and_show(self, filepath: str, save: bool, show: bool) -> None:
        """Tighten the layout of the current figure, then optionally save and/or show it.

        :param filepath: Destination path. Any file extension is replaced by ``self.save_format``.
        :type filepath: str
        :param save: Whether to write the figure to disk (with ``dpi``, ``bbox_inches`` and ``transparent``).
        :type save: bool
        :param show: Whether to call ``plt.show()``.
        :type show: bool
        """
        from os.path import splitext

        plt.tight_layout()
        if save:
            # Swap the extension for the requested format. ``splitext`` only looks at the final path component,
            # so dots in directory names ("./plots/name", "runs/v1.2/name") are not mistaken for an extension.
            root, _ = splitext(filepath)
            filepath = f'{root}.{self.save_format}'

            facecolor = 'none' if self.transparent else 'white'
            plt.savefig(filepath, dpi=self.dpi, bbox_inches=self.bbox_inches,
                        transparent=self.transparent, facecolor=facecolor)
        if show:
            plt.show()
1✔
653

654
    def _apply_axis_customizations(self, ax: matplotlib.axes.Axes) -> None:
        """Apply the configurable cosmetic options to ``ax`` in place.

        Draws a grid when ``show_grid`` is set and a title when ``title`` is not ``None``. Tick settings are
        applied to both axes, with length/width forwarded only when configured so matplotlib's own defaults
        apply otherwise. Spines are either hidden or, when ``spine_linewidth`` is set, given that width.

        :param ax: Axes to customize.
        :type ax: matplotlib.axes.Axes
        """
        if self.show_grid:
            ax.grid(True, alpha=self.grid_alpha, color=self.grid_color,
                    linestyle=self.grid_linestyle, linewidth=self.grid_linewidth)

        if self.title is not None:
            ax.set_title(self.title, fontsize=self.title_fontsize)

        tick_settings = dict(axis='both', which='both', pad=self.axis_tick_params_pad,
                             labelsize=self.fontsize_ticks, direction=self.tick_direction)
        for option, value in (('length', self.tick_length), ('width', self.tick_width)):
            if value is not None:
                tick_settings[option] = value
        ax.tick_params(**tick_settings)

        if self.show_spines and self.spine_linewidth is None:
            return  # spines keep matplotlib's default appearance
        for spine in ax.spines.values():
            if self.show_spines:
                spine.set_linewidth(self.spine_linewidth)
            else:
                spine.set_visible(False)
1✔
683

684

685
class IntegratedFluxPlotter(Plotter):
    """Plots integrated-flux data, model lightcurves drawn from a posterior, and residuals.

    Both axes default to a logarithmic scale unless ``xscale``/``yscale`` are supplied.
    """

    @property
    def _xlabel(self) -> str:
        """Label of the time axis."""
        return r"Time since burst [s]"

    @property
    def _ylabel(self) -> str:
        """Label of the flux axis, taken from the transient."""
        return self.transient.ylabel

    def plot_data(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Integrated flux data and returns Axes.

        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
                     Defaults to the current axes (``plt.gca()``).
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        ax = axes or plt.gca()

        ax.errorbar(self.transient.x, self.transient.y, xerr=self._x_err, yerr=self._y_err,
                    fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize,
                    fillstyle=self.markerfillstyle, markeredgecolor=self.markeredgecolor,
                    markeredgewidth=self.markeredgewidth)

        # Apply custom scales if specified, otherwise default to log-log axes
        ax.set_xscale(self.xscale if self.xscale is not None else 'log')
        ax.set_yscale(self.yscale if self.yscale is not None else 'log')

        ax.set_xlim(self._xlim_low, self._xlim_high)
        ax.set_ylim(self._ylim_low, self._ylim_high)
        ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes)
        ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes)

        # Transient name written onto the axes
        ax.annotate(
            self.transient.name, xy=self.xy, xycoords=self.xycoords,
            horizontalalignment=self.horizontalalignment, size=self.annotation_size)

        # Grid, title, tick and spine options
        self._apply_axis_customizations(ax)

        self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show)
        return ax

    def plot_lightcurve(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Integrated flux data and the lightcurve and returns Axes.

        :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        axes = axes or plt.gca()

        axes = self.plot_data(axes=axes, save=False, show=False)
        times = self._get_times(axes)

        self._plot_lightcurves(axes, times)

        self._save_and_show(filepath=self._lightcurve_plot_filepath, save=save, show=show)
        return axes

    def _plot_lightcurves(self, axes: matplotlib.axes.Axes, times: np.ndarray) -> None:
        """Draw the model lightcurves evaluated at ``times``.

        Draws the maximum-likelihood curve when ``plot_max_likelihood`` is set. Depending on ``uncertainty_mode``,
        it then draws either one curve per random posterior draw (``"random_models"``) or a shaded credible
        interval computed from those draws (``"credible_intervals"``).
        """
        if self.plot_max_likelihood:
            ys = self.model(times, **self._max_like_params, **self._model_kwargs)
            axes.plot(times, ys, color=self.max_likelihood_color, alpha=self.max_likelihood_alpha,
                      lw=self.linewidth, linestyle=self.max_likelihood_linestyle)

        random_ys_list = [self.model(times, **random_params, **self._model_kwargs)
                          for random_params in self._get_random_parameters()]
        if self.uncertainty_mode == "random_models":
            for ys in random_ys_list:
                axes.plot(times, ys, color=self.random_sample_color, alpha=self.random_sample_alpha,
                          lw=self.linewidth, linestyle=self.random_sample_linestyle, zorder=self.zorder)
        elif self.uncertainty_mode == "credible_intervals":
            lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(
                samples=random_ys_list, interval=self.credible_interval_level)
            axes.fill_between(
                times, lower_bound, upper_bound, alpha=self.uncertainty_band_alpha, color=self.max_likelihood_color)

    def _plot_single_lightcurve(self, axes: matplotlib.axes.Axes, times: np.ndarray, params: dict) -> None:
        """Draw one model lightcurve for ``params`` using the random-sample line style."""
        ys = self.model(times, **params, **self._model_kwargs)
        axes.plot(times, ys, color=self.random_sample_color, alpha=self.random_sample_alpha,
                  lw=self.linewidth, linestyle=self.random_sample_linestyle, zorder=self.zorder)

    def plot_residuals(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the residual of the Integrated flux data returns Axes.

        The top panel (``axes[0]``) shows the data with the model lightcurves. The bottom panel (``axes[1]``)
        shows the data minus the maximum-likelihood model.

        :param axes: Pair of matplotlib axes (top: lightcurve, bottom: residuals). A new two-row figure is
                     created when not given.
        :type axes: Union[Sequence[matplotlib.axes.Axes], None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        if axes is None:
            _, axes = plt.subplots(
                nrows=2, ncols=1, sharex=True, sharey=False, figsize=(10, 8), gridspec_kw=dict(height_ratios=[2, 1]))

        axes[0] = self.plot_lightcurve(axes=axes[0], save=False, show=False)
        # The shared time label belongs under the bottom panel only.
        axes[1].set_xlabel(axes[0].get_xlabel(), fontsize=self.fontsize_axes)
        axes[0].set_xlabel("")
        ys = self.model(self.transient.x, **self._max_like_params, **self._model_kwargs)
        axes[1].errorbar(
            self.transient.x, self.transient.y - ys, xerr=self._x_err, yerr=self._y_err,
            fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize,
            fillstyle=self.markerfillstyle, markeredgecolor=self.markeredgecolor,
            markeredgewidth=self.markeredgewidth)
        # NOTE(review): residuals (data - model) can be negative, which a log axis cannot display;
        # consider 'symlog' or 'linear' here.
        axes[1].set_yscale("log")
        axes[1].set_ylabel("Residual", fontsize=self.fontsize_axes)

        # Grid, tick and spine options for the residual panel.
        self._apply_axis_customizations(axes[1])
        # plot_lightcurve -> plot_data already applied the customizations (including the title) to the top panel.
        # Clear the title here so it is not drawn a second time above the residual panel, between the panels.
        if self.title is not None:
            axes[1].set_title("")

        self._save_and_show(filepath=self._residual_plot_filepath, save=save, show=show)
        return axes
1✔
814

815

816
class LuminosityOpticalPlotter(IntegratedFluxPlotter):
    """Integrated-flux plotter labelled for bolometric luminosity against days since explosion."""

    @property
    def _ylabel(self) -> str:
        """Bolometric luminosity axis label (units of 1e50 erg/s)."""
        return r"L$_{\rm bol}$ [$10^{50}$ erg s$^{-1}$]"

    @property
    def _xlabel(self) -> str:
        """Time axis label in days since explosion."""
        return r"Time since explosion [days]"
1✔
825

826
class LuminosityPlotter(IntegratedFluxPlotter):
    """Luminosity plotter; inherits all behaviour (including axis labels) from ``IntegratedFluxPlotter``."""
1✔
828

829

830
class MagnitudePlotter(Plotter):
1✔
831

832
    # x-limit multipliers used when the transient uses a phase model (applied to times relative to the
    # reference MJD in ``_xlim_low``/``_xlim_high``). They read the same kwargs as the plain multipliers but
    # have their own defaults.
    xlim_low_phase_model_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.9)
    xlim_high_phase_model_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.1)
    # Upper x-limit multiplier for non-phase-model data (see ``_xlim_high``).
    xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.2)

    # y-limit multipliers for magnitude data; they share the kwargs of the flux-data multipliers.
    ylim_low_magnitude_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.95)
    ylim_high_magnitude_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 1.05)

    # Number of subplot columns in the multiband figure.
    ncols = KwargsAccessorWithDefault("ncols", 2)
1✔
838

839
    @property
    def _colors(self) -> Any:
        """Sequence of colors, one per entry of ``self._filters`` (indexed by the filter's position).

        Returns the ``colors`` kwarg when supplied. Otherwise returns ``self.transient.get_colors(self._filters)``,
        which is only computed in that case (previously it was evaluated eagerly even when unused).
        """
        if "colors" in self.kwargs:
            return self.kwargs["colors"]
        return self.transient.get_colors(self._filters)
1✔
842

843
    @property
    def _xlabel(self) -> str:
        """x-axis label: the ``xlabel`` kwarg, else a label relative to the reference MJD (phase models)
        or the transient's own ``xlabel``."""
        default_label = (f"Time since {self._reference_mjd_date} MJD [days]"
                         if self.transient.use_phase_model else self.transient.xlabel)
        return self.kwargs.get("xlabel", default_label)
1✔
850

851
    @property
    def _ylabel(self) -> str:
        """y-axis label: the ``ylabel`` kwarg, else the transient's ``ylabel``."""
        transient_label = self.transient.ylabel
        return self.kwargs.get("ylabel", transient_label)
1✔
854

855
    @property
    def _get_bands_to_plot(self) -> list[str]:
        """Bands to draw model curves for: the ``bands_to_plot`` kwarg, else the transient's active bands."""
        active = self.transient.active_bands
        return self.kwargs.get("bands_to_plot", active)
1✔
858

859
    @property
    def _xlim_low(self) -> float:
        """Lower x limit: the ``xlim_low`` kwarg, else a multiple of the earliest time.

        For phase models the earliest time is first shifted by the reference MJD and scaled with
        ``xlim_low_phase_model_multiplier``; otherwise it is scaled with ``xlim_low_multiplier``. An exact-zero
        result is nudged to 1e-3 (presumably to keep a logarithmic time axis valid -- confirm).
        """
        earliest = self.transient.x[0]
        fallback = ((earliest - self._reference_mjd_date) * self.xlim_low_phase_model_multiplier
                    if self.transient.use_phase_model
                    else self.xlim_low_multiplier * earliest)
        fallback = fallback + 1e-3 if fallback == 0 else fallback
        return self.kwargs.get("xlim_low", fallback)
1✔
868

869
    @property
    def _xlim_high(self) -> float:
        """Upper x limit: the ``xlim_high`` kwarg, else a multiple of the latest time.

        For phase models the latest time is shifted by the reference MJD and scaled with
        ``xlim_high_phase_model_multiplier``; otherwise it is scaled with ``xlim_high_multiplier``.
        """
        latest = self.transient.x[-1]
        if self.transient.use_phase_model:
            fallback = (latest - self._reference_mjd_date) * self.xlim_high_phase_model_multiplier
        else:
            fallback = self.xlim_high_multiplier * latest
        return self.kwargs.get("xlim_high", fallback)
1✔
876

877
    @property
    def _ylim_low_magnitude(self) -> float:
        """Lower y limit for magnitude data: ``ylim_low_magnitude_multiplier`` times the smallest y value."""
        smallest_y = min(self.transient.y)
        return self.ylim_low_magnitude_multiplier * smallest_y
1✔
880

881
    @property
    def _ylim_high_magnitude(self) -> float:
        """Upper y limit for magnitude data: ``ylim_high_magnitude_multiplier`` times the largest y value."""
        largest_y = np.max(self.transient.y)
        return self.ylim_high_magnitude_multiplier * largest_y
1✔
884

885
    def _get_ylim_low_with_indices(self, indices: list) -> float:
        """Lower y limit from the points at ``indices``: ``ylim_low_multiplier`` times their minimum."""
        selected = self.transient.y[indices]
        return self.ylim_low_multiplier * min(selected)
1✔
887

888
    def _get_ylim_high_with_indices(self, indices: list) -> float:
        """Upper y limit from the points at ``indices``: ``ylim_high_multiplier`` times their maximum."""
        selected = self.transient.y[indices]
        return self.ylim_high_multiplier * np.max(selected)
1✔
890

891
    def _get_x_err(self, indices: list) -> np.ndarray:
        """x errors of the points at ``indices``, or ``None`` when the transient has no x errors."""
        all_x_err = self.transient.x_err
        if all_x_err is None:
            return None
        return all_x_err[indices]
1✔
893

894
    def _set_y_axis_data(self, ax: matplotlib.axes.Axes) -> None:
        """Set y limits and scale for the single-panel data plot.

        Magnitude data get the magnitude limits, an inverted axis (smaller magnitudes drawn higher) and a
        default ``'linear'`` scale. Other data get ``_ylim_low``/``_ylim_high`` and a default ``'log'`` scale.
        A user-supplied ``yscale`` overrides either default.
        """
        if not self.transient.magnitude_data:
            ax.set_ylim(self._ylim_low, self._ylim_high)
            ax.set_yscale("log" if self.yscale is None else self.yscale)
            return
        ax.set_ylim(self._ylim_low_magnitude, self._ylim_high_magnitude)
        ax.invert_yaxis()
        ax.set_yscale('linear' if self.yscale is None else self.yscale)
1✔
904

905
    def _set_y_axis_multiband_data(self, ax: matplotlib.axes.Axes, indices: list) -> None:
        """Set y limits and scale for one panel of the multiband plot.

        Magnitude panels share the global magnitude limits, are inverted and default to a ``'linear'`` scale.
        Other panels use limits computed from the points at ``indices`` and default to ``'log'``. A
        user-supplied ``yscale`` overrides either default.
        """
        if not self.transient.magnitude_data:
            ax.set_ylim(self._get_ylim_low_with_indices(indices=indices),
                        self._get_ylim_high_with_indices(indices=indices))
            ax.set_yscale("log" if self.yscale is None else self.yscale)
            return
        ax.set_ylim(self._ylim_low_magnitude, self._ylim_high_magnitude)
        ax.invert_yaxis()
        ax.set_yscale('linear' if self.yscale is None else self.yscale)
1✔
916

917
    def _set_x_axis(self, axes: matplotlib.axes.Axes) -> None:
        """Set the x scale and limits.

        A user-supplied ``xscale`` always wins. Phase-model data otherwise get a linear axis, and anything else
        keeps whatever scale the axes already have. Limits come from ``_xlim_low``/``_xlim_high``.
        """
        chosen_scale = self.xscale
        if chosen_scale is None and self.transient.use_phase_model:
            chosen_scale = "linear"
        if chosen_scale is not None:
            axes.set_xscale(chosen_scale)
        axes.set_xlim(self._xlim_low, self._xlim_high)
1✔
924

925
    @property
    def _nrows(self) -> int:
        """Number of subplot rows in the multiband figure.

        Defaults to just enough rows to hold one panel per filter with ``self.ncols`` columns. It can be
        overridden with the ``nrows`` kwarg.
        """
        # Divide by the configured column count rather than a hard-coded 2: with ncols=1 a fixed divisor would
        # create too few panels for the filters, and with ncols=3 it would add empty rows.
        default = int(np.ceil(len(self._filters) / self.ncols))
        return self._get_kwarg_with_default("nrows", default=default)
1✔
929

930
    @property
    def _npanels(self) -> int:
        """Total number of panels (rows x columns).

        :raises ValueError: if there are fewer panels than filters to plot.
        """
        panel_count = self.ncols * self._nrows
        needed = len(self._filters)
        if panel_count >= needed:
            return panel_count
        raise ValueError(f"Insufficient number of panels. {panel_count} panels were given "
                         f"but {needed} panels are needed.")
×
937

938
    @property
    def _figsize(self) -> tuple:
        """Figure size: the ``figsize`` kwarg, else 4 + 4 per column wide and 2 + 2 per row tall."""
        width = 4 + 4 * self.ncols
        height = 2 + 2 * self._nrows
        return self._get_kwarg_with_default("figsize", default=(width, height))
1✔
942

943
    @property
    def _reference_mjd_date(self) -> int:
        """Time offset subtracted from the data.

        It is 0 unless a phase model is used. In that case it is the ``reference_mjd_date`` kwarg, defaulting
        to the integer part of the first time.
        """
        if not self.transient.use_phase_model:
            return 0
        return self.kwargs.get("reference_mjd_date", int(self.transient.x[0]))
1✔
948

949
    @property
    def band_label_generator(self):
        """A fresh generator over ``self.band_labels``, or ``None`` when no custom band labels were given."""
        labels = self.band_labels
        return None if labels is None else (label for label in labels)
×
953

954
    def plot_data(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Magnitude data and returns Axes.

        Bands in ``self._filters`` get their own color and a legend label; custom ``band_labels`` replace the
        default labels. Other bands are drawn in black without a label when ``plot_others`` is set, and skipped
        otherwise. ``band_colors``, when given, overrides every color. Bands with an entry in ``band_scaling``
        are either multiplied (type ``'x'``, errors scaled too) or shifted (type ``'+'``) before plotting. A band
        whose scaling type is neither is not drawn.

        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        ax = axes or plt.gca()
        custom_labels = self.band_label_generator

        def scaled_band_label(band_name):
            # "<band> <type> <factor>", unless the scaling is the identity (x 1 or + 0).
            scale_type = self.band_scaling.get("type")
            factor = self.band_scaling.get(band_name)
            described = band_name + ' ' + scale_type + ' ' + str(factor)
            is_identity = (scale_type == 'x' and factor == 1) or (scale_type == '+' and factor == 0)
            return band_name if is_identity else described

        for indices, band in zip(self.transient.list_of_band_indices, self.transient.unique_bands):
            if band in self._filters:
                color = self._colors[list(self._filters).index(band)]
                if custom_labels is not None:
                    label = next(custom_labels)
                elif band in self.band_scaling:
                    label = scaled_band_label(band)
                else:
                    label = band
            elif self.plot_others:
                color, label = "black", None
            else:
                continue

            if isinstance(label, float):
                label = f"{label:.2e}"
            if self.band_colors is not None:
                color = self.band_colors[band]

            if band not in self.band_scaling:
                y_values = self.transient.y[indices]
                y_errors = self.transient.y_err[indices]
            elif self.band_scaling.get("type") == 'x':
                factor = self.band_scaling.get(band)
                y_values = self.transient.y[indices] * factor
                y_errors = self.transient.y_err[indices] * factor
            elif self.band_scaling.get("type") == '+':
                y_values = self.transient.y[indices] + self.band_scaling.get(band)
                y_errors = self.transient.y_err[indices]
            else:
                # Unsupported scaling type: nothing is drawn for this band.
                continue

            ax.errorbar(
                self.transient.x[indices] - self._reference_mjd_date, y_values,
                xerr=self._get_x_err(indices), yerr=y_errors,
                fmt=self.errorbar_fmt, ms=self.ms, color=color,
                elinewidth=self.elinewidth, capsize=self.capsize, label=label,
                fillstyle=self.markerfillstyle, markeredgecolor=self.markeredgecolor,
                markeredgewidth=self.markeredgewidth)

        self._set_x_axis(axes=ax)
        self._set_y_axis_data(ax)

        ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes)
        ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes)

        # Grid, title, tick and spine options
        self._apply_axis_customizations(ax)

        ax.legend(ncol=self.legend_cols, loc=self.legend_location, fontsize=self.fontsize_legend,
                  frameon=self.legend_frameon, shadow=self.legend_shadow,
                  fancybox=self.legend_fancybox, framealpha=self.legend_framealpha)

        self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show)
        return ax
1✔
1039

1040
    def plot_lightcurve(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True)\
            -> matplotlib.axes.Axes:
        """Plots the Magnitude data together with the model lightcurves and returns Axes.

        For every band in ``bands_to_plot`` the model is evaluated with per-band ``bands`` and ``frequency``
        keyword arguments. The maximum-likelihood curve is drawn when ``plot_max_likelihood`` is set. The
        posterior uncertainty is shown either as individual random draws (``uncertainty_mode="random_models"``)
        or as a shaded credible interval (``"credible_intervals"``). Curves get the same ``band_scaling`` as the
        data points; bands whose scaling type is neither ``'x'`` nor ``'+'`` get no model curves.

        :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        axes = axes or plt.gca()

        axes = self.plot_data(axes=axes, save=False, show=False)

        times = self._get_times(axes)
        plot_times = times - self._reference_mjd_date
        bands_to_plot = self._get_bands_to_plot
        # Fetch the model kwargs once and reuse that same dict for every model call below. ``_model_kwargs`` may
        # build a new dict on each access (e.g. when no ``model_kwargs`` were supplied). Re-reading it would then
        # silently drop the per-band "bands"/"frequency" entries written in the loop.
        model_kwargs = self._model_kwargs

        def apply_band_scaling(band, values):
            # Same scaling as the data points; None marks an unsupported scaling type (nothing is drawn).
            if band not in self.band_scaling:
                return values
            scale_type = self.band_scaling.get("type")
            if scale_type == 'x':
                return values * self.band_scaling.get(band)
            if scale_type == '+':
                return values + self.band_scaling.get(band)
            return None

        color_max = self.max_likelihood_color
        color_sample = self.random_sample_color
        for band, color in zip(bands_to_plot, self.transient.get_colors(bands_to_plot)):
            if self.set_same_color_per_subplot is True:
                if self.band_colors is not None:
                    color = self.band_colors[band]
                color_max = color
                color_sample = color

            sn_cosmo_band = redback.utils.sncosmo_bandname_from_band([band])
            model_kwargs["bands"] = [sn_cosmo_band[0]] * len(times)
            if isinstance(band, str):
                frequency = redback.utils.bands_to_frequency([band])
            else:
                frequency = band
            model_kwargs['frequency'] = np.ones(len(times)) * frequency

            if self.plot_max_likelihood:
                ys = apply_band_scaling(band, self.model(times, **self._max_like_params, **model_kwargs))
                if ys is not None:
                    axes.plot(plot_times, ys, color=color_max, alpha=self.max_likelihood_alpha,
                              lw=self.linewidth, linestyle=self.max_likelihood_linestyle)

            random_ys_list = [self.model(times, **random_params, **model_kwargs)
                              for random_params in self._get_random_parameters()]
            if self.uncertainty_mode == "random_models":
                for random_ys in random_ys_list:
                    ys = apply_band_scaling(band, random_ys)
                    if ys is not None:
                        axes.plot(plot_times, ys, color=color_sample, alpha=self.random_sample_alpha,
                                  lw=self.linewidth, linestyle=self.random_sample_linestyle, zorder=-1)
            elif self.uncertainty_mode == "credible_intervals":
                # Skip unsupported scaling types instead of using undefined (or the previous band's) bounds.
                samples = apply_band_scaling(band, np.array(random_ys_list))
                if samples is not None:
                    lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(
                        samples=samples, interval=self.credible_interval_level)
                    axes.fill_between(plot_times, lower_bound, upper_bound,
                                      alpha=self.uncertainty_band_alpha, color=color_sample)

        self._save_and_show(filepath=self._lightcurve_plot_filepath, save=save, show=show)
        return axes
1✔
1118

1119
    def _check_valid_multiband_data_mode(self) -> bool:
        """Return ``False`` (after logging a warning) for luminosity data, which has no per-band panels; else ``True``."""
        if not self.transient.luminosity_data:
            return True
        redback.utils.logger.warning(
            f"Plotting multiband lightcurve/data not possible for {self.transient.data_mode}. Returning.")
        return False
1✔
1125

1126
    def plot_multiband(
            self, figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None, save: bool = True,
            show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Magnitude multiband data and returns Axes.

        Each band listed in ``self._filters`` gets its own subplot. Subplots are filled in the order the
        bands appear in ``self.transient.unique_bands``.

        :param figure: Matplotlib figure to plot the data into. A new figure and axes grid are created
                       if either `figure` or `axes` is None.
        :type figure: matplotlib.figure.Figure
        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
                     A single Axes, a sequence of Axes or an array of Axes is accepted.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The flattened array of axes with the plot, or None if the data mode does not
                 support multiband plotting.
        :rtype: matplotlib.axes.Axes
        """
        if not self._check_valid_multiband_data_mode():
            return

        if figure is None or axes is None:
            # squeeze=False guarantees a 2D array of Axes even for a 1x1 grid; with the default
            # squeeze=True matplotlib returns a bare Axes there, which has no .ravel().
            figure, axes = plt.subplots(ncols=self.ncols, nrows=self._nrows, sharex='all',
                                        figsize=self._figsize, squeeze=False)
        # Flatten whatever was supplied (single Axes, list or ndarray of Axes) into a 1D array.
        axes = np.atleast_1d(axes).ravel()

        band_label_generator = self.band_label_generator
        # Resolve the property once instead of re-reading kwargs on every loop iteration.
        filters = self._filters

        ii = 0
        for indices, band, freq in zip(
                self.transient.list_of_band_indices, self.transient.unique_bands, self.transient.unique_frequencies):
            if band not in filters:
                continue

            x_err = self._get_x_err(indices)
            color = self._colors[list(filters).index(band)]
            if self.band_colors is not None:
                color = self.band_colors[band]
            if band_label_generator is None:
                label = self._get_multiband_plot_label(band, freq)
            else:
                label = next(band_label_generator)

            axes[ii].errorbar(
                self.transient.x[indices] - self._reference_mjd_date, self.transient.y[indices], xerr=x_err,
                yerr=self.transient.y_err[indices], fmt=self.errorbar_fmt, ms=self.ms, color=color,
                elinewidth=self.elinewidth, capsize=self.capsize, label=label,
                fillstyle=self.markerfillstyle, markeredgecolor=self.markeredgecolor,
                markeredgewidth=self.markeredgewidth)

            self._set_x_axis(axes[ii])
            self._set_y_axis_multiband_data(axes[ii], indices)

            # Apply new customizations
            self._apply_axis_customizations(axes[ii])

            # Legend with new customization options
            axes[ii].legend(ncol=self.legend_cols, loc=self.legend_location, fontsize=self.fontsize_legend,
                            frameon=self.legend_frameon, shadow=self.legend_shadow,
                            fancybox=self.legend_fancybox, framealpha=self.legend_framealpha)
            ii += 1

        figure.supxlabel(self._xlabel, fontsize=self.fontsize_figure)
        figure.supylabel(self._ylabel, fontsize=self.fontsize_figure)
        plt.subplots_adjust(wspace=self.wspace, hspace=self.hspace)

        self._save_and_show(filepath=self._multiband_data_plot_filepath, save=save, show=show)
        return axes
1192

1193
    @staticmethod
1✔
1194
    def _get_multiband_plot_label(band: str, freq: float) -> str:
1✔
1195
        if isinstance(band, str):
1✔
1196
            if 1e10 < float(freq) < 1e16:
1✔
1197
                label = band
1✔
1198
            else:
1199
                label = f"{freq:.2e}"
×
1200
        else:
1201
            label = f"{band:.2e}"
×
1202
        return label
1✔
1203

1204
    @property
1✔
1205
    def _filters(self) -> list[str]:
1✔
1206
        filters = self.kwargs.get("filters", self.transient.active_bands)
1✔
1207
        if 'bands_to_plot' in self.kwargs:
1✔
1208
            filters = self.kwargs['bands_to_plot']
×
1209
        if filters is None:
1✔
1210
            return self.transient.active_bands
×
1211
        elif str(filters) == 'default':
1✔
1212
            return self.transient.default_filters
×
1213
        return filters
1✔
1214

1215
    def plot_multiband_lightcurve(
        self, figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Magnitude multiband lightcurve and returns Axes.

        The multiband data are drawn first via ``plot_multiband``. Each band's subplot is then overlaid with
        the optional max-likelihood model and with the uncertainty visualisation selected by
        ``self.uncertainty_mode``: either individual random models or a credible-interval band.

        :param figure: Matplotlib figure to plot the data into. A new figure and axes grid are created
                       if either `figure` or `axes` is None.
        :type figure: matplotlib.figure.Figure
        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The flattened array of axes with the plot, or None if the data mode does not
                 support multiband plotting.
        :rtype: matplotlib.axes.Axes
        """
        if not self._check_valid_multiband_data_mode():
            return

        if figure is None or axes is None:
            # squeeze=False keeps a 1x1 grid as an array of Axes rather than a bare Axes.
            figure, axes = plt.subplots(ncols=self.ncols, nrows=self._nrows, sharex='all',
                                        figsize=self._figsize, squeeze=False)

        axes = self.plot_multiband(figure=figure, axes=axes, save=False, show=False)
        times = self._get_times(axes)

        # Resolve the property once instead of re-reading kwargs on every loop iteration.
        filters = self._filters
        color_max = self.max_likelihood_color
        color_sample = self.random_sample_color
        # Subplot index; advances only for plotted bands, matching the order used by plot_multiband.
        ii = 0
        for band, freq in zip(self.transient.unique_bands, self.transient.unique_frequencies):
            if band not in filters:
                continue
            new_model_kwargs = self._model_kwargs.copy()
            new_model_kwargs['frequency'] = freq
            # The model is evaluated with one (sncosmo-style) band name per time sample.
            sncosmo_band = redback.utils.sncosmo_bandname_from_band([band])[0]
            new_model_kwargs['bands'] = [sncosmo_band] * len(times)

            if self.set_same_color_per_subplot is True:
                color = self._colors[list(filters).index(band)]
                if self.band_colors is not None:
                    color = self.band_colors[band]
                color_max = color
                color_sample = color

            if self.plot_max_likelihood:
                ys = self.model(times, **self._max_like_params, **new_model_kwargs)
                axes[ii].plot(
                    times - self._reference_mjd_date, ys, color=color_max,
                    alpha=self.max_likelihood_alpha, lw=self.linewidth, linestyle=self.max_likelihood_linestyle)
            random_ys_list = [self.model(times, **random_params, **new_model_kwargs)
                              for random_params in self._get_random_parameters()]
            if self.uncertainty_mode == "random_models":
                for random_ys in random_ys_list:
                    axes[ii].plot(times - self._reference_mjd_date, random_ys, color=color_sample,
                                  alpha=self.random_sample_alpha, lw=self.linewidth,
                                  linestyle=self.random_sample_linestyle, zorder=self.zorder)
            elif self.uncertainty_mode == "credible_intervals":
                lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(
                    samples=random_ys_list, interval=self.credible_interval_level)
                axes[ii].fill_between(
                    times - self._reference_mjd_date, lower_bound, upper_bound,
                    alpha=self.uncertainty_band_alpha, color=color_sample)
            ii += 1

        self._save_and_show(filepath=self._multiband_lightcurve_plot_filepath, save=save, show=show)
        return axes
1278

1279

1280
class FluxDensityPlotter(MagnitudePlotter):
    """Plotter for flux-density data; reuses the MagnitudePlotter implementation unchanged."""
1282

1283
class IntegratedFluxOpticalPlotter(MagnitudePlotter):
    """Plotter for integrated optical flux data; reuses the MagnitudePlotter implementation unchanged."""
1285

1286
class SpectrumPlotter(SpecPlotter):
    """Plotter for spectrum data, model fits to it and the fit residuals.

    Flux densities on the data and fit panels are divided by ``_FLUX_SCALE`` before plotting.
    The residual panel is drawn in the transient's unscaled flux-density units.
    """

    # Divisor applied to flux densities on the data/fit panels.
    _FLUX_SCALE = 1e-17

    @property
    def _xlabel(self) -> str:
        """X-axis label, taken from ``self.transient.xlabel``."""
        return self.transient.xlabel

    @property
    def _ylabel(self) -> str:
        """Y-axis label, taken from ``self.transient.ylabel``."""
        return self.transient.ylabel

    def plot_data(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the spectrum data and returns Axes.

        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
                     Defaults to the current axes (``plt.gca()``).
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        ax = plt.gca() if axes is None else axes

        # The annotation shows ``transient.time`` when requested, otherwise the transient's name.
        label = self.transient.time if self.transient.plot_with_time_label else self.transient.name
        ax.plot(self.transient.angstroms, self.transient.flux_density / self._FLUX_SCALE, color=self.color,
                lw=self.linewidth, linestyle=self.linestyle)

        # Apply custom scales if specified, otherwise use defaults
        ax.set_xscale(self.xscale if self.xscale is not None else 'linear')
        ax.set_yscale(self.yscale)

        ax.set_xlim(self._xlim_low, self._xlim_high)
        ax.set_ylim(self._ylim_low, self._ylim_high)
        ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes)
        ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes)

        ax.annotate(
            label, xy=self.xy, xycoords=self.xycoords,
            horizontalalignment=self.horizontalalignment, size=self.annotation_size)

        # Apply new customizations
        self._apply_axis_customizations(ax)

        self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show)
        return ax

    def plot_spectrum(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the spectrum data and the fit and returns Axes.

        :param axes: Matplotlib axes to plot the spectrum into. Useful for user specific modifications to the plot.
                     Defaults to the current axes (``plt.gca()``).
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        axes = plt.gca() if axes is None else axes

        axes = self.plot_data(axes=axes, save=False, show=False)
        angstroms = self._get_angstroms(axes)

        self._plot_spectrums(axes, angstroms)

        self._save_and_show(filepath=self._spectrum_ppd_plot_filepath, save=save, show=show)
        return axes

    def _plot_spectrums(self, axes: matplotlib.axes.Axes, angstroms: np.ndarray) -> None:
        """Overlay model spectra evaluated at `angstroms` onto `axes`.

        Draws the max-likelihood model if ``self.plot_max_likelihood`` is set. Then, depending on
        ``self.uncertainty_mode``, it either draws one curve per parameter draw from
        ``self._get_random_parameters()`` ("random_models") or a filled credible-interval band
        ("credible_intervals").
        """
        if self.plot_max_likelihood:
            ys = self.model(angstroms, **self._max_like_params, **self._model_kwargs)
            axes.plot(angstroms, ys / self._FLUX_SCALE, color=self.max_likelihood_color,
                      alpha=self.max_likelihood_alpha, lw=self.linewidth, linestyle=self.max_likelihood_linestyle)

        random_ys_list = [self.model(angstroms, **random_params, **self._model_kwargs)
                          for random_params in self._get_random_parameters()]
        if self.uncertainty_mode == "random_models":
            for ys in random_ys_list:
                axes.plot(angstroms, ys / self._FLUX_SCALE, color=self.random_sample_color,
                          alpha=self.random_sample_alpha, lw=self.linewidth,
                          linestyle=self.random_sample_linestyle, zorder=self.zorder)
        elif self.uncertainty_mode == "credible_intervals":
            lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(
                samples=random_ys_list, interval=self.credible_interval_level)
            # NOTE(review): the lightcurve plotters fill credible bands with the random-sample colour;
            # this uses the max-likelihood colour instead. Confirm that is intentional.
            axes.fill_between(
                angstroms, lower_bound / self._FLUX_SCALE, upper_bound / self._FLUX_SCALE,
                alpha=self.uncertainty_band_alpha, color=self.max_likelihood_color)

    def plot_residuals(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the spectrum with its fit above the residuals of the max-likelihood model and returns Axes.

        :param axes: Two Matplotlib axes (spectrum panel, residual panel) to plot into. A new 2x1
                     figure is created if None.
        :type axes: Union[Sequence[matplotlib.axes.Axes], None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        if axes is None:
            _, axes = plt.subplots(
                nrows=2, ncols=1, sharex=True, sharey=False, figsize=(10, 8), gridspec_kw=dict(height_ratios=[2, 1]))

        # plot_spectrum draws into (and returns) the given axes, so no reassignment is needed;
        # this also keeps a caller-supplied tuple of axes working.
        self.plot_spectrum(axes=axes[0], save=False, show=False)
        # Only the bottom (residual) panel carries the x-axis label.
        axes[1].set_xlabel(axes[0].get_xlabel(), fontsize=self.fontsize_axes)
        axes[0].set_xlabel("")
        ys = self.model(self.transient.angstroms, **self._max_like_params, **self._model_kwargs)
        # NOTE(review): residuals and their errors are not divided by _FLUX_SCALE, unlike the
        # spectrum panel above. Confirm the mixed units are intended.
        axes[1].errorbar(
            self.transient.angstroms, self.transient.flux_density - ys, yerr=self.transient.flux_density_err,
            fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize,
            fillstyle=self.markerfillstyle, markeredgecolor=self.markeredgecolor,
            markeredgewidth=self.markeredgewidth)
        axes[1].set_yscale('linear')
        axes[1].set_ylabel("Residual", fontsize=self.fontsize_axes)

        # Apply new customizations
        self._apply_axis_customizations(axes[1])

        self._save_and_show(filepath=self._residual_plot_filepath, save=save, show=show)
        return axes
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc