• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nikhil-sarin / redback / 15047800526

15 May 2025 02:35PM UTC coverage: 86.893% (-0.05%) from 86.941%
15047800526

push

github

web-flow
Merge pull request #271 from nikhil-sarin/bugfix_eff_widths_and_freqs

Fix for effective widths

33 of 47 new or added lines in 5 files covered. (70.21%)

2 existing lines in 1 file now uncovered.

12755 of 14679 relevant lines covered (86.89%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.34
/redback/plotting.py
1
from __future__ import annotations
1✔
2

3
from os.path import join
1✔
4
from typing import Any, Union
1✔
5

6
import matplotlib
1✔
7
import matplotlib.pyplot as plt
1✔
8
import numpy as np
1✔
9
import pandas as pd
1✔
10

11
import redback
1✔
12
from redback.utils import KwargsAccessorWithDefault
1✔
13

14
class _FilenameGetter(object):
1✔
15
    def __init__(self, suffix: str) -> None:
1✔
16
        self.suffix = suffix
1✔
17

18
    def __get__(self, instance: Plotter, owner: object) -> str:
1✔
19
        return instance.get_filename(default=f"{instance.transient.name}_{self.suffix}.png")
1✔
20

21
    def __set__(self, instance: Plotter, value: object) -> None:
1✔
22
        pass
1✔
23

24

25
class _FilePathGetter(object):
1✔
26

27
    def __init__(self, directory_property: str, filename_property: str) -> None:
1✔
28
        self.directory_property = directory_property
1✔
29
        self.filename_property = filename_property
1✔
30

31
    def __get__(self, instance: Plotter, owner: object) -> str:
1✔
32
        return join(getattr(instance, self.directory_property), getattr(instance, self.filename_property))
1✔
33

34

35
class Plotter(object):
    """
    Base class for all lightcurve plotting classes in redback.

    Plot settings are declared below as ``KwargsAccessorWithDefault`` attributes,
    each pairing a keyword name with its default value.
    """

    capsize = KwargsAccessorWithDefault("capsize", 0.)
    legend_location = KwargsAccessorWithDefault("legend_location", "best")
    legend_cols = KwargsAccessorWithDefault("legend_cols", 2)
    band_colors = KwargsAccessorWithDefault("band_colors", None)
    color = KwargsAccessorWithDefault("color", "k")
    band_labels = KwargsAccessorWithDefault("band_labels", None)
    band_scaling = KwargsAccessorWithDefault("band_scaling", {})
    dpi = KwargsAccessorWithDefault("dpi", 300)
    elinewidth = KwargsAccessorWithDefault("elinewidth", 2)
    errorbar_fmt = KwargsAccessorWithDefault("errorbar_fmt", "o")
    model = KwargsAccessorWithDefault("model", None)
    ms = KwargsAccessorWithDefault("ms", 5)
    axis_tick_params_pad = KwargsAccessorWithDefault("axis_tick_params_pad", 10)

    max_likelihood_alpha = KwargsAccessorWithDefault("max_likelihood_alpha", 0.65)
    random_sample_alpha = KwargsAccessorWithDefault("random_sample_alpha", 0.05)
    uncertainty_band_alpha = KwargsAccessorWithDefault("uncertainty_band_alpha", 0.4)
    max_likelihood_color = KwargsAccessorWithDefault("max_likelihood_color", "blue")
    random_sample_color = KwargsAccessorWithDefault("random_sample_color", "red")

    bbox_inches = KwargsAccessorWithDefault("bbox_inches", "tight")
    linewidth = KwargsAccessorWithDefault("linewidth", 2)
    zorder = KwargsAccessorWithDefault("zorder", -1)

    xy = KwargsAccessorWithDefault("xy", (0.95, 0.9))
    xycoords = KwargsAccessorWithDefault("xycoords", "axes fraction")
    horizontalalignment = KwargsAccessorWithDefault("horizontalalignment", "right")
    annotation_size = KwargsAccessorWithDefault("annotation_size", 20)

    fontsize_axes = KwargsAccessorWithDefault("fontsize_axes", 18)
    fontsize_figure = KwargsAccessorWithDefault("fontsize_figure", 30)
    fontsize_legend = KwargsAccessorWithDefault("fontsize_legend", 18)
    fontsize_ticks = KwargsAccessorWithDefault("fontsize_ticks", 16)
    hspace = KwargsAccessorWithDefault("hspace", 0.04)
    wspace = KwargsAccessorWithDefault("wspace", 0.15)

    plot_others = KwargsAccessorWithDefault("plot_others", True)
    random_models = KwargsAccessorWithDefault("random_models", 100)
    uncertainty_mode = KwargsAccessorWithDefault("uncertainty_mode", "random_models")
    credible_interval_level = KwargsAccessorWithDefault("credible_interval_level", 0.9)
    plot_max_likelihood = KwargsAccessorWithDefault("plot_max_likelihood", True)
    set_same_color_per_subplot = KwargsAccessorWithDefault("set_same_color_per_subplot", True)

    xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 2.0)
    xlim_low_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.5)
    ylim_high_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 2.0)
    ylim_low_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.5)

    def __init__(self, transient: Union[redback.transient.Transient, None], **kwargs) -> None:
        """
        :param transient: An instance of `redback.transient.Transient`. Contains the data to be plotted.
        :param kwargs: Additional kwargs the plotter uses. -------
        :keyword capsize: Same as matplotlib capsize.
        :keyword bands_to_plot: List of bands to plot in plot lightcurve and multiband lightcurve. Default is active bands.
        :keyword legend_location: Same as matplotlib legend location.
        :keyword legend_cols: Same as matplotlib legend columns.
        :keyword color: Color of the data points.
        :keyword band_colors: A dictionary with the colors of the bands.
        :keyword band_labels: List with the names of the bands.
        :keyword band_scaling: Dict with the scaling for each band. First entry should be {type: '+' or 'x'} for different types.
        :keyword dpi: Same as matplotlib dpi.
        :keyword elinewidth: same as matplotlib elinewidth
        :keyword errorbar_fmt: 'fmt' argument of `ax.errorbar`.
        :keyword model: str or callable, the model to plot.
        :keyword ms: Same as matplotlib markersize.
        :keyword axis_tick_params_pad: `pad` argument in calls to `ax.tick_params` when setting the axes.
        :keyword max_likelihood_alpha: `alpha` argument, i.e. transparency, when plotting the max likelihood curve.
        :keyword random_sample_alpha: `alpha` argument, i.e. transparency, when plotting random sample curves.
        :keyword uncertainty_band_alpha: `alpha` argument, i.e. transparency, when plotting a credible band.
        :keyword max_likelihood_color: Color of the maximum likelihood curve.
        :keyword random_sample_color: Color of the random sample curves.
        :keyword bbox_inches: Setting for saving plots. Default is 'tight'.
        :keyword linewidth: Same as matplotlib linewidth
        :keyword zorder: Same as matplotlib zorder
        :keyword xy: For `ax.annotate` x and y coordinates of the point to annotate.
        :keyword xycoords: The coordinate system `xy` is given in. Default is 'axes fraction'
        :keyword horizontalalignment: Horizontal alignment of the annotation. Default is 'right'
        :keyword annotation_size: `size` argument of `ax.annotate`.
        :keyword fontsize_axes: Font size of the x and y labels.
        :keyword fontsize_legend: Font size of the legend.
        :keyword fontsize_figure: Font size of the figure. Relevant for multiband plots.
                                  Used on `supxlabel` and `supylabel`.
        :keyword fontsize_ticks: Font size of the axis ticks.
        :keyword hspace: Argument for `subplots_adjust`, sets vertical spacing between panels.
        :keyword wspace: Argument for `subplots_adjust`, sets horizontal spacing between panels.
        :keyword plot_others: Whether to plot additional bands in the data plot, all in the same colors
        :keyword random_models: Number of random draws to use to calculate credible bands or to plot.
        :keyword uncertainty_mode: 'random_models': Plot random draws from the available parameter sets.
                                   'credible_intervals': Plot a credible interval that is calculated based
                                   on the available parameter sets.
        :keyword reference_mjd_date: Date to use as reference point for the x axis.
                                    Default is the first date in the data.
        :keyword credible_interval_level: 0.9: Plot the 90% credible interval.
        :keyword plot_max_likelihood: Plots the draw corresponding to the maximum likelihood. Default is 'True'.
        :keyword set_same_color_per_subplot: Sets the lightcurve to be the same color as the data per subplot. Default is 'True'.
        :keyword xlim_high_multiplier: Adjust the maximum xlim based on available x values.
        :keyword xlim_low_multiplier: Adjust the minimum xlim based on available x values.
        :keyword ylim_high_multiplier: Adjust the maximum ylim based on available y values.
        :keyword ylim_low_multiplier: Adjust the minimum ylim based on available y values.
        """
        self.transient = transient
        self.kwargs = kwargs or dict()
        self._posterior_sorted = False

    # The keyword section of the __init__ docstring (everything after the
    # "-------" marker), exposed so it can be reused in other docstrings.
    keyword_docstring = __init__.__doc__.split("-------")[1]

    def _get_times(self, axes: matplotlib.axes.Axes) -> np.ndarray:
        """
        Build 200 evaluation times spanning the current x limits.

        :param axes: The axes used in the plotting procedure. If an array of axes is given, the first is used.
        :type axes: matplotlib.axes.Axes

        :return: Linearly or logarithmically spaced time values depending on the y scale used in the plot,
                 shifted by the reference MJD date when the transient uses a phase model.
        :rtype: np.ndarray
        """
        ax = axes[0] if isinstance(axes, np.ndarray) else axes

        # NOTE(review): spacing is chosen from the *y* scale, not the x scale;
        # kept as-is because subclasses may rely on it.
        if ax.get_yscale() == 'linear':
            times = np.linspace(self._xlim_low, self._xlim_high, 200)
        else:
            times = np.exp(np.linspace(np.log(self._xlim_low), np.log(self._xlim_high), 200))

        if self.transient.use_phase_model:
            times = times + self._reference_mjd_date
        return times

    @property
    def _xlim_low(self) -> float:
        """Lower x limit: ``xlim_low`` kwarg, else multiplier * first x value (nudged off zero for log axes)."""
        default = self.xlim_low_multiplier * self.transient.x[0]
        if default == 0:
            default += 1e-3
        return self.kwargs.get("xlim_low", default)

    @property
    def _xlim_high(self) -> float:
        """Upper x limit: ``xlim_high`` kwarg, else multiplier * (last x value + its upper x error, if any)."""
        if self._x_err is None:
            default = self.xlim_high_multiplier * self.transient.x[-1]
        else:
            default = self.xlim_high_multiplier * (self.transient.x[-1] + self._x_err[1][-1])
        return self.kwargs.get("xlim_high", default)

    @property
    def _ylim_low(self) -> float:
        """Lower y limit: ``ylim_low`` kwarg, else multiplier * minimum y value."""
        default = self.ylim_low_multiplier * min(self.transient.y)
        return self.kwargs.get("ylim_low", default)

    @property
    def _ylim_high(self) -> float:
        """Upper y limit: ``ylim_high`` kwarg, else multiplier * maximum y value."""
        default = self.ylim_high_multiplier * np.max(self.transient.y)
        return self.kwargs.get("ylim_high", default)

    @property
    def _x_err(self) -> Union[np.ndarray, None]:
        """
        x errors as a (2, N) array ``[abs(x_err[1]), x_err[0]]``, or None when the transient has none.

        Matplotlib reads a (2, N) ``xerr`` as (lower, upper); the row order assumed
        for ``transient.x_err`` is not visible here — confirm against Transient.
        """
        if self.transient.x_err is not None:
            return np.array([np.abs(self.transient.x_err[1, :]), self.transient.x_err[0, :]])
        else:
            return None

    @property
    def _y_err(self) -> np.ndarray:
        """y errors: (2, N) for asymmetric 2-D input, otherwise absolute values wrapped as shape (1, N)."""
        if self.transient.y_err.ndim > 1:
            return np.array([np.abs(self.transient.y_err[1, :]), self.transient.y_err[0, :]])
        else:
            return np.array([np.abs(self.transient.y_err)])

    @property
    def _lightcurve_plot_outdir(self) -> str:
        """Output directory for model plots: ``outdir`` kwarg, else the transient directory joined with the model name."""
        return self._get_outdir(join(self.transient.directory_structure.directory_path, self.model.__name__))

    @property
    def _data_plot_outdir(self) -> str:
        """Output directory for data-only plots: ``outdir`` kwarg, else the transient directory."""
        return self._get_outdir(self.transient.directory_structure.directory_path)

    def _get_outdir(self, default: str) -> str:
        return self._get_kwarg_with_default(kwarg="outdir", default=default)

    def get_filename(self, default: str) -> str:
        return self._get_kwarg_with_default(kwarg="filename", default=default)

    def _get_kwarg_with_default(self, kwarg: str, default: Any) -> Any:
        """Return ``self.kwargs[kwarg]``; a missing *or falsy* value (None, '', {}) falls back to ``default``."""
        return self.kwargs.get(kwarg, default) or default

    @property
    def _model_kwargs(self) -> dict:
        return self._get_kwarg_with_default("model_kwargs", dict())

    @property
    def _posterior(self) -> pd.DataFrame:
        """
        Posterior samples sorted by ascending ``log_likelihood``.

        NOTE: the DataFrame passed as the ``posterior`` kwarg is sorted in place,
        once per plotter. An empty posterior (including the default
        ``pd.DataFrame()``, which has no ``log_likelihood`` column and would make
        ``sort_values`` raise ``KeyError``) is returned unsorted.
        """
        posterior = self.kwargs.get("posterior", pd.DataFrame())
        if not self._posterior_sorted and posterior is not None and not posterior.empty:
            posterior.sort_values(by='log_likelihood', inplace=True)
            self._posterior_sorted = True
        return posterior

    @property
    def _max_like_params(self) -> pd.core.series.Series:
        """The last row of the sorted posterior, i.e. the sample with the highest log likelihood."""
        return self._posterior.iloc[-1]

    def _get_random_parameters(self) -> list[pd.core.series.Series]:
        """Draw ``random_models`` posterior rows uniformly at random, with replacement."""
        # Resolve the property once rather than once per draw.
        posterior = self._posterior
        indices = np.random.choice(np.arange(len(posterior)), size=self.random_models)
        return [posterior.iloc[idx] for idx in indices]

    _data_plot_filename = _FilenameGetter(suffix="data")
    _lightcurve_plot_filename = _FilenameGetter(suffix="lightcurve")
    _residual_plot_filename = _FilenameGetter(suffix="residual")
    _multiband_data_plot_filename = _FilenameGetter(suffix="multiband_data")
    _multiband_lightcurve_plot_filename = _FilenameGetter(suffix="multiband_lightcurve")

    _data_plot_filepath = _FilePathGetter(
        directory_property="_data_plot_outdir", filename_property="_data_plot_filename")
    _lightcurve_plot_filepath = _FilePathGetter(
        directory_property="_lightcurve_plot_outdir", filename_property="_lightcurve_plot_filename")
    _residual_plot_filepath = _FilePathGetter(
        directory_property="_lightcurve_plot_outdir", filename_property="_residual_plot_filename")
    _multiband_data_plot_filepath = _FilePathGetter(
        directory_property="_data_plot_outdir", filename_property="_multiband_data_plot_filename")
    _multiband_lightcurve_plot_filepath = _FilePathGetter(
        directory_property="_lightcurve_plot_outdir", filename_property="_multiband_lightcurve_plot_filename")

    def _save_and_show(self, filepath: str, save: bool, show: bool) -> None:
        """Apply ``tight_layout`` to the current figure, then optionally save it to ``filepath`` and/or show it."""
        plt.tight_layout()
        if save:
            plt.savefig(filepath, dpi=self.dpi, bbox_inches=self.bbox_inches, transparent=False, facecolor='white')
        if show:
            plt.show()
267

268
class SpecPlotter(object):
    """
    Base class for all spectrum plotting classes in redback.

    Plot settings are declared below as ``KwargsAccessorWithDefault`` attributes,
    each pairing a keyword name with its default value.
    """

    capsize = KwargsAccessorWithDefault("capsize", 0.)
    elinewidth = KwargsAccessorWithDefault("elinewidth", 2)
    errorbar_fmt = KwargsAccessorWithDefault("errorbar_fmt", "x")
    legend_location = KwargsAccessorWithDefault("legend_location", "best")
    legend_cols = KwargsAccessorWithDefault("legend_cols", 2)
    color = KwargsAccessorWithDefault("color", "k")
    dpi = KwargsAccessorWithDefault("dpi", 300)
    model = KwargsAccessorWithDefault("model", None)
    ms = KwargsAccessorWithDefault("ms", 1)
    axis_tick_params_pad = KwargsAccessorWithDefault("axis_tick_params_pad", 10)

    max_likelihood_alpha = KwargsAccessorWithDefault("max_likelihood_alpha", 0.65)
    random_sample_alpha = KwargsAccessorWithDefault("random_sample_alpha", 0.05)
    uncertainty_band_alpha = KwargsAccessorWithDefault("uncertainty_band_alpha", 0.4)
    max_likelihood_color = KwargsAccessorWithDefault("max_likelihood_color", "blue")
    random_sample_color = KwargsAccessorWithDefault("random_sample_color", "red")

    bbox_inches = KwargsAccessorWithDefault("bbox_inches", "tight")
    linewidth = KwargsAccessorWithDefault("linewidth", 2)
    zorder = KwargsAccessorWithDefault("zorder", -1)
    yscale = KwargsAccessorWithDefault("yscale", "linear")

    xy = KwargsAccessorWithDefault("xy", (0.95, 0.9))
    xycoords = KwargsAccessorWithDefault("xycoords", "axes fraction")
    horizontalalignment = KwargsAccessorWithDefault("horizontalalignment", "right")
    annotation_size = KwargsAccessorWithDefault("annotation_size", 20)

    fontsize_axes = KwargsAccessorWithDefault("fontsize_axes", 18)
    fontsize_figure = KwargsAccessorWithDefault("fontsize_figure", 30)
    fontsize_legend = KwargsAccessorWithDefault("fontsize_legend", 18)
    fontsize_ticks = KwargsAccessorWithDefault("fontsize_ticks", 16)
    hspace = KwargsAccessorWithDefault("hspace", 0.04)
    wspace = KwargsAccessorWithDefault("wspace", 0.15)

    random_models = KwargsAccessorWithDefault("random_models", 100)
    uncertainty_mode = KwargsAccessorWithDefault("uncertainty_mode", "random_models")
    credible_interval_level = KwargsAccessorWithDefault("credible_interval_level", 0.9)
    plot_max_likelihood = KwargsAccessorWithDefault("plot_max_likelihood", True)
    set_same_color_per_subplot = KwargsAccessorWithDefault("set_same_color_per_subplot", True)

    xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.05)
    xlim_low_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.9)
    ylim_high_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 1.1)
    ylim_low_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.5)

    def __init__(self, spectrum: Union[redback.transient.Spectrum, None], **kwargs) -> None:
        """
        :param spectrum: An instance of `redback.transient.Spectrum`. Contains the data to be plotted.
        :param kwargs: Additional kwargs the plotter uses. -------
        :keyword capsize: Same as matplotlib capsize.
        :keyword elinewidth: same as matplotlib elinewidth
        :keyword errorbar_fmt: 'fmt' argument of `ax.errorbar`.
        :keyword ms: Same as matplotlib markersize.
        :keyword legend_location: Same as matplotlib legend location.
        :keyword legend_cols: Same as matplotlib legend columns.
        :keyword color: Color of the data points.
        :keyword dpi: Same as matplotlib dpi.
        :keyword model: str or callable, the model to plot.
        :keyword axis_tick_params_pad: `pad` argument in calls to `ax.tick_params` when setting the axes.
        :keyword max_likelihood_alpha: `alpha` argument, i.e. transparency, when plotting the max likelihood curve.
        :keyword random_sample_alpha: `alpha` argument, i.e. transparency, when plotting random sample curves.
        :keyword uncertainty_band_alpha: `alpha` argument, i.e. transparency, when plotting a credible band.
        :keyword max_likelihood_color: Color of the maximum likelihood curve.
        :keyword random_sample_color: Color of the random sample curves.
        :keyword bbox_inches: Setting for saving plots. Default is 'tight'.
        :keyword linewidth: Same as matplotlib linewidth
        :keyword zorder: Same as matplotlib zorder
        :keyword yscale: Same as matplotlib yscale, default is linear
        :keyword xy: For `ax.annotate` x and y coordinates of the point to annotate.
        :keyword xycoords: The coordinate system `xy` is given in. Default is 'axes fraction'
        :keyword horizontalalignment: Horizontal alignment of the annotation. Default is 'right'
        :keyword annotation_size: `size` argument of `ax.annotate`.
        :keyword fontsize_axes: Font size of the x and y labels.
        :keyword fontsize_legend: Font size of the legend.
        :keyword fontsize_figure: Font size of the figure. Relevant for multiband plots.
                                  Used on `supxlabel` and `supylabel`.
        :keyword fontsize_ticks: Font size of the axis ticks.
        :keyword hspace: Argument for `subplots_adjust`, sets vertical spacing between panels.
        :keyword wspace: Argument for `subplots_adjust`, sets horizontal spacing between panels.
        :keyword plot_others: Whether to plot additional bands in the data plot, all in the same colors
        :keyword random_models: Number of random draws to use to calculate credible bands or to plot.
        :keyword uncertainty_mode: 'random_models': Plot random draws from the available parameter sets.
                                   'credible_intervals': Plot a credible interval that is calculated based
                                   on the available parameter sets.
        :keyword credible_interval_level: 0.9: Plot the 90% credible interval.
        :keyword plot_max_likelihood: Plots the draw corresponding to the maximum likelihood. Default is 'True'.
        :keyword set_same_color_per_subplot: Sets the lightcurve to be the same color as the data per subplot. Default is 'True'.
        :keyword xlim_high_multiplier: Adjust the maximum xlim based on available x values.
        :keyword xlim_low_multiplier: Adjust the minimum xlim based on available x values.
        :keyword ylim_high_multiplier: Adjust the maximum ylim based on available y values.
        :keyword ylim_low_multiplier: Adjust the minimum ylim based on available y values.
        """
        # Stored as ``transient`` so the shared filename/path descriptors
        # (which read ``instance.transient``) work for spectra too.
        self.transient = spectrum
        self.kwargs = kwargs or dict()
        self._posterior_sorted = False

    # The keyword section of the __init__ docstring (everything after the
    # "-------" marker), exposed so it can be reused in other docstrings.
    keyword_docstring = __init__.__doc__.split("-------")[1]

    def _get_angstroms(self, axes: matplotlib.axes.Axes) -> np.ndarray:
        """
        Build 200 evaluation wavelengths spanning the current x limits.

        :param axes: The axes used in the plotting procedure. If an array of axes is given, the first is used.
        :type axes: matplotlib.axes.Axes

        :return: Linearly or logarithmically spaced angstrom values depending on the y scale used in the plot.
        :rtype: np.ndarray
        """
        ax = axes[0] if isinstance(axes, np.ndarray) else axes

        # NOTE(review): spacing follows the *y* scale, mirroring Plotter._get_times.
        if ax.get_yscale() == 'linear':
            angstroms = np.linspace(self._xlim_low, self._xlim_high, 200)
        else:
            angstroms = np.exp(np.linspace(np.log(self._xlim_low), np.log(self._xlim_high), 200))

        return angstroms

    @property
    def _xlim_low(self) -> float:
        """Lower x limit: ``xlim_low`` kwarg, else multiplier * first wavelength (nudged off zero)."""
        default = self.xlim_low_multiplier * self.transient.angstroms[0]
        if default == 0:
            default += 1e-3
        return self.kwargs.get("xlim_low", default)

    @property
    def _xlim_high(self) -> float:
        """Upper x limit: ``xlim_high`` kwarg, else multiplier * last wavelength."""
        default = self.xlim_high_multiplier * self.transient.angstroms[-1]
        return self.kwargs.get("xlim_high", default)

    @property
    def _ylim_low(self) -> float:
        """Lower y limit: ``ylim_low`` kwarg as given, else multiplier * min flux density, divided by 1e-17."""
        default = self.ylim_low_multiplier * min(self.transient.flux_density)
        # Only the computed default is rescaled by 1e-17; a user-supplied
        # ``ylim_low`` is returned as-is (presumably already in those units — confirm).
        return self.kwargs.get("ylim_low", default/1e-17)

    @property
    def _ylim_high(self) -> float:
        """Upper y limit: ``ylim_high`` kwarg as given, else multiplier * max flux density, divided by 1e-17."""
        default = self.ylim_high_multiplier * np.max(self.transient.flux_density)
        return self.kwargs.get("ylim_high", default/1e-17)

    @property
    def _y_err(self) -> np.ndarray:
        """Absolute flux-density errors wrapped as a single-row array."""
        return np.array([np.abs(self.transient.flux_density_err)])

    @property
    def _data_plot_outdir(self) -> str:
        """Output directory for plots: ``outdir`` kwarg, else the spectrum's directory."""
        return self._get_outdir(self.transient.directory_structure.directory_path)

    def _get_outdir(self, default: str) -> str:
        return self._get_kwarg_with_default(kwarg="outdir", default=default)

    def get_filename(self, default: str) -> str:
        return self._get_kwarg_with_default(kwarg="filename", default=default)

    def _get_kwarg_with_default(self, kwarg: str, default: Any) -> Any:
        """Return ``self.kwargs[kwarg]``; a missing *or falsy* value (None, '', {}) falls back to ``default``."""
        return self.kwargs.get(kwarg, default) or default

    @property
    def _model_kwargs(self) -> dict:
        return self._get_kwarg_with_default("model_kwargs", dict())

    @property
    def _posterior(self) -> pd.DataFrame:
        """
        Posterior samples sorted by ascending ``log_likelihood``.

        NOTE: the DataFrame passed as the ``posterior`` kwarg is sorted in place,
        once per plotter. An empty posterior (including the default
        ``pd.DataFrame()``, which has no ``log_likelihood`` column and would make
        ``sort_values`` raise ``KeyError``) is returned unsorted.
        """
        posterior = self.kwargs.get("posterior", pd.DataFrame())
        if not self._posterior_sorted and posterior is not None and not posterior.empty:
            posterior.sort_values(by='log_likelihood', inplace=True)
            self._posterior_sorted = True
        return posterior

    @property
    def _max_like_params(self) -> pd.core.series.Series:
        """The last row of the sorted posterior, i.e. the sample with the highest log likelihood."""
        return self._posterior.iloc[-1]

    def _get_random_parameters(self) -> list[pd.core.series.Series]:
        """Draw ``random_models`` posterior rows uniformly at random, with replacement."""
        # Resolve the property once rather than once per draw.
        posterior = self._posterior
        indices = np.random.choice(np.arange(len(posterior)), size=self.random_models)
        return [posterior.iloc[idx] for idx in indices]

    _data_plot_filename = _FilenameGetter(suffix="data")
    _spectrum_ppd_plot_filename = _FilenameGetter(suffix="spectrum_ppd")
    _residual_plot_filename = _FilenameGetter(suffix="residual")

    _data_plot_filepath = _FilePathGetter(
        directory_property="_data_plot_outdir", filename_property="_data_plot_filename")
    _spectrum_ppd_plot_filepath = _FilePathGetter(
        directory_property="_data_plot_outdir", filename_property="_spectrum_ppd_plot_filename")
    _residual_plot_filepath = _FilePathGetter(
        directory_property="_data_plot_outdir", filename_property="_residual_plot_filename")

    def _save_and_show(self, filepath: str, save: bool, show: bool) -> None:
        """Apply ``tight_layout`` to the current figure, then optionally save it to ``filepath`` and/or show it."""
        plt.tight_layout()
        if save:
            plt.savefig(filepath, dpi=self.dpi, bbox_inches=self.bbox_inches, transparent=False, facecolor='white')
        if show:
            plt.show()
469

470

471
class IntegratedFluxPlotter(Plotter):
    """Plotter for integrated-flux lightcurves, drawn on log-log axes."""

    @property
    def _xlabel(self) -> str:
        return r"Time since burst [s]"

    @property
    def _ylabel(self) -> str:
        return self.transient.ylabel

    def plot_data(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Integrated flux data and returns Axes.

        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
                     Defaults to the current axes (``plt.gca()``).
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        # Explicit None check instead of relying on the truthiness of an Axes.
        ax = axes if axes is not None else plt.gca()

        ax.errorbar(self.transient.x, self.transient.y, xerr=self._x_err, yerr=self._y_err,
                    fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize)

        ax.set_xscale('log')
        ax.set_yscale('log')

        ax.set_xlim(self._xlim_low, self._xlim_high)
        ax.set_ylim(self._ylim_low, self._ylim_high)
        ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes)
        ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes)

        ax.annotate(
            self.transient.name, xy=self.xy, xycoords=self.xycoords,
            horizontalalignment=self.horizontalalignment, size=self.annotation_size)

        ax.tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)

        self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show)
        return ax

    def plot_lightcurve(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Integrated flux data and the lightcurve and returns Axes.

        :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot.
                     Defaults to the current axes (``plt.gca()``).
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        # plot_data resolves a None ``axes`` to plt.gca() itself.
        axes = self.plot_data(axes=axes, save=False, show=False)
        times = self._get_times(axes)

        self._plot_lightcurves(axes, times)

        self._save_and_show(filepath=self._lightcurve_plot_filepath, save=save, show=show)
        return axes

    def _plot_lightcurves(self, axes: matplotlib.axes.Axes, times: np.ndarray) -> None:
        """Draw the max-likelihood curve (optional) and the posterior uncertainty per ``uncertainty_mode``."""
        if self.plot_max_likelihood:
            ys = self.model(times, **self._max_like_params, **self._model_kwargs)
            axes.plot(times, ys, color=self.max_likelihood_color, alpha=self.max_likelihood_alpha, lw=self.linewidth)

        # Evaluated before the mode check, so the random draws are consumed
        # regardless of which uncertainty representation is chosen.
        random_ys_list = [self.model(times, **random_params, **self._model_kwargs)
                          for random_params in self._get_random_parameters()]
        if self.uncertainty_mode == "random_models":
            for ys in random_ys_list:
                axes.plot(times, ys, color=self.random_sample_color, alpha=self.random_sample_alpha, lw=self.linewidth,
                          zorder=self.zorder)
        elif self.uncertainty_mode == "credible_intervals":
            lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(
                samples=random_ys_list, interval=self.credible_interval_level)
            axes.fill_between(
                times, lower_bound, upper_bound, alpha=self.uncertainty_band_alpha, color=self.max_likelihood_color)

    def _plot_single_lightcurve(self, axes: matplotlib.axes.Axes, times: np.ndarray, params: dict) -> None:
        """Draw one model curve for ``params`` in the random-sample style."""
        ys = self.model(times, **params, **self._model_kwargs)
        axes.plot(times, ys, color=self.random_sample_color, alpha=self.random_sample_alpha, lw=self.linewidth,
                  zorder=self.zorder)

    def plot_residuals(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the residual of the Integrated flux data and returns Axes.

        :param axes: Two stacked matplotlib axes (lightcurve on top, residuals below). If None, a new
                     2x1 figure with a shared x axis is created.
        :param save: Whether to save the plot. (Default value = True)
        :param show: Whether to show the plot. (Default value = True)

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        if axes is None:
            _, axes = plt.subplots(
                nrows=2, ncols=1, sharex=True, sharey=False, figsize=(10, 8), gridspec_kw=dict(height_ratios=[2, 1]))

        axes[0] = self.plot_lightcurve(axes=axes[0], save=False, show=False)
        # Move the x label to the bottom panel only.
        axes[1].set_xlabel(axes[0].get_xlabel(), fontsize=self.fontsize_axes)
        axes[0].set_xlabel("")
        ys = self.model(self.transient.x, **self._max_like_params, **self._model_kwargs)
        axes[1].errorbar(
            self.transient.x, self.transient.y - ys, xerr=self._x_err, yerr=self._y_err,
            fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize)
        # NOTE(review): residuals (data - model) can be negative, and a log
        # y-scale hides those points; kept as-is pending confirmation.
        axes[1].set_yscale("log")
        axes[1].set_ylabel("Residual", fontsize=self.fontsize_axes)
        axes[1].tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)

        self._save_and_show(filepath=self._residual_plot_filepath, save=save, show=show)
        return axes
591

592

593
class LuminosityOpticalPlotter(IntegratedFluxPlotter):
    """IntegratedFluxPlotter with fixed axis labels for bolometric luminosity versus time since explosion."""

    @property
    def _xlabel(self) -> str:
        """Fixed x-axis label: time since explosion in days."""
        label = "Time since explosion [days]"
        return label

    @property
    def _ylabel(self) -> str:
        """Fixed y-axis label: bolometric luminosity in units of 10^50 erg/s."""
        label = "L$_{\\rm bol}$ [$10^{50}$ erg s$^{-1}$]"
        return label
1✔
602

603
class LuminosityPlotter(IntegratedFluxPlotter):
    """Luminosity lightcurve plotter; reuses IntegratedFluxPlotter without modification."""
1✔
605

606

607
class MagnitudePlotter(Plotter):
    """
    Plotter for band-resolved lightcurves, drawn either in a single panel or with one panel per band.
    """

    xlim_low_phase_model_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.9)
    # NOTE: the next two accessors are configured by the same "xlim_high_multiplier" key;
    # they differ only in their default value.
    xlim_high_phase_model_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.1)
    xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.2)
    ylim_low_magnitude_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.8)
    ylim_high_magnitude_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 1.2)
    ncols = KwargsAccessorWithDefault("ncols", 2)

    @property
    def _colors(self) -> Any:
        """Per-filter colours: the ``colors`` kwarg if given, else ``transient.get_colors(self._filters)``."""
        return self.kwargs.get("colors", self.transient.get_colors(self._filters))

    @property
    def _xlabel(self) -> str:
        """X-axis label (overridable via ``xlabel``); for phase models it names the reference MJD."""
        if self.transient.use_phase_model:
            default = f"Time since {self._reference_mjd_date} MJD [days]"
        else:
            default = self.transient.xlabel
        return self.kwargs.get("xlabel", default)

    @property
    def _ylabel(self) -> str:
        """Y-axis label (overridable via ``ylabel``), defaulting to the transient's label."""
        return self.kwargs.get("ylabel", self.transient.ylabel)

    @property
    def _get_bands_to_plot(self) -> list[str]:
        """Bands whose model curves are drawn; ``bands_to_plot`` kwarg or the transient's active bands."""
        return self.kwargs.get("bands_to_plot", self.transient.active_bands)

    @property
    def _xlim_low(self) -> float:
        """Lower x limit: the first data time scaled by a multiplier (overridable via ``xlim_low``)."""
        if self.transient.use_phase_model:
            default = (self.transient.x[0] - self._reference_mjd_date) * self.xlim_low_phase_model_multiplier
        else:
            default = self.xlim_low_multiplier * self.transient.x[0]
        if default == 0:
            # Nudge away from exactly zero, presumably so a log-scaled time axis stays valid.
            default += 1e-3
        return self.kwargs.get("xlim_low", default)

    @property
    def _xlim_high(self) -> float:
        """Upper x limit: the last data time scaled by a multiplier (overridable via ``xlim_high``)."""
        if self.transient.use_phase_model:
            default = (self.transient.x[-1] - self._reference_mjd_date) * self.xlim_high_phase_model_multiplier
        else:
            default = self.xlim_high_multiplier * self.transient.x[-1]
        return self.kwargs.get("xlim_high", default)

    @property
    def _ylim_low_magnitude(self) -> float:
        """Lower y limit for magnitude data: smallest y value times ``ylim_low_magnitude_multiplier``."""
        return self.ylim_low_magnitude_multiplier * min(self.transient.y)

    @property
    def _ylim_high_magnitude(self) -> float:
        """Upper y limit for magnitude data: largest y value times ``ylim_high_magnitude_multiplier``."""
        return self.ylim_high_magnitude_multiplier * np.max(self.transient.y)

    def _get_ylim_low_with_indices(self, indices: list) -> float:
        """Lower y limit computed from the subset ``transient.y[indices]``."""
        return self.ylim_low_multiplier * min(self.transient.y[indices])

    def _get_ylim_high_with_indices(self, indices: list) -> float:
        """Upper y limit computed from the subset ``transient.y[indices]``."""
        return self.ylim_high_multiplier * np.max(self.transient.y[indices])

    def _get_x_err(self, indices: list) -> np.ndarray:
        """X errors for ``indices``, or None when the transient has no x errors."""
        return self.transient.x_err[indices] if self.transient.x_err is not None else self.transient.x_err

    def _set_y_axis_data(self, ax: matplotlib.axes.Axes) -> None:
        """Set y limits; magnitudes get an inverted axis, anything else a log axis."""
        if self.transient.magnitude_data:
            ax.set_ylim(self._ylim_low_magnitude, self._ylim_high_magnitude)
            ax.invert_yaxis()
        else:
            ax.set_ylim(self._ylim_low, self._ylim_high)
            ax.set_yscale("log")

    def _set_y_axis_multiband_data(self, ax: matplotlib.axes.Axes, indices: list) -> None:
        """Like ``_set_y_axis_data``, but non-magnitude limits come from the band's own data points."""
        if self.transient.magnitude_data:
            ax.set_ylim(self._ylim_low_magnitude, self._ylim_high_magnitude)
            ax.invert_yaxis()
        else:
            ax.set_ylim(self._get_ylim_low_with_indices(indices=indices),
                        self._get_ylim_high_with_indices(indices=indices))
            ax.set_yscale("log")

    def _set_x_axis(self, axes: matplotlib.axes.Axes) -> None:
        """Set x limits; phase models additionally get a log-scaled time axis."""
        if self.transient.use_phase_model:
            axes.set_xscale("log")
        axes.set_xlim(self._xlim_low, self._xlim_high)

    @property
    def _nrows(self) -> int:
        """Rows of the multiband grid: enough to fit every filter into ``ncols`` columns (kwarg ``nrows``)."""
        default = int(np.ceil(len(self._filters) / self.ncols))
        return self._get_kwarg_with_default("nrows", default=default)

    @property
    def _npanels(self) -> int:
        """Number of panels in the grid.

        :raises ValueError: If the grid holds fewer panels than there are filters.
        """
        npanels = self._nrows * self.ncols
        if npanels < len(self._filters):
            raise ValueError(f"Insufficient number of panels. {npanels} panels were given "
                             f"but {len(self._filters)} panels are needed.")
        return npanels

    @property
    def _figsize(self) -> tuple:
        """Figure size scaled with the grid shape (overridable via ``figsize``)."""
        default = (4 + 4 * self.ncols, 2 + 2 * self._nrows)
        return self._get_kwarg_with_default("figsize", default=default)

    @property
    def _reference_mjd_date(self) -> int:
        """Time offset subtracted from x values: first data time (or kwarg) for phase models, else 0."""
        if self.transient.use_phase_model:
            return self.kwargs.get("reference_mjd_date", int(self.transient.x[0]))
        return 0

    @property
    def band_label_generator(self):
        """Generator over the user-supplied ``band_labels``, or None when no labels were given."""
        if self.band_labels is not None:
            return (bl for bl in self.band_labels)

    def _scale_values_for_band(self, values, band):
        """Apply the ``band_scaling`` factor/offset for ``band`` to ``values``.

        ``band_scaling`` holds a ``"type"`` entry (``'x'`` multiplies, ``'+'`` adds) and one
        amount per band. Bands without an entry are returned unchanged.

        :raises ValueError: If ``band`` is scaled but the scaling type is neither ``'x'`` nor ``'+'``.
        """
        if band not in self.band_scaling:
            return values
        scaling_type = self.band_scaling.get("type")
        amount = self.band_scaling.get(band)
        if scaling_type == 'x':
            return values * amount
        if scaling_type == '+':
            return values + amount
        raise ValueError(f"Unknown band_scaling type {scaling_type!r} for band {band!r}; expected 'x' or '+'.")

    def _scale_errors_for_band(self, errors, band):
        """Scale y errors consistently with ``_scale_values_for_band``; only a multiplicative scaling changes them."""
        if band in self.band_scaling and self.band_scaling.get("type") == 'x':
            return errors * self.band_scaling.get(band)
        return errors

    def _create_multiband_figure(self) -> tuple:
        """Create an ``_nrows`` x ``ncols`` figure of x-shared panels.

        :raises ValueError: If the grid cannot hold every filter (checked via ``_npanels``).
        """
        _ = self._npanels  # evaluated for its validation side effect
        return plt.subplots(ncols=self.ncols, nrows=self._nrows, sharex='all', figsize=self._figsize)

    def plot_data(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Magnitude data and returns Axes.

        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        :raises ValueError: If ``band_scaling`` uses a type other than ``'x'`` or ``'+'``.
        """
        ax = axes or plt.gca()

        band_label_generator = self.band_label_generator
        filters = self._filters
        filter_list = list(filters)
        colors = self._colors
        reference_date = self._reference_mjd_date

        for indices, band in zip(self.transient.list_of_band_indices, self.transient.unique_bands):
            if band in filters:
                color = colors[filter_list.index(band)]
                if band_label_generator is not None:
                    label = next(band_label_generator)
                elif band in self.band_scaling:
                    label = f"{band} {self.band_scaling.get('type')} {self.band_scaling.get(band)}"
                else:
                    label = band
            elif self.plot_others:
                color = "black"
                label = None
            else:
                continue
            if isinstance(label, float):
                label = f"{label:.2e}"
            if self.band_colors is not None:
                color = self.band_colors[band]
            ax.errorbar(
                self.transient.x[indices] - reference_date,
                self._scale_values_for_band(self.transient.y[indices], band),
                xerr=self._get_x_err(indices),
                yerr=self._scale_errors_for_band(self.transient.y_err[indices], band),
                fmt=self.errorbar_fmt, ms=self.ms, color=color,
                elinewidth=self.elinewidth, capsize=self.capsize, label=label)

        self._set_x_axis(axes=ax)
        self._set_y_axis_data(ax)

        ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes)
        ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes)

        ax.tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)
        ax.legend(ncol=self.legend_cols, loc=self.legend_location, fontsize=self.fontsize_legend)

        self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show)
        return ax

    def plot_lightcurve(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True)\
            -> matplotlib.axes.Axes:
        """Plots the Magnitude data and returns Axes.

        :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        :raises ValueError: If ``band_scaling`` uses a type other than ``'x'`` or ``'+'``.
        """
        axes = axes or plt.gca()

        axes = self.plot_data(axes=axes, save=False, show=False)
        # NOTE(review): this forces a log y-axis even for magnitude data, for which plot_data
        # only set limits and inverted the axis — confirm this is intended.
        axes.set_yscale('log')

        times = self._get_times(axes)
        shifted_times = times - self._reference_mjd_date
        bands_to_plot = self._get_bands_to_plot
        # Fetch the kwargs dict once and write the per-band entries into that same object, so the
        # dict passed to the model is the one that was updated whatever ``_model_kwargs`` returns.
        model_kwargs = self._model_kwargs

        color_max = self.max_likelihood_color
        color_sample = self.random_sample_color
        for band, color in zip(bands_to_plot, self.transient.get_colors(bands_to_plot)):
            if self.set_same_color_per_subplot is True:
                if self.band_colors is not None:
                    color = self.band_colors[band]
                color_max = color
                color_sample = color
            sn_cosmo_band = redback.utils.sncosmo_bandname_from_band([band])
            model_kwargs["bands"] = [sn_cosmo_band[0] for _ in range(len(times))]
            if isinstance(band, str):
                frequency = redback.utils.bands_to_frequency([band])
            else:
                frequency = band
            model_kwargs['frequency'] = np.ones(len(times)) * frequency

            if self.plot_max_likelihood:
                ys = self.model(times, **self._max_like_params, **model_kwargs)
                axes.plot(shifted_times, self._scale_values_for_band(ys, band), color=color_max,
                          alpha=self.max_likelihood_alpha, lw=self.linewidth)

            random_ys_list = [self.model(times, **random_params, **model_kwargs)
                              for random_params in self._get_random_parameters()]
            if self.uncertainty_mode == "random_models":
                for ys in random_ys_list:
                    axes.plot(shifted_times, self._scale_values_for_band(ys, band), color=color_sample,
                              alpha=self.random_sample_alpha, lw=self.linewidth, zorder=-1)
            elif self.uncertainty_mode == "credible_intervals":
                lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(
                    samples=self._scale_values_for_band(np.array(random_ys_list), band),
                    interval=self.credible_interval_level)
                axes.fill_between(
                    shifted_times, lower_bound, upper_bound,
                    alpha=self.uncertainty_band_alpha, color=color_sample)

        self._save_and_show(filepath=self._lightcurve_plot_filepath, save=save, show=show)
        return axes

    def _check_valid_multiband_data_mode(self) -> bool:
        """Return False (with a warning) for luminosity data, which cannot be plotted per band."""
        if self.transient.luminosity_data:
            redback.utils.logger.warning(
                f"Plotting multiband lightcurve/data not possible for {self.transient.data_mode}. Returning.")
            return False
        return True

    def plot_multiband(
            self, figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None, save: bool = True,
            show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Magnitude multiband data and returns Axes.

        :param figure: Matplotlib figure to plot the data into.
        :type figure: matplotlib.figure.Figure
        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
                     A single Axes, a list, or an array of Axes is accepted.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot (flattened to 1-D).
        :rtype: matplotlib.axes.Axes
        :raises ValueError: If a new figure is needed and ``nrows * ncols`` is smaller than the number of filters.
        """
        if not self._check_valid_multiband_data_mode():
            return

        if figure is None or axes is None:
            figure, axes = self._create_multiband_figure()
        # atleast_1d also copes with a single Axes (1x1 grid) or a plain list of Axes.
        axes = np.atleast_1d(axes).ravel()

        band_label_generator = self.band_label_generator
        filters = self._filters
        filter_list = list(filters)
        colors = self._colors
        reference_date = self._reference_mjd_date

        panel_index = 0
        for indices, band, freq in zip(
                self.transient.list_of_band_indices, self.transient.unique_bands, self.transient.unique_frequencies):
            if band not in filters:
                continue

            color = colors[filter_list.index(band)]
            if self.band_colors is not None:
                color = self.band_colors[band]
            if band_label_generator is None:
                label = self._get_multiband_plot_label(band, freq)
            else:
                label = next(band_label_generator)

            ax = axes[panel_index]
            ax.errorbar(
                self.transient.x[indices] - reference_date, self.transient.y[indices],
                xerr=self._get_x_err(indices), yerr=self.transient.y_err[indices],
                fmt=self.errorbar_fmt, ms=self.ms, color=color,
                elinewidth=self.elinewidth, capsize=self.capsize, label=label)

            self._set_x_axis(ax)
            self._set_y_axis_multiband_data(ax, indices)
            ax.legend(ncol=self.legend_cols, loc=self.legend_location, fontsize=self.fontsize_legend)
            ax.tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)
            panel_index += 1

        figure.supxlabel(self._xlabel, fontsize=self.fontsize_figure)
        figure.supylabel(self._ylabel, fontsize=self.fontsize_figure)
        plt.subplots_adjust(wspace=self.wspace, hspace=self.hspace)

        self._save_and_show(filepath=self._multiband_data_plot_filepath, save=save, show=show)
        return axes

    @staticmethod
    def _get_multiband_plot_label(band: str, freq: float) -> str:
        """Panel label: the band name for 1e10 < freq < 1e16, otherwise the frequency (or numeric band)."""
        if isinstance(band, str):
            if 1e10 < float(freq) < 1e16:
                label = band
            else:
                label = f"{freq:.2e}"
        else:
            label = f"{band:.2e}"
        return label

    @property
    def _filters(self) -> list[str]:
        """Bands to plot: ``bands_to_plot`` > ``filters`` kwarg > active bands; ``'default'`` selects default filters."""
        filters = self.kwargs.get("filters", self.transient.active_bands)
        if 'bands_to_plot' in self.kwargs:
            filters = self.kwargs['bands_to_plot']
        if filters is None:
            return self.transient.active_bands
        elif str(filters) == 'default':
            return self.transient.default_filters
        return filters

    def plot_multiband_lightcurve(
        self, figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the Magnitude multiband lightcurve and returns Axes.

        :param figure: Matplotlib figure to plot the data into.
        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        :raises ValueError: If a new figure is needed and ``nrows * ncols`` is smaller than the number of filters.
        """
        if not self._check_valid_multiband_data_mode():
            return

        if figure is None or axes is None:
            figure, axes = self._create_multiband_figure()

        axes = self.plot_multiband(figure=figure, axes=axes, save=False, show=False)
        times = self._get_times(axes)
        shifted_times = times - self._reference_mjd_date
        filters = self._filters
        filter_list = list(filters)

        panel_index = 0
        color_max = self.max_likelihood_color
        color_sample = self.random_sample_color
        for band, freq in zip(self.transient.unique_bands, self.transient.unique_frequencies):
            if band not in filters:
                continue
            # Per-band copy so the shared model kwargs are not modified.
            new_model_kwargs = self._model_kwargs.copy()
            new_model_kwargs['frequency'] = freq
            new_model_kwargs['bands'] = band

            if self.set_same_color_per_subplot is True:
                color = self._colors[filter_list.index(band)]
                if self.band_colors is not None:
                    color = self.band_colors[band]
                color_max = color
                color_sample = color

            ax = axes[panel_index]
            if self.plot_max_likelihood:
                ys = self.model(times, **self._max_like_params, **new_model_kwargs)
                ax.plot(shifted_times, ys, color=color_max, alpha=self.max_likelihood_alpha, lw=self.linewidth)
            random_ys_list = [self.model(times, **random_params, **new_model_kwargs)
                              for random_params in self._get_random_parameters()]
            if self.uncertainty_mode == "random_models":
                for random_ys in random_ys_list:
                    ax.plot(shifted_times, random_ys, color=color_sample,
                            alpha=self.random_sample_alpha, lw=self.linewidth, zorder=self.zorder)
            elif self.uncertainty_mode == "credible_intervals":
                lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(
                    samples=random_ys_list, interval=self.credible_interval_level)
                ax.fill_between(
                    shifted_times, lower_bound, upper_bound,
                    alpha=self.uncertainty_band_alpha, color=color_sample)
            panel_index += 1

        self._save_and_show(filepath=self._multiband_lightcurve_plot_filepath, save=save, show=show)
        return axes
1✔
1015

1016

1017
class FluxDensityPlotter(MagnitudePlotter):
    """Flux-density lightcurve plotter; uses MagnitudePlotter's band-resolved plotting unchanged."""
1✔
1019

1020
class IntegratedFluxOpticalPlotter(MagnitudePlotter):
    """Optical integrated-flux plotter; uses MagnitudePlotter's band-resolved plotting unchanged."""
1✔
1022

1023
class SpectrumPlotter(SpecPlotter):
    """
    Plotter for a single spectrum (flux density against wavelength in Angstrom) and its model fit.

    All flux densities (data, model curves, uncertainty bands and residuals) are drawn divided by
    ``_FLUX_DENSITY_UNIT``.
    """

    # Divisor applied to every plotted flux density.
    _FLUX_DENSITY_UNIT = 1e-17

    @property
    def _xlabel(self) -> str:
        """X-axis label taken from the transient."""
        return self.transient.xlabel

    @property
    def _ylabel(self) -> str:
        """Y-axis label taken from the transient."""
        return self.transient.ylabel

    def plot_data(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the spectrum data and returns Axes.

        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        ax = axes or plt.gca()

        # The annotation shows either the spectrum's time or the transient's name.
        if self.transient.plot_with_time_label:
            label = self.transient.time
        else:
            label = self.transient.name
        ax.plot(self.transient.angstroms, self.transient.flux_density / self._FLUX_DENSITY_UNIT, color=self.color,
                lw=self.linewidth)
        ax.set_xscale('linear')
        ax.set_yscale(self.yscale)

        ax.set_xlim(self._xlim_low, self._xlim_high)
        ax.set_ylim(self._ylim_low, self._ylim_high)
        ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes)
        ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes)

        ax.annotate(
            label, xy=self.xy, xycoords=self.xycoords,
            horizontalalignment=self.horizontalalignment, size=self.annotation_size)

        ax.tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)

        self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show)
        return ax

    def plot_spectrum(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the spectrum data and the fit and returns Axes.

        :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        axes = axes or plt.gca()

        axes = self.plot_data(axes=axes, save=False, show=False)
        angstroms = self._get_angstroms(axes)

        self._plot_spectrums(axes, angstroms)

        self._save_and_show(filepath=self._spectrum_ppd_plot_filepath, save=save, show=show)
        return axes

    def _plot_spectrums(self, axes: matplotlib.axes.Axes, angstroms: np.ndarray) -> None:
        """Overlay the maximum-likelihood model and posterior-sample uncertainty on ``axes``."""
        unit = self._FLUX_DENSITY_UNIT
        if self.plot_max_likelihood:
            ys = self.model(angstroms, **self._max_like_params, **self._model_kwargs)
            axes.plot(angstroms, ys / unit, color=self.max_likelihood_color, alpha=self.max_likelihood_alpha,
                      lw=self.linewidth)

        random_ys_list = [self.model(angstroms, **random_params, **self._model_kwargs)
                          for random_params in self._get_random_parameters()]
        if self.uncertainty_mode == "random_models":
            for ys in random_ys_list:
                axes.plot(angstroms, ys / unit, color=self.random_sample_color, alpha=self.random_sample_alpha,
                          lw=self.linewidth, zorder=self.zorder)
        elif self.uncertainty_mode == "credible_intervals":
            lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=random_ys_list,
                                                                                interval=self.credible_interval_level)
            axes.fill_between(
                angstroms, lower_bound / unit, upper_bound / unit, alpha=self.uncertainty_band_alpha,
                color=self.max_likelihood_color)

    def plot_residuals(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the residuals of the spectrum (data minus maximum-likelihood model) and returns Axes.

        Residuals and their error bars use the same flux-density scaling as the spectrum panel.

        :param axes: A pair of Matplotlib axes (spectrum, residuals) to plot into. Useful for user specific
                     modifications to the plot. A new two-row figure is created if None.
        :param save: Whether to save the plot. (Default value = True)
        :param show: Whether to show the plot. (Default value = True)

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        if axes is None:
            _, axes = plt.subplots(
                nrows=2, ncols=1, sharex=True, sharey=False, figsize=(10, 8), gridspec_kw=dict(height_ratios=[2, 1]))

        axes[0] = self.plot_spectrum(axes=axes[0], save=False, show=False)
        # The panels share x, so only the bottom panel keeps the x label.
        axes[1].set_xlabel(axes[0].get_xlabel(), fontsize=self.fontsize_axes)
        axes[0].set_xlabel("")
        ys = self.model(self.transient.angstroms, **self._max_like_params, **self._model_kwargs)
        unit = self._FLUX_DENSITY_UNIT
        flux_density_err = self.transient.flux_density_err
        residual_err = None if flux_density_err is None else flux_density_err / unit
        axes[1].errorbar(
            self.transient.angstroms, (self.transient.flux_density - ys) / unit, yerr=residual_err,
            fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize)
        axes[1].set_yscale('linear')
        axes[1].set_ylabel("Residual", fontsize=self.fontsize_axes)
        axes[1].tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)

        self._save_and_show(filepath=self._residual_plot_filepath, save=save, show=show)
        return axes
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc