• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nikhil-sarin / redback / 14430354752

13 Apr 2025 02:23PM UTC coverage: 86.635% (+6.0%) from 80.663%
14430354752

Pull #266

github

web-flow
Merge 8147dba2c into e087188ab
Pull Request #266: A big overhaul

1621 of 1828 new or added lines in 14 files covered. (88.68%)

4 existing lines in 2 files now uncovered.

12673 of 14628 relevant lines covered (86.64%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

77.73
/redback/transient/transient.py
1
from __future__ import annotations
1✔
2

3
from typing import Union
1✔
4

5
import matplotlib
1✔
6
import numpy as np
1✔
7
import pandas as pd
1✔
8

9
import redback
1✔
10
from redback.plotting import \
1✔
11
    LuminosityPlotter, FluxDensityPlotter, IntegratedFluxPlotter, MagnitudePlotter, \
12
    IntegratedFluxOpticalPlotter, SpectrumPlotter, LuminosityOpticalPlotter
13
from redback.model_library import all_models_dict
1✔
14
from collections import namedtuple
1✔
15

16
class Spectrum(object):
    """Container for a single spectrum (wavelength, flux density and error) with plotting helpers."""

    def __init__(self, angstroms: np.ndarray, flux_density: np.ndarray, flux_density_err: np.ndarray,
                 time: str = None, name: str = '', **kwargs) -> None:
        """
        A class to store spectral data.

        :param angstroms: Wavelength in angstroms.
        :param flux_density: flux density in ergs/s/cm^2/angstrom.
        :param flux_density_err: flux density error in ergs/s/cm^2/angstrom.
        :param time: Time of the spectrum. Could be a phase or time since burst. Only used for plotting.
        :param name: Name of the spectrum.
        :param kwargs: Accepted for call-site compatibility; not used.
        """
        self.angstroms = angstroms
        self.flux_density = flux_density
        self.flux_density_err = flux_density_err
        self.time = time
        self.name = name
        # Plots only carry a time label when a time was actually provided.
        if self.time is None:
            self.plot_with_time_label = False
        else:
            self.plot_with_time_label = True
        self.directory_structure = redback.get_data.directory.spectrum_directory_structure(transient=name)
        self.data_mode = 'spectrum'

    @property
    def xlabel(self) -> str:
        """
        :return: xlabel used in plotting functions
        :rtype: str
        """
        return r'Wavelength [$\mathrm{\AA}$]'

    @property
    def ylabel(self) -> str:
        """
        :return: ylabel used in plotting functions (flux density per unit wavelength)
        :rtype: str
        """
        # Flux density is per angstrom, hence the inverse-angstrom unit.
        return r'Flux ($10^{-17}$ erg s$^{-1}$ cm$^{-2}$ $\mathrm{\AA}^{-1}$)'

    def plot_data(self, axes: matplotlib.axes.Axes = None, filename: str = None, outdir: str = None, save: bool = True,
            show: bool = True, color: str = 'k', **kwargs) -> matplotlib.axes.Axes:
        """Plots the spectrum data and returns Axes.

        :param axes: Matplotlib axes to plot the spectrum into. Useful for user specific modifications to the plot.
        :param filename: Name of the file to be plotted in.
        :param outdir: The directory in which to save the file in.
        :param save: Whether to save the plot. (Default value = True)
        :param show: Whether to show the plot. (Default value = True)
        :param color: Color of the data.
        :param kwargs: Additional keyword arguments to pass in the Plotter methods.
        Available in the online documentation under at `redback.plotting.Plotter`.
        `print(Spectrum.plot_data.__doc__)` to see all options!
        :return: The axes with the plot.
        """
        plotter = SpectrumPlotter(spectrum=self, color=color, filename=filename, outdir=outdir, **kwargs)
        return plotter.plot_data(axes=axes, save=save, show=show)

    def plot_spectrum(
            self, model: callable, filename: str = None, outdir: str = None, axes: matplotlib.axes.Axes = None,
            save: bool = True, show: bool = True, random_models: int = 100, posterior: pd.DataFrame = None,
            model_kwargs: dict = None, **kwargs) -> matplotlib.axes.Axes:
        """Plots the spectrum data together with model realisations and returns Axes.

        :param model: The model used to plot the spectrum.
        :param filename: The output filename. If not given, the default chosen by `SpectrumPlotter` is used.
        :param axes: Axes to plot in if given.
        :param save: Whether to save the plot.
        :param show: Whether to show the plot.
        :param random_models: Number of random posterior samples plotted faintly. (Default value = 100)
        :param posterior: Posterior samples from which the model realisations are drawn.
        :param outdir: Out directory in which to save the plot. Default is the current working directory.
        :param model_kwargs: Additional keyword arguments to be passed into the model.
        :param kwargs: Additional keyword arguments to pass in the Plotter methods.
        Available in the online documentation under at `redback.plotting.Plotter`.
        `print(Spectrum.plot_spectrum.__doc__)` to see all options!
        :return: The axes.
        """
        plotter = SpectrumPlotter(
            spectrum=self, model=model, filename=filename, outdir=outdir,
            posterior=posterior, model_kwargs=model_kwargs, random_models=random_models, **kwargs)
        return plotter.plot_spectrum(axes=axes, save=save, show=show)

    def plot_residual(self, model: callable, filename: str = None, outdir: str = None, axes: matplotlib.axes.Axes = None,
                      save: bool = True, show: bool = True, posterior: pd.DataFrame = None,
                      model_kwargs: dict = None, **kwargs) -> matplotlib.axes.Axes:
        """Plots the residuals between the spectrum data and the model and returns Axes.

        :param model: The model used to compute the residuals.
        :param filename: The output filename. If not given, the default chosen by `SpectrumPlotter` is used.
        :param axes: Axes to plot in if given.
        :param save: Whether to save the plot.
        :param show: Whether to show the plot.
        :param posterior: Posterior samples used to evaluate the model.
        :param outdir: Out directory in which to save the plot. Default is the current working directory.
        :param model_kwargs: Additional keyword arguments to be passed into the model.
        :param kwargs: Additional keyword arguments to pass in the Plotter methods.
        Available in the online documentation under at `redback.plotting.Plotter`.
        `print(Spectrum.plot_residual.__doc__)` to see all options!
        :return: The axes.
        """
        plotter = SpectrumPlotter(
            spectrum=self, model=model, filename=filename, outdir=outdir,
            posterior=posterior, model_kwargs=model_kwargs, **kwargs)
        return plotter.plot_residuals(axes=axes, save=save, show=show)
124

125
class Transient(object):
1✔
126
    # Data modes accepted by the `data_mode` setter (which also accepts None).
    DATA_MODES = ['luminosity', 'flux', 'flux_density', 'magnitude', 'counts', 'ttes']
    # Maps a data mode to the name of the attribute holding its y values; "_err" is appended for errors.
    # NOTE(review): there is no 'ttes' entry, so `y`/`y_err` raise KeyError in 'ttes' mode — confirm intended.
    _ATTRIBUTE_NAME_DICT = dict(luminosity="Lum50", flux="flux", flux_density="flux_density",
                                counts="counts", magnitude="magnitude")

    # y-axis labels per data mode, looked up by the `ylabel` property.
    ylabel_dict = dict(luminosity=r'Luminosity [$10^{50}$ erg s$^{-1}$]',
                       magnitude=r'Magnitude',
                       flux=r'Flux [erg cm$^{-2}$ s$^{-1}$]',
                       flux_density=r'Flux density [mJy]',
                       counts=r'Counts')

    # One switch per data mode. Elsewhere in this class they are evaluated as booleans
    # (e.g. `_time_attribute_name`, `get_filtered_data`) to test which mode is active;
    # the exact semantics live in redback.utils.DataModeSwitch.
    luminosity_data = redback.utils.DataModeSwitch('luminosity')
    flux_data = redback.utils.DataModeSwitch('flux')
    flux_density_data = redback.utils.DataModeSwitch('flux_density')
    magnitude_data = redback.utils.DataModeSwitch('magnitude')
    counts_data = redback.utils.DataModeSwitch('counts')
    tte_data = redback.utils.DataModeSwitch('ttes')
142

143
    def __init__(
            self, time: np.ndarray = None, time_err: np.ndarray = None, time_mjd: np.ndarray = None,
            time_mjd_err: np.ndarray = None, time_rest_frame: np.ndarray = None, time_rest_frame_err: np.ndarray = None,
            Lum50: np.ndarray = None, Lum50_err: np.ndarray = None, flux: np.ndarray = None,
            flux_err: np.ndarray = None, flux_density: np.ndarray = None, flux_density_err: np.ndarray = None,
            magnitude: np.ndarray = None, magnitude_err: np.ndarray = None, counts: np.ndarray = None,
            ttes: np.ndarray = None, bin_size: float = None, redshift: float = np.nan, data_mode: str = None,
            name: str = '', photon_index: float = np.nan, use_phase_model: bool = False,
            optical_data: bool = False, frequency: np.ndarray = None, system: np.ndarray = None, bands: np.ndarray = None,
            active_bands: Union[np.ndarray, str] = None, plotting_order: Union[np.ndarray, str] = None, **kwargs: None) -> None:
        """This is a general constructor for the Transient class. Note that you only need to give data corresponding to
        the data mode you are using. For luminosity data provide times in the rest frame, if using a phase model
        provide time in MJD, else use the default time (observer frame).

        :param time: Times in the observer frame.
        :type time: np.ndarray, optional
        :param time_err: Time errors in the observer frame.
        :type time_err: np.ndarray, optional
        :param time_mjd: Times in MJD. Used if using phase model.
        :type time_mjd: np.ndarray, optional
        :param time_mjd_err: Time errors in MJD. Used if using phase model.
        :type time_mjd_err: np.ndarray, optional
        :param time_rest_frame: Times in the rest frame. Used for luminosity data.
        :type time_rest_frame: np.ndarray, optional
        :param time_rest_frame_err: Time errors in the rest frame. Used for luminosity data.
        :type time_rest_frame_err: np.ndarray, optional
        :param Lum50: Luminosity values.
        :type Lum50: np.ndarray, optional
        :param Lum50_err: Luminosity error values.
        :type Lum50_err: np.ndarray, optional
        :param flux: Flux values.
        :type flux: np.ndarray, optional
        :param flux_err: Flux error values.
        :type flux_err: np.ndarray, optional
        :param flux_density: Flux density values.
        :type flux_density: np.ndarray, optional
        :param flux_density_err: Flux density error values.
        :type flux_density_err: np.ndarray, optional
        :param magnitude: Magnitude values for photometry data.
        :type magnitude: np.ndarray, optional
        :param magnitude_err: Magnitude error values for photometry data.
        :type magnitude_err: np.ndarray, optional
        :param counts: Counts for prompt data. Errors are set to sqrt(counts).
        :type counts: np.ndarray, optional
        :param ttes: Time-tagged events data for unbinned prompt data.
                     In 'ttes' data mode these are binned into `time` and `counts`.
        :type ttes: np.ndarray, optional
        :param bin_size: Bin size for binning time-tagged event data.
        :type bin_size: float, optional
        :param redshift: Redshift value.
        :type redshift: float, optional
        :param data_mode: Data mode. Must be one from `Transient.DATA_MODES`.
        :type data_mode: str, optional
        :param name: Name of the transient.
        :type name: str, optional
        :param photon_index: Photon index value.
        :type photon_index: float, optional
        :param use_phase_model: Whether we are using a phase model.
        :type use_phase_model: bool, optional
        :param optical_data: Whether we are fitting optical data, useful for plotting.
        :type optical_data: bool, optional
        :param frequency: Array of band frequencies in photometry data.
        :type frequency: np.ndarray, optional
        :param system: System values.
        :type system: np.ndarray, optional
        :param bands: Band values.
        :type bands: np.ndarray, optional
        :param active_bands: List or array of active bands to be used in the analysis.
                             Use all available bands if 'all' is given.
        :type active_bands: Union[np.ndarray, str], optional
        :param plotting_order: Order in which to plot the bands/and how unique bands are stored.
        :type plotting_order: Union[np.ndarray, str], optional
        :param kwargs: Additional callables:
                       bands_to_frequency: Conversion function to convert a list of bands to frequencies.
                                           Use redback.utils.bands_to_frequency if not given.
                       bin_ttes: Binning function for time-tagged event data.
                                 Use redback.utils.bin_ttes if not given.
        :type kwargs: None, optional
        """
        self.bin_size = bin_size
        self.bin_ttes = kwargs.get("bin_ttes", redback.utils.bin_ttes)
        self.bands_to_frequency = kwargs.get("bands_to_frequency", redback.utils.bands_to_frequency)

        # Unbinned time-tagged events are converted to (time, counts) before anything is stored.
        if data_mode == 'ttes':
            time, counts = self.bin_ttes(ttes, self.bin_size)

        self.time = time
        self.time_err = time_err
        self.time_mjd = time_mjd
        self.time_mjd_err = time_mjd_err
        self.time_rest_frame = time_rest_frame
        self.time_rest_frame_err = time_rest_frame_err

        self.Lum50 = Lum50
        self.Lum50_err = Lum50_err
        self.flux = flux
        self.flux_err = flux_err
        self.flux_density = flux_density
        self.flux_density_err = flux_density_err
        self.magnitude = magnitude
        self.magnitude_err = magnitude_err
        self.counts = counts
        # Counting errors: sqrt(N) per bin.
        self.counts_err = np.sqrt(counts) if counts is not None else None
        self.ttes = ttes

        self._frequency = None
        self._bands = None
        self.set_bands_and_frequency(bands=bands, frequency=frequency)
        self.system = system
        self.data_mode = data_mode
        # Must come after the bands are set: 'all' expands to the unique values of self.bands.
        self.active_bands = active_bands
        self.sncosmo_bands = redback.utils.sncosmo_bandname_from_band(self.bands)
        self.redshift = redshift
        self.name = name
        self.use_phase_model = use_phase_model
        self.optical_data = optical_data
        self.plotting_order = plotting_order

        self.meta_data = None
        self.photon_index = photon_index
        # Placeholder directory structure pointing at the current directory.
        self.directory_structure = redback.get_data.directory.DirectoryStructure(
            directory_path=".", raw_file_path=".", processed_file_path=".")
264

265
    @staticmethod
1✔
266
    def load_data_generic(processed_file_path, data_mode="magnitude"):
1✔
267
        """Loads data from specified directory and file, and returns it as a tuple.
268

269
        :param processed_file_path: Path to the processed file to load
270
        :type processed_file_path: str
271
        :param data_mode: Name of the data mode.
272
                          Must be from ['magnitude', 'flux_density', 'all']. Default is magnitude.
273
        :type data_mode: str, optional
274

275
        :return: Six elements when querying magnitude or flux_density data, Eight for 'all'.
276
        :rtype: tuple
277
        """
278
        DATA_MODES = ['luminosity', 'flux', 'flux_density', 'magnitude', 'counts', 'ttes', 'all']
1✔
279
        df = pd.read_csv(processed_file_path)
1✔
280
        time_days = np.array(df["time (days)"])
1✔
281
        time_mjd = np.array(df["time"])
1✔
282
        magnitude = np.array(df["magnitude"])
1✔
283
        magnitude_err = np.array(df["e_magnitude"])
1✔
284
        bands = np.array(df["band"])
1✔
285
        flux_density = np.array(df["flux_density(mjy)"])
1✔
286
        flux_density_err = np.array(df["flux_density_error"])
1✔
287
        if data_mode not in DATA_MODES:
1✔
288
            raise ValueError(f"Data mode {data_mode} not in {DATA_MODES}")
1✔
289
        if data_mode == "magnitude":
1✔
290
            return time_days, time_mjd, magnitude, magnitude_err, bands
1✔
291
        elif data_mode == "flux_density":
1✔
292
            return time_days, time_mjd, flux_density, flux_density_err, bands
1✔
293
        elif data_mode == "all":
1✔
294
            return time_days, time_mjd, flux_density, flux_density_err, magnitude, magnitude_err, bands
1✔
295

296
    @classmethod
    def from_lasair_data(
            cls, name: str, data_mode: str = "magnitude", active_bands: Union[np.ndarray, str] = 'all',
            use_phase_model: bool = False, plotting_order: Union[np.ndarray, str] = None) -> Transient:
        """Build an instance from processed LASAIR data.

        :param name: Name of the transient.
        :type name: str
        :param data_mode: Data mode used. Must be from `OpticalTransient.DATA_MODES`. Default is magnitude.
        :type data_mode: str, optional
        :param active_bands: Sets active bands based on array given.
                             If argument is 'all', all unique bands in `self.bands` will be used.
        :type active_bands: Union[np.ndarray, str]
        :param plotting_order: Order in which to plot the bands/and how unique bands are stored.
        :type plotting_order: Union[np.ndarray, str], optional
        :param use_phase_model: Whether to use a phase model.
        :type use_phase_model: bool, optional

        :return: An instance of the calling class.
        :rtype: OpticalTransient
        """
        # The TDE class name does not match its directory name, so map it explicitly.
        transient_type = "tidal_disruption_event" if cls.__name__ == "TDE" else cls.__name__.lower()
        directory_structure = redback.get_data.directory.lasair_directory_structure(
            transient=name, transient_type=transient_type)
        table = pd.read_csv(directory_structure.processed_file_path)
        # (constructor keyword, CSV column) pairs, read in this order.
        column_map = (
            ("time", "time (days)"),
            ("time_mjd", "time"),
            ("magnitude", "magnitude"),
            ("magnitude_err", "e_magnitude"),
            ("bands", "band"),
            ("flux", "flux(erg/cm2/s)"),
            ("flux_err", "flux_error"),
            ("flux_density", "flux_density(mjy)"),
            ("flux_density_err", "flux_density_error"),
        )
        arrays = {keyword: np.array(table[column]) for keyword, column in column_map}
        return cls(name=name, data_mode=data_mode, time_err=None, active_bands=active_bands,
                   use_phase_model=use_phase_model, optical_data=True, plotting_order=plotting_order,
                   **arrays)
337

338
    @classmethod
1✔
339
    def from_simulated_optical_data(
1✔
340
            cls, name: str, data_mode: str = "magnitude", active_bands: Union[np.ndarray, str] = 'all',
341
            plotting_order: Union[np.ndarray, str] = None, use_phase_model: bool = False) -> Transient:
342
        """Constructor method to built object from SimulatedOpticalTransient.
343

344
        :param name: Name of the transient.
345
        :type name: str
346
        :param data_mode: Data mode used. Must be from `OpticalTransient.DATA_MODES`. Default is magnitude.
347
        :type data_mode: str, optional
348
        :param active_bands: Sets active bands based on array given.
349
                             If argument is 'all', all unique bands in `self.bands` will be used.
350
        :type active_bands: Union[np.ndarray, str]
351
        :param plotting_order: Order in which to plot the bands/and how unique bands are stored.
352
        :type plotting_order: Union[np.ndarray, str], optional
353
        :param use_phase_model: Whether to use a phase model.
354
        :type use_phase_model: bool, optional
355

356
        :return: A class instance.
357
        :rtype: OpticalTransient
358
        """
359
        path = "simulated/" + name + ".csv"
1✔
360
        df = pd.read_csv(path)
1✔
361
        df = df[df.detected != 0]
1✔
362
        time_days = np.array(df["time (days)"])
1✔
363
        time_mjd = np.array(df["time"])
1✔
364
        magnitude = np.array(df["magnitude"])
1✔
365
        magnitude_err = np.array(df["e_magnitude"])
1✔
366
        bands = np.array(df["band"])
1✔
367
        flux = np.array(df["flux(erg/cm2/s)"])
1✔
368
        flux_err = np.array(df["flux_error"])
1✔
369
        flux_density = np.array(df["flux_density(mjy)"])
1✔
370
        flux_density_err = np.array(df["flux_density_error"])
1✔
371
        return cls(name=name, data_mode=data_mode, time=time_days, time_err=None, time_mjd=time_mjd,
1✔
372
                   flux_density=flux_density, flux_density_err=flux_density_err, magnitude=magnitude,
373
                   magnitude_err=magnitude_err, flux=flux, flux_err=flux_err, bands=bands, active_bands=active_bands,
374
                   use_phase_model=use_phase_model, optical_data=True, plotting_order=plotting_order)
375

376
    @property
1✔
377
    def _time_attribute_name(self) -> str:
1✔
378
        if self.luminosity_data:
1✔
379
            return "time_rest_frame"
1✔
380
        elif self.use_phase_model:
1✔
381
            return "time_mjd"
1✔
382
        return "time"
1✔
383

384
    @property
1✔
385
    def _time_err_attribute_name(self) -> str:
1✔
386
        return self._time_attribute_name + "_err"
1✔
387

388
    @property
1✔
389
    def _y_attribute_name(self) -> str:
1✔
390
        return self._ATTRIBUTE_NAME_DICT[self.data_mode]
1✔
391

392
    @property
1✔
393
    def _y_err_attribute_name(self) -> str:
1✔
394
        return self._ATTRIBUTE_NAME_DICT[self.data_mode] + "_err"
1✔
395

396
    @property
1✔
397
    def x(self) -> np.ndarray:
1✔
398
        """
399
        :return: The time values given the active data mode.
400
        :rtype: np.ndarray
401
        """
402
        return getattr(self, self._time_attribute_name)
1✔
403

404
    @x.setter
1✔
405
    def x(self, x: np.ndarray) -> None:
1✔
406
        """Sets the time values for the active data mode.
407
        :param x: The desired time values.
408
        :type x: np.ndarray
409
        """
410
        setattr(self, self._time_attribute_name, x)
1✔
411

412
    @property
1✔
413
    def x_err(self) -> np.ndarray:
1✔
414
        """
415
        :return: The time error values given the active data mode.
416
        :rtype: np.ndarray
417
        """
418
        return getattr(self, self._time_err_attribute_name)
1✔
419

420
    @x_err.setter
1✔
421
    def x_err(self, x_err: np.ndarray) -> None:
1✔
422
        """Sets the time error values for the active data mode.
423
        :param x_err: The desired time error values.
424
        :type x_err: np.ndarray
425
        """
426
        setattr(self, self._time_err_attribute_name, x_err)
1✔
427

428
    @property
1✔
429
    def y(self) -> np.ndarray:
1✔
430
        """
431
        :return: The y values given the active data mode.
432
        :rtype: np.ndarray
433
        """
434

435
        return getattr(self, self._y_attribute_name)
1✔
436

437
    @y.setter
1✔
438
    def y(self, y: np.ndarray) -> None:
1✔
439
        """Sets the y values for the active data mode.
440
        :param y: The desired y values.
441
        :type y: np.ndarray
442
        """
443
        setattr(self, self._y_attribute_name, y)
1✔
444

445
    @property
1✔
446
    def y_err(self) -> np.ndarray:
1✔
447
        """
448
        :return: The y error values given the active data mode.
449
        :rtype: np.ndarray
450
        """
451
        return getattr(self, self._y_err_attribute_name)
1✔
452

453
    @y_err.setter
1✔
454
    def y_err(self, y_err: np.ndarray) -> None:
1✔
455
        """Sets the y error values for the active data mode.
456
        :param y_err: The desired y error values.
457
        :type y_err: np.ndarray
458
        """
459
        setattr(self, self._y_err_attribute_name, y_err)
1✔
460

461
    @property
1✔
462
    def data_mode(self) -> str:
1✔
463
        """
464
        :return: The currently active data mode (one in `Transient.DATA_MODES`).
465
        :rtype: str
466
        """
467
        return self._data_mode
1✔
468

469
    @data_mode.setter
1✔
470
    def data_mode(self, data_mode: str) -> None:
1✔
471
        """
472
        :param data_mode: One of the data modes in `Transient.DATA_MODES`.
473
        :type data_mode: str
474
        """
475
        if data_mode in self.DATA_MODES or data_mode is None:
1✔
476
            self._data_mode = data_mode
1✔
477
        else:
478
            raise ValueError("Unknown data mode.")
1✔
479

480
    @property
1✔
481
    def xlabel(self) -> str:
1✔
482
        """
483
        :return: xlabel used in plotting functions
484
        :rtype: str
485
        """
486
        if self.use_phase_model:
1✔
487
            return r"Time [MJD]"
1✔
488
        else:
489
            return r"Time since explosion [days]"
1✔
490

491
    @property
1✔
492
    def ylabel(self) -> str:
1✔
493
        """
494
        :return: ylabel used in plotting functions
495
        :rtype: str
496
        """
497
        try:
1✔
498
            return self.ylabel_dict[self.data_mode]
1✔
499
        except KeyError:
1✔
500
            raise ValueError("No data mode specified")
1✔
501

502
    def set_bands_and_frequency(
1✔
503
            self, bands: Union[None, list, np.ndarray], frequency: Union[None, list, np.ndarray]):
504
        """Sets bands and frequencies at the same time to keep the logic consistent. If both are given use those values.
505
        If only frequencies are given, use them also as band names.
506
        If only bands are given, try to convert them to frequencies.
507

508
        :param bands: The bands, e.g. ['g', 'i'].
509
        :type bands: Union[None, list, np.ndarray]
510
        :param frequency: The frequencies associated with the bands i.e., the effective frequency.
511
        :type frequency: Union[None, list, np.ndarray]
512
        """
513
        if (bands is None and frequency is None) or (bands is not None and frequency is not None):
1✔
514
            self._bands = bands
1✔
515
            self._frequency = frequency
1✔
516
        elif bands is None and frequency is not None:
1✔
517
            self._frequency = frequency
1✔
518
            self._bands = self.frequency
1✔
519
        elif bands is not None and frequency is None:
1✔
520
            self._bands = bands
1✔
521
            self._frequency = self.bands_to_frequency(self.bands)
1✔
522

523
    @property
1✔
524
    def frequency(self) -> np.ndarray:
1✔
525
        """
526
        :return: Used band frequencies
527
        :rtype: np.ndarray
528
        """
529
        return self._frequency
1✔
530

531
    @frequency.setter
1✔
532
    def frequency(self, frequency: np.ndarray) -> None:
1✔
533
        """
534
        :param frequency: Set band frequencies if an array is given. Otherwise, convert bands to frequencies.
535
        :type frequency: np.ndarray
536
        """
537
        self.set_bands_and_frequency(bands=self.bands, frequency=frequency)
1✔
538

539
    @property
1✔
540
    def bands(self) -> Union[list, None, np.ndarray]:
1✔
541
        return self._bands
1✔
542

543
    @bands.setter
1✔
544
    def bands(self, bands: Union[list, None, np.ndarray]):
1✔
545
        self.set_bands_and_frequency(bands=bands, frequency=self.frequency)
×
546

547
    @property
1✔
548
    def filtered_frequencies(self) -> np.array:
1✔
549
        """
550
        :return: The frequencies only associated with the active bands.
551
        :rtype: np.ndarray
552
        """
553
        return self.frequency[self.filtered_indices]
1✔
554

555
    @property
1✔
556
    def filtered_sncosmo_bands(self) -> np.array:
1✔
557
        """
558
        :return: The sncosmo bands only associated with the active bands.
559
        :rtype: np.ndarray
560
        """
561
        return self.sncosmo_bands[self.filtered_indices]
×
562

563
    @property
1✔
564
    def filtered_bands(self) -> np.array:
1✔
565
        """
566
        :return: The band names only associated with the active bands.
567
        :rtype: np.ndarray
568
        """
569
        return self.bands[self.filtered_indices]
×
570

571
    @property
1✔
572
    def active_bands(self) -> list:
1✔
573
        """
574
        :return: List of active bands used.
575
        :rtype list:
576
        """
577
        return self._active_bands
1✔
578

579
    @active_bands.setter
1✔
580
    def active_bands(self, active_bands: Union[list, str, None]) -> None:
1✔
581
        """
582
        :param active_bands: Sets active bands based on list given.
583
                             If argument is 'all', all unique bands in `self.bands` will be used.
584
        :type active_bands: Union[list, str]
585
        """
586
        if str(active_bands) == 'all':
1✔
587
            self._active_bands = list(np.unique(self.bands))
1✔
588
        else:
589
            self._active_bands = active_bands
1✔
590

591
    @property
1✔
592
    def filtered_indices(self) -> Union[list, None]:
1✔
593
        """
594
        :return: The list indices in `bands` associated with the active bands.
595
        :rtype: Union[list, None]
596
        """
597
        if self.bands is None:
1✔
598
            return list(np.arange(len(self.x)))
×
599
        return [b in self.active_bands for b in self.bands]
1✔
600

601
    def get_filtered_data(self) -> tuple:
1✔
602
        """Used to filter flux density, photometry or integrated flux data, so we only use data that is using the active bands.
603
        :return: A tuple with the filtered data. Format is (x, x_err, y, y_err)
604
        :rtype: tuple
605
        """
606
        if any([self.flux_data, self.magnitude_data, self.flux_density_data]):
1✔
607
            filtered_x = self.x[self.filtered_indices]
1✔
608
            try:
1✔
609
                filtered_x_err = self.x_err[self.filtered_indices]
1✔
610
            except (IndexError, TypeError):
1✔
611
                filtered_x_err = None
1✔
612
            filtered_y = self.y[self.filtered_indices]
1✔
613
            filtered_y_err = self.y_err[self.filtered_indices]
1✔
614
            return filtered_x, filtered_x_err, filtered_y, filtered_y_err
1✔
615
        else:
616
            raise ValueError(f"Transient needs to be in flux density, magnitude or flux data mode, "
1✔
617
                             f"but is in {self.data_mode} instead.")
618

619
    @property
1✔
620
    def unique_bands(self) -> np.ndarray:
1✔
621
        """
622
        :return: All bands that we get from the data, eliminating all duplicates.
623
        :rtype: np.ndarray
624
        """
625
        if self.plotting_order is not None:
1✔
626
            return self.plotting_order
×
627
        else:
628
            return np.unique(self.bands)
1✔
629

630
    @property
    def unique_frequencies(self) -> np.ndarray:
        """
        Frequencies of the unique bands.

        If the unique bands are already numeric (i.e. they are frequencies), they are returned
        unchanged; otherwise they are converted with `self.bands_to_frequency`.

        :return: All frequencies that we get from the data, eliminating all duplicates.
        :rtype: np.ndarray
        """
        bands = self.unique_bands
        try:
            first_band = bands[0]
        except (TypeError, IndexError):
            # Empty or non-indexable band collections fall through to the converter.
            first_band = None
        # numpy scalar types such as np.int64 or np.float32 are not subclasses of the builtin
        # int/float, so np.number must be accepted as well.
        if isinstance(first_band, (float, int, np.number)):
            return bands
        return self.bands_to_frequency(bands)
642

643
    @property
    def list_of_band_indices(self) -> list:
        """
        :return: One integer index array per entry of `unique_bands`, giving the positions in
                 `bands` where that band occurs.
        :rtype: list
        """
        data_bands = self.bands
        return [np.nonzero(data_bands == np.array(band))[0] for band in self.unique_bands]
650

651
    @property
    def default_filters(self) -> list:
        """
        :return: Filter names used when the caller does not choose any: g, r, i, z, y, J, H, K.
        :rtype: list
        """
        return list("grizyJHK")
658

659
    @staticmethod
    def get_colors(filters: Union[np.ndarray, list]) -> matplotlib.colors.Colormap:
        """
        Pick one colour per filter by sampling the rainbow colormap at evenly spaced points in [0, 1].

        :param filters: Array or list of filters to use in the plot; only its length is used.
        :type filters: Union[np.ndarray, list]
        :return: Colours sampled from `matplotlib.cm.rainbow`, one per filter (calling a matplotlib
                 colormap on an array of values returns one RGBA row per value).
        :rtype: matplotlib.colors.Colormap
        """
        sample_points = np.linspace(0, 1, len(filters))
        rainbow = matplotlib.cm.rainbow
        return rainbow(sample_points)
668

669
    def plot_data(self, axes: matplotlib.axes.Axes = None, filename: str = None, outdir: str = None, save: bool = True,
            show: bool = True, plot_others: bool = True, color: str = 'k', **kwargs) -> matplotlib.axes.Axes:
        """Plots the Transient data and returns Axes.

        The plotter is chosen from the data mode (flux, luminosity, flux density, magnitude). For flux
        and luminosity data, it also depends on whether the data is optical.

        :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot.
        :param filename: Name of the file to be plotted in.
        :param outdir: The directory in which to save the file in.
        :param save: Whether to save the plot. (Default value = True)
        :param show: Whether to show the plot. (Default value = True)
        :param plot_others: Whether to plot inactive bands. (Default value = True)
        :param color: Color of the data.
        :param kwargs: Additional keyword arguments to pass in the Plotter methods.
        Available in the online documentation at `redback.plotting.Plotter`.
        `print(Transient.plot_data.__doc__)` to see all options!
        :return: The axes with the plot, or `axes` unchanged if no data mode flag is set.
        """
        shared = dict(transient=self, color=color, filename=filename, outdir=outdir)
        if self.flux_data:
            if self.optical_data:
                plotter = IntegratedFluxOpticalPlotter(plot_others=plot_others, **shared, **kwargs)
            else:
                plotter = IntegratedFluxPlotter(**shared, **kwargs)
        elif self.luminosity_data:
            luminosity_cls = LuminosityOpticalPlotter if self.optical_data else LuminosityPlotter
            plotter = luminosity_cls(**shared, **kwargs)
        elif self.flux_density_data:
            plotter = FluxDensityPlotter(plot_others=plot_others, **shared, **kwargs)
        elif self.magnitude_data:
            plotter = MagnitudePlotter(plot_others=plot_others, **shared, **kwargs)
        else:
            return axes
        return plotter.plot_data(axes=axes, save=save, show=show)
707

708
    def plot_multiband(
            self, figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None, filename: str = None,
            outdir: str = None, ncols: int = 2, save: bool = True, show: bool = True,
            nrows: int = None, figsize: tuple = None, filters: list = None, **kwargs: None) \
            -> matplotlib.axes.Axes:
        """
        Plot the data with one panel per band.

        :param figure: Figure can be given if defaults are not satisfying.
        :param axes: Axes can be given if defaults are not satisfying.
        :param filename: Name of the file to be plotted in.
        :param outdir: The directory in which to save the file in.
        :param save: Whether to save the plot. (Default value = True)
        :param show: Whether to show the plot. (Default value = True)
        :param ncols: Number of columns to use on the plot. Default is 2.
        :param nrows: Number of rows to use on the plot. If None are given this will
                      be inferred from ncols and the number of filters.
        :param figsize: Size of the figure. A default based on ncols and nrows will be used if None is given.
        :param filters: Which bands to plot. Will use default filters if None is given.
        :param kwargs: Additional keyword arguments to pass in the Plotter methods.
        Available in the online documentation at `redback.plotting.Plotter`.
        `print(Transient.plot_multiband.__doc__)` to see all options!
        :return: The axes.
        :raises ValueError: If the data mode is not flux_density, magnitude or flux.
        """
        if self.data_mode not in ['flux_density', 'magnitude', 'flux']:
            raise ValueError(
                f'You cannot plot multiband data with {self.data_mode} data mode . Why are you doing this?')
        if self.magnitude_data:
            plotter_cls = MagnitudePlotter
        elif self.flux_density_data:
            plotter_cls = FluxDensityPlotter
        elif self.flux_data:
            plotter_cls = IntegratedFluxOpticalPlotter
        else:
            return None
        layout = dict(transient=self, filters=filters, filename=filename, outdir=outdir,
                      nrows=nrows, ncols=ncols, figsize=figsize)
        plotter = plotter_cls(**layout, **kwargs)
        return plotter.plot_multiband(figure=figure, axes=axes, save=save, show=show)
745

746
    def plot_lightcurve(
            self, model: callable, filename: str = None, outdir: str = None, axes: matplotlib.axes.Axes = None,
            save: bool = True, show: bool = True, random_models: int = 100, posterior: pd.DataFrame = None,
            model_kwargs: dict = None, **kwargs: None) -> matplotlib.axes.Axes:
        """
        Plot the data together with model lightcurves.

        :param model: The model used to plot the lightcurve.
        :param filename: The output filename. Otherwise, use default which starts with the name
                         attribute and ends with *lightcurve.png.
        :param axes: Axes to plot in if given.
        :param save: Whether to save the plot.
        :param show: Whether to show the plot.
        :param random_models: Number of random posterior samples plotted faintly. (Default value = 100)
        :param posterior: Posterior distribution to which to draw samples from. Is optional but must be given.
        :param outdir: Out directory in which to save the plot. Default is the current working directory.
        :param model_kwargs: Additional keyword arguments to be passed into the model.
        :param kwargs: Additional keyword arguments to pass in the Plotter methods.
        Available in the online documentation at `redback.plotting.Plotter`.
        `print(Transient.plot_lightcurve.__doc__)` to see all options!
        :return: The axes, or `axes` unchanged if no data mode flag is set.
        """
        model_settings = dict(transient=self, model=model, filename=filename, outdir=outdir,
                              posterior=posterior, model_kwargs=model_kwargs, random_models=random_models)
        if self.flux_data:
            plotter_cls = IntegratedFluxOpticalPlotter if self.optical_data else IntegratedFluxPlotter
        elif self.luminosity_data:
            plotter_cls = LuminosityOpticalPlotter if self.optical_data else LuminosityPlotter
        elif self.flux_density_data:
            plotter_cls = FluxDensityPlotter
        elif self.magnitude_data:
            plotter_cls = MagnitudePlotter
        else:
            return axes
        plotter = plotter_cls(**model_settings, **kwargs)
        return plotter.plot_lightcurve(axes=axes, save=save, show=show)
794

795
    def plot_residual(self, model: callable, filename: str = None, outdir: str = None, axes: matplotlib.axes.Axes = None,
                      save: bool = True, show: bool = True, posterior: pd.DataFrame = None,
                      model_kwargs: dict = None, **kwargs: None) -> matplotlib.axes.Axes:
        """
        Plot the lightcurve together with its residuals with respect to the model.

        Only flux and luminosity data modes are supported.

        :param model: The model used to plot the lightcurve.
        :param filename: The output filename. Otherwise, use default which starts with the name
                         attribute and ends with *lightcurve.png.
        :param axes: Axes to plot in if given.
        :param save: Whether to save the plot.
        :param show: Whether to show the plot.
        :param posterior: Posterior distribution to which to draw samples from. Is optional but must be given.
        :param outdir: Out directory in which to save the plot. Default is the current working directory.
        :param model_kwargs: Additional keyword arguments to be passed into the model.
        :param kwargs: Additional keyword arguments to pass in the Plotter methods.
        Available in the online documentation at `redback.plotting.Plotter`.
        `print(Transient.plot_residual.__doc__)` to see all options!
        :return: The axes.
        :raises ValueError: If the data mode is neither flux nor luminosity.
        """
        if self.flux_data:
            plotter_cls = IntegratedFluxPlotter
        elif self.luminosity_data:
            plotter_cls = LuminosityOpticalPlotter if self.optical_data else LuminosityPlotter
        else:
            raise ValueError("Residual plotting not implemented for this data mode")
        residual_settings = dict(transient=self, model=model, filename=filename, outdir=outdir,
                                 posterior=posterior, model_kwargs=model_kwargs)
        plotter = plotter_cls(**residual_settings, **kwargs)
        return plotter.plot_residuals(axes=axes, save=save, show=show)
829

830

831
    def fit_gp(self, mean_model, kernel, prior=None, use_frequency=True):
        """
        Fit a GP to the data using george and scipy minimization.

        The data are divided by their maximum value before fitting. That factor is returned as
        ``y_scaler`` so that predictions can be mapped back to the original units.

        :param mean_model: Mean model to use in the GP fit. Can be a string to refer to a redback model,
            a callable, or None (GP with no mean model).
        :param kernel: George GP kernel to use. User must ensure this is set up correctly
            (2D when `use_frequency` is True, 1D otherwise).
        :param prior: Prior to use when fitting with a mean model. Required for a callable mean model.
            For a redback model name, the default redback prior is used if this is None.
        :param use_frequency: Whether to use the effective frequency in a 2D GP fit. Cannot be used with most mean models.
        :return: Named tuple (class) with attributes ``gp``, ``scaled_y``, ``y_scaler``, ``use_frequency`` and
            ``mean_model``. When a mean model is fitted to flux density, magnitude or flux data, ``gp`` and
            ``scaled_y`` are dicts keyed by band (bands without active data are skipped).
            Otherwise ``gp`` is a single george GP.
        :raises ImportError: If george (or bilby) is not installed.
        :raises ValueError: If a mean model is given but no prior is available.
        """
        try:
            import george
        except ImportError as err:
            # Fail here: continuing without george only produces a NameError further down.
            raise ImportError("George must be installed to use GP fitting.") from err
        import scipy.optimize as op
        from bilby.core.likelihood import function_to_george_mean_model

        # NOTE: attributes are assigned on the namedtuple *class* itself (no instance is created).
        # This is kept so the returned object behaves exactly as before for callers.
        output = namedtuple("gp_out", ["gp", "scaled_y", "y_scaler", 'use_frequency', 'mean_model'])
        output.use_frequency = use_frequency
        output.mean_model = mean_model

        if self.data_mode == 'luminosity':
            x = self.time_rest_frame
            y = self.y
            y_err = np.asarray(self.y_err)
            if y_err.ndim == 2:
                # Two-row errors (presumably lower/upper): use the larger error for each point.
                # Applying np.max(axis=0) to 1D errors would instead collapse them to one scalar.
                y_err = np.max(y_err, axis=0)
        else:
            x, x_err, y, y_err = self.get_filtered_data()
        redback.utils.logger.info("Rescaling data for GP fitting.")
        y_scaler = np.max(y)
        gp_y_err = y_err / y_scaler
        gp_y = y / y_scaler
        output.scaled_y = gp_y
        output.y_scaler = y_scaler

        def nll(p):
            # Negative log-likelihood of the single GP `gp`; non-finite values are mapped to a large penalty.
            gp.set_parameter_vector(p)
            ll = gp.log_likelihood(gp_y, quiet=True)
            return -ll if np.isfinite(ll) else 1e25

        def grad_nll(p):
            gp.set_parameter_vector(p)
            return -gp.grad_log_likelihood(gp_y, quiet=True)

        if use_frequency:
            redback.utils.logger.info("Using frequencies and time in the GP fit.")
            redback.utils.logger.info("Kernel used: " + str(kernel))
            redback.utils.logger.info("Ensure that the kernel is set up correctly for 2D GP.")
            redback.utils.logger.info("You will be returned a single GP object with frequency as a parameter")
            freqs = self.filtered_frequencies
            X = np.column_stack((freqs, x))
        else:
            redback.utils.logger.info("Using time in GP fit.")
            redback.utils.logger.info("Kernel used: " + str(kernel))
            redback.utils.logger.info("Ensure that the kernel is set up correctly for 1D GP.")
            redback.utils.logger.info("You will be returned a GP object unique to a band/frequency"
                                      " in the data if working with multiband data")
            X = x

        if mean_model is None:
            redback.utils.logger.info("Mean model not given, fitting GP with no mean model.")
            gp = george.GP(kernel)
            # Small jitter keeps the covariance matrix well conditioned.
            gp.compute(X, gp_y_err + 1e-8)
            p0 = gp.get_parameter_vector()
            results = op.minimize(nll, p0, jac=grad_nll)
            gp.set_parameter_vector(results.x)
            redback.utils.logger.info(f"GP final loglikelihood: {gp.log_likelihood(gp_y)}")
            redback.utils.logger.info(f"GP final parameters: {gp.get_parameter_dict()}")
            output.gp = gp
        else:
            if isinstance(mean_model, str):
                mean_model_func = all_models_dict[mean_model]
                redback.utils.logger.info("Using inbuilt redback function {} as a mean model.".format(mean_model))
                if prior is None:
                    redback.utils.logger.warning("No prior given for mean model. Using default prior.")
                    prior = redback.priors.get_priors(mean_model)
            else:
                mean_model_func = mean_model
                redback.utils.logger.info("Using user-defined python function as a mean model.")

            if prior is None:
                redback.utils.logger.warning("Prior must be specified for GP fit with a mean model")
                raise ValueError("No prior specified")

            if self.data_mode in ['flux_density', 'magnitude', 'flux']:
                redback.utils.logger.info("Setting up GP version of mean model.")
                # gp_y, gp_y_err and X only hold points in the active bands, so band membership must
                # be computed on the filtered bands. `list_of_band_indices` indexes the unfiltered
                # `self.bands` and would pick the wrong points when only some bands are active.
                filtered_bands = np.asarray(self.bands)[self.filtered_indices]
                gp_dict = {}
                scaled_y_dict = {}
                for band in self.unique_bands:
                    band_indices = np.where(filtered_bands == np.array(band))[0]
                    if len(band_indices) == 0:
                        redback.utils.logger.info("No active data for band {}, skipping.".format(band))
                        continue
                    band_y = gp_y[band_indices]
                    scaled_y_dict[band] = band_y
                    redback.utils.logger.info("Fitting for band {}".format(band))

                    mean_model_class = function_to_george_mean_model(mean_model_func)
                    mm = mean_model_class(**prior.sample())
                    # NOTE(review): every band's GP receives the same `kernel` object; if george does
                    # not copy it, fitting one band also changes the hyperparameters of the others.
                    band_gp = george.GP(kernel, mean=mm, fit_mean=True)
                    band_gp.compute(X[band_indices], gp_y_err[band_indices] + 1e-8)

                    def band_nll(p, band_gp=band_gp, band_y=band_y):
                        # Bind this band's GP and data explicitly rather than via late-binding closure.
                        band_gp.set_parameter_vector(p)
                        ll = band_gp.log_likelihood(band_y, quiet=True)
                        return -ll if np.isfinite(ll) else 1e25

                    results = op.minimize(band_nll, band_gp.get_parameter_vector())
                    band_gp.set_parameter_vector(results.x)
                    redback.utils.logger.info(f"GP final loglikelihood: {band_gp.log_likelihood(band_y)}")
                    redback.utils.logger.info(f"GP final parameters: {band_gp.get_parameter_dict()}")
                    gp_dict[band] = band_gp
                output.gp = gp_dict
                output.scaled_y = scaled_y_dict
            else:
                mean_model_class = function_to_george_mean_model(mean_model_func)
                mm = mean_model_class(**prior.sample())
                gp = george.GP(kernel, mean=mm, fit_mean=True)
                gp.compute(X, gp_y_err + 1e-8)
                p0 = gp.get_parameter_vector()
                results = op.minimize(nll, p0)
                gp.set_parameter_vector(results.x)
                redback.utils.logger.info(f"GP final loglikelihood: {gp.log_likelihood(gp_y)}")
                redback.utils.logger.info(f"GP final parameters: {gp.get_parameter_dict()}")
                output.gp = gp
        return output
956

957
    def plot_multiband_lightcurve(
            self, model: callable, filename: str = None, outdir: str = None,
            figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None,
            save: bool = True, show: bool = True, random_models: int = 100, posterior: pd.DataFrame = None,
            model_kwargs: dict = None, **kwargs: object) -> matplotlib.axes.Axes:
        """
        Plot model lightcurves against the data with one panel per band.

        :param model: The model used to plot the lightcurve.
        :param filename: The output filename. Otherwise, use default which starts with the name
                         attribute and ends with *lightcurve.png.
        :param figure: Figure can be given if defaults are not satisfying.
        :param axes: Axes to plot in if given.
        :param save: Whether to save the plot.
        :param show: Whether to show the plot.
        :param random_models: Number of random posterior samples plotted faintly. (Default value = 100)
        :param posterior: Posterior distribution to which to draw samples from. Is optional but must be given.
        :param outdir: Out directory in which to save the plot. Default is the current working directory.
        :param model_kwargs: Additional keyword arguments to be passed into the model.
        :param kwargs: Additional keyword arguments to pass in the Plotter methods.
        Available in the online documentation at `redback.plotting.Plotter`.
        `print(Transient.plot_multiband_lightcurve.__doc__)` to see all options!

        :return: The axes.
        :raises ValueError: If the data mode is not flux_density, magnitude or flux.
        """
        if self.data_mode not in ['flux_density', 'magnitude', 'flux']:
            raise ValueError(
                f'You cannot plot multiband data with {self.data_mode} data mode . Why are you doing this?')
        if self.magnitude_data:
            plotter_cls = MagnitudePlotter
        elif self.flux_data:
            plotter_cls = IntegratedFluxOpticalPlotter
        elif self.flux_density_data:
            plotter_cls = FluxDensityPlotter
        else:
            return None
        plotter = plotter_cls(transient=self, model=model, filename=filename, outdir=outdir,
                              posterior=posterior, model_kwargs=model_kwargs,
                              random_models=random_models, **kwargs)
        return plotter.plot_multiband_lightcurve(figure=figure, axes=axes, save=save, show=show)
997

998
    # Expand the "see all options" placeholder in each plotting docstring into the full list of
    # Plotter keyword options, so help() on these methods shows them directly.
    _formatted_kwargs_options = redback.plotting.Plotter.keyword_docstring
    for _documented_method in (plot_data, plot_multiband, plot_lightcurve, plot_multiband_lightcurve,
                               plot_residual):
        _documented_method.__doc__ = _documented_method.__doc__.replace(
            f"`print(Transient.{_documented_method.__name__}.__doc__)` to see all options!",
            _formatted_kwargs_options)
    # Remove the loop variable so it does not linger as a class attribute.
    del _documented_method
1009

1010

1011
class OpticalTransient(Transient):
1✔
1012
    # Data modes accepted by OpticalTransient; the constructor's `data_mode` must be one of these.
    DATA_MODES = ['flux', 'flux_density', 'magnitude', 'luminosity']
1✔
1013

1014
    @staticmethod
    def load_data(processed_file_path, data_mode="magnitude"):
        """Loads data from specified directory and file, and returns it as a tuple.

        Only the columns needed for the requested data mode are read from the file.

        :param processed_file_path: Path to the processed file to load (anything `pandas.read_csv` accepts).
        :type processed_file_path: str
        :param data_mode: Name of the data mode.
                          Must be from ['magnitude', 'flux_density', 'flux', 'all']. Default is magnitude.
        :type data_mode: str, optional

        :return: For 'magnitude', 'flux_density' or 'flux': six arrays
                 (time_days, time_mjd, values, value_errors, bands, system).
                 For 'all': ten arrays (time_days, time_mjd, flux_density, flux_density_err,
                 magnitude, magnitude_err, flux, flux_err, bands, system).
        :rtype: tuple
        :raises ValueError: If `data_mode` is not one of the supported modes.
        """
        # Value and error column names for each single-quantity data mode.
        value_columns = {
            "magnitude": ("magnitude", "e_magnitude"),
            "flux_density": ("flux_density(mjy)", "flux_density_error"),
            "flux": ("flux(erg/cm2/s)", "flux_error"),
        }
        if data_mode != "all" and data_mode not in value_columns:
            raise ValueError(
                f"data_mode must be one of {list(value_columns) + ['all']}, got {data_mode!r}.")

        df = pd.read_csv(processed_file_path)

        def column(name):
            return np.array(df[name])

        time_days = column("time (days)")
        time_mjd = column("time")
        bands = column("band")
        system = column("system")
        if data_mode == "all":
            return (time_days, time_mjd,
                    column("flux_density(mjy)"), column("flux_density_error"),
                    column("magnitude"), column("e_magnitude"),
                    column("flux(erg/cm2/s)"), column("flux_error"),
                    bands, system)
        value_column, error_column = value_columns[data_mode]
        return time_days, time_mjd, column(value_column), column(error_column), bands, system
1047

1048
    def __init__(
            self, name: str, data_mode: str = 'magnitude', time: np.ndarray = None, time_err: np.ndarray = None,
            time_mjd: np.ndarray = None, time_mjd_err: np.ndarray = None, time_rest_frame: np.ndarray = None,
            time_rest_frame_err: np.ndarray = None, Lum50: np.ndarray = None, Lum50_err: np.ndarray = None,
            flux: np.ndarray = None, flux_err: np.ndarray = None, flux_density: np.ndarray = None,
            flux_density_err: np.ndarray = None, magnitude: np.ndarray = None, magnitude_err: np.ndarray = None,
            redshift: float = np.nan, photon_index: float = np.nan, frequency: np.ndarray = None,
            bands: np.ndarray = None, system: np.ndarray = None, active_bands: Union[np.ndarray, str] = 'all',
            plotting_order: Union[np.ndarray, str] = None, use_phase_model: bool = False,
            optical_data:bool = True, **kwargs: None) -> None:
        """Constructor for optical transients. Only the data belonging to the chosen data mode is needed.
        For luminosity data provide rest-frame times. If using a phase model provide MJD times;
        otherwise use the default (observer-frame) time.

        All arguments are forwarded to the parent constructor. The directory structure is then set to
        the current working directory.

        :param name: Name of the transient.
        :type name: str
        :param data_mode: Data mode. Must be one from `OpticalTransient.DATA_MODES`.
        :type data_mode: str, optional
        :param time: Times in the observer frame.
        :type time: np.ndarray, optional
        :param time_err: Time errors in the observer frame.
        :type time_err: np.ndarray, optional
        :param time_mjd: Times in MJD. Used if using phase model.
        :type time_mjd: np.ndarray, optional
        :param time_mjd_err: Time errors in MJD. Used if using phase model.
        :type time_mjd_err: np.ndarray, optional
        :param time_rest_frame: Times in the rest frame. Used for luminosity data.
        :type time_rest_frame: np.ndarray, optional
        :param time_rest_frame_err: Time errors in the rest frame. Used for luminosity data.
        :type time_rest_frame_err: np.ndarray, optional
        :param Lum50: Luminosity values.
        :type Lum50: np.ndarray, optional
        :param Lum50_err: Luminosity error values.
        :type Lum50_err: np.ndarray, optional
        :param flux: Flux values.
        :type flux: np.ndarray, optional
        :param flux_err: Flux error values.
        :type flux_err: np.ndarray, optional
        :param flux_density: Flux density values.
        :type flux_density: np.ndarray, optional
        :param flux_density_err: Flux density error values.
        :type flux_density_err: np.ndarray, optional
        :param magnitude: Magnitude values for photometry data.
        :type magnitude: np.ndarray, optional
        :param magnitude_err: Magnitude error values for photometry data.
        :type magnitude_err: np.ndarray, optional
        :param redshift: Redshift value.
        :type redshift: float, optional
        :param photon_index: Photon index value.
        :type photon_index: float, optional
        :param frequency: Array of band frequencies in photometry data.
        :type frequency: np.ndarray, optional
        :param bands: Band values.
        :type bands: np.ndarray, optional
        :param system: System values.
        :type system: np.ndarray, optional
        :param active_bands: List or array of active bands to be used in the analysis.
                             Use all available bands if 'all' is given.
        :type active_bands: Union[np.ndarray, str], optional
        :param plotting_order: Order in which to plot the bands/and how unique bands are stored.
        :type plotting_order: Union[np.ndarray, str], optional
        :param use_phase_model: Whether we are using a phase model.
        :type use_phase_model: bool, optional
        :param optical_data: Whether we are fitting optical data, useful for plotting.
        :type optical_data: bool, optional
        :param kwargs:
            Additional callables:
            bands_to_frequency: Conversion function to convert a list of bands to frequencies. Use
                                  redback.utils.bands_to_frequency if not given.
        :type kwargs: dict, optional
        """
        time_arrays = dict(time=time, time_err=time_err, time_mjd=time_mjd, time_mjd_err=time_mjd_err,
                           time_rest_frame=time_rest_frame, time_rest_frame_err=time_rest_frame_err)
        measurements = dict(Lum50=Lum50, Lum50_err=Lum50_err, flux=flux, flux_err=flux_err,
                            flux_density=flux_density, flux_density_err=flux_density_err,
                            magnitude=magnitude, magnitude_err=magnitude_err)
        band_settings = dict(frequency=frequency, bands=bands, system=system,
                             active_bands=active_bands, plotting_order=plotting_order)
        super().__init__(name=name, data_mode=data_mode, redshift=redshift, photon_index=photon_index,
                         use_phase_model=use_phase_model, optical_data=optical_data,
                         **time_arrays, **measurements, **band_settings, **kwargs)
        self.directory_structure = redback.get_data.directory.DirectoryStructure(
            directory_path=".", raw_file_path=".", processed_file_path=".")
1130

1131
    @classmethod
    def from_open_access_catalogue(
            cls, name: str, data_mode: str = "magnitude", active_bands: Union[np.ndarray, str] = 'all',
            plotting_order: Union[np.ndarray, str] = None, use_phase_model: bool = False) -> OpticalTransient:
        """Build an instance from the Open Access Catalogue data already processed on disk.

        The catalogue directory is derived from the class name (``TDE`` maps to
        ``tidal_disruption_event``; other classes use their lower-cased name). All data columns are
        loaded and passed on, so the instance can later switch between data modes.

        :param name: Name of the transient.
        :type name: str
        :param data_mode: Data mode used. Must be from `OpticalTransient.DATA_MODES`. Default is magnitude.
        :type data_mode: str, optional
        :param active_bands:
            Sets active bands based on array given.
            If argument is 'all', all unique bands in `self.bands` will be used.
        :type active_bands: Union[np.ndarray, str]
        :param plotting_order: Order in which to plot the bands/and how unique bands are stored.
        :type plotting_order: Union[np.ndarray, str], optional
        :param use_phase_model: Whether to use a phase model.
        :type use_phase_model: bool, optional

        :return: A class instance
        :rtype: OpticalTransient
        """
        transient_type = "tidal_disruption_event" if cls.__name__ == "TDE" else cls.__name__.lower()
        directory_structure = redback.get_data.directory.open_access_directory_structure(
            transient=name, transient_type=transient_type)
        loaded = cls.load_data(processed_file_path=directory_structure.processed_file_path, data_mode="all")
        (time_days, time_mjd, flux_density, flux_density_err,
         magnitude, magnitude_err, flux, flux_err, bands, system) = loaded
        return cls(name=name, data_mode=data_mode, time=time_days, time_err=None, time_mjd=time_mjd,
                   flux=flux, flux_err=flux_err, flux_density=flux_density, flux_density_err=flux_density_err,
                   magnitude=magnitude, magnitude_err=magnitude_err, bands=bands, system=system,
                   active_bands=active_bands, plotting_order=plotting_order,
                   use_phase_model=use_phase_model, optical_data=True)
1166

1167
    @property
    def event_table(self) -> str:
        """
        Location of the metadata CSV for this transient.

        :return: ``<directory_structure.directory_path>/<name>_metadata.csv``.
        :rtype: str
        """
        directory = self.directory_structure.directory_path
        return "{}/{}_metadata.csv".format(directory, self.name)
1174

1175
    def _set_data(self) -> None:
        """Read the metadata table at ``self.event_table`` into ``self.meta_data``.

        All columns are read as strings and malformed lines are skipped. If the file does not
        exist, two warnings are logged and ``self.meta_data`` is set to ``None``.
        """
        table = None
        try:
            table = pd.read_csv(self.event_table, on_bad_lines='skip', delimiter=',', dtype='str')
        except FileNotFoundError as missing:
            log = redback.utils.logger
            log.warning(missing)
            log.warning("Setting metadata to None. This is not an error, but a warning that no metadata could be found online.")
        self.meta_data = table
1184

1185
    @property
    def transient_dir(self) -> str:
        """
        Directory holding this transient's open-access data.

        :return: The path produced by ``self._get_transient_dir()``.
        :rtype: str
        """
        directory = self._get_transient_dir()
        return directory
1192

1193
    def _get_transient_dir(self) -> str:
        """
        Return the open-access directory for this transient.

        The transient type is derived from the class name the same way as in
        ``from_open_access_catalogue``. ``TDE`` maps to ``"tidal_disruption_event"``; every other
        class uses its lower-cased name. This keeps the returned directory consistent with the one
        the data were loaded from.

        :return: The transient directory path (first element of the directory structure).
        :rtype: str
        """
        class_name = self.__class__.__name__
        # Must agree with from_open_access_catalogue, otherwise TDEs resolve to a "tde" directory.
        transient_type = "tidal_disruption_event" if class_name == "TDE" else class_name.lower()
        transient_dir, _, _ = redback.get_data.directory.open_access_directory_structure(
            transient=self.name, transient_type=transient_type)
        return transient_dir
1202

1203
    def estimate_bb_params(self, distance: float = 1e27, bin_width: float = 1.0, min_filters: int = 3, **kwargs):
        """
        Estimate the blackbody temperature and photospheric radius as functions of time by fitting
        a blackbody SED to the multi-band photometry.

        The filtered photometry is grouped into time bins (epochs) of width ``bin_width``, in the same
        units as the filtered times (typically days). Bin ``i`` covers
        ``[t_min + i * bin_width, t_min + (i + 1) * bin_width)``, where ``t_min`` is the earliest time.
        Every point, including the latest one, therefore lands in exactly one bin, and data taken at a
        single epoch form one bin. Each bin holding at least ``min_filters`` measurements is fitted
        with ``scipy.optimize.curve_fit`` in the parameters (log10 T, log10 R).

        Two model evaluation modes exist:

        * Effective-wavelength approximation. This is used when ``data_mode`` is not "flux" or
          "magnitude" (e.g. "flux_density"). The filtered frequencies are K-corrected to the rest
          frame and the model is evaluated in mJy, so the data values are compared to it as mJy.
        * Full bandpass integration. This is used for "flux" or "magnitude" data. A blackbody SED on
          a 100-80000 Å grid is integrated through each band in ``self.filtered_sncosmo_bands``. The
          result is returned as an AB magnitude, or for "flux" data converted with
          ``redback.utils.bandpass_magnitude_to_flux``.

        For "flux" or "magnitude" data, ``use_eff_wavelength=True`` switches to the
        effective-wavelength approximation instead:

        * Magnitudes are converted with ``redback.utils.abmag_to_flux_density_and_error_inmjy``.
        * Bandpass fluxes are converted with ``redback.utils.bandpass_flux_to_flux_density``, using the
          widths from ``redback.utils.bands_to_effective_width``.

        Parameters
        ----------
        distance : float, optional
            Distance to the transient in centimeters. Default is 1e27 cm.
        bin_width : float, optional
            Width of the time bins used to group the photometric data. Must be positive.
            Default is 1.0.
        min_filters : int, optional
            Minimum number of measurements required in a bin to perform the fit. Every measurement
            counts, so repeated observations in the same filter each contribute. Default is 3.
        kwargs : Additional keyword arguments
            maxfev : int, optional, default is 1000. Passed to curve_fit.
            T_init : float, optional, default is 1e4. Initial temperature guess (K).
            R_init : float, optional, default is 1e15. Initial radius guess (cm).
            use_eff_wavelength : bool, optional, default is False. See above.

        Returns
        -------
        df_bb : pandas.DataFrame or None
            A DataFrame containing columns:
              - epoch_times : mean time of the measurements in each fitted bin,
              - temperature : best-fit blackbody temperatures (Kelvin),
              - radius : best-fit photospheric radii (cm),
              - temp_err : 1σ uncertainties on the temperatures,
              - radius_err : 1σ uncertainties on the radii.
            Epochs whose fractional temperature or radius uncertainty is >= 1 (or NaN) are dropped.
            Returns None if there is no data, no bin has enough measurements, or no fit succeeds.

        Raises
        ------
        ValueError
            If ``bin_width`` is not positive, or if the redshift is missing, NaN, or <= 0.
        """
        from scipy.optimize import curve_fit
        import astropy.units as uu

        if bin_width <= 0:
            raise ValueError("bin_width must be positive.")

        # get_filtered_data() returns (time, time_err, y, y_err). np.asarray makes the fancy
        # indexing below positional even if pandas objects are returned.
        time_data, _, flux_data, flux_err_data = self.get_filtered_data()
        time_data = np.asarray(time_data)

        redback.utils.logger.info("Estimating blackbody parameters for {}.".format(self.name))
        redback.utils.logger.info("Using data mode = {}".format(self.data_mode))

        band_data = None
        freq_data = None
        use_bandpass = hasattr(self, "data_mode") and self.data_mode in ['flux', 'magnitude']
        if use_bandpass:
            # Band names understood by the sncosmo-style bandpass integration.
            band_data = self.filtered_sncosmo_bands
        else:
            redback.utils.logger.info("Using effective wavelength approximation for {}".format(self.data_mode))
            freq_data = self.filtered_frequencies

        # Optionally force the effective-wavelength approximation for bandpass data.
        if use_bandpass and kwargs.get('use_eff_wavelength', False):
            redback.utils.logger.warning("Using effective wavelength approximation for {}".format(self.data_mode))
            if self.data_mode == 'magnitude':
                from redback.utils import abmag_to_flux_density_and_error_inmjy
                flux_data, flux_err_data = abmag_to_flux_density_and_error_inmjy(flux_data, flux_err_data)
            else:
                from redback.utils import bandpass_flux_to_flux_density, bands_to_effective_width
                redback.utils.logger.warning("Ensure filters.csv has the correct bandpass effective widths for your filter.")
                effective_widths = bands_to_effective_width(band_data)
                flux_data, flux_err_data = bandpass_flux_to_flux_density(flux_data, flux_err_data, effective_widths)
            freq_data = redback.utils.bands_to_frequency(band_data)
            use_bandpass = False

        T_init = kwargs.get('T_init', 1e4)
        R_init = kwargs.get('R_init', 1e15)
        maxfev = kwargs.get('maxfev', 1000)

        # Sort every per-point array by time.
        sort_idx = np.argsort(time_data)
        time_data = time_data[sort_idx]
        flux_data = np.asarray(flux_data)[sort_idx]
        flux_err_data = np.asarray(flux_err_data)[sort_idx]
        if use_bandpass:
            band_data = np.asarray(band_data)[sort_idx]
        else:
            freq_data = np.asarray(freq_data)[sort_idx]

        # A missing (None) or NaN redshift is treated as 0 and rejected. Previously None led to
        # an opaque TypeError from `None <= 0.`.
        redshift = 0. if self.redshift is None else np.nan_to_num(self.redshift)
        if redshift <= 0.:
            raise ValueError("Redshift must be provided to perform K-correction.")

        if not use_bandpass:
            # --- Effective-wavelength model ---
            freq_data, _ = redback.utils.calc_kcorrected_properties(frequency=freq_data,
                                                                    redshift=redshift, time=0.)
            # 1 mJy = 1e-26 erg/s/cm^2/Hz
            mjy_unit = 1e-26 * uu.erg / uu.s / uu.cm ** 2 / uu.Hz

            def bb_model(freq, logT, logR):
                model_flux_cgs = redback.sed.blackbody_to_flux_density(10 ** logT, 10 ** logR, distance, freq)
                return (model_flux_cgs / mjy_unit).value
        else:
            # --- Full bandpass integration model ---
            from redback.utils import calc_kcorrected_properties, lambda_to_nu, bandpass_magnitude_to_flux

            # The wavelength grid (Å) and its K-corrected frequencies depend only on the redshift.
            # Build them once here rather than on every curve_fit evaluation.
            lambda_obs = np.geomspace(100, 80000, 300)
            frequency_grid, _ = calc_kcorrected_properties(frequency=lambda_to_nu(lambda_obs),
                                                           redshift=redshift, time=0.)
            per_angstrom = uu.erg / uu.cm ** 2 / uu.s / uu.Angstrom
            wav_equivalency = uu.spectral_density(wav=lambda_obs * uu.Angstrom)
            phases = np.arange(5)
            output_format = self.data_mode

            def bb_model_bandpass(band, logT, logR):
                """Bandpass magnitude (or flux for "flux" data) of a blackbody in each of `band`."""
                model_flux = redback.sed.blackbody_to_flux_density(10 ** logT, 10 ** logR, distance,
                                                                   frequency_grid)
                spectrum = model_flux.to(per_angstrom, equivalencies=wav_equivalency)
                # The time-series source needs several phases; the same SED is repeated at each.
                source = redback.sed.RedbackTimeSeriesSource(
                    phase=phases, wave=lambda_obs, flux=np.tile(spectrum.value, (phases.size, 1)))
                if output_format == 'flux':
                    mag = source.bandmag(phase=0, band=band, magsys='ab')
                    return bandpass_magnitude_to_flux(magnitude=mag, bands=band)
                if output_format == 'magnitude':
                    return source.bandmag(phase=0, band=band, magsys='ab')
                raise ValueError("Unknown output_format in bb_model_bandpass.")

            def make_bandpass_model(epoch_bands):
                """Wrap bb_model_bandpass so curve_fit sees numeric xdata (indices into epoch_bands)."""
                epoch_bands = np.asarray(epoch_bands)

                def model(x, logT, logR):
                    # curve_fit passes xdata as floats; round back to integer indices.
                    return bb_model_bandpass(epoch_bands[np.round(x).astype(int)], logT, logR)

                return model

        if time_data.size == 0:
            redback.utils.logger.warning("No photometric data available; fitting cannot proceed.")
            return None

        # Assign each point to bin floor((t - t_min) / bin_width). Unlike np.arange bin edges, this
        # never drops the latest points and still yields a bin when all data share one epoch.
        t_min = np.min(time_data)
        bin_index = np.floor((time_data - t_min) / bin_width).astype(int)
        n_bins = int(bin_index.max()) + 1
        redback.utils.logger.info("Number of bins: {}".format(n_bins))
        counts = np.bincount(bin_index, minlength=n_bins)
        fit_bins = np.flatnonzero(counts >= min_filters)
        if fit_bins.size == 0:
            redback.utils.logger.warning("No time bins have at least {} measurements. Fitting cannot proceed.".format(min_filters))
            redback.utils.logger.warning("Try generating more data through GPs, increasing bin widths, or using fewer filters.")
            return None

        epoch_times, temperatures, radii, temp_errs, radius_errs = [], [], [], [], []
        p0 = [np.log10(T_init), np.log10(R_init)]
        for i in fit_bins:
            mask = bin_index == i
            t_epoch = np.mean(time_data[mask])
            if use_bandpass:
                epoch_bands = band_data[mask]
                model_func = make_bandpass_model(epoch_bands)
                xdata = np.arange(len(epoch_bands))
            else:
                model_func = bb_model
                xdata = freq_data[mask]
            try:
                popt, pcov = curve_fit(
                    model_func,
                    xdata,
                    flux_data[mask],
                    sigma=flux_err_data[mask],
                    p0=p0,
                    absolute_sigma=True,
                    maxfev=maxfev
                )
            except Exception as e:
                # Best effort: a failed epoch (non-convergence, bad errors, band lookup failure)
                # is logged and skipped rather than aborting the whole light curve.
                redback.utils.logger.warning(f"Fit failed for epoch {i}: {e}")
                redback.utils.logger.warning(f"Skipping epoch {i} with time {t_epoch:.2f} days.")
                continue

            logT_fit, logR_fit = popt
            T_fit = 10 ** logT_fit
            R_fit = 10 ** logR_fit
            perr = np.sqrt(np.diag(pcov))
            # Propagate log10-space errors: d(10**x) = ln(10) * 10**x * dx.
            epoch_times.append(t_epoch)
            temperatures.append(T_fit)
            radii.append(R_fit)
            temp_errs.append(np.log(10) * T_fit * perr[0])
            radius_errs.append(np.log(10) * R_fit * perr[1])

        if len(epoch_times) == 0:
            redback.utils.logger.warning("No epochs with sufficient data yielded a successful fit.")
            return None

        df_bb = pd.DataFrame({
            'epoch_times': epoch_times,
            'temperature': temperatures,
            'radius': radii,
            'temp_err': temp_errs,
            'radius_err': radius_errs
        })

        redback.utils.logger.info('Masking epochs with likely wrong extractions')
        df_bb = df_bb[df_bb['temp_err'] / df_bb['temperature'] < 1]
        df_bb = df_bb[df_bb['radius_err'] / df_bb['radius'] < 1]
        return df_bb
1451

1452

1453
    def estimate_bolometric_luminosity(self, distance: float = 1e27, bin_width: float = 1.0,
                                       min_filters: int = 3, **kwargs):
        """
        Estimate the bolometric luminosity as a function of time from per-epoch blackbody fits.

        Temperatures and radii come from ``estimate_bb_params``. For each epoch the bolometric
        luminosity is computed with the Stefan–Boltzmann law:

            L_bol = 4 π R² σ_SB T⁴

        Uncertainties in T and R are propagated assuming

            (ΔL_bol / L_bol)² = (2 ΔR / R)² + (4 ΔT / T)².

        Optionally, two corrections can be applied:

        1. A boost factor to restore missing blue flux (e.g. from line blanketing). If ``lambda_cut``
           (in angstroms) is given, it is converted to centimeters. The boosted and pure blackbody
           luminosities are then obtained from ``redback.sed.boosted_bolometric_luminosity(T, R,
           lambda_cut)``. The intended boost is

               Boost = F_tot / F_red,

           where F_tot = σ_SB T⁴ and F_red is π B_λ(T) integrated from the cutoff wavelength to
           infinity, so that L_boosted = Boost × (4π R² σ_SB T⁴). NOTE: the boosted luminosity gets
           the same fractional error as the pure blackbody. The temperature dependence of the boost
           factor itself is not propagated.
        2. An extinction correction. If the bolometric extinction ``A_ext`` (magnitudes) is given,
           BOTH the boosted and the pure blackbody luminosities are multiplied by 10^(+0.4·A_ext).
           This increases them to undo dust dimming. A_ext defaults to 0, i.e. no correction.

        Parameters
        ----------
        distance : float, optional
            Distance to the transient in centimeters. (Default is 1e27 cm.)
        bin_width : float, optional
            Width of the time bins (in days) used for grouping photometry. (Default is 1.0.)
        min_filters : int, optional
            Minimum number of measurements required in a bin to perform a fit. (Default is 3.)
        kwargs : dict, optional
            Additional keyword arguments passed to `estimate_bb_params` (e.g. maxfev, T_init, R_init,
            use_eff_wavelength). Two further keys are consumed here and not forwarded:

            - 'lambda_cut': if provided (in angstroms), the bolometric luminosity is boosted to
              account for missing blue flux.
            - 'A_ext': bolometric extinction in magnitudes. The luminosities are increased by a
              factor 10^(+0.4·A_ext). (Default is 0.)

        Returns
        -------
        df_bol : pandas.DataFrame or None
            The columns of `estimate_bb_params` (epoch_times, temperature, radius, temp_err,
            radius_err) plus:
              - lum_bol: bolometric luminosity (1e50 erg/s), boosted if lambda_cut is given and
                         extinction-corrected if A_ext is given.
              - lum_bol_err: 1σ uncertainty on lum_bol (1e50 erg/s) from error propagation.
              - lum_bol_bb: un-boosted blackbody luminosity 4π R² σ_SB T⁴ (1e50 erg/s). The
                            extinction correction IS applied to it when A_ext is given.
              - lum_bol_bb_err: 1σ uncertainty on lum_bol_bb (1e50 erg/s).
              - time_rest_frame: epoch_times / (1 + redshift).
            Rows with lum_bol_err / lum_bol >= 1 are dropped.
            Returns None if no valid blackbody fits were obtained.
        """
        from redback.sed import boosted_bolometric_luminosity

        # Optional cutoff wavelength (in angstroms) for the blue-flux boost.
        lambda_cut_angstrom = kwargs.pop('lambda_cut', None)
        if lambda_cut_angstrom is not None:
            redback.utils.logger.info("Including effects of missing flux due to line blanketing.")
            redback.utils.logger.info(
                "Using lambda_cut = {} Å for bolometric luminosity boost.".format(lambda_cut_angstrom))
            # 1 Å = 1e-8 cm
            lambda_cut = lambda_cut_angstrom * 1e-8
        else:
            redback.utils.logger.info("No lambda_cut provided; no correction applied. Assuming a pure blackbody SED.")
            lambda_cut = None

        # Optional bolometric extinction in magnitudes; the correction brightens the source.
        A_ext = kwargs.pop('A_ext', 0.0)
        if A_ext != 0.0:
            redback.utils.logger.info("Applying extinction correction with A_ext = {} mag.".format(A_ext))
        extinction_factor = 10 ** (0.4 * A_ext)

        df_bb = self.estimate_bb_params(distance=distance, bin_width=bin_width, min_filters=min_filters, **kwargs)
        if df_bb is None or len(df_bb) == 0:
            redback.utils.logger.warning("No valid blackbody fits were obtained; cannot estimate bolometric luminosity.")
            return None

        L_bol = []
        L_bol_err = []
        L_bol_bb = []
        L_bol_bb_err = []
        for _, row in df_bb.iterrows():
            temp = row['temperature']
            radius = row['radius']
            T_err = row['temp_err']
            R_err = row['radius_err']

            if lambda_cut is not None:
                lum, lum_bb = boosted_bolometric_luminosity(temp, radius, lambda_cut)
            else:
                lum = 4 * np.pi * (radius ** 2) * redback.constants.sigma_sb * (temp ** 4)
                lum_bb = lum

            # Extinction correction applies to both the boosted and the pure blackbody luminosity.
            lum *= extinction_factor
            lum_bb *= extinction_factor

            # (ΔL/L)² = (2 ΔR / R)² + (4 ΔT / T)², applied to both luminosities.
            rel_err = np.sqrt((2 * R_err / radius) ** 2 + (4 * T_err / temp) ** 2)

            L_bol.append(lum)
            L_bol_bb.append(lum_bb)
            L_bol_err.append(lum * rel_err)
            L_bol_bb_err.append(lum_bb * rel_err)

        df_bol = df_bb.copy()
        # Plain arrays are assigned positionally, so a non-contiguous index from masking is fine.
        df_bol['lum_bol'] = np.array(L_bol) / 1e50
        df_bol['lum_bol_err'] = np.array(L_bol_err) / 1e50
        df_bol['lum_bol_bb'] = np.array(L_bol_bb) / 1e50
        df_bol['lum_bol_bb_err'] = np.array(L_bol_bb_err) / 1e50
        df_bol['time_rest_frame'] = df_bol['epoch_times'] / (1 + self.redshift)

        redback.utils.logger.info('Masking bolometric estimates with likely wrong extractions')
        df_bol = df_bol[df_bol['lum_bol_err'] / df_bol['lum_bol'] < 1]
        redback.utils.logger.info(
            "Estimated bolometric luminosity using blackbody integration (with boost and extinction corrections if specified).")
        return df_bol
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc