• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nikhil-sarin / redback / 13488561206

24 Feb 2025 12:42AM UTC coverage: 81.072% (-1.0%) from 82.065%
13488561206

Pull #256

github

web-flow
Merge 46c5928af into a8e311392
Pull Request #256: Spectral analysis features

128 of 263 new or added lines in 11 files covered. (48.67%)

37 existing lines in 3 files now uncovered.

9890 of 12199 relevant lines covered (81.07%)

0.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

67.61
/redback/plotting.py
1
from __future__ import annotations
1✔
2

3
from os.path import join
1✔
4
from typing import Any, Union
1✔
5

6
import matplotlib
1✔
7
import matplotlib.pyplot as plt
1✔
8
import numpy as np
1✔
9
import pandas as pd
1✔
10

11
import redback
1✔
12
from redback.utils import KwargsAccessorWithDefault
1✔
13

14
class _FilenameGetter(object):
1✔
15
    def __init__(self, suffix: str) -> None:
1✔
16
        self.suffix = suffix
1✔
17

18
    def __get__(self, instance: Plotter, owner: object) -> str:
1✔
19
        return instance.get_filename(default=f"{instance.transient.name}_{self.suffix}.png")
1✔
20

21
    def __set__(self, instance: Plotter, value: object) -> None:
1✔
22
        pass
×
23

24

25
class _FilePathGetter(object):
1✔
26

27
    def __init__(self, directory_property: str, filename_property: str) -> None:
1✔
28
        self.directory_property = directory_property
1✔
29
        self.filename_property = filename_property
1✔
30

31
    def __get__(self, instance: Plotter, owner: object) -> str:
1✔
32
        return join(getattr(instance, self.directory_property), getattr(instance, self.filename_property))
1✔
33

34

35
class Plotter(object):
1✔
36
    """
37
    Base class for all lightcurve plotting classes in redback.
38
    """
39

40
    capsize = KwargsAccessorWithDefault("capsize", 0.)
1✔
41
    legend_location = KwargsAccessorWithDefault("legend_location", "best")
1✔
42
    legend_cols = KwargsAccessorWithDefault("legend_cols", 2)
1✔
43
    band_colors = KwargsAccessorWithDefault("band_colors", None)
1✔
44
    color = KwargsAccessorWithDefault("color", "k")
1✔
45
    band_labels = KwargsAccessorWithDefault("band_labels", None)
1✔
46
    band_scaling = KwargsAccessorWithDefault("band_scaling", {})
1✔
47
    dpi = KwargsAccessorWithDefault("dpi", 300)
1✔
48
    elinewidth = KwargsAccessorWithDefault("elinewidth", 2)
1✔
49
    errorbar_fmt = KwargsAccessorWithDefault("errorbar_fmt", "x")
1✔
50
    model = KwargsAccessorWithDefault("model", None)
1✔
51
    ms = KwargsAccessorWithDefault("ms", 1)
1✔
52
    axis_tick_params_pad = KwargsAccessorWithDefault("axis_tick_params_pad", 10)
1✔
53

54
    max_likelihood_alpha = KwargsAccessorWithDefault("max_likelihood_alpha", 0.65)
1✔
55
    random_sample_alpha = KwargsAccessorWithDefault("random_sample_alpha", 0.05)
1✔
56
    uncertainty_band_alpha = KwargsAccessorWithDefault("uncertainty_band_alpha", 0.4)
1✔
57
    max_likelihood_color = KwargsAccessorWithDefault("max_likelihood_color", "blue")
1✔
58
    random_sample_color = KwargsAccessorWithDefault("random_sample_color", "red")
1✔
59

60
    bbox_inches = KwargsAccessorWithDefault("bbox_inches", "tight")
1✔
61
    linewidth = KwargsAccessorWithDefault("linewidth", 2)
1✔
62
    zorder = KwargsAccessorWithDefault("zorder", -1)
1✔
63

64
    xy = KwargsAccessorWithDefault("xy", (0.95, 0.9))
1✔
65
    xycoords = KwargsAccessorWithDefault("xycoords", "axes fraction")
1✔
66
    horizontalalignment = KwargsAccessorWithDefault("horizontalalignment", "right")
1✔
67
    annotation_size = KwargsAccessorWithDefault("annotation_size", 20)
1✔
68

69
    fontsize_axes = KwargsAccessorWithDefault("fontsize_axes", 18)
1✔
70
    fontsize_figure = KwargsAccessorWithDefault("fontsize_figure", 30)
1✔
71
    fontsize_legend = KwargsAccessorWithDefault("fontsize_legend", 18)
1✔
72
    fontsize_ticks = KwargsAccessorWithDefault("fontsize_ticks", 16)
1✔
73
    hspace = KwargsAccessorWithDefault("hspace", 0.04)
1✔
74
    wspace = KwargsAccessorWithDefault("wspace", 0.15)
1✔
75

76
    plot_others = KwargsAccessorWithDefault("plot_others", True)
1✔
77
    random_models = KwargsAccessorWithDefault("random_models", 100)
1✔
78
    uncertainty_mode = KwargsAccessorWithDefault("uncertainty_mode", "random_models")
1✔
79
    credible_interval_level = KwargsAccessorWithDefault("credible_interval_level", 0.9)
1✔
80
    plot_max_likelihood = KwargsAccessorWithDefault("plot_max_likelihood", True)
1✔
81
    set_same_color_per_subplot = KwargsAccessorWithDefault("set_same_color_per_subplot", True)
1✔
82

83
    xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 2.0)
1✔
84
    xlim_low_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.5)
1✔
85
    ylim_high_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 2.0)
1✔
86
    ylim_low_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.5)
1✔
87

88
    def __init__(self, transient: Union[redback.transient.Transient, None], **kwargs) -> None:
1✔
89
        """
90
        :param transient: An instance of `redback.transient.Transient`. Contains the data to be plotted.
91
        :param kwargs: Additional kwargs the plotter uses. -------
92
        :keyword capsize: Same as matplotlib capsize.
93
        :keyword bands_to_plot: List of bands to plot in plot lightcurve and multiband lightcurve. Default is active bands.
94
        :keyword legend_location: Same as matplotlib legend location.
95
        :keyword legend_cols: Same as matplotlib legend columns.
96
        :keyword color: Color of the data points.
97
        :keyword band_colors: A dictionary with the colors of the bands.
98
        :keyword band_labels: List with the names of the bands.
99
        :keyword band_scaling: Dict with the scaling for each band. First entry should be {type: '+' or 'x'} for different types.
100
        :keyword dpi: Same as matplotlib dpi.
101
        :keyword elinewidth: same as matplotlib elinewidth
102
        :keyword errorbar_fmt: 'fmt' argument of `ax.errorbar`.
103
        :keyword model: str or callable, the model to plot.
104
        :keyword ms: Same as matplotlib markersize.
105
        :keyword axis_tick_params_pad: `pad` argument in calls to `ax.tick_params` when setting the axes.
106
        :keyword max_likelihood_alpha: `alpha` argument, i.e. transparency, when plotting the max likelihood curve.
107
        :keyword random_sample_alpha: `alpha` argument, i.e. transparency, when plotting random sample curves.
108
        :keyword uncertainty_band_alpha: `alpha` argument, i.e. transparency, when plotting a credible band.
109
        :keyword max_likelihood_color: Color of the maximum likelihood curve.
110
        :keyword random_sample_color: Color of the random sample curves.
111
        :keyword bbox_inches: Setting for saving plots. Default is 'tight'.
112
        :keyword linewidth: Same as matplotlib linewidth
113
        :keyword zorder: Same as matplotlib zorder
114
        :keyword xy: For `ax.annotate`, the x and y coordinates of the point to annotate.
115
        :keyword xycoords: The coordinate system `xy` is given in. Default is 'axes fraction'
116
        :keyword horizontalalignment: Horizontal alignment of the annotation. Default is 'right'
117
        :keyword annotation_size: `size` argument of `ax.annotate`.
118
        :keyword fontsize_axes: Font size of the x and y labels.
119
        :keyword fontsize_legend: Font size of the legend.
120
        :keyword fontsize_figure: Font size of the figure. Relevant for multiband plots.
121
                                  Used on `supxlabel` and `supylabel`.
122
        :keyword fontsize_ticks: Font size of the axis ticks.
123
        :keyword hspace: Argument for `subplots_adjust`, sets vertical spacing between panels.
124
        :keyword wspace: Argument for `subplots_adjust`, sets horizontal spacing between panels.
125
        :keyword plot_others: Whether to plot additional bands in the data plot, all in the same colors
126
        :keyword random_models: Number of random draws to use to calculate credible bands or to plot.
127
        :keyword uncertainty_mode: 'random_models': Plot random draws from the available parameter sets.
128
                                   'credible_intervals': Plot a credible interval that is calculated based
129
                                   on the available parameter sets.
130
        :keyword reference_mjd_date: Date to use as reference point for the x axis.
131
                                    Default is the first date in the data.
132
        :keyword credible_interval_level: 0.9: Plot the 90% credible interval.
133
        :keyword plot_max_likelihood: Plots the draw corresponding to the maximum likelihood. Default is 'True'.
134
        :keyword set_same_color_per_subplot: Sets the lightcurve to be the same color as the data per subplot. Default is 'True'.
135
        :keyword xlim_high_multiplier: Adjust the maximum xlim based on available x values.
136
        :keyword xlim_low_multiplier: Adjust the minimum xlim based on available x values.
137
        :keyword ylim_high_multiplier: Adjust the maximum ylim based on available y values.
138
        :keyword ylim_low_multiplier: Adjust the minimum ylim based on available y values.
139
        """
140
        self.transient = transient
1✔
141
        self.kwargs = kwargs or dict()
1✔
142
        self._posterior_sorted = False
1✔
143

144
    keyword_docstring = __init__.__doc__.split("-------")[1]
1✔
145

146
    def _get_times(self, axes: matplotlib.axes.Axes) -> np.ndarray:
1✔
147
        """
148
        :param axes: The axes used in the plotting procedure.
149
        :type axes: matplotlib.axes.Axes
150

151
        :return: Linearly or logarithmically scaled time values depending on the y scale used in the plot.
152
        :rtype: np.ndarray
153
        """
154
        if isinstance(axes, np.ndarray):
1✔
155
            ax = axes[0]
1✔
156
        else:
157
            ax = axes
1✔
158

159
        if ax.get_yscale() == 'linear':
1✔
160
            times = np.linspace(self._xlim_low, self._xlim_high, 200)
1✔
161
        else:
162
            times = np.exp(np.linspace(np.log(self._xlim_low), np.log(self._xlim_high), 200))
1✔
163

164
        if self.transient.use_phase_model:
1✔
165
            times = times + self._reference_mjd_date
1✔
166
        return times
1✔
167

168
    @property
1✔
169
    def _xlim_low(self) -> float:
1✔
170
        default = self.xlim_low_multiplier * self.transient.x[0]
×
171
        if default == 0:
×
172
            default += 1e-3
×
173
        return self.kwargs.get("xlim_low", default)
×
174

175
    @property
1✔
176
    def _xlim_high(self) -> float:
1✔
177
        if self._x_err is None:
×
178
            default = self.xlim_high_multiplier * self.transient.x[-1]
×
179
        else:
180
            default = self.xlim_high_multiplier * (self.transient.x[-1] + self._x_err[1][-1])
×
181
        return self.kwargs.get("xlim_high", default)
×
182

183
    @property
1✔
184
    def _ylim_low(self) -> float:
1✔
185
        default = self.ylim_low_multiplier * min(self.transient.y)
1✔
186
        return self.kwargs.get("ylim_low", default)
1✔
187

188
    @property
1✔
189
    def _ylim_high(self) -> float:
1✔
190
        default = self.ylim_high_multiplier * np.max(self.transient.y)
1✔
191
        return self.kwargs.get("ylim_high", default)
1✔
192

193
    @property
1✔
194
    def _x_err(self) -> Union[np.ndarray, None]:
1✔
195
        if self.transient.x_err is not None:
×
196
            return np.array([np.abs(self.transient.x_err[1, :]), self.transient.x_err[0, :]])
×
197
        else:
198
            return None
×
199

200
    @property
1✔
201
    def _y_err(self) -> np.ndarray:
1✔
202
        if self.transient.y_err.ndim > 1.:
×
203
            return np.array([np.abs(self.transient.y_err[1, :]), self.transient.y_err[0, :]])
×
204
        else:
205
            return np.array([np.abs(self.transient.y_err)])
×
206
    @property
1✔
207
    def _lightcurve_plot_outdir(self) -> str:
1✔
208
        return self._get_outdir(join(self.transient.directory_structure.directory_path, self.model.__name__))
1✔
209

210
    @property
1✔
211
    def _data_plot_outdir(self) -> str:
1✔
212
        return self._get_outdir(self.transient.directory_structure.directory_path)
1✔
213

214
    def _get_outdir(self, default: str) -> str:
1✔
215
        return self._get_kwarg_with_default(kwarg="outdir", default=default)
1✔
216

217
    def get_filename(self, default: str) -> str:
1✔
218
        return self._get_kwarg_with_default(kwarg="filename", default=default)
1✔
219

220
    def _get_kwarg_with_default(self, kwarg: str, default: Any) -> Any:
1✔
221
        return self.kwargs.get(kwarg, default) or default
1✔
222

223
    @property
1✔
224
    def _model_kwargs(self) -> dict:
1✔
225
        return self._get_kwarg_with_default("model_kwargs", dict())
1✔
226

227
    @property
1✔
228
    def _posterior(self) -> pd.DataFrame:
1✔
229
        posterior = self.kwargs.get("posterior", pd.DataFrame())
1✔
230
        if not self._posterior_sorted and posterior is not None:
1✔
231
            posterior.sort_values(by='log_likelihood', inplace=True)
1✔
232
            self._posterior_sorted = True
1✔
233
        return posterior
1✔
234

235
    @property
1✔
236
    def _max_like_params(self) -> pd.core.series.Series:
1✔
237
        return self._posterior.iloc[-1]
×
238

239
    def _get_random_parameters(self) -> list[pd.core.series.Series]:
1✔
240
        integers = np.arange(len(self._posterior))
1✔
241
        indices = np.random.choice(integers, size=self.random_models)
1✔
242
        return [self._posterior.iloc[idx] for idx in indices]
1✔
243

244
    _data_plot_filename = _FilenameGetter(suffix="data")
1✔
245
    _lightcurve_plot_filename = _FilenameGetter(suffix="lightcurve")
1✔
246
    _residual_plot_filename = _FilenameGetter(suffix="residual")
1✔
247
    _multiband_data_plot_filename = _FilenameGetter(suffix="multiband_data")
1✔
248
    _multiband_lightcurve_plot_filename = _FilenameGetter(suffix="multiband_lightcurve")
1✔
249

250
    _data_plot_filepath = _FilePathGetter(
1✔
251
        directory_property="_data_plot_outdir", filename_property="_data_plot_filename")
252
    _lightcurve_plot_filepath = _FilePathGetter(
1✔
253
        directory_property="_lightcurve_plot_outdir", filename_property="_lightcurve_plot_filename")
254
    _residual_plot_filepath = _FilePathGetter(
1✔
255
        directory_property="_lightcurve_plot_outdir", filename_property="_residual_plot_filename")
256
    _multiband_data_plot_filepath = _FilePathGetter(
1✔
257
        directory_property="_data_plot_outdir", filename_property="_multiband_data_plot_filename")
258
    _multiband_lightcurve_plot_filepath = _FilePathGetter(
1✔
259
        directory_property="_lightcurve_plot_outdir", filename_property="_multiband_lightcurve_plot_filename")
260

261
    def _save_and_show(self, filepath: str, save: bool, show: bool) -> None:
1✔
262
        plt.tight_layout()
1✔
263
        if save:
1✔
264
            plt.savefig(filepath, dpi=self.dpi, bbox_inches=self.bbox_inches, transparent=False, facecolor='white')
×
265
        if show:
1✔
266
            plt.show()
×
267

268
class SpecPlotter(object):
1✔
269
    """
270
    Base class for spectrum plotting classes in redback.
271
    """
272

273
    capsize = KwargsAccessorWithDefault("capsize", 0.)
1✔
274
    elinewidth = KwargsAccessorWithDefault("elinewidth", 2)
1✔
275
    errorbar_fmt = KwargsAccessorWithDefault("errorbar_fmt", "x")
1✔
276
    legend_location = KwargsAccessorWithDefault("legend_location", "best")
1✔
277
    legend_cols = KwargsAccessorWithDefault("legend_cols", 2)
1✔
278
    color = KwargsAccessorWithDefault("color", "k")
1✔
279
    dpi = KwargsAccessorWithDefault("dpi", 300)
1✔
280
    model = KwargsAccessorWithDefault("model", None)
1✔
281
    ms = KwargsAccessorWithDefault("ms", 1)
1✔
282
    axis_tick_params_pad = KwargsAccessorWithDefault("axis_tick_params_pad", 10)
1✔
283

284
    max_likelihood_alpha = KwargsAccessorWithDefault("max_likelihood_alpha", 0.65)
1✔
285
    random_sample_alpha = KwargsAccessorWithDefault("random_sample_alpha", 0.05)
1✔
286
    uncertainty_band_alpha = KwargsAccessorWithDefault("uncertainty_band_alpha", 0.4)
1✔
287
    max_likelihood_color = KwargsAccessorWithDefault("max_likelihood_color", "blue")
1✔
288
    random_sample_color = KwargsAccessorWithDefault("random_sample_color", "red")
1✔
289

290
    bbox_inches = KwargsAccessorWithDefault("bbox_inches", "tight")
1✔
291
    linewidth = KwargsAccessorWithDefault("linewidth", 2)
1✔
292
    zorder = KwargsAccessorWithDefault("zorder", -1)
1✔
293
    yscale = KwargsAccessorWithDefault("yscale", "linear")
1✔
294

295
    xy = KwargsAccessorWithDefault("xy", (0.95, 0.9))
1✔
296
    xycoords = KwargsAccessorWithDefault("xycoords", "axes fraction")
1✔
297
    horizontalalignment = KwargsAccessorWithDefault("horizontalalignment", "right")
1✔
298
    annotation_size = KwargsAccessorWithDefault("annotation_size", 20)
1✔
299

300
    fontsize_axes = KwargsAccessorWithDefault("fontsize_axes", 18)
1✔
301
    fontsize_figure = KwargsAccessorWithDefault("fontsize_figure", 30)
1✔
302
    fontsize_legend = KwargsAccessorWithDefault("fontsize_legend", 18)
1✔
303
    fontsize_ticks = KwargsAccessorWithDefault("fontsize_ticks", 16)
1✔
304
    hspace = KwargsAccessorWithDefault("hspace", 0.04)
1✔
305
    wspace = KwargsAccessorWithDefault("wspace", 0.15)
1✔
306

307
    random_models = KwargsAccessorWithDefault("random_models", 100)
1✔
308
    uncertainty_mode = KwargsAccessorWithDefault("uncertainty_mode", "random_models")
1✔
309
    credible_interval_level = KwargsAccessorWithDefault("credible_interval_level", 0.9)
1✔
310
    plot_max_likelihood = KwargsAccessorWithDefault("plot_max_likelihood", True)
1✔
311
    set_same_color_per_subplot = KwargsAccessorWithDefault("set_same_color_per_subplot", True)
1✔
312

313
    xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.05)
1✔
314
    xlim_low_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.9)
1✔
315
    ylim_high_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 1.1)
1✔
316
    ylim_low_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.5)
1✔
317

318
    def __init__(self, spectrum: Union[redback.transient.Spectrum, None], **kwargs) -> None:
1✔
319
        """
320
        :param spectrum: An instance of `redback.transient.Spectrum`. Contains the data to be plotted.
321
        :param kwargs: Additional kwargs the plotter uses. -------
322
        :keyword capsize: Same as matplotlib capsize.
323
        :keyword elinewidth: same as matplotlib elinewidth
324
        :keyword errorbar_fmt: 'fmt' argument of `ax.errorbar`.
325
        :keyword ms: Same as matplotlib markersize.
326
        :keyword legend_location: Same as matplotlib legend location.
327
        :keyword legend_cols: Same as matplotlib legend columns.
328
        :keyword color: Color of the data points.
329
        :keyword dpi: Same as matplotlib dpi.
330
        :keyword model: str or callable, the model to plot.
331
        :keyword ms: Same as matplotlib markersize.
332
        :keyword axis_tick_params_pad: `pad` argument in calls to `ax.tick_params` when setting the axes.
333
        :keyword max_likelihood_alpha: `alpha` argument, i.e. transparency, when plotting the max likelihood curve.
334
        :keyword random_sample_alpha: `alpha` argument, i.e. transparency, when plotting random sample curves.
335
        :keyword uncertainty_band_alpha: `alpha` argument, i.e. transparency, when plotting a credible band.
336
        :keyword max_likelihood_color: Color of the maximum likelihood curve.
337
        :keyword random_sample_color: Color of the random sample curves.
338
        :keyword bbox_inches: Setting for saving plots. Default is 'tight'.
339
        :keyword linewidth: Same as matplotlib linewidth
340
        :keyword zorder: Same as matplotlib zorder
341
        :keyword yscale: Same as matplotlib yscale, default is linear
342
        :keyword xy: For `ax.annotate`, the x and y coordinates of the point to annotate.
343
        :keyword xycoords: The coordinate system `xy` is given in. Default is 'axes fraction'
344
        :keyword horizontalalignment: Horizontal alignment of the annotation. Default is 'right'
345
        :keyword annotation_size: `size` argument of `ax.annotate`.
346
        :keyword fontsize_axes: Font size of the x and y labels.
347
        :keyword fontsize_legend: Font size of the legend.
348
        :keyword fontsize_figure: Font size of the figure. Relevant for multiband plots.
349
                                  Used on `supxlabel` and `supylabel`.
350
        :keyword fontsize_ticks: Font size of the axis ticks.
351
        :keyword hspace: Argument for `subplots_adjust`, sets vertical spacing between panels.
352
        :keyword wspace: Argument for `subplots_adjust`, sets horizontal spacing between panels.
353
        :keyword plot_others: Whether to plot additional bands in the data plot, all in the same colors
354
        :keyword random_models: Number of random draws to use to calculate credible bands or to plot.
355
        :keyword uncertainty_mode: 'random_models': Plot random draws from the available parameter sets.
356
                                   'credible_intervals': Plot a credible interval that is calculated based
357
                                   on the available parameter sets.
358
        :keyword credible_interval_level: 0.9: Plot the 90% credible interval.
359
        :keyword plot_max_likelihood: Plots the draw corresponding to the maximum likelihood. Default is 'True'.
360
        :keyword set_same_color_per_subplot: Sets the lightcurve to be the same color as the data per subplot. Default is 'True'.
361
        :keyword xlim_high_multiplier: Adjust the maximum xlim based on available x values.
362
        :keyword xlim_low_multiplier: Adjust the minimum xlim based on available x values.
363
        :keyword ylim_high_multiplier: Adjust the maximum ylim based on available flux density values.
364
        :keyword ylim_low_multiplier: Adjust the minimum ylim based on available flux density values.
365
        """
NEW
366
        self.transient = spectrum
×
NEW
367
        self.kwargs = kwargs or dict()
×
NEW
368
        self._posterior_sorted = False
×
369

370
    keyword_docstring = __init__.__doc__.split("-------")[1]
1✔
371

372
    def _get_angstroms(self, axes: matplotlib.axes.Axes) -> np.ndarray:
1✔
373
        """
374
        :param axes: The axes used in the plotting procedure.
375
        :type axes: matplotlib.axes.Axes
376

377
        :return: Linearly or logarithmically scaled angstrom values depending on the y scale used in the plot.
378
        :rtype: np.ndarray
379
        """
NEW
380
        if isinstance(axes, np.ndarray):
×
NEW
381
            ax = axes[0]
×
382
        else:
NEW
383
            ax = axes
×
384

NEW
385
        if ax.get_yscale() == 'linear':
×
NEW
386
            angstroms = np.linspace(self._xlim_low, self._xlim_high, 200)
×
387
        else:
NEW
388
            angstroms = np.exp(np.linspace(np.log(self._xlim_low), np.log(self._xlim_high), 200))
×
389

NEW
390
        return angstroms
×
391

392
    @property
1✔
393
    def _xlim_low(self) -> float:
1✔
NEW
394
        default = self.xlim_low_multiplier * self.transient.angstroms[0]
×
NEW
395
        if default == 0:
×
NEW
396
            default += 1e-3
×
NEW
397
        return self.kwargs.get("xlim_low", default)
×
398

399
    @property
1✔
400
    def _xlim_high(self) -> float:
1✔
NEW
401
        default = self.xlim_high_multiplier * self.transient.angstroms[-1]
×
NEW
402
        return self.kwargs.get("xlim_high", default)
×
403

404
    @property
1✔
405
    def _ylim_low(self) -> float:
1✔
NEW
406
        default = self.ylim_low_multiplier * min(self.transient.flux_density)
×
NEW
407
        return self.kwargs.get("ylim_low", default/1e-17)
×
408

409
    @property
1✔
410
    def _ylim_high(self) -> float:
1✔
NEW
411
        default = self.ylim_high_multiplier * np.max(self.transient.flux_density)
×
NEW
412
        return self.kwargs.get("ylim_high", default/1e-17)
×
413

414
    @property
1✔
415
    def _y_err(self) -> np.ndarray:
1✔
NEW
416
            return np.array([np.abs(self.transient.y_err)])
×
417

418
    @property
1✔
419
    def _data_plot_outdir(self) -> str:
1✔
NEW
420
        return self._get_outdir(self.transient.directory_structure.directory_path)
×
421

422
    def _get_outdir(self, default: str) -> str:
1✔
NEW
423
        return self._get_kwarg_with_default(kwarg="outdir", default=default)
×
424

425
    def get_filename(self, default: str) -> str:
1✔
NEW
426
        return self._get_kwarg_with_default(kwarg="filename", default=default)
×
427

428
    def _get_kwarg_with_default(self, kwarg: str, default: Any) -> Any:
1✔
NEW
429
        return self.kwargs.get(kwarg, default) or default
×
430

431
    @property
1✔
432
    def _model_kwargs(self) -> dict:
1✔
NEW
433
        return self._get_kwarg_with_default("model_kwargs", dict())
×
434

435
    @property
1✔
436
    def _posterior(self) -> pd.DataFrame:
1✔
NEW
437
        posterior = self.kwargs.get("posterior", pd.DataFrame())
×
NEW
438
        if not self._posterior_sorted and posterior is not None:
×
NEW
439
            posterior.sort_values(by='log_likelihood', inplace=True)
×
NEW
440
            self._posterior_sorted = True
×
NEW
441
        return posterior
×
442

443
    @property
1✔
444
    def _max_like_params(self) -> pd.core.series.Series:
1✔
NEW
445
        return self._posterior.iloc[-1]
×
446

447
    def _get_random_parameters(self) -> list[pd.core.series.Series]:
1✔
NEW
448
        integers = np.arange(len(self._posterior))
×
NEW
449
        indices = np.random.choice(integers, size=self.random_models)
×
NEW
450
        return [self._posterior.iloc[idx] for idx in indices]
×
451

452
    _data_plot_filename = _FilenameGetter(suffix="data")
1✔
453
    _spectrum_ppd_plot_filename = _FilenameGetter(suffix="spectrum_ppd")
1✔
454
    _residual_plot_filename = _FilenameGetter(suffix="residual")
1✔
455

456
    _data_plot_filepath = _FilePathGetter(
1✔
457
        directory_property="_data_plot_outdir", filename_property="_data_plot_filename")
458
    _spectrum_ppd_plot_filepath = _FilePathGetter(
1✔
459
        directory_property="_data_plot_outdir", filename_property="_spectrum_ppd_plot_filename")
460
    _residual_plot_filepath = _FilePathGetter(
1✔
461
        directory_property="_data_plot_outdir", filename_property="_residual_plot_filename")
462

463
    def _save_and_show(self, filepath: str, save: bool, show: bool) -> None:
1✔
NEW
464
        plt.tight_layout()
×
NEW
465
        if save:
×
NEW
466
            plt.savefig(filepath, dpi=self.dpi, bbox_inches=self.bbox_inches, transparent=False, facecolor='white')
×
NEW
467
        if show:
×
NEW
468
            plt.show()
×
469

470

471
class IntegratedFluxPlotter(Plotter):
1✔
472

473
    @property
1✔
474
    def _xlabel(self) -> str:
1✔
475
        return r"Time since burst [s]"
×
476

477
    @property
1✔
478
    def _ylabel(self) -> str:
1✔
479
        return self.transient.ylabel
×
480

481
    def plot_data(
1✔
482
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
483
        """Plots the Integrated flux data and returns Axes.
484

485
        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
486
        :type axes: Union[matplotlib.axes.Axes, None], optional
487
        :param save: Whether to save the plot. (Default value = True)
488
        :type save: bool
489
        :param show: Whether to show the plot. (Default value = True)
490
        :type show: bool
491

492
        :return: The axes with the plot.
493
        :rtype: matplotlib.axes.Axes
494
        """
495
        ax = axes or plt.gca()
×
496

497
        ax.errorbar(self.transient.x, self.transient.y, xerr=self._x_err, yerr=self._y_err,
×
498
                    fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize)
499

500
        ax.set_xscale('log')
×
501
        ax.set_yscale('log')
×
502

503
        ax.set_xlim(self._xlim_low, self._xlim_high)
×
504
        ax.set_ylim(self._ylim_low, self._ylim_high)
×
505
        ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes)
×
506
        ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes)
×
507

508
        ax.annotate(
×
509
            self.transient.name, xy=self.xy, xycoords=self.xycoords,
510
            horizontalalignment=self.horizontalalignment, size=self.annotation_size)
511

512
        ax.tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)
×
513

514
        self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show)
×
515
        return ax
×
516

517
    def plot_lightcurve(
1✔
518
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
519
        """Plots the Integrated flux data and the lightcurve and returns Axes.
520

521
        :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot.
522
        :type axes: Union[matplotlib.axes.Axes, None], optional
523
        :param save: Whether to save the plot. (Default value = True)
524
        :type save: bool
525
        :param show: Whether to show the plot. (Default value = True)
526
        :type show: bool
527

528
        :return: The axes with the plot.
529
        :rtype: matplotlib.axes.Axes
530
        """
531
        
532
        axes = axes or plt.gca()
×
533

534
        axes = self.plot_data(axes=axes, save=False, show=False)
×
535
        times = self._get_times(axes)
×
536

537
        self._plot_lightcurves(axes, times)
×
538

539
        self._save_and_show(filepath=self._lightcurve_plot_filepath, save=save, show=show)
×
540
        return axes
×
541

542
    def _plot_lightcurves(self, axes: matplotlib.axes.Axes, times: np.ndarray) -> None:
1✔
543
        if self.plot_max_likelihood:
×
544
            ys = self.model(times, **self._max_like_params, **self._model_kwargs)
×
545
            axes.plot(times, ys, color=self.max_likelihood_color, alpha=self.max_likelihood_alpha, lw=self.linewidth)
×
546

547
        random_ys_list = [self.model(times, **random_params, **self._model_kwargs)
×
548
                          for random_params in self._get_random_parameters()]
549
        if self.uncertainty_mode == "random_models":
×
550
            for ys in random_ys_list:
×
551
                axes.plot(times, ys, color=self.random_sample_color, alpha=self.random_sample_alpha, lw=self.linewidth,
×
552
                          zorder=self.zorder)
553
        elif self.uncertainty_mode == "credible_intervals":
×
554
            lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(samples=random_ys_list, interval=self.credible_interval_level)
×
555
            axes.fill_between(
×
556
                times, lower_bound, upper_bound, alpha=self.uncertainty_band_alpha, color=self.max_likelihood_color)
557

558
    def _plot_single_lightcurve(self, axes: matplotlib.axes.Axes, times: np.ndarray, params: dict) -> None:
1✔
559
        ys = self.model(times, **params, **self._model_kwargs)
×
560
        axes.plot(times, ys, color=self.random_sample_color, alpha=self.random_sample_alpha, lw=self.linewidth,
×
561
                  zorder=self.zorder)
562

563
    def plot_residuals(
1✔
564
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
565
        """Plots the residual of the Integrated flux data returns Axes.
566

567
        :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot.
568
        :param save: Whether to save the plot. (Default value = True)
569
        :param show: Whether to show the plot. (Default value = True)
570

571
        :return: The axes with the plot.
572
        :rtype: matplotlib.axes.Axes
573
        """
574
        if axes is None:
×
575
            fig, axes = plt.subplots(
×
576
                nrows=2, ncols=1, sharex=True, sharey=False, figsize=(10, 8), gridspec_kw=dict(height_ratios=[2, 1]))
577

578
        axes[0] = self.plot_lightcurve(axes=axes[0], save=False, show=False)
×
579
        axes[1].set_xlabel(axes[0].get_xlabel(), fontsize=self.fontsize_axes)
×
580
        axes[0].set_xlabel("")
×
581
        ys = self.model(self.transient.x, **self._max_like_params, **self._model_kwargs)
×
582
        axes[1].errorbar(
×
583
            self.transient.x, self.transient.y - ys, xerr=self._x_err, yerr=self._y_err,
584
            fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize)
585
        axes[1].set_yscale("log")
×
586
        axes[1].set_ylabel("Residual", fontsize=self.fontsize_axes)
×
587
        axes[1].tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)
×
588

589
        self._save_and_show(filepath=self._residual_plot_filepath, save=save, show=show)
×
590
        return axes
×
591

592

593
class LuminosityPlotter(IntegratedFluxPlotter):
    """Plotter that inherits everything from ``IntegratedFluxPlotter`` and adds nothing of its own."""
1✔
595

596

597
class MagnitudePlotter(Plotter):
    """
    Plotter for band-resolved data: draws the data and model curves per band, either into a single axes
    (``plot_data``/``plot_lightcurve``) or into one panel per band
    (``plot_multiband``/``plot_multiband_lightcurve``).
    """

    # Each accessor is built from a kwarg key and a default value.
    # NOTE(review): the phase-model multipliers use the same kwarg keys as the regular multipliers
    # ("xlim_low_multiplier"/"xlim_high_multiplier") but different defaults - confirm this is intended.
    xlim_low_phase_model_multiplier = KwargsAccessorWithDefault("xlim_low_multiplier", 0.9)
    xlim_high_phase_model_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.1)
    xlim_high_multiplier = KwargsAccessorWithDefault("xlim_high_multiplier", 1.2)
    ylim_low_magnitude_multiplier = KwargsAccessorWithDefault("ylim_low_multiplier", 0.8)
    ylim_high_magnitude_multiplier = KwargsAccessorWithDefault("ylim_high_multiplier", 1.2)
    ncols = KwargsAccessorWithDefault("ncols", 2)

    @property
    def _colors(self) -> Any:
        """Colors for the bands in ``_filters`` (indexed by position); the "colors" kwarg wins if given."""
        return self.kwargs.get("colors", self.transient.get_colors(self._filters))

    @property
    def _xlabel(self) -> str:
        """X axis label; the "xlabel" kwarg overrides the default."""
        if self.transient.use_phase_model:
            default = f"Time since {self._reference_mjd_date} MJD [days]"
        else:
            default = self.transient.xlabel
        return self.kwargs.get("xlabel", default)

    @property
    def _ylabel(self) -> str:
        """Y axis label; the "ylabel" kwarg overrides the transient's label."""
        return self.kwargs.get("ylabel", self.transient.ylabel)

    @property
    def _get_bands_to_plot(self) -> list[str]:
        """Bands for which model curves are drawn; defaults to the transient's active bands."""
        return self.kwargs.get("bands_to_plot", self.transient.active_bands)

    @property
    def _xlim_low(self) -> float:
        """Lower x limit; the "xlim_low" kwarg overrides the value derived from the first data point."""
        if self.transient.use_phase_model:
            default = (self.transient.x[0] - self._reference_mjd_date) * self.xlim_low_phase_model_multiplier
        else:
            default = self.xlim_low_multiplier * self.transient.x[0]
        if default == 0:
            # Nudge away from exactly zero, which a logarithmic x axis cannot use as a limit.
            default += 1e-3
        return self.kwargs.get("xlim_low", default)

    @property
    def _xlim_high(self) -> float:
        """Upper x limit; the "xlim_high" kwarg overrides the value derived from the last data point."""
        if self.transient.use_phase_model:
            default = (self.transient.x[-1] - self._reference_mjd_date) * self.xlim_high_phase_model_multiplier
        else:
            default = self.xlim_high_multiplier * self.transient.x[-1]
        return self.kwargs.get("xlim_high", default)

    @property
    def _ylim_low_magnitude(self) -> float:
        """Lower y limit used for magnitude data: scaled minimum of all y values."""
        return self.ylim_low_magnitude_multiplier * min(self.transient.y)

    @property
    def _ylim_high_magnitude(self) -> float:
        """Upper y limit used for magnitude data: scaled maximum of all y values."""
        return self.ylim_high_magnitude_multiplier * np.max(self.transient.y)

    def _get_ylim_low_with_indices(self, indices: list) -> float:
        """Lower y limit computed only from the data points selected by ``indices``."""
        return self.ylim_low_multiplier * min(self.transient.y[indices])

    def _get_ylim_high_with_indices(self, indices: list) -> float:
        """Upper y limit computed only from the data points selected by ``indices``."""
        return self.ylim_high_multiplier * np.max(self.transient.y[indices])

    def _get_x_err(self, indices: list) -> np.ndarray:
        """Return the x errors selected by ``indices``, or None if the transient has no x errors."""
        return self.transient.x_err[indices] if self.transient.x_err is not None else self.transient.x_err

    def _apply_band_scaling(self, band: Any, values: Any) -> Any:
        """Apply the display scaling configured for ``band`` in ``self.band_scaling``.

        :param band: Band whose scaling entry is looked up.
        :param values: Values to scale (data or model y values).

        :return: ``values`` unchanged if the band has no scaling entry, ``values * factor`` for
                 scaling type 'x', ``values + factor`` for type '+', or None for any other type.
        """
        if band not in self.band_scaling:
            return values
        scaling_type = self.band_scaling.get("type")
        factor = self.band_scaling.get(band)
        if scaling_type == 'x':
            return values * factor
        if scaling_type == '+':
            return values + factor
        return None

    def _set_y_axis_data(self, ax: matplotlib.axes.Axes) -> None:
        """Set y limits and orientation/scale for a single-panel plot."""
        if self.transient.magnitude_data:
            ax.set_ylim(self._ylim_low_magnitude, self._ylim_high_magnitude)
            ax.invert_yaxis()
        else:
            ax.set_ylim(self._ylim_low, self._ylim_high)
            ax.set_yscale("log")

    def _set_y_axis_multiband_data(self, ax: matplotlib.axes.Axes, indices: list) -> None:
        """Set y limits and orientation/scale for one panel of a multiband plot.

        Magnitude data use the limits of the full data set; other data use limits computed from the
        points selected by ``indices``.
        """
        if self.transient.magnitude_data:
            ax.set_ylim(self._ylim_low_magnitude, self._ylim_high_magnitude)
            ax.invert_yaxis()
        else:
            ax.set_ylim(self._get_ylim_low_with_indices(indices=indices),
                        self._get_ylim_high_with_indices(indices=indices))
            ax.set_yscale("log")

    def _set_x_axis(self, axes: matplotlib.axes.Axes) -> None:
        """Set the x limits; the x axis is switched to log scale when the phase model is used."""
        if self.transient.use_phase_model:
            axes.set_xscale("log")
        axes.set_xlim(self._xlim_low, self._xlim_high)

    @property
    def _nrows(self) -> int:
        """Number of panel rows; defaults to just enough rows of ``ncols`` panels to fit all filters."""
        # Use the configured column count so that a non-default ``ncols`` still yields enough panels.
        default = int(np.ceil(len(self._filters) / self.ncols))
        return self._get_kwarg_with_default("nrows", default=default)

    @property
    def _npanels(self) -> int:
        """Total number of panels in the grid.

        :raises ValueError: If the grid has fewer panels than there are filters.
        """
        npanels = self._nrows * self.ncols
        if npanels < len(self._filters):
            raise ValueError(f"Insufficient number of panels. {npanels} panels were given "
                             f"but {len(self._filters)} panels are needed.")
        return npanels

    @property
    def _figsize(self) -> tuple:
        """Figure size; defaults to a size that grows with the number of columns and rows."""
        default = (4 + 4 * self.ncols, 2 + 2 * self._nrows)
        return self._get_kwarg_with_default("figsize", default=default)

    @property
    def _reference_mjd_date(self) -> int:
        """Date subtracted from all times when the phase model is used, otherwise 0."""
        if self.transient.use_phase_model:
            return self.kwargs.get("reference_mjd_date", int(self.transient.x[0]))
        return 0

    @property
    def band_label_generator(self):
        """Generator over the user supplied ``band_labels``, or None if no labels were supplied."""
        if self.band_labels is not None:
            return (bl for bl in self.band_labels)

    def plot_data(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the data of all selected bands into a single axes and returns Axes.

        Bands listed in ``band_scaling`` are drawn with their scaling applied. A band whose scaling type
        is neither 'x' nor '+' is not drawn and a warning is logged.

        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        ax = axes or plt.gca()

        band_label_generator = self.band_label_generator
        # Both are invariant during the call, so they are looked up once.
        filters = self._filters
        colors = self._colors

        for indices, band in zip(self.transient.list_of_band_indices, self.transient.unique_bands):
            if band in filters:
                color = colors[list(filters).index(band)]
                if band_label_generator is not None:
                    label = next(band_label_generator)
                elif band in self.band_scaling:
                    label = f"{self.band_scaling.get(band)} {self.band_scaling.get('type')} {band}"
                else:
                    label = band
            elif self.plot_others:
                color = "black"
                label = None
            else:
                continue
            if isinstance(label, float):
                label = f"{label:.2e}"
            if self.band_colors is not None:
                color = self.band_colors[band]

            y = self._apply_band_scaling(band, self.transient.y[indices])
            if y is None:
                redback.utils.logger.warning(
                    f"Unsupported band_scaling type {self.band_scaling.get('type')!r}; "
                    f"data for band {band} are not plotted.")
                continue
            y_err = self.transient.y_err[indices]
            # Only a multiplicative scaling changes the size of the error bars.
            if band in self.band_scaling and self.band_scaling.get("type") == 'x':
                y_err = y_err * self.band_scaling.get(band)

            ax.errorbar(
                self.transient.x[indices] - self._reference_mjd_date, y,
                xerr=self._get_x_err(indices), yerr=y_err,
                fmt=self.errorbar_fmt, ms=self.ms, color=color,
                elinewidth=self.elinewidth, capsize=self.capsize, label=label)

        self._set_x_axis(axes=ax)
        self._set_y_axis_data(ax)

        ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes)
        ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes)

        ax.tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)
        ax.legend(ncol=self.legend_cols, loc=self.legend_location, fontsize=self.fontsize_legend)

        self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show)
        return ax

    def plot_lightcurve(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True)\
            -> matplotlib.axes.Axes:
        """Plots the data together with the model lightcurves of every band and returns Axes.

        For each band the maximum likelihood curve (if enabled) and either random posterior curves or a
        credible interval band are drawn, with the band's ``band_scaling`` applied. Curves of a band whose
        scaling type is neither 'x' nor '+' are not drawn.

        :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        axes = axes or plt.gca()

        axes = self.plot_data(axes=axes, save=False, show=False)
        # NOTE(review): this forces a log y axis even for magnitude data - confirm this is intended.
        axes.set_yscale('log')

        times = self._get_times(axes)
        plot_times = times - self._reference_mjd_date
        bands_to_plot = self._get_bands_to_plot

        color_max = self.max_likelihood_color
        color_sample = self.random_sample_color
        for band, color in zip(bands_to_plot, self.transient.get_colors(bands_to_plot)):
            if self.set_same_color_per_subplot is True:
                if self.band_colors is not None:
                    color = self.band_colors[band]
                color_max = color
                color_sample = color

            # Work on a copy, as plot_multiband_lightcurve does, so the shared model kwargs are not
            # left holding the band/frequency of whichever band was plotted last.
            model_kwargs = self._model_kwargs.copy()
            sn_cosmo_band = redback.utils.sncosmo_bandname_from_band([band])
            model_kwargs["bands"] = [sn_cosmo_band[0]] * len(times)
            if isinstance(band, str):
                frequency = redback.utils.bands_to_frequency([band])
            else:
                frequency = band
            model_kwargs['frequency'] = np.ones(len(times)) * frequency

            if self.plot_max_likelihood:
                ys = self._apply_band_scaling(
                    band, self.model(times, **self._max_like_params, **model_kwargs))
                if ys is not None:
                    axes.plot(plot_times, ys, color=color_max, alpha=self.max_likelihood_alpha,
                              lw=self.linewidth)

            random_ys_list = [self.model(times, **random_params, **model_kwargs)
                              for random_params in self._get_random_parameters()]
            if self.uncertainty_mode == "random_models":
                for random_ys in random_ys_list:
                    ys = self._apply_band_scaling(band, random_ys)
                    if ys is not None:
                        axes.plot(plot_times, ys, color=color_sample, alpha=self.random_sample_alpha,
                                  lw=self.linewidth, zorder=-1)
            elif self.uncertainty_mode == "credible_intervals":
                samples = self._apply_band_scaling(band, np.array(random_ys_list))
                # Without this guard an unsupported scaling type would reuse the bounds of the
                # previous band (or fail on unbound names for the first band).
                if samples is not None:
                    lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(
                        samples=samples, interval=self.credible_interval_level)
                    axes.fill_between(
                        plot_times, lower_bound, upper_bound,
                        alpha=self.uncertainty_band_alpha, color=color_sample)

        self._save_and_show(filepath=self._lightcurve_plot_filepath, save=save, show=show)
        return axes

    def _check_valid_multiband_data_mode(self) -> bool:
        """Return False (and log a warning) for luminosity data, True otherwise."""
        if self.transient.luminosity_data:
            redback.utils.logger.warning(
                f"Plotting multiband lightcurve/data not possible for {self.transient.data_mode}. Returning.")
            return False
        return True

    def plot_multiband(
            self, figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None, save: bool = True,
            show: bool = True) -> matplotlib.axes.Axes:
        """Plots the data of every selected band into its own panel and returns the panels.

        :param figure: Matplotlib figure to plot the data into.
        :type figure: matplotlib.figure.Figure
        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
                     A new grid of panels is created unless both ``figure`` and ``axes`` are given.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The flattened array of axes with the plot, or None if the data mode does not
                 support multiband plots.
        :rtype: matplotlib.axes.Axes
        """
        if not self._check_valid_multiband_data_mode():
            return

        if figure is None or axes is None:
            figure, axes = plt.subplots(ncols=self.ncols, nrows=self._nrows, sharex='all', figsize=self._figsize)
        # plt.subplots returns a bare Axes for a 1x1 grid and callers may pass a single Axes or a list;
        # normalise all of these to a flat array so the panels can be indexed below.
        axes = np.atleast_1d(axes).ravel()

        band_label_generator = self.band_label_generator
        # Both are invariant during the call, so they are looked up once.
        filters = self._filters
        colors = self._colors

        ii = 0
        for indices, band, freq in zip(
                self.transient.list_of_band_indices, self.transient.unique_bands, self.transient.unique_frequencies):
            if band not in filters:
                continue

            x_err = self._get_x_err(indices)
            color = colors[list(filters).index(band)]
            if self.band_colors is not None:
                color = self.band_colors[band]
            if band_label_generator is None:
                label = self._get_multiband_plot_label(band, freq)
            else:
                label = next(band_label_generator)

            panel = axes[ii]
            panel.errorbar(
                self.transient.x[indices] - self._reference_mjd_date, self.transient.y[indices], xerr=x_err,
                yerr=self.transient.y_err[indices], fmt=self.errorbar_fmt, ms=self.ms, color=color,
                elinewidth=self.elinewidth, capsize=self.capsize,
                label=label)

            self._set_x_axis(panel)
            self._set_y_axis_multiband_data(panel, indices)
            panel.legend(ncol=self.legend_cols, loc=self.legend_location, fontsize=self.fontsize_legend)
            panel.tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)
            ii += 1

        figure.supxlabel(self._xlabel, fontsize=self.fontsize_figure)
        figure.supylabel(self._ylabel, fontsize=self.fontsize_figure)
        plt.subplots_adjust(wspace=self.wspace, hspace=self.hspace)

        self._save_and_show(filepath=self._multiband_data_plot_filepath, save=save, show=show)
        return axes

    @staticmethod
    def _get_multiband_plot_label(band: str, freq: float) -> str:
        """Build the legend label of one multiband panel.

        :param band: Band identifier; either a name (str) or a number.
        :param freq: Frequency associated with the band.

        :return: The band name if ``band`` is a string and ``freq`` lies strictly between 1e10 and 1e16,
                 the frequency in scientific notation for other string bands, and the band value
                 itself in scientific notation if ``band`` is not a string.
        """
        if not isinstance(band, str):
            return f"{band:.2e}"
        if 1e10 < float(freq) < 1e16:
            return band
        return f"{freq:.2e}"

    @property
    def _filters(self) -> list[str]:
        """Bands selected for plotting.

        The "bands_to_plot" kwarg takes precedence over the "filters" kwarg. None falls back to the
        transient's active bands and the value 'default' to the transient's default filters.
        """
        filters = self.kwargs.get("filters", self.transient.active_bands)
        if 'bands_to_plot' in self.kwargs:
            filters = self.kwargs['bands_to_plot']
        if filters is None:
            return self.transient.active_bands
        elif str(filters) == 'default':
            return self.transient.default_filters
        return filters

    def plot_multiband_lightcurve(
        self, figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots data and model lightcurves of every selected band into its own panel and returns the panels.

        :param figure: Matplotlib figure to plot the data into.
        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
                     A new grid of panels is created unless both ``figure`` and ``axes`` are given.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The flattened array of axes with the plot, or None if the data mode does not
                 support multiband plots.
        :rtype: matplotlib.axes.Axes
        """
        if not self._check_valid_multiband_data_mode():
            return

        if figure is None or axes is None:
            figure, axes = plt.subplots(ncols=self.ncols, nrows=self._nrows, sharex='all', figsize=self._figsize)

        axes = self.plot_multiband(figure=figure, axes=axes, save=False, show=False)
        times = self._get_times(axes)
        plot_times = times - self._reference_mjd_date
        filters = self._filters

        # Panels are filled in the same band order that plot_multiband used for the data.
        ii = 0
        color_max = self.max_likelihood_color
        color_sample = self.random_sample_color
        for band, freq in zip(self.transient.unique_bands, self.transient.unique_frequencies):
            if band not in filters:
                continue
            new_model_kwargs = self._model_kwargs.copy()
            new_model_kwargs['frequency'] = freq
            new_model_kwargs['bands'] = band

            if self.set_same_color_per_subplot is True:
                color = self._colors[list(filters).index(band)]
                if self.band_colors is not None:
                    color = self.band_colors[band]
                color_max = color
                color_sample = color

            panel = axes[ii]
            if self.plot_max_likelihood:
                ys = self.model(times, **self._max_like_params, **new_model_kwargs)
                panel.plot(
                    plot_times, ys, color=color_max,
                    alpha=self.max_likelihood_alpha, lw=self.linewidth)
            random_ys_list = [self.model(times, **random_params, **new_model_kwargs)
                              for random_params in self._get_random_parameters()]
            if self.uncertainty_mode == "random_models":
                for random_ys in random_ys_list:
                    panel.plot(plot_times, random_ys, color=color_sample,
                               alpha=self.random_sample_alpha, lw=self.linewidth, zorder=self.zorder)
            elif self.uncertainty_mode == "credible_intervals":
                lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(
                    samples=random_ys_list, interval=self.credible_interval_level)
                panel.fill_between(
                    plot_times, lower_bound, upper_bound,
                    alpha=self.uncertainty_band_alpha, color=color_sample)
            ii += 1

        self._save_and_show(filepath=self._multiband_lightcurve_plot_filepath, save=save, show=show)
        return axes
1✔
1005

1006

1007
class FluxDensityPlotter(MagnitudePlotter):
    """Plotter that inherits everything from ``MagnitudePlotter`` and adds nothing of its own."""
1✔
1009

1010
class IntegratedFluxOpticalPlotter(MagnitudePlotter):
    """Plotter that inherits everything from ``MagnitudePlotter`` and adds nothing of its own."""
1✔
1012

1013
class SpectrumPlotter(SpecPlotter):
    """
    Plotter for a single spectrum: flux density against wavelength, optionally with model spectra
    and a residual panel.
    """

    @property
    def _xlabel(self) -> str:
        """X axis label, taken from the transient."""
        return self.transient.xlabel

    @property
    def _ylabel(self) -> str:
        """Y axis label, taken from the transient."""
        return self.transient.ylabel

    def plot_data(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the spectrum data and returns Axes.

        The flux density is drawn divided by 1e-17. The plot is annotated with the transient's time
        if ``plot_with_time_label`` is set on the transient, otherwise with its name.

        :param axes: Matplotlib axes to plot the data into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        ax = axes or plt.gca()

        annotation = self.transient.time if self.transient.plot_with_time_label else self.transient.name

        scaled_flux = self.transient.flux_density/1e-17
        ax.plot(self.transient.angstroms, scaled_flux, color=self.color, lw=self.linewidth)

        ax.set_xscale('linear')
        ax.set_yscale(self.yscale)
        ax.set_xlim(self._xlim_low, self._xlim_high)
        ax.set_ylim(self._ylim_low, self._ylim_high)
        ax.set_xlabel(self._xlabel, fontsize=self.fontsize_axes)
        ax.set_ylabel(self._ylabel, fontsize=self.fontsize_axes)

        ax.annotate(
            annotation, xy=self.xy, xycoords=self.xycoords,
            horizontalalignment=self.horizontalalignment, size=self.annotation_size)

        ax.tick_params(axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)

        self._save_and_show(filepath=self._data_plot_filepath, save=save, show=show)
        return ax

    def plot_spectrum(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the spectrum data together with the fitted model spectra and returns Axes.

        :param axes: Matplotlib axes to plot the spectrum into. Useful for user specific modifications to the plot.
        :type axes: Union[matplotlib.axes.Axes, None], optional
        :param save: Whether to save the plot. (Default value = True)
        :type save: bool
        :param show: Whether to show the plot. (Default value = True)
        :type show: bool

        :return: The axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        target = axes or plt.gca()
        target = self.plot_data(axes=target, save=False, show=False)

        wavelengths = self._get_angstroms(target)
        self._plot_spectrums(target, wavelengths)

        self._save_and_show(filepath=self._spectrum_ppd_plot_filepath, save=save, show=show)
        return target

    def _plot_spectrums(self, axes: matplotlib.axes.Axes, angstroms: np.ndarray) -> None:
        """Draw the model spectra into ``axes``.

        Draws the maximum likelihood spectrum (if enabled) and, depending on ``uncertainty_mode``,
        either the individual random-sample spectra or a credible interval band. All model values are
        divided by 1e-17 before plotting, as the data are in ``plot_data``.

        :param axes: Axes to draw into.
        :param angstroms: Wavelengths at which the model is evaluated.
        """
        if self.plot_max_likelihood:
            best_fit = self.model(angstroms, **self._max_like_params, **self._model_kwargs)
            axes.plot(
                angstroms, best_fit/1e-17, color=self.max_likelihood_color,
                alpha=self.max_likelihood_alpha, lw=self.linewidth)

        sample_spectra = [
            self.model(angstroms, **sample_params, **self._model_kwargs)
            for sample_params in self._get_random_parameters()]

        mode = self.uncertainty_mode
        if mode == "random_models":
            for spectrum in sample_spectra:
                axes.plot(
                    angstroms, spectrum/1e-17, color=self.random_sample_color,
                    alpha=self.random_sample_alpha, lw=self.linewidth, zorder=self.zorder)
        elif mode == "credible_intervals":
            lower_bound, upper_bound, _ = redback.utils.calc_credible_intervals(
                samples=sample_spectra, interval=self.credible_interval_level)
            axes.fill_between(
                angstroms, lower_bound/1e-17, upper_bound/1e-17,
                alpha=self.uncertainty_band_alpha, color=self.max_likelihood_color)

    def plot_residuals(
            self, axes: matplotlib.axes.Axes = None, save: bool = True, show: bool = True) -> matplotlib.axes.Axes:
        """Plots the spectrum fit with a residual panel underneath and returns the axes.

        The upper panel is drawn by ``plot_spectrum``. The lower panel shows the flux density minus the
        model evaluated with the maximum likelihood parameters.

        :param axes: Indexable pair of matplotlib axes (upper panel first). If None, a new figure with
                     two stacked panels sharing the x axis is created.
        :param save: Whether to save the plot. (Default value = True)
        :param show: Whether to show the plot. (Default value = True)

        :return: The pair of axes with the plot.
        :rtype: matplotlib.axes.Axes
        """
        if axes is None:
            _, axes = plt.subplots(
                nrows=2, ncols=1, sharex=True, sharey=False, figsize=(10, 8),
                gridspec_kw={"height_ratios": [2, 1]})

        axes[0] = self.plot_spectrum(axes=axes[0], save=False, show=False)
        upper_panel, lower_panel = axes[0], axes[1]

        # Move the x label from the upper panel to the lower one.
        lower_panel.set_xlabel(upper_panel.get_xlabel(), fontsize=self.fontsize_axes)
        upper_panel.set_xlabel("")

        best_fit = self.model(self.transient.angstroms, **self._max_like_params, **self._model_kwargs)
        # NOTE(review): the residuals and their errors are not divided by 1e-17, unlike the curves in
        # the upper panel - confirm the differing units are intended.
        residuals = self.transient.flux_density - best_fit
        lower_panel.errorbar(
            self.transient.angstroms, residuals, yerr=self.transient.flux_density_err,
            fmt=self.errorbar_fmt, c=self.color, ms=self.ms, elinewidth=self.elinewidth, capsize=self.capsize)
        lower_panel.set_yscale('linear')
        lower_panel.set_ylabel("Residual", fontsize=self.fontsize_axes)
        lower_panel.tick_params(
            axis='both', which='both', pad=self.axis_tick_params_pad, labelsize=self.fontsize_ticks)

        self._save_and_show(filepath=self._residual_plot_filepath, save=save, show=show)
        return axes
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc