• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nikhil-sarin / redback / 19571845612

21 Nov 2025 01:22PM UTC coverage: 86.583% (+0.02%) from 86.568%
19571845612

push

github

web-flow
Merge pull request #327 from nikhil-sarin/update_for_plotting

Change some defaults in the plotting and remove some matplotlib style…

3 of 9 new or added lines in 1 file covered. (33.33%)

1 existing line in 1 file now uncovered.

10564 of 12201 relevant lines covered (86.58%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.57
/redback/plotting.py
1
from __future__ import annotations
1✔
2

3
from os.path import join
1✔
4
from typing import Any, Union
1✔
5

6
import matplotlib
1✔
7
import matplotlib.pyplot as plt
1✔
8
import numpy as np
1✔
9
import pandas as pd
1✔
10

11
import redback
1✔
12
from redback.utils import KwargsAccessorWithDefault
1✔
13

14
class _FilenameGetter(object):
    """Descriptor that resolves the output filename of one kind of plot.

    Read from a plotter instance, it returns the user supplied ``filename`` keyword
    (via ``instance.get_filename``) or, by default, ``"<transient name>_<suffix>.png"``.
    Read from the class itself, it returns the descriptor object.
    """

    def __init__(self, suffix: str) -> None:
        """
        :param suffix: Suffix appended to the transient name in the default filename, e.g. 'data'.
        """
        self.suffix = suffix

    def __get__(self, instance: "Plotter", owner: object) -> "Union[str, _FilenameGetter]":
        """Return the filename for ``instance``, or the descriptor itself on class access."""
        if instance is None:
            # Class-level access (e.g. ``Plotter._data_plot_filename``) has no instance;
            # follow the descriptor protocol instead of failing on ``None.get_filename``.
            return self
        return instance.get_filename(default=f"{instance.transient.name}_{self.suffix}.png")

    def __set__(self, instance: "Plotter", value: object) -> None:
        """Ignore assignments.

        Defining ``__set__`` makes this a data descriptor, so assigning to the attribute on an
        instance is silently discarded instead of shadowing the computed filename.
        """
23

24

25
class _FilePathGetter(object):
    """Descriptor that joins a plotter's output directory and filename into a full path.

    Both parts are read from other attributes of the plotter instance, named by
    ``directory_property`` and ``filename_property``. Read from the class itself, it
    returns the descriptor object.
    """

    def __init__(self, directory_property: str, filename_property: str) -> None:
        """
        :param directory_property: Name of the instance attribute holding the output directory.
        :param filename_property: Name of the instance attribute holding the filename.
        """
        self.directory_property = directory_property
        self.filename_property = filename_property

    def __get__(self, instance: "Plotter", owner: object) -> "Union[str, _FilePathGetter]":
        """Return ``os.path.join(directory, filename)`` for ``instance``, or the descriptor on class access."""
        if instance is None:
            # Class-level access has no instance to read the two attributes from.
            return self
        return join(getattr(instance, self.directory_property), getattr(instance, self.filename_property))
33

34

35
class Plotter(object):
    """
    Base class for all lightcurve plotting classes in redback.

    Plot settings come from the keyword arguments given to the constructor. Each
    ``KwargsAccessorWithDefault`` class attribute below is declared with the name of one
    such keyword and the default value to use (see ``redback.utils.KwargsAccessorWithDefault``
    for the exact lookup rules).
    """

    capsize = KwargsAccessorWithDefault("capsize", 0.)
    legend_location = KwargsAccessorWithDefault("legend_location", "best")
    legend_cols = KwargsAccessorWithDefault("legend_cols", 2)
    band_colors = KwargsAccessorWithDefault("band_colors", None)
    color = KwargsAccessorWithDefault("color", "k")
    band_labels = KwargsAccessorWithDefault("band_labels", None)
    # NOTE(review): this default dict is created once and shared by all instances;
    # confirm that no caller mutates the value returned by ``band_scaling``.
    band_scaling = KwargsAccessorWithDefault("band_scaling", {})
    dpi = KwargsAccessorWithDefault("dpi", 300)
    elinewidth = KwargsAccessorWithDefault("elinewidth", 2)
    errorbar_fmt = KwargsAccessorWithDefault("errorbar_fmt", "o")
    model = KwargsAccessorWithDefault("model", None)
    ms = KwargsAccessorWithDefault("ms", 5)
    axis_tick_params_pad = KwargsAccessorWithDefault("axis_tick_params_pad", 10)

    max_likelihood_alpha = KwargsAccessorWithDefault("max_likelihood_alpha", 0.65)
    random_sample_alpha = KwargsAccessorWithDefault("random_sample_alpha", 0.05)
    uncertainty_band_alpha = KwargsAccessorWithDefault("uncertainty_band_alpha", 0.4)
    max_likelihood_color = KwargsAccessorWithDefault("max_likelihood_color", "blue")
    random_sample_color = KwargsAccessorWithDefault("random_sample_color", "red")

    bbox_inches = KwargsAccessorWithDefault("bbox_inches", "tight")
    linewidth = KwargsAccessorWithDefault("linewidth", 2)
    zorder = KwargsAccessorWithDefault("zorder", -1)

    xy = KwargsAccessorWithDefault("xy", (0.95, 0.9))
    xycoords = KwargsAccessorWithDefault("xycoords", "axes fraction")
    horizontalalignment = KwargsAccessorWithDefault("horizontalalignment", "right")
    annotation_size = KwargsAccessorWithDefault("annotation_size", 20)

    fontsize_axes = KwargsAccessorWithDefault("fontsize_axes", 18)
    fontsize_figure = KwargsAccessorWithDefault("fontsize_figure", 30)
    fontsize_legend = KwargsAccessorWithDefault("fontsize_legend", 18)
    fontsize_ticks = KwargsAccessorWithDefault("fontsize_ticks", 16)
    hspace = KwargsAccessorWithDefault("hspace", 0.04)
    wspace = KwargsAccessorWithDefault("wspace", 0.15)

    plot_others = KwargsAccessorWithDefault("plot_others", True)
    random_models = KwargsAccessorWithDefault("random_models", 100)
    uncertainty_mode = KwargsAccessorWithDefault("uncertainty_mode", "random_models")
    credible_interval_level = KwargsAccessorWithDefault("credible_interval_level", 0.9)
    plot_max_likelihood = KwargsAccessorWithDefault("plot_max_likelihood", True)
    set_same_color_per_subplot = KwargsAccessorWithDefault("set_same_color_per_subplot", True)

    xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 2.0)
    xlim_low_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.5)
    ylim_high_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 2.0)
    ylim_low_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.5)

    def __init__(self, transient: Union[redback.transient.Transient, None], **kwargs) -> None:
        """
        :param transient: An instance of `redback.transient.Transient`. Contains the data to be plotted.
        :param kwargs: Additional kwargs the plotter uses. -------
        :keyword capsize: Same as matplotlib capsize.
        :keyword bands_to_plot: List of bands to plot in plot lightcurve and multiband lightcurve. Default is active bands.
        :keyword legend_location: Same as matplotlib legend location.
        :keyword legend_cols: Same as matplotlib legend columns.
        :keyword color: Color of the data points.
        :keyword band_colors: A dictionary with the colors of the bands.
        :keyword band_labels: List with the names of the bands.
        :keyword band_scaling: Dict with the scaling for each band. First entry should be {type: '+' or 'x'} for different types.
        :keyword dpi: Same as matplotlib dpi.
        :keyword elinewidth: same as matplotlib elinewidth
        :keyword errorbar_fmt: 'fmt' argument of `ax.errorbar`.
        :keyword model: str or callable, the model to plot.
        :keyword ms: Same as matplotlib markersize.
        :keyword axis_tick_params_pad: `pad` argument in calls to `ax.tick_params` when setting the axes.
        :keyword max_likelihood_alpha: `alpha` argument, i.e. transparency, when plotting the max likelihood curve.
        :keyword random_sample_alpha: `alpha` argument, i.e. transparency, when plotting random sample curves.
        :keyword uncertainty_band_alpha: `alpha` argument, i.e. transparency, when plotting a credible band.
        :keyword max_likelihood_color: Color of the maximum likelihood curve.
        :keyword random_sample_color: Color of the random sample curves.
        :keyword bbox_inches: Setting for saving plots. Default is 'tight'.
        :keyword linewidth: Same as matplotlib linewidth
        :keyword zorder: Same as matplotlib zorder
        :keyword xy: For `ax.annotate` x and y coordinates of the point to annotate.
        :keyword xycoords: The coordinate system `xy` is given in. Default is 'axes fraction'
        :keyword horizontalalignment: Horizontal alignment of the annotation. Default is 'right'
        :keyword annotation_size: `size` argument of `ax.annotate`.
        :keyword fontsize_axes: Font size of the x and y labels.
        :keyword fontsize_legend: Font size of the legend.
        :keyword fontsize_figure: Font size of the figure. Relevant for multiband plots.
                                  Used on `supxlabel` and `supylabel`.
        :keyword fontsize_ticks: Font size of the axis ticks.
        :keyword hspace: Argument for `subplots_adjust`, sets vertical spacing between panels.
        :keyword wspace: Argument for `subplots_adjust`, sets horizontal spacing between panels.
        :keyword plot_others: Whether to plot additional bands in the data plot, all in the same colors
        :keyword random_models: Number of random draws to use to calculate credible bands or to plot.
        :keyword uncertainty_mode: 'random_models': Plot random draws from the available parameter sets.
                                   'credible_intervals': Plot a credible interval that is calculated based
                                   on the available parameter sets.
        :keyword reference_mjd_date: Date to use as reference point for the x axis.
                                    Default is the first date in the data.
        :keyword credible_interval_level: 0.9: Plot the 90% credible interval.
        :keyword plot_max_likelihood: Plots the draw corresponding to the maximum likelihood. Default is 'True'.
        :keyword set_same_color_per_subplot: Sets the lightcurve to be the same color as the data per subplot. Default is 'True'.
        :keyword xlim_high_multiplier: Adjust the maximum xlim based on available x values.
        :keyword xlim_low_multiplier: Adjust the minimum xlim based on available x values.
        :keyword ylim_high_multiplier: Adjust the maximum ylim based on available y values.
        :keyword ylim_low_multiplier: Adjust the minimum ylim based on available y values.
        :keyword xlim_low: Lower x limit. Overrides the value derived from `xlim_low_multiplier`.
        :keyword xlim_high: Upper x limit. Overrides the value derived from `xlim_high_multiplier`.
        :keyword ylim_low: Lower y limit. Overrides the value derived from `ylim_low_multiplier`.
        :keyword ylim_high: Upper y limit. Overrides the value derived from `ylim_high_multiplier`.
        :keyword outdir: Directory the plots are saved in. Default is the transient's directory path
                         (with the model name appended for lightcurve and residual plots).
        :keyword filename: Name of the saved file. Default is '<transient name>_<plot type>.png'.
        :keyword posterior: pandas DataFrame of posterior samples with a 'log_likelihood' column.
                            It is sorted in place by 'log_likelihood'.
        :keyword model_kwargs: Dict of additional keyword arguments passed to the model.
        """
        self.transient = transient
        self.kwargs = kwargs or dict()
        # The posterior is sorted lazily, once, on first access to ``_posterior``.
        self._posterior_sorted = False

    # Keyword documentation (everything after the '-------' marker), reused by other docstrings.
    keyword_docstring = __init__.__doc__.split("-------")[1]

    def _get_times(self, axes: matplotlib.axes.Axes) -> np.ndarray:
        """
        :param axes: The axes used in the plotting procedure. If an array of axes is given,
                     the first one is used.
        :type axes: matplotlib.axes.Axes

        :return: 200 linearly or logarithmically spaced time values between the x limits,
                 depending on the y scale used in the plot. Shifted by the reference MJD
                 date when the transient uses a phase model.
        :rtype: np.ndarray
        """
        ax = axes[0] if isinstance(axes, np.ndarray) else axes

        # NOTE(review): the spacing follows the *y* scale, not the x scale the times are
        # plotted on; kept as-is, but confirm this is intended.
        if ax.get_yscale() == 'linear':
            times = np.linspace(self._xlim_low, self._xlim_high, 200)
        else:
            times = np.exp(np.linspace(np.log(self._xlim_low), np.log(self._xlim_high), 200))

        if self.transient.use_phase_model:
            # ``_reference_mjd_date`` is not defined on this class; presumably subclasses provide it.
            times = times + self._reference_mjd_date
        return times

    @property
    def _xlim_low(self) -> float:
        """Lower x limit: the ``xlim_low`` keyword, else ``xlim_low_multiplier`` times the first x value."""
        default = self.xlim_low_multiplier * self.transient.x[0]
        if default == 0:
            # Keep the limit strictly positive so it stays usable on a log axis.
            default += 1e-3
        return self.kwargs.get("xlim_low", default)

    @property
    def _xlim_high(self) -> float:
        """Upper x limit: the ``xlim_high`` keyword, else ``xlim_high_multiplier`` times the last
        x value (plus its x error when x errors are present)."""
        if self._x_err is None:
            default = self.xlim_high_multiplier * self.transient.x[-1]
        else:
            default = self.xlim_high_multiplier * (self.transient.x[-1] + self._x_err[1][-1])
        return self.kwargs.get("xlim_high", default)

    @property
    def _ylim_low(self) -> float:
        """Lower y limit: the ``ylim_low`` keyword, else ``ylim_low_multiplier`` times the smallest y value."""
        # nanmin: a NaN data point must not turn the axis limit into NaN.
        default = self.ylim_low_multiplier * np.nanmin(self.transient.y)
        return self.kwargs.get("ylim_low", default)

    @property
    def _ylim_high(self) -> float:
        """Upper y limit: the ``ylim_high`` keyword, else ``ylim_high_multiplier`` times the largest y value."""
        # nanmax: a NaN data point must not turn the axis limit into NaN.
        default = self.ylim_high_multiplier * np.nanmax(self.transient.y)
        return self.kwargs.get("ylim_high", default)

    @property
    def _x_err(self) -> Union[np.ndarray, None]:
        """x errors as the 2-row array ``[abs(x_err[1]), x_err[0]]`` for matplotlib, or None if absent."""
        if self.transient.x_err is not None:
            return np.array([np.abs(self.transient.x_err[1, :]), self.transient.x_err[0, :]])
        else:
            return None

    @property
    def _y_err(self) -> np.ndarray:
        """y errors for matplotlib: ``[abs(y_err[1]), y_err[0]]`` for multi-dimensional errors,
        otherwise a single-row array of absolute errors."""
        if self.transient.y_err.ndim > 1:
            return np.array([np.abs(self.transient.y_err[1, :]), self.transient.y_err[0, :]])
        else:
            return np.array([np.abs(self.transient.y_err)])

    @property
    def _lightcurve_plot_outdir(self) -> str:
        """Output directory for model plots: ``outdir`` keyword, else ``<directory_path>/<model name>``."""
        return self._get_outdir(join(self.transient.directory_structure.directory_path, self.model.__name__))

    @property
    def _data_plot_outdir(self) -> str:
        """Output directory for data-only plots: ``outdir`` keyword, else the transient's directory path."""
        return self._get_outdir(self.transient.directory_structure.directory_path)

    def _get_outdir(self, default: str) -> str:
        """Return the ``outdir`` keyword, or ``default`` if it is missing or falsy."""
        return self._get_kwarg_with_default(kwarg="outdir", default=default)

    def get_filename(self, default: str) -> str:
        """Return the ``filename`` keyword, or ``default`` if it is missing or falsy."""
        return self._get_kwarg_with_default(kwarg="filename", default=default)

    def _get_kwarg_with_default(self, kwarg: str, default: Any) -> Any:
        """Return ``self.kwargs[kwarg]``, falling back to ``default`` when the keyword is
        missing or falsy (e.g. None, '' or an empty dict)."""
        return self.kwargs.get(kwarg, default) or default

    @property
    def _model_kwargs(self) -> dict:
        """Extra keyword arguments for the model (``model_kwargs`` keyword, default empty dict)."""
        return self._get_kwarg_with_default("model_kwargs", dict())

    @property
    def _posterior(self) -> pd.DataFrame:
        """The ``posterior`` keyword (default: an empty DataFrame).

        On first access it is sorted *in place* by ascending ``log_likelihood``, so the
        maximum-likelihood sample is the last row. Note that this mutates the caller's
        DataFrame. pandas raises KeyError if the column is missing (including for the empty default).
        """
        posterior = self.kwargs.get("posterior", pd.DataFrame())
        if not self._posterior_sorted and posterior is not None:
            posterior.sort_values(by='log_likelihood', inplace=True)
            self._posterior_sorted = True
        return posterior

    @property
    def _max_like_params(self) -> pd.core.series.Series:
        """The posterior row with the highest log likelihood (last row after sorting)."""
        return self._posterior.iloc[-1]

    def _get_random_parameters(self) -> list[pd.core.series.Series]:
        """Draw ``random_models`` posterior rows uniformly at random, with replacement."""
        integers = np.arange(len(self._posterior))
        indices = np.random.choice(integers, size=self.random_models)
        return [self._posterior.iloc[idx] for idx in indices]

    _data_plot_filename = _FilenameGetter(suffix="data")
    _lightcurve_plot_filename = _FilenameGetter(suffix="lightcurve")
    _residual_plot_filename = _FilenameGetter(suffix="residual")
    _multiband_data_plot_filename = _FilenameGetter(suffix="multiband_data")
    _multiband_lightcurve_plot_filename = _FilenameGetter(suffix="multiband_lightcurve")

    _data_plot_filepath = _FilePathGetter(
        directory_property="_data_plot_outdir", filename_property="_data_plot_filename")
    _lightcurve_plot_filepath = _FilePathGetter(
        directory_property="_lightcurve_plot_outdir", filename_property="_lightcurve_plot_filename")
    _residual_plot_filepath = _FilePathGetter(
        directory_property="_lightcurve_plot_outdir", filename_property="_residual_plot_filename")
    _multiband_data_plot_filepath = _FilePathGetter(
        directory_property="_data_plot_outdir", filename_property="_multiband_data_plot_filename")
    _multiband_lightcurve_plot_filepath = _FilePathGetter(
        directory_property="_lightcurve_plot_outdir", filename_property="_multiband_lightcurve_plot_filename")

    def _save_and_show(self, filepath: str, save: bool, show: bool) -> None:
        """Apply ``tight_layout`` to the current figure, then optionally save it to ``filepath``
        (white, non-transparent background) and/or show it."""
        plt.tight_layout()
        if save:
            plt.savefig(filepath, dpi=self.dpi, bbox_inches=self.bbox_inches, transparent=False, facecolor='white')
        if show:
            plt.show()
267

268
class SpecPlotter(object):
    """
    Base class for all spectrum plotting classes in redback.

    Plot settings come from the keyword arguments given to the constructor. Each
    ``KwargsAccessorWithDefault`` class attribute below is declared with the name of one
    such keyword and the default value to use.
    """

    capsize = KwargsAccessorWithDefault("capsize", 0.)
    elinewidth = KwargsAccessorWithDefault("elinewidth", 2)
    errorbar_fmt = KwargsAccessorWithDefault("errorbar_fmt", "x")
    legend_location = KwargsAccessorWithDefault("legend_location", "best")
    legend_cols = KwargsAccessorWithDefault("legend_cols", 2)
    color = KwargsAccessorWithDefault("color", "k")
    dpi = KwargsAccessorWithDefault("dpi", 300)
    model = KwargsAccessorWithDefault("model", None)
    ms = KwargsAccessorWithDefault("ms", 1)
    axis_tick_params_pad = KwargsAccessorWithDefault("axis_tick_params_pad", 10)

    max_likelihood_alpha = KwargsAccessorWithDefault("max_likelihood_alpha", 0.65)
    random_sample_alpha = KwargsAccessorWithDefault("random_sample_alpha", 0.05)
    uncertainty_band_alpha = KwargsAccessorWithDefault("uncertainty_band_alpha", 0.4)
    max_likelihood_color = KwargsAccessorWithDefault("max_likelihood_color", "blue")
    random_sample_color = KwargsAccessorWithDefault("random_sample_color", "red")

    bbox_inches = KwargsAccessorWithDefault("bbox_inches", "tight")
    linewidth = KwargsAccessorWithDefault("linewidth", 2)
    zorder = KwargsAccessorWithDefault("zorder", -1)
    yscale = KwargsAccessorWithDefault("yscale", "linear")

    xy = KwargsAccessorWithDefault("xy", (0.95, 0.9))
    xycoords = KwargsAccessorWithDefault("xycoords", "axes fraction")
    horizontalalignment = KwargsAccessorWithDefault("horizontalalignment", "right")
    annotation_size = KwargsAccessorWithDefault("annotation_size", 20)

    fontsize_axes = KwargsAccessorWithDefault("fontsize_axes", 18)
    fontsize_figure = KwargsAccessorWithDefault("fontsize_figure", 30)
    fontsize_legend = KwargsAccessorWithDefault("fontsize_legend", 18)
    fontsize_ticks = KwargsAccessorWithDefault("fontsize_ticks", 16)
    hspace = KwargsAccessorWithDefault("hspace", 0.04)
    wspace = KwargsAccessorWithDefault("wspace", 0.15)

    random_models = KwargsAccessorWithDefault("random_models", 100)
    uncertainty_mode = KwargsAccessorWithDefault("uncertainty_mode", "random_models")
    credible_interval_level = KwargsAccessorWithDefault("credible_interval_level", 0.9)
    plot_max_likelihood = KwargsAccessorWithDefault("plot_max_likelihood", True)
    set_same_color_per_subplot = KwargsAccessorWithDefault("set_same_color_per_subplot", True)

    xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.05)
    xlim_low_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.9)
    ylim_high_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 1.1)
    ylim_low_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.5)

    def __init__(self, spectrum: Union[redback.transient.Spectrum, None], **kwargs) -> None:
        """
        :param spectrum: An instance of `redback.transient.Spectrum`. Contains the data to be plotted.
        :param kwargs: Additional kwargs the plotter uses. -------
        :keyword capsize: Same as matplotlib capsize.
        :keyword elinewidth: same as matplotlib elinewidth
        :keyword errorbar_fmt: 'fmt' argument of `ax.errorbar`.
        :keyword ms: Same as matplotlib markersize.
        :keyword legend_location: Same as matplotlib legend location.
        :keyword legend_cols: Same as matplotlib legend columns.
        :keyword color: Color of the data points.
        :keyword dpi: Same as matplotlib dpi.
        :keyword model: str or callable, the model to plot.
        :keyword axis_tick_params_pad: `pad` argument in calls to `ax.tick_params` when setting the axes.
        :keyword max_likelihood_alpha: `alpha` argument, i.e. transparency, when plotting the max likelihood curve.
        :keyword random_sample_alpha: `alpha` argument, i.e. transparency, when plotting random sample curves.
        :keyword uncertainty_band_alpha: `alpha` argument, i.e. transparency, when plotting a credible band.
        :keyword max_likelihood_color: Color of the maximum likelihood curve.
        :keyword random_sample_color: Color of the random sample curves.
        :keyword bbox_inches: Setting for saving plots. Default is 'tight'.
        :keyword linewidth: Same as matplotlib linewidth
        :keyword zorder: Same as matplotlib zorder
        :keyword yscale: Same as matplotlib yscale, default is linear
        :keyword xy: For `ax.annotate` x and y coordinates of the point to annotate.
        :keyword xycoords: The coordinate system `xy` is given in. Default is 'axes fraction'
        :keyword horizontalalignment: Horizontal alignment of the annotation. Default is 'right'
        :keyword annotation_size: `size` argument of `ax.annotate`.
        :keyword fontsize_axes: Font size of the x and y labels.
        :keyword fontsize_legend: Font size of the legend.
        :keyword fontsize_figure: Font size of the figure. Relevant for multiband plots.
                                  Used on `supxlabel` and `supylabel`.
        :keyword fontsize_ticks: Font size of the axis ticks.
        :keyword hspace: Argument for `subplots_adjust`, sets vertical spacing between panels.
        :keyword wspace: Argument for `subplots_adjust`, sets horizontal spacing between panels.
        :keyword plot_others: Whether to plot additional bands in the data plot, all in the same colors
        :keyword random_models: Number of random draws to use to calculate credible bands or to plot.
        :keyword uncertainty_mode: 'random_models': Plot random draws from the available parameter sets.
                                   'credible_intervals': Plot a credible interval that is calculated based
                                   on the available parameter sets.
        :keyword credible_interval_level: 0.9: Plot the 90% credible interval.
        :keyword plot_max_likelihood: Plots the draw corresponding to the maximum likelihood. Default is 'True'.
        :keyword set_same_color_per_subplot: Sets the lightcurve to be the same color as the data per subplot. Default is 'True'.
        :keyword xlim_high_multiplier: Adjust the maximum xlim based on available x values.
        :keyword xlim_low_multiplier: Adjust the minimum xlim based on available x values.
        :keyword ylim_high_multiplier: Adjust the maximum ylim based on available y values.
        :keyword ylim_low_multiplier: Adjust the minimum ylim based on available y values.
        :keyword xlim_low: Lower x limit. Overrides the value derived from `xlim_low_multiplier`.
        :keyword xlim_high: Upper x limit. Overrides the value derived from `xlim_high_multiplier`.
        :keyword ylim_low: Lower y limit, used as given. The default is derived from the flux density
                           and divided by 1e-17.
        :keyword ylim_high: Upper y limit, used as given. The default is derived from the flux density
                            and divided by 1e-17.
        :keyword outdir: Directory the plots are saved in. Default is the spectrum's directory path.
        :keyword filename: Name of the saved file. Default is '<spectrum name>_<plot type>.png'.
        :keyword posterior: pandas DataFrame of posterior samples with a 'log_likelihood' column.
                            It is sorted in place by 'log_likelihood'.
        :keyword model_kwargs: Dict of additional keyword arguments passed to the model.
        """
        # Stored under ``transient`` so the shared filename/filepath descriptors work unchanged.
        self.transient = spectrum
        self.kwargs = kwargs or dict()
        # The posterior is sorted lazily, once, on first access to ``_posterior``.
        self._posterior_sorted = False

    # Keyword documentation (everything after the '-------' marker), reused by other docstrings.
    keyword_docstring = __init__.__doc__.split("-------")[1]

    def _get_angstroms(self, axes: matplotlib.axes.Axes) -> np.ndarray:
        """
        :param axes: The axes used in the plotting procedure. If an array of axes is given,
                     the first one is used.
        :type axes: matplotlib.axes.Axes

        :return: 200 linearly or logarithmically spaced angstrom values between the x limits,
                 depending on the y scale used in the plot.
        :rtype: np.ndarray
        """
        ax = axes[0] if isinstance(axes, np.ndarray) else axes

        # NOTE(review): the spacing follows the *y* scale, not the x scale the wavelengths are
        # plotted on; kept as-is, but confirm this is intended.
        if ax.get_yscale() == 'linear':
            angstroms = np.linspace(self._xlim_low, self._xlim_high, 200)
        else:
            angstroms = np.exp(np.linspace(np.log(self._xlim_low), np.log(self._xlim_high), 200))

        return angstroms

    @property
    def _xlim_low(self) -> float:
        """Lower x limit: the ``xlim_low`` keyword, else ``xlim_low_multiplier`` times the first wavelength."""
        default = self.xlim_low_multiplier * self.transient.angstroms[0]
        if default == 0:
            # Keep the limit strictly positive so it stays usable on a log axis.
            default += 1e-3
        return self.kwargs.get("xlim_low", default)

    @property
    def _xlim_high(self) -> float:
        """Upper x limit: the ``xlim_high`` keyword, else ``xlim_high_multiplier`` times the last wavelength."""
        default = self.xlim_high_multiplier * self.transient.angstroms[-1]
        return self.kwargs.get("xlim_high", default)

    @property
    def _ylim_low(self) -> float:
        """Lower y limit: the ``ylim_low`` keyword as given, else ``ylim_low_multiplier`` times the
        smallest flux density, divided by 1e-17."""
        # nanmin: a NaN data point must not turn the axis limit into NaN.
        default = self.ylim_low_multiplier * np.nanmin(self.transient.flux_density)
        # Only the default is rescaled; presumably the y axis is drawn in units of 1e-17 — confirm.
        return self.kwargs.get("ylim_low", default/1e-17)

    @property
    def _ylim_high(self) -> float:
        """Upper y limit: the ``ylim_high`` keyword as given, else ``ylim_high_multiplier`` times the
        largest flux density, divided by 1e-17."""
        # nanmax: a NaN data point must not turn the axis limit into NaN.
        default = self.ylim_high_multiplier * np.nanmax(self.transient.flux_density)
        return self.kwargs.get("ylim_high", default/1e-17)

    @property
    def _y_err(self) -> np.ndarray:
        """Flux density errors as a single-row array of absolute values."""
        return np.array([np.abs(self.transient.flux_density_err)])

    @property
    def _data_plot_outdir(self) -> str:
        """Output directory: ``outdir`` keyword, else the spectrum's directory path."""
        return self._get_outdir(self.transient.directory_structure.directory_path)

    def _get_outdir(self, default: str) -> str:
        """Return the ``outdir`` keyword, or ``default`` if it is missing or falsy."""
        return self._get_kwarg_with_default(kwarg="outdir", default=default)

    def get_filename(self, default: str) -> str:
        """Return the ``filename`` keyword, or ``default`` if it is missing or falsy."""
        return self._get_kwarg_with_default(kwarg="filename", default=default)

    def _get_kwarg_with_default(self, kwarg: str, default: Any) -> Any:
        """Return ``self.kwargs[kwarg]``, falling back to ``default`` when the keyword is
        missing or falsy (e.g. None, '' or an empty dict)."""
        return self.kwargs.get(kwarg, default) or default

    @property
    def _model_kwargs(self) -> dict:
        """Extra keyword arguments for the model (``model_kwargs`` keyword, default empty dict)."""
        return self._get_kwarg_with_default("model_kwargs", dict())

    @property
    def _posterior(self) -> pd.DataFrame:
        """The ``posterior`` keyword (default: an empty DataFrame).

        On first access it is sorted *in place* by ascending ``log_likelihood``, so the
        maximum-likelihood sample is the last row. Note that this mutates the caller's
        DataFrame. pandas raises KeyError if the column is missing (including for the empty default).
        """
        posterior = self.kwargs.get("posterior", pd.DataFrame())
        if not self._posterior_sorted and posterior is not None:
            posterior.sort_values(by='log_likelihood', inplace=True)
            self._posterior_sorted = True
        return posterior

    @property
    def _max_like_params(self) -> pd.core.series.Series:
        """The posterior row with the highest log likelihood (last row after sorting)."""
        return self._posterior.iloc[-1]

    def _get_random_parameters(self) -> list[pd.core.series.Series]:
        """Draw ``random_models`` posterior rows uniformly at random, with replacement."""
        integers = np.arange(len(self._posterior))
        indices = np.random.choice(integers, size=self.random_models)
        return [self._posterior.iloc[idx] for idx in indices]

    _data_plot_filename = _FilenameGetter(suffix="data")
    _spectrum_ppd_plot_filename = _FilenameGetter(suffix="spectrum_ppd")
    _residual_plot_filename = _FilenameGetter(suffix="residual")

    _data_plot_filepath = _FilePathGetter(
        directory_property="_data_plot_outdir", filename_property="_data_plot_filename")
    _spectrum_ppd_plot_filepath = _FilePathGetter(
        directory_property="_data_plot_outdir", filename_property="_spectrum_ppd_plot_filename")
    _residual_plot_filepath = _FilePathGetter(
        directory_property="_data_plot_outdir", filename_property="_residual_plot_filename")

    def _save_and_show(self, filepath: str, save: bool, show: bool) -> None:
        """Apply ``tight_layout`` to the current figure, then optionally save it to ``filepath``
        (white, non-transparent background) and/or show it."""
        plt.tight_layout()
        if save:
            plt.savefig(filepath, dpi=self.dpi, bbox_inches=self.bbox_inches, transparent=False, facecolor='white')
        if show:
            plt.show()
469

470

471
class IntegratedFluxPlotter(Plotter):
    """
    Plotter for integrated-flux data. Plots the data and model lightcurves on logarithmic
    time and flux axes, and the residuals of the maximum-likelihood model.
    """

    @property
    def _xlabel(self) -> str:
        """Label of the time axis."""
        return r"Time since burst [s]"

    @property
    def _ylabel(self) -> str:
        """Label of the flux axis, taken from the transient."""
        return self.transient.ylabel

    def plot_data(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Integrated flux data and returns Axes.

        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
                     The current axes (``plt.gca()``) are used if not given.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        # Explicit None check: only fall back to the current axes when none were given.
        ax = axes if axes is not None else plt.gca()

        ax.errorbar(self.transient.x, self.transient.y, xerr=self._x_err, yerr=self._y_err,
                    fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize)

        ax.set_xscale('log')
        ax.set_yscale('log')

        ax.set_xlim(self._xlim_low, self._xlim_high)
        ax.set_ylim(self._ylim_low, self._ylim_high)
        ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes)
        ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes)

        ax.annotate(
            self.transient.name, xy=self.xy, xycoords=self.xycoords,
            horizontalalignment=self.horizontalalignment, size=self.annotation_size)

        ax.tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)

        self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show)
        return ax

    def plot_lightcurve(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Integrated flux data and the lightcurve and returns Axes.

        :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot.
                     The current axes (``plt.gca()``) are used if not given.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        axes = axes if axes is not None else plt.gca()

        # Draw the data first (without saving/showing) so the axis scales and limits are set
        # before the model evaluation times are derived from them.
        axes = self.plot_data(axes=axes, save=False, show=False)
        times = self._get_times(axes)

        self._plot_lightcurves(axes, times)

        self._save_and_show(filepath=self._lightcurve_plot_filepath, save=save, show=show)
        return axes

    def _plot_lightcurves(self, axes: matplotlib.axes.Axes, times: np.ndarray) -> None:
        """Draw the maximum-likelihood model and the posterior uncertainty onto ``axes``.

        The uncertainty is drawn as individual random posterior draws ('random_models') or as a
        shaded credible band ('credible_intervals'); any other ``uncertainty_mode`` draws nothing.
        """
        if self.plot_max_likelihood:
            ys = self.model(times, **self._max_like_params, **self._model_kwargs)
            axes.plot(times, ys, color=self.max_likelihood_color, alpha=self.max_likelihood_alpha, lw=self.linewidth)

        if self.uncertainty_mode not in ("random_models", "credible_intervals"):
            # Nothing would be drawn, so skip evaluating the model for every random draw.
            return

        random_ys_list = [self.model(times, **random_params, **self._model_kwargs)
                          for random_params in self._get_random_parameters()]
        if self.uncertainty_mode == "random_models":
            for ys in random_ys_list:
                axes.plot(times, ys, color=self.random_sample_color, alpha=self.random_sample_alpha, lw=self.linewidth,
                          zorder=self.zorder)
        else:
            lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(
                samples=random_ys_list, interval=self.credible_interval_level)
            axes.fill_between(
                times, lower_bound, upper_bound, alpha=self.uncertainty_band_alpha, color=self.max_likelihood_color)

    def _plot_single_lightcurve(self, axes: matplotlib.axes.Axes, times: np.ndarray, params: dict) -> None:
        """Draw the model for one parameter set in the random-sample style."""
        ys = self.model(times, **params, **self._model_kwargs)
        axes.plot(times, ys, color=self.random_sample_color, alpha=self.random_sample_alpha, lw=self.linewidth,
                  zorder=self.zorder)

    def plot_residuals(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the residual of the Integrated flux data returns Axes.

        The top panel shows the data with the model lightcurves. The bottom panel shows the residuals
        ``y - model(x)`` for the maximum-likelihood parameters.

        :param axes: Sequence of two Matplotlib axes (lightcurve, residuals). If not given, a new figure
                     with two vertically stacked panels sharing the x axis is created.
        :type axes: Union[numpy.ndarray, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The two axes with the plot.
        :rtype: numpy.ndarray
        """
        if axes is None:
            fig, axes = plt.subplots(
                nrows=2, ncols=1, sharex=True, sharey=False, figsize=(10, 8), gridspec_kw=dict(height_ratios=[2, 1]))

        axes[0] = self.plot_lightcurve(axes=axes[0], save=False, show=False)
        # Only the bottom panel carries the (shared) x label.
        axes[1].set_xlabel(axes[0].get_xlabel(), fontsize=self.fontsize_axes)
        axes[0].set_xlabel("")
        ys = self.model(self.transient.x, **self._max_like_params, **self._model_kwargs)
        axes[1].errorbar(
            self.transient.x, self.transient.y - ys, xerr=self._x_err, yerr=self._y_err,
            fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize)
        # Residuals are signed (data - model): a log axis cannot show zero or negative
        # values and would silently drop roughly half of the points.
        axes[1].set_yscale("linear")
        axes[1].set_ylabel("Residual", fontsize=self.fontsize_axes)
        axes[1].tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)

        self._save_and_show(filepath=self._residual_plot_filepath, save=save, show=show)
        return axes
591

592

593
class LuminosityOpticalPlotter(IntegratedFluxPlotter):
    """IntegratedFluxPlotter whose default axis labels describe a bolometric luminosity lightcurve.

    The defaults can be overridden with the ``xlabel`` / ``ylabel`` plotting kwargs.
    """

    @property
    def _xlabel(self) -> str:
        """x-axis label; defaults to time since explosion in days."""
        return self.kwargs.get("xlabel", r"Time since explosion [days]")

    @property
    def _ylabel(self) -> str:
        """y-axis label; defaults to bolometric luminosity in units of 1e50 erg/s."""
        return self.kwargs.get("ylabel", r"L$_{\rm bol}$ [$10^{50}$ erg s$^{-1}$]")
1✔
602

603
class LuminosityPlotter(IntegratedFluxPlotter):
    """Luminosity lightcurve plotter; reuses IntegratedFluxPlotter without modification."""
1✔
605

606

607
class MagnitudePlotter(Plotter):
    """Plotter for band-resolved lightcurves (e.g. magnitudes or flux densities).

    Offers a single panel with all bands (:meth:`plot_data`, :meth:`plot_lightcurve`) and one panel per
    band (:meth:`plot_multiband`, :meth:`plot_multiband_lightcurve`).

    In the single-panel plots, bands listed in the ``band_scaling`` kwarg are rescaled.
    ``band_scaling["type"]`` selects ``'x'`` (multiply by ``band_scaling[band]``) or ``'+'``
    (add ``band_scaling[band]``).
    """

    # NOTE: several accessors reuse the kwarg key of another accessor with a different default, e.g.
    # xlim_high_phase_model_multiplier and xlim_high_multiplier both read "xlim_high_multiplier".
    xlim_low_phase_model_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.9)
    xlim_high_phase_model_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.1)
    xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.2)
    ylim_low_magnitude_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.95)
    ylim_high_magnitude_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 1.05)
    ncols = KwargsAccessorWithDefault("ncols", 2)

    # Supported values of band_scaling["type"]: multiplicative ('x') or additive ('+').
    _BAND_SCALING_TYPES = ('x', '+')

    @property
    def _colors(self) -> Any:
        """Per-filter colours aligned with ``_filters``; overridable via the ``colors`` kwarg."""
        return self.kwargs.get("colors", self.transient.get_colors(self._filters))

    @property
    def _xlabel(self) -> str:
        """x-axis label; overridable via the ``xlabel`` kwarg."""
        if self.transient.use_phase_model:
            default = f"Time since {self._reference_mjd_date} MJD [days]"
        else:
            default = self.transient.xlabel
        return self.kwargs.get("xlabel", default)

    @property
    def _ylabel(self) -> str:
        """y-axis label; overridable via the ``ylabel`` kwarg."""
        return self.kwargs.get("ylabel", self.transient.ylabel)

    @property
    def _get_bands_to_plot(self) -> list[str]:
        """Bands whose model lightcurves are drawn; defaults to the transient's active bands."""
        return self.kwargs.get("bands_to_plot", self.transient.active_bands)

    @property
    def _xlim_low(self) -> float:
        """Lower x limit; overridable via the ``xlim_low`` kwarg."""
        if self.transient.use_phase_model:
            default = (self.transient.x[0] - self._reference_mjd_date) * self.xlim_low_phase_model_multiplier
        else:
            default = self.xlim_low_multiplier * self.transient.x[0]
        if default == 0:
            # Avoid an exactly-zero lower limit (presumably so a log-scaled time axis stays valid).
            default += 1e-3
        return self.kwargs.get("xlim_low", default)

    @property
    def _xlim_high(self) -> float:
        """Upper x limit; overridable via the ``xlim_high`` kwarg."""
        if self.transient.use_phase_model:
            default = (self.transient.x[-1] - self._reference_mjd_date) * self.xlim_high_phase_model_multiplier
        else:
            default = self.xlim_high_multiplier * self.transient.x[-1]
        return self.kwargs.get("xlim_high", default)

    @property
    def _ylim_low_magnitude(self) -> float:
        """Lower y limit for magnitude data, from the smallest y value over all bands."""
        return self.ylim_low_magnitude_multiplier * min(self.transient.y)

    @property
    def _ylim_high_magnitude(self) -> float:
        """Upper y limit for magnitude data, from the largest y value over all bands."""
        return self.ylim_high_magnitude_multiplier * np.max(self.transient.y)

    def _get_ylim_low_with_indices(self, indices: list) -> float:
        """Lower y limit computed from the data points selected by ``indices`` only."""
        return self.ylim_low_multiplier * min(self.transient.y[indices])

    def _get_ylim_high_with_indices(self, indices: list) -> float:
        """Upper y limit computed from the data points selected by ``indices`` only."""
        return self.ylim_high_multiplier * np.max(self.transient.y[indices])

    def _get_x_err(self, indices: list) -> np.ndarray:
        """x uncertainties for ``indices``, or None when the transient has no x errors."""
        return self.transient.x_err[indices] if self.transient.x_err is not None else self.transient.x_err

    def _set_y_axis_data(self, ax: matplotlib.axes.Axes) -> None:
        """Set y limits/scale: inverted linear axis for magnitudes, log axis otherwise."""
        if self.transient.magnitude_data:
            ax.set_ylim(self._ylim_low_magnitude, self._ylim_high_magnitude)
            ax.invert_yaxis()
            ax.set_yscale('linear')
        else:
            ax.set_ylim(self._ylim_low, self._ylim_high)
            ax.set_yscale("log")

    def _set_y_axis_multiband_data(self, ax: matplotlib.axes.Axes, indices: list) -> None:
        """Like :meth:`_set_y_axis_data`, but non-magnitude limits use only the band's data points."""
        if self.transient.magnitude_data:
            ax.set_ylim(self._ylim_low_magnitude, self._ylim_high_magnitude)
            ax.invert_yaxis()
            ax.set_yscale('linear')
        else:
            ax.set_ylim(self._get_ylim_low_with_indices(indices=indices),
                        self._get_ylim_high_with_indices(indices=indices))
            ax.set_yscale("log")

    def _set_x_axis(self, axes: matplotlib.axes.Axes) -> None:
        """Set x limits; phase-model plots additionally use a linear x axis."""
        if self.transient.use_phase_model:
            axes.set_xscale("linear")
        axes.set_xlim(self._xlim_low, self._xlim_high)

    @property
    def _nrows(self) -> int:
        """Number of subplot rows; by default just enough rows for all filters at ``ncols`` per row."""
        # Must divide by the actual column count (not a hard-coded 2) so that a user-supplied
        # ``ncols`` does not produce extra, empty rows.
        default = int(np.ceil(len(self._filters) / self.ncols))
        return self._get_kwarg_with_default("nrows", default=default)

    @property
    def _npanels(self) -> int:
        """Total number of panels.

        :raises ValueError: if there are fewer panels than filters to plot.
        """
        npanels = self._nrows * self.ncols
        if npanels < len(self._filters):
            raise ValueError(f"Insufficient number of panels. {npanels} panels were given "
                             f"but {len(self._filters)} panels are needed.")
        return npanels

    @property
    def _figsize(self) -> tuple:
        """Figure size scaled with the subplot grid; overridable via the ``figsize`` kwarg."""
        default = (4 + 4 * self.ncols, 2 + 2 * self._nrows)
        return self._get_kwarg_with_default("figsize", default=default)

    @property
    def _reference_mjd_date(self) -> int:
        """Time offset subtracted from x values: the (integer) first time for phase models, else 0."""
        if self.transient.use_phase_model:
            return self.kwargs.get("reference_mjd_date", int(self.transient.x[0]))
        return 0

    @property
    def band_label_generator(self):
        """Generator over user-supplied ``band_labels``, or None if no labels were given."""
        if self.band_labels is not None:
            return (bl for bl in self.band_labels)
        return None

    def _get_band_scaling(self, band: str) -> tuple | None:
        """Return ``(scaling_type, value)`` for ``band``, or None if the band is not rescaled.

        :raises ValueError: if ``band`` is listed in ``band_scaling`` but ``band_scaling["type"]``
            is not one of ``_BAND_SCALING_TYPES``. Otherwise such a band would be dropped from the
            plot, or would reuse another band's credible interval.
        """
        if band not in self.band_scaling:
            return None
        scaling_type = self.band_scaling.get("type")
        if scaling_type not in self._BAND_SCALING_TYPES:
            raise ValueError(
                f"Unsupported band_scaling type {scaling_type!r} for band {band!r}; "
                f"expected one of {self._BAND_SCALING_TYPES}.")
        return scaling_type, self.band_scaling.get(band)

    def _scale_values(self, band: str, values):
        """Apply ``band``'s scaling rule to data or model values; unscaled bands are returned unchanged."""
        scaling = self._get_band_scaling(band)
        if scaling is None:
            return values
        scaling_type, factor = scaling
        if scaling_type == 'x':
            return values * factor
        return values + factor

    def _scale_errors(self, band: str, errors):
        """Scale uncertainties consistently with :meth:`_scale_values` (only 'x' scaling changes them)."""
        scaling = self._get_band_scaling(band)
        if scaling is not None and scaling[0] == 'x':
            return errors * scaling[1]
        return errors

    def _get_band_label(self, band: str) -> str:
        """Legend label for ``band``, annotated with its scaling unless that scaling is the identity."""
        scaling = self._get_band_scaling(band)
        if scaling is None:
            return band
        scaling_type, factor = scaling
        if (scaling_type == 'x' and factor == 1) or (scaling_type == '+' and factor == 0):
            return band
        return f"{band} {scaling_type} {factor}"

    def plot_data(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Magnitude data and returns Axes.

        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        :raises ValueError: if a plotted band is in ``band_scaling`` with an unsupported ``type``.
        """
        ax = axes if axes is not None else plt.gca()

        band_label_generator = self.band_label_generator
        reference_mjd_date = self._reference_mjd_date

        for indices, band in zip(self.transient.list_of_band_indices, self.transient.unique_bands):
            if band in self._filters:
                color = self._colors[list(self._filters).index(band)]
                if band_label_generator is None:
                    label = self._get_band_label(band)
                else:
                    label = next(band_label_generator)
            elif self.plot_others:
                # Bands outside the selected filters are drawn in black without a legend entry.
                color = "black"
                label = None
            else:
                continue
            if isinstance(label, float):
                label = f"{label:.2e}"
            if self.band_colors is not None:
                color = self.band_colors[band]
            ax.errorbar(
                self.transient.x[indices] - reference_mjd_date,
                self._scale_values(band, self.transient.y[indices]),
                xerr=self._get_x_err(indices),
                yerr=self._scale_errors(band, self.transient.y_err[indices]),
                fmt=self.errorbar_fmt, ms=self.ms, color=color,
                elinewidth=self.elinewidth, capsize=self.capsize, label=label)

        self._set_x_axis(axes=ax)
        self._set_y_axis_data(ax)

        ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes)
        ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes)

        ax.tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)
        ax.legend(ncol=self.legend_cols, loc=self.legend_location, fontsize=self.fontsize_legend)

        self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show)
        return ax

    def plot_lightcurve(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True)\
            -> matplotlib.axes.Axes:
        """Plots the Magnitude data together with the model lightcurves and returns Axes.

        :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        :raises ValueError: if a plotted band is in ``band_scaling`` with an unsupported ``type``.
        """
        axes = axes if axes is not None else plt.gca()

        axes = self.plot_data(axes=axes, save=False, show=False)

        times = self._get_times(axes)
        plot_times = times - self._reference_mjd_date
        bands_to_plot = self._get_bands_to_plot

        color_max = self.max_likelihood_color
        color_sample = self.random_sample_color
        for band, color in zip(bands_to_plot, self.transient.get_colors(bands_to_plot)):
            if self.set_same_color_per_subplot is True:
                if self.band_colors is not None:
                    color = self.band_colors[band]
                color_max = color
                color_sample = color

            # Per-band copy so the shared self._model_kwargs is not mutated by this plot.
            model_kwargs = self._model_kwargs.copy()
            sn_cosmo_band = redback.utils.sncosmo_bandname_from_band([band])
            model_kwargs["bands"] = [sn_cosmo_band[0] for _ in range(len(times))]
            if isinstance(band, str):
                frequency = redback.utils.bands_to_frequency([band])
            else:
                # Non-string bands are treated as frequencies already.
                frequency = band
            model_kwargs['frequency'] = np.ones(len(times)) * frequency

            if self.plot_max_likelihood:
                ys = self.model(times, **self._max_like_params, **model_kwargs)
                axes.plot(plot_times, self._scale_values(band, ys), color=color_max,
                          alpha=self.max_likelihood_alpha, lw=self.linewidth)

            random_ys_list = [self.model(times, **random_params, **model_kwargs)
                              for random_params in self._get_random_parameters()]
            if self.uncertainty_mode == "random_models":
                for ys in random_ys_list:
                    axes.plot(plot_times, self._scale_values(band, ys), color=color_sample,
                              alpha=self.random_sample_alpha, lw=self.linewidth, zorder=-1)
            elif self.uncertainty_mode == "credible_intervals":
                lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(
                    samples=self._scale_values(band, np.array(random_ys_list)),
                    interval=self.credible_interval_level)
                axes.fill_between(
                    plot_times, lower_bound, upper_bound,
                    alpha=self.uncertainty_band_alpha, color=color_sample)

        self._save_and_show(filepath=self._lightcurve_plot_filepath, save=save, show=show)
        return axes

    def _check_valid_multiband_data_mode(self) -> bool:
        """Return False (and log a warning) for luminosity data, which has no multiband view."""
        if self.transient.luminosity_data:
            redback.utils.logger.warning(
                f"Plotting multiband lightcurve/data not possible for {self.transient.data_mode}. Returning.")
            return False
        return True

    def plot_multiband(
            self, figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None, save: bool = True,
            show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Magnitude multiband data and returns Axes.

        :param figure: Matplotlib figure to plot the data into.
        :type figure: matplotlib.figure.Figure
        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
                     May be a single Axes or an array of Axes.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The flattened array of axes with the plot, or None for luminosity data.
        :rtype: matplotlib.axes.Axes
        """
        if not self._check_valid_multiband_data_mode():
            return

        if figure is None or axes is None:
            figure, axes = plt.subplots(ncols=self.ncols, nrows=self._nrows, sharex='all', figsize=self._figsize)
        # plt.subplots returns a bare Axes (not an array) for a 1x1 grid; normalise to a flat array.
        axes = np.atleast_1d(axes).ravel()

        band_label_generator = self.band_label_generator
        reference_mjd_date = self._reference_mjd_date

        ii = 0
        for indices, band, freq in zip(
                self.transient.list_of_band_indices, self.transient.unique_bands, self.transient.unique_frequencies):
            if band not in self._filters:
                continue

            x_err = self._get_x_err(indices)
            color = self._colors[list(self._filters).index(band)]
            if self.band_colors is not None:
                color = self.band_colors[band]
            if band_label_generator is None:
                label = self._get_multiband_plot_label(band, freq)
            else:
                label = next(band_label_generator)

            axes[ii].errorbar(
                self.transient.x[indices] - reference_mjd_date, self.transient.y[indices], xerr=x_err,
                yerr=self.transient.y_err[indices], fmt=self.errorbar_fmt, ms=self.ms, color=color,
                elinewidth=self.elinewidth, capsize=self.capsize,
                label=label)

            self._set_x_axis(axes[ii])
            self._set_y_axis_multiband_data(axes[ii], indices)
            axes[ii].legend(ncol=self.legend_cols, loc=self.legend_location, fontsize=self.fontsize_legend)
            axes[ii].tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)
            ii += 1

        figure.supxlabel(self._xlabel, fontsize=self.fontsize_figure)
        figure.supylabel(self._ylabel, fontsize=self.fontsize_figure)
        plt.subplots_adjust(wspace=self.wspace, hspace=self.hspace)

        self._save_and_show(filepath=self._multiband_data_plot_filepath, save=save, show=show)
        return axes

    @staticmethod
    def _get_multiband_plot_label(band: str, freq: float) -> str:
        """Use the band name when the frequency lies in (1e10, 1e16); otherwise a formatted frequency."""
        if isinstance(band, str):
            if 1e10 < float(freq) < 1e16:
                label = band
            else:
                label = f"{freq:.2e}"
        else:
            label = f"{band:.2e}"
        return label

    @property
    def _filters(self) -> list[str]:
        """Bands to plot.

        ``bands_to_plot`` takes precedence over ``filters``. None means the active bands, and the
        string 'default' means the transient's default filters.
        """
        filters = self.kwargs.get("filters", self.transient.active_bands)
        if 'bands_to_plot' in self.kwargs:
            filters = self.kwargs['bands_to_plot']
        if filters is None:
            return self.transient.active_bands
        elif str(filters) == 'default':
            return self.transient.default_filters
        return filters

    def plot_multiband_lightcurve(
        self, figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Magnitude multiband lightcurve and returns Axes.

        :param figure: Matplotlib figure to plot the data into.
        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        if not self._check_valid_multiband_data_mode():
            return

        if figure is None or axes is None:
            figure, axes = plt.subplots(ncols=self.ncols, nrows=self._nrows, sharex='all', figsize=self._figsize)

        axes = self.plot_multiband(figure=figure, axes=axes, save=False, show=False)
        times = self._get_times(axes)
        plot_times = times - self._reference_mjd_date

        ii = 0
        color_max = self.max_likelihood_color
        color_sample = self.random_sample_color
        for band, freq in zip(self.transient.unique_bands, self.transient.unique_frequencies):
            if band not in self._filters:
                continue
            # Copy so per-band settings never leak into self._model_kwargs.
            new_model_kwargs = self._model_kwargs.copy()
            new_model_kwargs['frequency'] = freq
            sn_cosmo_band = redback.utils.sncosmo_bandname_from_band([band])
            new_model_kwargs['bands'] = [sn_cosmo_band[0] for _ in range(len(times))]

            if self.set_same_color_per_subplot is True:
                color = self._colors[list(self._filters).index(band)]
                if self.band_colors is not None:
                    color = self.band_colors[band]
                color_max = color
                color_sample = color

            if self.plot_max_likelihood:
                ys = self.model(times, **self._max_like_params, **new_model_kwargs)
                axes[ii].plot(
                    plot_times, ys, color=color_max,
                    alpha=self.max_likelihood_alpha, lw=self.linewidth)
            random_ys_list = [self.model(times, **random_params, **new_model_kwargs)
                              for random_params in self._get_random_parameters()]
            if self.uncertainty_mode == "random_models":
                for random_ys in random_ys_list:
                    axes[ii].plot(plot_times, random_ys, color=color_sample,
                                  alpha=self.random_sample_alpha, lw=self.linewidth, zorder=self.zorder)
            elif self.uncertainty_mode == "credible_intervals":
                lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(
                    samples=random_ys_list, interval=self.credible_interval_level)
                axes[ii].fill_between(
                    plot_times, lower_bound, upper_bound,
                    alpha=self.uncertainty_band_alpha, color=color_sample)
            ii += 1

        self._save_and_show(filepath=self._multiband_lightcurve_plot_filepath, save=save, show=show)
        return axes
1✔
1023

1024

1025
class FluxDensityPlotter(MagnitudePlotter):
    """Flux density lightcurve plotter; reuses MagnitudePlotter without modification."""
1✔
1027

1028
class IntegratedFluxOpticalPlotter(MagnitudePlotter):
    """Optical integrated-flux lightcurve plotter; reuses MagnitudePlotter without modification."""
1✔
1030

1031
class SpectrumPlotter(SpecPlotter):
    """Plotter for a single spectrum and its model fit.

    Flux densities (data, models, credible intervals and residuals) are all displayed divided by
    ``_flux_scale`` so that every panel shares the same units.
    """

    # Display unit for flux densities on every panel.
    _flux_scale = 1e-17

    @property
    def _xlabel(self) -> str:
        """x-axis label taken from the transient."""
        return self.transient.xlabel

    @property
    def _ylabel(self) -> str:
        """y-axis label taken from the transient."""
        return self.transient.ylabel

    def plot_data(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the spectrum data and returns Axes.

        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        ax = axes if axes is not None else plt.gca()

        # The annotation shows either the observation time or the transient name.
        if self.transient.plot_with_time_label:
            label = self.transient.time
        else:
            label = self.transient.name
        ax.plot(self.transient.angstroms, self.transient.flux_density / self._flux_scale, color=self.color,
                lw=self.linewidth)
        ax.set_xscale('linear')
        ax.set_yscale(self.yscale)

        ax.set_xlim(self._xlim_low, self._xlim_high)
        ax.set_ylim(self._ylim_low, self._ylim_high)
        ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes)
        ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes)

        ax.annotate(
            label, xy=self.xy, xycoords=self.xycoords,
            horizontalalignment=self.horizontalalignment, size=self.annotation_size)

        ax.tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)

        self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show)
        return ax

    def plot_spectrum(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the spectrum data and the fit and returns Axes.

        :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        axes = axes if axes is not None else plt.gca()

        axes = self.plot_data(axes=axes, save=False, show=False)
        angstroms = self._get_angstroms(axes)

        self._plot_spectrums(axes, angstroms)

        self._save_and_show(filepath=self._spectrum_ppd_plot_filepath, save=save, show=show)
        return axes

    def _plot_spectrums(self, axes: matplotlib.axes.Axes, angstroms: np.ndarray) -> None:
        """Draw the max-likelihood model and the posterior uncertainty (random draws or credible band)."""
        if self.plot_max_likelihood:
            ys = self.model(angstroms, **self._max_like_params, **self._model_kwargs)
            axes.plot(angstroms, ys / self._flux_scale, color=self.max_likelihood_color,
                      alpha=self.max_likelihood_alpha, lw=self.linewidth)

        random_ys_list = [self.model(angstroms, **random_params, **self._model_kwargs)
                          for random_params in self._get_random_parameters()]
        if self.uncertainty_mode == "random_models":
            for ys in random_ys_list:
                axes.plot(angstroms, ys / self._flux_scale, color=self.random_sample_color,
                          alpha=self.random_sample_alpha, lw=self.linewidth, zorder=self.zorder)
        elif self.uncertainty_mode == "credible_intervals":
            lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=random_ys_list,
                                                                                interval=self.credible_interval_level)
            axes.fill_between(
                angstroms, lower_bound / self._flux_scale, upper_bound / self._flux_scale,
                alpha=self.uncertainty_band_alpha, color=self.max_likelihood_color)

    def plot_residuals(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the spectrum fit with a residual panel underneath and returns the axes.

        Residuals (data minus maximum-likelihood model) and their errors are shown in the same
        scaled units (divided by ``_flux_scale``) as the spectrum panel.

        :param axes: Pair of Matplotlib axes (top: spectrum, bottom: residuals). If None, a new 2x1
                     figure with a shared x-axis is created. Useful for user specific modifications to the plot.
        :param save: Whether to save the plot. (Default value = True)
        :param show: Whether to show the plot. (Default value = True)

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        if axes is None:
            _, axes = plt.subplots(
                nrows=2, ncols=1, sharex=True, sharey=False, figsize=(10, 8), gridspec_kw=dict(height_ratios=[2, 1]))

        axes[0] = self.plot_spectrum(axes=axes[0], save=False, show=False)
        # The panels share the x-axis, so only the bottom panel keeps the x label.
        axes[1].set_xlabel(axes[0].get_xlabel(), fontsize=self.fontsize_axes)
        axes[0].set_xlabel("")
        ys = self.model(self.transient.angstroms, **self._max_like_params, **self._model_kwargs)
        residuals = (self.transient.flux_density - ys) / self._flux_scale
        flux_density_err = self.transient.flux_density_err
        residual_errors = None if flux_density_err is None else flux_density_err / self._flux_scale
        axes[1].errorbar(
            self.transient.angstroms, residuals, yerr=residual_errors,
            fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize)
        axes[1].set_yscale('linear')
        axes[1].set_ylabel("Residual", fontsize=self.fontsize_axes)
        axes[1].tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)

        self._save_and_show(filepath=self._residual_plot_filepath, save=save, show=show)
        return axes
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc