• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nikhil-sarin / redback / 18131194739

30 Sep 2025 01:16PM UTC coverage: 87.629% (-0.04%) from 87.665%
18131194739

Pull #293

github

web-flow
Merge 05c94ad74 into 8141796a4
Pull Request #293: Add a class method to load from LightCurveLynx output

47 of 49 new or added lines in 2 files covered. (95.92%)

15 existing lines in 2 files now uncovered.

14429 of 16466 relevant lines covered (87.63%)

0.88 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

78.15
/redback/transient/transient.py
1
from __future__ import annotations
1✔
2

3
from typing import Union
1✔
4

5
import matplotlib
1✔
6
import numpy as np
1✔
7
import pandas as pd
1✔
8

9
import redback
1✔
10
from redback.plotting import \
1✔
11
    LuminosityPlotter, FluxDensityPlotter, IntegratedFluxPlotter, MagnitudePlotter, \
12
    IntegratedFluxOpticalPlotter, SpectrumPlotter, LuminosityOpticalPlotter
13
from redback.model_library import all_models_dict
1✔
14
from collections import namedtuple
1✔
15

16
class Spectrum(object):
    """Container for a single spectrum: wavelengths, flux densities and their uncertainties."""

    def __init__(self, angstroms: np.ndarray, flux_density: np.ndarray, flux_density_err: np.ndarray,
                 time: str = None, name: str = '', **kwargs) -> None:
        """
        A class to store spectral data.

        :param angstroms: Wavelength in angstroms.
        :param flux_density: flux density in ergs/s/cm^2/angstrom.
        :param flux_density_err: flux density error in ergs/s/cm^2/angstrom.
        :param time: Time of the spectrum. Could be a phase or time since burst. Only used for plotting.
        :param name: Name of the spectrum.
        :param kwargs: Accepted for call compatibility; not used.
        """
        self.angstroms = angstroms
        self.flux_density = flux_density
        self.flux_density_err = flux_density_err
        self.time = time
        self.name = name
        # Plots only carry a time label when a time was supplied.
        self.plot_with_time_label = self.time is not None
        self.directory_structure = redback.get_data.directory.spectrum_directory_structure(transient=name)
        self.data_mode = 'spectrum'

    @property
    def xlabel(self) -> str:
        """
        :return: xlabel used in plotting functions
        :rtype: str
        """
        return r'Wavelength [$\mathrm{\AA}$]'

    @property
    def ylabel(self) -> str:
        """
        :return: ylabel used in plotting functions. The flux density is given per unit
                 wavelength, so the unit carries an inverse angstrom.
        :rtype: str
        """
        return r'Flux ($10^{-17}$ erg s$^{-1}$ cm$^{-2}$ $\mathrm{\AA}^{-1}$)'

    def plot_data(self, axes: matplotlib.axes.Axes = None, filename: str = None, outdir: str = None, save: bool = True,
                  show: bool = True, color: str = 'k', **kwargs) -> matplotlib.axes.Axes:
        """Plots the spectrum data and returns Axes.

        :param axes: Matplotlib axes to plot the spectrum into. Useful for user specific modifications to the plot.
        :param filename: Name of the file to be plotted in.
        :param outdir: The directory in which to save the file in.
        :param save: Whether to save the plot. (Default value = True)
        :param show: Whether to show the plot. (Default value = True)
        :param color: Color of the data.
        :param kwargs: Additional keyword arguments to pass in the Plotter methods.
        Available in the online documentation under at `redback.plotting.Plotter`.
        `print(Spectrum.plot_data.__doc__)` to see all options!
        :return: The axes with the plot.
        """
        plotter = SpectrumPlotter(spectrum=self, color=color, filename=filename, outdir=outdir, **kwargs)
        return plotter.plot_data(axes=axes, save=save, show=show)

    def plot_spectrum(
            self, model: callable, filename: str = None, outdir: str = None, axes: matplotlib.axes.Axes = None,
            save: bool = True, show: bool = True, random_models: int = 100, posterior: pd.DataFrame = None,
            model_kwargs: dict = None, **kwargs: None) -> matplotlib.axes.Axes:
        """Plots the spectrum together with model realisations drawn from the posterior.

        :param model: The model used to plot the spectrum.
        :param filename: The output filename. If not given, the plotter chooses a default name.
        :param outdir: Out directory in which to save the plot. Default is the current working directory.
        :param axes: Axes to plot in if given.
        :param save: Whether to save the plot.
        :param show: Whether to show the plot.
        :param random_models: Number of random posterior samples plotted faintly. (Default value = 100)
        :param posterior: Posterior distribution to which to draw samples from. Is optional but must be given.
        :param model_kwargs: Additional keyword arguments to be passed into the model.
        :param kwargs: Additional keyword arguments to pass in the Plotter methods.
        Available in the online documentation under at `redback.plotting.Plotter`.
        `print(Spectrum.plot_spectrum.__doc__)` to see all options!
        :return: The axes.
        """
        plotter = SpectrumPlotter(
            spectrum=self, model=model, filename=filename, outdir=outdir,
            posterior=posterior, model_kwargs=model_kwargs, random_models=random_models, **kwargs)
        return plotter.plot_spectrum(axes=axes, save=save, show=show)

    def plot_residual(self, model: callable, filename: str = None, outdir: str = None, axes: matplotlib.axes.Axes = None,
                      save: bool = True, show: bool = True, posterior: pd.DataFrame = None,
                      model_kwargs: dict = None, **kwargs: None) -> matplotlib.axes.Axes:
        """Plots the residuals of the spectrum with respect to the model.

        :param model: The model used to compute the residuals.
        :param filename: The output filename. If not given, the plotter chooses a default name.
        :param outdir: Out directory in which to save the plot. Default is the current working directory.
        :param axes: Axes to plot in if given.
        :param save: Whether to save the plot.
        :param show: Whether to show the plot.
        :param posterior: Posterior distribution to which to draw samples from. Is optional but must be given.
        :param model_kwargs: Additional keyword arguments to be passed into the model.
        :param kwargs: Additional keyword arguments to pass in the Plotter methods.
        Available in the online documentation under at `redback.plotting.Plotter`.
        `print(Spectrum.plot_residual.__doc__)` to see all options!
        :return: The axes.
        """
        plotter = SpectrumPlotter(
            spectrum=self, model=model, filename=filename, outdir=outdir,
            posterior=posterior, model_kwargs=model_kwargs, **kwargs)
        return plotter.plot_residuals(axes=axes, save=save, show=show)
124

125
class Transient(object):
    """General container for transient data stored in one of several data modes (see `DATA_MODES`)."""

    DATA_MODES = ['luminosity', 'flux', 'flux_density', 'magnitude', 'counts', 'ttes']

    # Data mode -> attribute holding the measurements; the matching errors live in "<attribute>_err".
    # Note that 'ttes' has no entry here.
    _ATTRIBUTE_NAME_DICT = {
        "luminosity": "Lum50",
        "flux": "flux",
        "flux_density": "flux_density",
        "counts": "counts",
        "magnitude": "magnitude",
    }

    # Y-axis labels used by the plotting helpers, keyed by data mode.
    ylabel_dict = {
        "luminosity": r'Luminosity [$10^{50}$ erg s$^{-1}$]',
        "magnitude": r'Magnitude',
        "flux": r'Flux [erg cm$^{-2}$ s$^{-1}$]',
        "flux_density": r'Flux density [mJy]',
        "counts": r'Counts',
    }

    # One switch per data mode; presumably truthy when that mode is active
    # (used as a boolean in `_time_attribute_name`) -- see redback.utils.DataModeSwitch.
    luminosity_data = redback.utils.DataModeSwitch('luminosity')
    flux_data = redback.utils.DataModeSwitch('flux')
    flux_density_data = redback.utils.DataModeSwitch('flux_density')
    magnitude_data = redback.utils.DataModeSwitch('magnitude')
    counts_data = redback.utils.DataModeSwitch('counts')
    tte_data = redback.utils.DataModeSwitch('ttes')
142

143
    def __init__(
            self, time: np.ndarray = None, time_err: np.ndarray = None, time_mjd: np.ndarray = None,
            time_mjd_err: np.ndarray = None, time_rest_frame: np.ndarray = None, time_rest_frame_err: np.ndarray = None,
            Lum50: np.ndarray = None, Lum50_err: np.ndarray = None, flux: np.ndarray = None,
            flux_err: np.ndarray = None, flux_density: np.ndarray = None, flux_density_err: np.ndarray = None,
            magnitude: np.ndarray = None, magnitude_err: np.ndarray = None, counts: np.ndarray = None,
            ttes: np.ndarray = None, bin_size: float = None, redshift: float = np.nan, data_mode: str = None,
            name: str = '', photon_index: float = np.nan, use_phase_model: bool = False,
            optical_data: bool = False, frequency: np.ndarray = None, system: np.ndarray = None, bands: np.ndarray = None,
            active_bands: Union[np.ndarray, str] = None, plotting_order: Union[np.ndarray, str] = None, **kwargs: None) -> None:
        """This is a general constructor for the Transient class. Note that you only need to give data corresponding to
        the data mode you are using. For luminosity data provide times in the rest frame, if using a phase model
        provide time in MJD, else use the default time (observer frame).

        :param time: Times in the observer frame.
        :type time: np.ndarray, optional
        :param time_err: Time errors in the observer frame.
        :type time_err: np.ndarray, optional
        :param time_mjd: Times in MJD. Used if using phase model.
        :type time_mjd: np.ndarray, optional
        :param time_mjd_err: Time errors in MJD. Used if using phase model.
        :type time_mjd_err: np.ndarray, optional
        :param time_rest_frame: Times in the rest frame. Used for luminosity data.
        :type time_rest_frame: np.ndarray, optional
        :param time_rest_frame_err: Time errors in the rest frame. Used for luminosity data.
        :type time_rest_frame_err: np.ndarray, optional
        :param Lum50: Luminosity values.
        :type Lum50: np.ndarray, optional
        :param Lum50_err: Luminosity error values.
        :type Lum50_err: np.ndarray, optional
        :param flux: Flux values.
        :type flux: np.ndarray, optional
        :param flux_err: Flux error values.
        :type flux_err: np.ndarray, optional
        :param flux_density: Flux density values.
        :type flux_density: np.ndarray, optional
        :param flux_density_err: Flux density error values.
        :type flux_density_err: np.ndarray, optional
        :param magnitude: Magnitude values for photometry data.
        :type magnitude: np.ndarray, optional
        :param magnitude_err: Magnitude error values for photometry data.
        :type magnitude_err: np.ndarray, optional
        :param counts: Counts for prompt data. Their errors are set to sqrt(counts).
        :type counts: np.ndarray, optional
        :param ttes: Time-tagged events data for unbinned prompt data. In 'ttes' data mode these are
                     binned with `bin_size` and the result replaces `time` and `counts`.
        :type ttes: np.ndarray, optional
        :param bin_size: Bin size for binning time-tagged event data.
        :type bin_size: float, optional
        :param redshift: Redshift value.
        :type redshift: float, optional
        :param data_mode: Data mode. Must be one from `Transient.DATA_MODES` (or None).
        :type data_mode: str, optional
        :param name: Name of the transient.
        :type name: str, optional
        :param photon_index: Photon index value.
        :type photon_index: float, optional
        :param use_phase_model: Whether we are using a phase model.
        :type use_phase_model: bool, optional
        :param optical_data: Whether we are fitting optical data, useful for plotting.
        :type optical_data: bool, optional
        :param frequency: Array of band frequencies in photometry data.
        :type frequency: np.ndarray, optional
        :param system: System values.
        :type system: np.ndarray, optional
        :param bands: Band values.
        :type bands: np.ndarray, optional
        :param active_bands: List or array of active bands to be used in the analysis.
                             Use all available bands if 'all' is given.
        :type active_bands: Union[list, np.ndarray, str], optional
        :param plotting_order: Order in which to plot the bands/and how unique bands are stored.
        :type plotting_order: Union[np.ndarray, str], optional
        :param kwargs: Additional callables:
                       bands_to_frequency: Conversion function to convert a list of bands to frequencies.
                                           Use redback.utils.bands_to_frequency if not given.
                       bin_ttes: Binning function for time-tagged event data.
                                 Use redback.utils.bin_ttes if not given.
        :type kwargs: None, optional
        :raises ValueError: If `data_mode` is not in `Transient.DATA_MODES` (raised by the `data_mode` setter).
        """
        self.bin_size = bin_size
        self.bin_ttes = kwargs.get("bin_ttes", redback.utils.bin_ttes)
        self.bands_to_frequency = kwargs.get("bands_to_frequency", redback.utils.bands_to_frequency)

        # Time-tagged events are binned up front; the binned times/counts replace any passed values.
        if data_mode == 'ttes':
            time, counts = self.bin_ttes(ttes, self.bin_size)

        self.time = time
        self.time_err = time_err
        self.time_mjd = time_mjd
        self.time_mjd_err = time_mjd_err
        self.time_rest_frame = time_rest_frame
        self.time_rest_frame_err = time_rest_frame_err

        self.Lum50 = Lum50
        self.Lum50_err = Lum50_err
        self.flux = flux
        self.flux_err = flux_err
        self.flux_density = flux_density
        self.flux_density_err = flux_density_err
        self.magnitude = magnitude
        self.magnitude_err = magnitude_err
        self.counts = counts
        # Poisson (sqrt N) uncertainties on the counts.
        self.counts_err = np.sqrt(counts) if counts is not None else None
        self.ttes = ttes

        # Bands/frequencies must be set before `active_bands` and `sncosmo_bands`, which both read `self.bands`.
        self._frequency = None
        self._bands = None
        self.set_bands_and_frequency(bands=bands, frequency=frequency)
        self.system = system
        self.data_mode = data_mode
        self.active_bands = active_bands
        self.sncosmo_bands = redback.utils.sncosmo_bandname_from_band(self.bands)
        self.redshift = redshift
        self.name = name
        self.use_phase_model = use_phase_model
        self.optical_data = optical_data
        self.plotting_order = plotting_order

        self.meta_data = None
        self.photon_index = photon_index
        self.directory_structure = redback.get_data.directory.DirectoryStructure(
            directory_path=".", raw_file_path=".", processed_file_path=".")
264

265
    @staticmethod
1✔
266
    def load_data_generic(processed_file_path, data_mode="magnitude"):
1✔
267
        """Loads data from specified directory and file, and returns it as a tuple.
268

269
        :param processed_file_path: Path to the processed file to load
270
        :type processed_file_path: str
271
        :param data_mode: Name of the data mode.
272
                          Must be from ['magnitude', 'flux_density', 'all']. Default is magnitude.
273
        :type data_mode: str, optional
274

275
        :return: Six elements when querying magnitude or flux_density data, Eight for 'all'.
276
        :rtype: tuple
277
        """
278
        DATA_MODES = ['luminosity', 'flux', 'flux_density', 'magnitude', 'counts', 'ttes', 'all']
1✔
279
        df = pd.read_csv(processed_file_path)
1✔
280
        time_days = np.array(df["time (days)"])
1✔
281
        time_mjd = np.array(df["time"])
1✔
282
        magnitude = np.array(df["magnitude"])
1✔
283
        magnitude_err = np.array(df["e_magnitude"])
1✔
284
        bands = np.array(df["band"])
1✔
285
        flux_density = np.array(df["flux_density(mjy)"])
1✔
286
        flux_density_err = np.array(df["flux_density_error"])
1✔
287
        if data_mode not in DATA_MODES:
1✔
288
            raise ValueError(f"Data mode {data_mode} not in {DATA_MODES}")
1✔
289
        if data_mode == "magnitude":
1✔
290
            return time_days, time_mjd, magnitude, magnitude_err, bands
1✔
291
        elif data_mode == "flux_density":
1✔
292
            return time_days, time_mjd, flux_density, flux_density_err, bands
1✔
293
        elif data_mode == "all":
1✔
294
            return time_days, time_mjd, flux_density, flux_density_err, magnitude, magnitude_err, bands
1✔
295

296
    @classmethod
    def from_lasair_data(
            cls, name: str, data_mode: str = "magnitude", active_bands: Union[np.ndarray, str] = 'all',
            use_phase_model: bool = False, plotting_order: Union[np.ndarray, str] = None) -> Transient:
        """Constructor method to build an object from processed LASAIR data.

        :param name: Name of the transient.
        :type name: str
        :param data_mode: Data mode used. Must be from `OpticalTransient.DATA_MODES`. Default is magnitude.
        :type data_mode: str, optional
        :param active_bands: Sets active bands based on array given.
                             If argument is 'all', all unique bands in `self.bands` will be used.
        :type active_bands: Union[np.ndarray, str]
        :param plotting_order: Order in which to plot the bands/and how unique bands are stored.
        :type plotting_order: Union[np.ndarray, str], optional
        :param use_phase_model: Whether to use a phase model.
        :type use_phase_model: bool, optional

        :return: A class instance.
        :rtype: OpticalTransient
        """
        # The LASAIR directory layout uses the long name for TDEs and the lower-cased class name otherwise.
        transient_type = "tidal_disruption_event" if cls.__name__ == "TDE" else cls.__name__.lower()
        layout = redback.get_data.directory.lasair_directory_structure(
            transient=name, transient_type=transient_type)
        table = pd.read_csv(layout.processed_file_path)

        # Constructor keyword -> CSV column, read in this order.
        column_for = {
            "time": "time (days)",
            "time_mjd": "time",
            "magnitude": "magnitude",
            "magnitude_err": "e_magnitude",
            "bands": "band",
            "flux": "flux(erg/cm2/s)",
            "flux_err": "flux_error",
            "flux_density": "flux_density(mjy)",
            "flux_density_err": "flux_density_error",
        }
        arrays = {keyword: np.array(table[column]) for keyword, column in column_for.items()}
        return cls(name=name, data_mode=data_mode, time_err=None, active_bands=active_bands,
                   use_phase_model=use_phase_model, optical_data=True, plotting_order=plotting_order,
                   **arrays)
337

338
    @classmethod
1✔
339
    def from_simulated_optical_data(
1✔
340
            cls, name: str, data_mode: str = "magnitude", active_bands: Union[np.ndarray, str] = 'all',
341
            plotting_order: Union[np.ndarray, str] = None, use_phase_model: bool = False) -> Transient:
342
        """Constructor method to built object from SimulatedOpticalTransient.
343

344
        :param name: Name of the transient.
345
        :type name: str
346
        :param data_mode: Data mode used. Must be from `OpticalTransient.DATA_MODES`. Default is magnitude.
347
        :type data_mode: str, optional
348
        :param active_bands: Sets active bands based on array given.
349
                             If argument is 'all', all unique bands in `self.bands` will be used.
350
        :type active_bands: Union[np.ndarray, str]
351
        :param plotting_order: Order in which to plot the bands/and how unique bands are stored.
352
        :type plotting_order: Union[np.ndarray, str], optional
353
        :param use_phase_model: Whether to use a phase model.
354
        :type use_phase_model: bool, optional
355

356
        :return: A class instance.
357
        :rtype: OpticalTransient
358
        """
359
        path = "simulated/" + name + ".csv"
1✔
360
        df = pd.read_csv(path)
1✔
361
        df = df[df.detected != 0]
1✔
362
        time_days = np.array(df["time (days)"])
1✔
363
        time_mjd = np.array(df["time"])
1✔
364
        magnitude = np.array(df["magnitude"])
1✔
365
        magnitude_err = np.array(df["e_magnitude"])
1✔
366
        bands = np.array(df["band"])
1✔
367
        flux = np.array(df["flux(erg/cm2/s)"])
1✔
368
        flux_err = np.array(df["flux_error"])
1✔
369
        flux_density = np.array(df["flux_density(mjy)"])
1✔
370
        flux_density_err = np.array(df["flux_density_error"])
1✔
371
        return cls(name=name, data_mode=data_mode, time=time_days, time_err=None, time_mjd=time_mjd,
1✔
372
                   flux_density=flux_density, flux_density_err=flux_density_err, magnitude=magnitude,
373
                   magnitude_err=magnitude_err, flux=flux, flux_err=flux_err, bands=bands, active_bands=active_bands,
374
                   use_phase_model=use_phase_model, optical_data=True, plotting_order=plotting_order)
375

376

377
    @classmethod
1✔
378
    def from_lightcurvelynx(
1✔
379
            cls, name: str, data: pd.DataFrame = None, data_mode: str = "magnitude",
380
            active_bands: Union[np.ndarray, str] = 'all', plotting_order: Union[np.ndarray, str] = None,
381
            use_phase_model: bool = False) -> Transient:
382
        """Constructor method to built object from a LightCurveLynx simulated light curve.
383
        https://github.com/lincc-frameworks/lightcurvelynx
384

385
        :param name: Name of the transient.
386
        :type name: str
387
        :param data: DataFrame containing the light curve data. If None, it will try to load from "simulated/{name}.csv".
388
        :type data: pd.DataFrame, optional
389
        :param data_mode: Data mode used. Must be from `OpticalTransient.DATA_MODES`. Default is magnitude.
390
        :type data_mode: str, optional
391
        :param active_bands: Sets active bands based on array given.
392
                             If argument is 'all', all unique bands in `self.bands` will be used.
393
        :type active_bands: Union[np.ndarray, str]
394
        :param plotting_order: Order in which to plot the bands/and how unique bands are stored.
395
        :type plotting_order: Union[np.ndarray, str], optional
396
        :param use_phase_model: Whether to use a phase model.
397
        :type use_phase_model: bool, optional
398

399
        :return: A class instance.
400
        :rtype: OpticalTransient
401
        """
402
        if data is None:
1✔
NEW
403
            path = "simulated/" + name + ".csv"
×
NEW
404
            data = pd.read_csv(path)
×
405

406
        # Filter out the non-detections.
407
        if "detected" in data.columns:
1✔
408
            data = data[data.detected != 0]
1✔
409

410
        # Process the time and bands data.
411
        time_mjd = data["mjd"].to_numpy()
1✔
412
        time_days = data["time_rel"].to_numpy() if "time_rel" in data.columns else None
1✔
413
        bands = data["filter"].to_numpy()
1✔
414

415
        # Process the magnitude data. Checking that we have the values if the data mode is magnitude.
416
        if "mag" in data.columns:
1✔
417
            magnitude = data["mag"].to_numpy()
1✔
418
            magnitude_err = data["magerr"].to_numpy()
1✔
419
        elif data_mode == "magnitude":
1✔
420
            raise ValueError("Magnitude data mode selected but no magnitude data found in the DataFrame.")
1✔
421
        else:
422
            magnitude = None
1✔
423
            magnitude_err = None
1✔
424

425
        # Handle the flux density data, converting from nanojanskys to milijanskys.
426
        flux_density = np.array(data["flux"]) / 1e6  # Convert from nJy to mJy
1✔
427
        flux_density_err = np.array(data["fluxerr"]) / 1e6  # Convert from nJy to mJy
1✔
428

429
        # We do not have the flux information.
430
        flux = None
1✔
431
        flux_err = None
1✔
432

433
        return cls(name=name, data_mode=data_mode, time=time_days, time_err=None, time_mjd=time_mjd,
1✔
434
                   flux_density=flux_density, flux_density_err=flux_density_err, magnitude=magnitude,
435
                   magnitude_err=magnitude_err, flux=flux, flux_err=flux_err, bands=bands, active_bands=active_bands,
436
                   use_phase_model=use_phase_model, optical_data=True, plotting_order=plotting_order)
437

438

439
    @property
1✔
440
    def _time_attribute_name(self) -> str:
1✔
441
        if self.luminosity_data:
1✔
442
            return "time_rest_frame"
1✔
443
        elif self.use_phase_model:
1✔
444
            return "time_mjd"
1✔
445
        return "time"
1✔
446

447
    @property
1✔
448
    def _time_err_attribute_name(self) -> str:
1✔
449
        return self._time_attribute_name + "_err"
1✔
450

451
    @property
1✔
452
    def _y_attribute_name(self) -> str:
1✔
453
        return self._ATTRIBUTE_NAME_DICT[self.data_mode]
1✔
454

455
    @property
1✔
456
    def _y_err_attribute_name(self) -> str:
1✔
457
        return self._ATTRIBUTE_NAME_DICT[self.data_mode] + "_err"
1✔
458

459
    @property
1✔
460
    def x(self) -> np.ndarray:
1✔
461
        """
462
        :return: The time values given the active data mode.
463
        :rtype: np.ndarray
464
        """
465
        return getattr(self, self._time_attribute_name)
1✔
466

467
    @x.setter
1✔
468
    def x(self, x: np.ndarray) -> None:
1✔
469
        """Sets the time values for the active data mode.
470
        :param x: The desired time values.
471
        :type x: np.ndarray
472
        """
473
        setattr(self, self._time_attribute_name, x)
1✔
474

475
    @property
1✔
476
    def x_err(self) -> np.ndarray:
1✔
477
        """
478
        :return: The time error values given the active data mode.
479
        :rtype: np.ndarray
480
        """
481
        return getattr(self, self._time_err_attribute_name)
1✔
482

483
    @x_err.setter
1✔
484
    def x_err(self, x_err: np.ndarray) -> None:
1✔
485
        """Sets the time error values for the active data mode.
486
        :param x_err: The desired time error values.
487
        :type x_err: np.ndarray
488
        """
489
        setattr(self, self._time_err_attribute_name, x_err)
1✔
490

491
    @property
1✔
492
    def y(self) -> np.ndarray:
1✔
493
        """
494
        :return: The y values given the active data mode.
495
        :rtype: np.ndarray
496
        """
497

498
        return getattr(self, self._y_attribute_name)
1✔
499

500
    @y.setter
1✔
501
    def y(self, y: np.ndarray) -> None:
1✔
502
        """Sets the y values for the active data mode.
503
        :param y: The desired y values.
504
        :type y: np.ndarray
505
        """
506
        setattr(self, self._y_attribute_name, y)
1✔
507

508
    @property
1✔
509
    def y_err(self) -> np.ndarray:
1✔
510
        """
511
        :return: The y error values given the active data mode.
512
        :rtype: np.ndarray
513
        """
514
        return getattr(self, self._y_err_attribute_name)
1✔
515

516
    @y_err.setter
1✔
517
    def y_err(self, y_err: np.ndarray) -> None:
1✔
518
        """Sets the y error values for the active data mode.
519
        :param y_err: The desired y error values.
520
        :type y_err: np.ndarray
521
        """
522
        setattr(self, self._y_err_attribute_name, y_err)
1✔
523

524
    @property
1✔
525
    def data_mode(self) -> str:
1✔
526
        """
527
        :return: The currently active data mode (one in `Transient.DATA_MODES`).
528
        :rtype: str
529
        """
530
        return self._data_mode
1✔
531

532
    @data_mode.setter
1✔
533
    def data_mode(self, data_mode: str) -> None:
1✔
534
        """
535
        :param data_mode: One of the data modes in `Transient.DATA_MODES`.
536
        :type data_mode: str
537
        """
538
        if data_mode in self.DATA_MODES or data_mode is None:
1✔
539
            self._data_mode = data_mode
1✔
540
        else:
541
            raise ValueError("Unknown data mode.")
1✔
542

543
    @property
1✔
544
    def xlabel(self) -> str:
1✔
545
        """
546
        :return: xlabel used in plotting functions
547
        :rtype: str
548
        """
549
        if self.use_phase_model:
1✔
550
            return r"Time [MJD]"
1✔
551
        else:
552
            return r"Time since explosion [days]"
1✔
553

554
    @property
1✔
555
    def ylabel(self) -> str:
1✔
556
        """
557
        :return: ylabel used in plotting functions
558
        :rtype: str
559
        """
560
        try:
1✔
561
            return self.ylabel_dict[self.data_mode]
1✔
562
        except KeyError:
1✔
563
            raise ValueError("No data mode specified")
1✔
564

565
    def set_bands_and_frequency(
1✔
566
            self, bands: Union[None, list, np.ndarray], frequency: Union[None, list, np.ndarray]):
567
        """Sets bands and frequencies at the same time to keep the logic consistent. If both are given use those values.
568
        If only frequencies are given, use them also as band names.
569
        If only bands are given, try to convert them to frequencies.
570

571
        :param bands: The bands, e.g. ['g', 'i'].
572
        :type bands: Union[None, list, np.ndarray]
573
        :param frequency: The frequencies associated with the bands i.e., the effective frequency.
574
        :type frequency: Union[None, list, np.ndarray]
575
        """
576
        if (bands is None and frequency is None) or (bands is not None and frequency is not None):
1✔
577
            self._bands = bands
1✔
578
            self._frequency = frequency
1✔
579
        elif bands is None and frequency is not None:
1✔
580
            self._frequency = frequency
1✔
581
            self._bands = self.frequency
1✔
582
        elif bands is not None and frequency is None:
1✔
583
            self._bands = bands
1✔
584
            self._frequency = self.bands_to_frequency(self.bands)
1✔
585

586
    @property
1✔
587
    def frequency(self) -> np.ndarray:
1✔
588
        """
589
        :return: Used band frequencies
590
        :rtype: np.ndarray
591
        """
592
        return self._frequency
1✔
593

594
    @frequency.setter
1✔
595
    def frequency(self, frequency: np.ndarray) -> None:
1✔
596
        """
597
        :param frequency: Set band frequencies if an array is given. Otherwise, convert bands to frequencies.
598
        :type frequency: np.ndarray
599
        """
600
        self.set_bands_and_frequency(bands=self.bands, frequency=frequency)
1✔
601

602
    @property
1✔
603
    def bands(self) -> Union[list, None, np.ndarray]:
1✔
604
        return self._bands
1✔
605

606
    @bands.setter
1✔
607
    def bands(self, bands: Union[list, None, np.ndarray]):
1✔
608
        self.set_bands_and_frequency(bands=bands, frequency=self.frequency)
×
609

610
    @property
    def filtered_frequencies(self) -> np.ndarray:
        """
        :return: The frequencies only associated with the active bands.
            `self.frequency` is converted with `np.asarray` first, so a plain list also supports
            indexing with the boolean mask / index list returned by `filtered_indices`.
        :rtype: np.ndarray
        """
        return np.asarray(self.frequency)[self.filtered_indices]
617

618
    @property
    def filtered_sncosmo_bands(self) -> np.ndarray:
        """
        :return: The sncosmo bands only associated with the active bands.
            `self.sncosmo_bands` is converted with `np.asarray` first, so a plain list also supports
            indexing with the boolean mask / index list returned by `filtered_indices`.
        :rtype: np.ndarray
        """
        return np.asarray(self.sncosmo_bands)[self.filtered_indices]
625

626
    @property
    def filtered_bands(self) -> np.ndarray:
        """
        :return: The band names only associated with the active bands.
            `self.bands` may be a list (see the `bands` type hint). It is converted with `np.asarray`
            so it can be indexed with the boolean mask / index list returned by `filtered_indices`.
        :rtype: np.ndarray
        """
        return np.asarray(self.bands)[self.filtered_indices]
633

634
    @property
    def active_bands(self) -> list:
        """
        :return: The bands currently selected for filtering, fitting and plotting.
        :rtype: list
        """
        selected = self._active_bands
        return selected
641

642
    @active_bands.setter
    def active_bands(self, active_bands: Union[list, str, None]) -> None:
        """
        Select which bands are treated as active.

        :param active_bands: The bands to activate. If its string form is 'all', every unique band in
                             `self.bands` is activated; any other value is stored unchanged.
        :type active_bands: Union[list, str]
        """
        use_every_band = str(active_bands) == 'all'
        self._active_bands = list(np.unique(self.bands)) if use_every_band else active_bands
653

654
    @property
    def filtered_indices(self) -> Union[list, None]:
        """
        :return: A boolean flag per data point that is True when the point's band is an active band.
                 If no bands are set, the integer indices of all data points are returned instead.
        :rtype: Union[list, None]
        """
        if self.bands is not None:
            return [band in self.active_bands for band in self.bands]
        return list(np.arange(len(self.x)))
663

664
    def get_filtered_data(self) -> tuple:
        """Return the flux density, magnitude or integrated flux data restricted to the active bands.

        :return: ``(x, x_err, y, y_err)`` with every array indexed by `filtered_indices`.
                 ``x_err`` is None when it cannot be indexed, e.g. when no time errors are stored.
        :rtype: tuple
        :raises ValueError: If the transient is not in flux density, magnitude or flux data mode.
        """
        if not any((self.flux_data, self.magnitude_data, self.flux_density_data)):
            raise ValueError(f"Transient needs to be in flux density, magnitude or flux data mode, "
                             f"but is in {self.data_mode} instead.")

        selection = self.filtered_indices
        selected_x = self.x[selection]
        try:
            selected_x_err = self.x_err[selection]
        except (IndexError, TypeError):
            # x_err is optional (typically None), so filtering it is best-effort.
            selected_x_err = None
        selected_y = self.y[selection]
        selected_y_err = self.y_err[selection]
        return selected_x, selected_x_err, selected_y, selected_y_err
681

682
    @property
    def unique_bands(self) -> np.ndarray:
        """
        :return: `plotting_order` if one is set; otherwise the sorted unique entries of `self.bands`.
        :rtype: np.ndarray
        """
        order = self.plotting_order
        return np.unique(self.bands) if order is None else order
692

693
    @property
    def unique_frequencies(self) -> np.ndarray:
        """
        :return: The frequencies of the unique bands. If the unique bands are already numeric
                 (Python or numpy int/float), they are returned directly; otherwise they are
                 converted with `bands_to_frequency`.
        :rtype: np.ndarray
        """
        bands = self.unique_bands
        try:
            first_band = bands[0]
        except (TypeError, IndexError):
            # Empty or non-indexable bands: fall through to the conversion.
            first_band = None
        # np.number is needed because numpy integer/float32 scalars (e.g. from np.unique on an
        # int array) are not subclasses of Python int/float.
        if isinstance(first_band, (float, int, np.number)):
            return bands
        return self.bands_to_frequency(bands)
705

706
    @property
    def list_of_band_indices(self) -> list:
        """
        :return: For each entry of `unique_bands`, the array of positions in `self.bands` that carry that band.
        :rtype: list
        """
        indices = []
        for band in self.unique_bands:
            matches = self.bands == np.array(band)
            indices.append(np.where(matches)[0])
        return indices
713

714
    @property
    def default_filters(self) -> list:
        """
        :return: The default filter names: g, r, i, z, y followed by J, H, K.
        :rtype: list
        """
        optical = list("grizy")
        infrared = ["J", "H", "K"]
        return optical + infrared
721

722
    @staticmethod
    def get_colors(filters: Union[np.ndarray, list]) -> np.ndarray:
        """
        Pick one plotting color per filter from the matplotlib ``rainbow`` colormap.

        :param filters: Array or list of filters to use in the plot. Only its length is used.
        :type filters: Union[np.ndarray, list]
        :return: RGBA colors sampled evenly along the colormap, one row per filter.
                 Calling a colormap on an array returns an RGBA array, not a Colormap.
        :rtype: np.ndarray
        """
        return matplotlib.cm.rainbow(np.linspace(0, 1, len(filters)))
731

732
    def plot_data(self, axes: matplotlib.axes.Axes = None, filename: str = None, outdir: str = None, save: bool = True,
                  show: bool = True, plot_others: bool = True, color: str = 'k', **kwargs) -> matplotlib.axes.Axes:
        """Plots the Transient data and returns Axes.

        The plotter is chosen from the data mode: integrated flux or luminosity (each with an optical
        variant when `optical_data` is set), flux density, or magnitude. For any other data mode nothing
        is plotted and `axes` is returned as given.

        :param axes: Matplotlib axes to plot the lightcurve into. Useful for user specific modifications to the plot.
        :param filename: Name of the file to be plotted in.
        :param outdir: The directory in which to save the file in.
        :param save: Whether to save the plot. (Default value = True)
        :param show: Whether to show the plot. (Default value = True)
        :param plot_others: Whether to plot inactive bands. Only forwarded to the optical integrated flux,
                            flux density and magnitude plotters. (Default value = True)
        :param color: Color of the data.
        :param kwargs: Additional keyword arguments to pass in the Plotter methods.
        Available in the online documentation at `redback.plotting.Plotter`.
        `print(Transient.plot_data.__doc__)` to see all options!
        :return: The axes with the plot.
        """
        if self.flux_data:
            if self.optical_data:
                plotter_class, forwards_plot_others = IntegratedFluxOpticalPlotter, True
            else:
                plotter_class, forwards_plot_others = IntegratedFluxPlotter, False
        elif self.luminosity_data:
            plotter_class = LuminosityOpticalPlotter if self.optical_data else LuminosityPlotter
            forwards_plot_others = False
        elif self.flux_density_data:
            plotter_class, forwards_plot_others = FluxDensityPlotter, True
        elif self.magnitude_data:
            plotter_class, forwards_plot_others = MagnitudePlotter, True
        else:
            return axes

        plotter_kwargs = dict(transient=self, color=color, filename=filename, outdir=outdir)
        if forwards_plot_others:
            plotter_kwargs['plot_others'] = plot_others
        plotter = plotter_class(**plotter_kwargs, **kwargs)
        return plotter.plot_data(axes=axes, save=save, show=show)
770

771
    def plot_multiband(
            self, figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None, filename: str = None,
            outdir: str = None, ncols: int = 2, save: bool = True, show: bool = True,
            nrows: int = None, figsize: tuple = None, filters: list = None, **kwargs: None) \
            -> matplotlib.axes.Axes:
        """
        Plot the data with one panel per band.

        :param figure: Figure can be given if defaults are not satisfying.
        :param axes: Axes can be given if defaults are not satisfying.
        :param filename: Name of the file to be plotted in.
        :param outdir: The directory in which to save the file in.
        :param save: Whether to save the plot. (Default value = True)
        :param show: Whether to show the plot. (Default value = True)
        :param ncols: Number of columns to use on the plot. Default is 2.
        :param nrows: Number of rows to use on the plot. If None are given this will
                      be inferred from ncols and the number of filters.
        :param figsize: Size of the figure. A default based on ncols and nrows will be used if None is given.
        :param filters: Which bands to plot. Will use default filters if None is given.
        :param kwargs: Additional keyword arguments to pass in the Plotter methods.
        Available in the online documentation at `redback.plotting.Plotter`.
        `print(Transient.plot_multiband.__doc__)` to see all options!
        :return: The axes.
        :raises ValueError: If the data mode is not 'flux_density', 'magnitude' or 'flux'.
        """
        if self.data_mode not in ['flux_density', 'magnitude', 'flux']:
            raise ValueError(
                f'You cannot plot multiband data with {self.data_mode} data mode . Why are you doing this?')

        if self.magnitude_data:
            plotter_class = MagnitudePlotter
        elif self.flux_density_data:
            plotter_class = FluxDensityPlotter
        elif self.flux_data:
            plotter_class = IntegratedFluxOpticalPlotter
        else:
            return None

        plotter = plotter_class(transient=self, filters=filters, filename=filename, outdir=outdir,
                                nrows=nrows, ncols=ncols, figsize=figsize, **kwargs)
        return plotter.plot_multiband(figure=figure, axes=axes, save=save, show=show)
808

809
    def plot_lightcurve(
            self, model: callable, filename: str = None, outdir: str = None, axes: matplotlib.axes.Axes = None,
            save: bool = True, show: bool = True, random_models: int = 100, posterior: pd.DataFrame = None,
            model_kwargs: dict = None, **kwargs: None) -> matplotlib.axes.Axes:
        """
        Plot the data together with model lightcurves drawn from the posterior.

        :param model: The model used to plot the lightcurve.
        :param filename: The output filename. Otherwise, use default which starts with the name
                         attribute and ends with *lightcurve.png.
        :param axes: Axes to plot in if given.
        :param save: Whether to save the plot.
        :param show: Whether to show the plot.
        :param random_models: Number of random posterior samples plotted faintly. (Default value = 100)
        :param posterior: Posterior distribution to which to draw samples from. Is optional but must be given.
        :param outdir: Out directory in which to save the plot. Default is the current working directory.
        :param model_kwargs: Additional keyword arguments to be passed into the model.
        :param kwargs: Additional keyword arguments to pass in the Plotter methods.
        Available in the online documentation at `redback.plotting.Plotter`.
        `print(Transient.plot_lightcurve.__doc__)` to see all options!
        :return: The axes. If the data mode has no matching plotter, `axes` is returned as given.
        """
        if self.flux_data:
            plotter_class = IntegratedFluxOpticalPlotter if self.optical_data else IntegratedFluxPlotter
        elif self.luminosity_data:
            plotter_class = LuminosityOpticalPlotter if self.optical_data else LuminosityPlotter
        elif self.flux_density_data:
            plotter_class = FluxDensityPlotter
        elif self.magnitude_data:
            plotter_class = MagnitudePlotter
        else:
            return axes

        plotter = plotter_class(
            transient=self, model=model, filename=filename, outdir=outdir,
            posterior=posterior, model_kwargs=model_kwargs, random_models=random_models, **kwargs)
        return plotter.plot_lightcurve(axes=axes, save=save, show=show)
857

858
    def plot_residual(self, model: callable, filename: str = None, outdir: str = None, axes: matplotlib.axes.Axes = None,
                      save: bool = True, show: bool = True, posterior: pd.DataFrame = None,
                      model_kwargs: dict = None, **kwargs: None) -> matplotlib.axes.Axes:
        """
        Plot the data, the model and the residuals. Only integrated flux and luminosity data are supported.

        :param model: The model used to plot the lightcurve.
        :param filename: The output filename. Otherwise, use default which starts with the name
                         attribute and ends with *lightcurve.png.
        :param axes: Axes to plot in if given.
        :param save: Whether to save the plot.
        :param show: Whether to show the plot.
        :param posterior: Posterior distribution to which to draw samples from. Is optional but must be given.
        :param outdir: Out directory in which to save the plot. Default is the current working directory.
        :param model_kwargs: Additional keyword arguments to be passed into the model.
        :param kwargs: Additional keyword arguments to pass in the Plotter methods.
        Available in the online documentation at `redback.plotting.Plotter`.
        `print(Transient.plot_residual.__doc__)` to see all options!
        :return: The axes.
        :raises ValueError: For data modes other than integrated flux or luminosity.
        """
        if self.flux_data:
            plotter_class = IntegratedFluxPlotter
        elif self.luminosity_data:
            plotter_class = LuminosityOpticalPlotter if self.optical_data else LuminosityPlotter
        else:
            raise ValueError("Residual plotting not implemented for this data mode")

        plotter = plotter_class(
            transient=self, model=model, filename=filename, outdir=outdir,
            posterior=posterior, model_kwargs=model_kwargs, **kwargs)
        return plotter.plot_residuals(axes=axes, save=save, show=show)
892

893

894
    def fit_gp(self, mean_model, kernel, prior=None, use_frequency=True):
        """
        Fit a GP to the data using george and scipy minimization.

        Luminosity data is fitted against `time_rest_frame`. All other data modes use the active-band data
        from `get_filtered_data`. The y values and errors are divided by max(y) before fitting.

        :param mean_model: Mean model to use in the GP fit. Can be a string to refer to a redback model,
            a callable, or None.
        :param kernel: George kernel to use. User must ensure this is set up correctly
            (2D when `use_frequency` is True, 1D otherwise).
        :param prior: Prior to use when fitting with a mean model. For a string `mean_model` the default
            redback prior is used when None is given.
        :param use_frequency: Whether to use the effective frequency in a 2D GP fit. Cannot be used with most
            mean models.
        :return: Named tuple with George GP object and additional useful data. It has attributes
            `gp` (a george GP, or a dict of GPs keyed by band for multiband data with a mean model),
            `scaled_y` (a dict keyed by band in that multiband case), `y_scaler`, `use_frequency`
            and `mean_model`.
        :raises ImportError: If george is not installed.
        :raises ValueError: If a mean model is given without a prior.
        """
        try:
            import george
        except ImportError:
            redback.utils.logger.warning("George must be installed to use GP fitting.")
            # Nothing below can work without george; fail here instead of with a NameError later.
            raise
        import copy
        import scipy.optimize as op
        from bilby.core.likelihood import function_to_george_mean_model

        # NOTE: the attributes are set on the namedtuple class itself (not an instance); this is
        # kept for backward compatibility with callers that use the returned object.
        output = namedtuple("gp_out", ["gp", "scaled_y", "y_scaler", 'use_frequency', 'mean_model'])
        output.use_frequency = use_frequency
        output.mean_model = mean_model

        if self.data_mode == 'luminosity':
            x = self.time_rest_frame
            y = self.y
            y_err = np.asarray(self.y_err)
            if y_err.ndim == 2:
                # Asymmetric errors: use the larger side for each point. 1D errors are kept per point
                # (a plain np.max(..., axis=0) would collapse them to one scalar).
                y_err = np.max(y_err, axis=0)
        else:
            x, _, y, y_err = self.get_filtered_data()
        redback.utils.logger.info("Rescaling data for GP fitting.")
        y_scaler = np.max(y)
        gp_y_err = y_err / y_scaler
        gp_y = y / y_scaler
        output.scaled_y = gp_y
        output.y_scaler = y_scaler

        # These closures refer to the `gp` created further down in this function.
        def nll(p):
            gp.set_parameter_vector(p)
            ll = gp.log_likelihood(gp_y, quiet=True)
            return -ll if np.isfinite(ll) else 1e25

        def grad_nll(p):
            gp.set_parameter_vector(p)
            return -gp.grad_log_likelihood(gp_y, quiet=True)

        if use_frequency:
            redback.utils.logger.info("Using frequencies and time in the GP fit.")
            redback.utils.logger.info("Kernel used: " + str(kernel))
            redback.utils.logger.info("Ensure that the kernel is set up correctly for 2D GP.")
            redback.utils.logger.info("You will be returned a single GP object with frequency as a parameter")
            freqs = self.filtered_frequencies
            X = np.column_stack((freqs, x))
        else:
            redback.utils.logger.info("Using time in GP fit.")
            redback.utils.logger.info("Kernel used: " + str(kernel))
            redback.utils.logger.info("Ensure that the kernel is set up correctly for 1D GP.")
            redback.utils.logger.info("You will be returned a GP object unique to a band/frequency"
                                      " in the data if working with multiband data")
            X = x

        if mean_model is None:
            redback.utils.logger.info("Mean model not given, fitting GP with no mean model.")
            gp = george.GP(kernel)
            gp.compute(X, gp_y_err + 1e-8)
            p0 = gp.get_parameter_vector()
            results = op.minimize(nll, p0, jac=grad_nll)
            gp.set_parameter_vector(results.x)
            redback.utils.logger.info(f"GP final loglikelihood: {gp.log_likelihood(gp_y)}")
            redback.utils.logger.info(f"GP final parameters: {gp.get_parameter_dict()}")
            output.gp = gp
            return output

        if isinstance(mean_model, str):
            mean_model_func = all_models_dict[mean_model]
            redback.utils.logger.info("Using inbuilt redback function {} as a mean model.".format(mean_model))
            if prior is None:
                redback.utils.logger.warning("No prior given for mean model. Using default prior.")
                prior = redback.priors.get_priors(mean_model)
        else:
            mean_model_func = mean_model
            redback.utils.logger.info("Using user-defined python function as a mean model.")

        if prior is None:
            redback.utils.logger.warning("Prior must be specified for GP fit with a mean model")
            raise ValueError("No prior specified")

        if self.data_mode in ['flux_density', 'magnitude', 'flux']:
            redback.utils.logger.info("Setting up GP version of mean model.")
            # gp_y, gp_y_err and X only hold the active-band data from get_filtered_data, so the per-band
            # indices are computed on the filtered band labels. self.list_of_band_indices indexes the
            # unfiltered data and is recomputed on every access.
            filtered_bands = np.asarray(self.bands)[self.filtered_indices]
            gp_dict = {}
            scaled_y_dict = {}
            for band in self.unique_bands:
                band_indices = np.where(filtered_bands == np.array(band))[0]
                if len(band_indices) == 0:
                    redback.utils.logger.info("No active data for band {}, skipping.".format(band))
                    continue
                band_y = gp_y[band_indices]
                scaled_y_dict[band] = band_y
                redback.utils.logger.info("Fitting for band {}".format(band))
                gp_x = X[band_indices]

                mean_model_class = function_to_george_mean_model(mean_model_func)
                mm = mean_model_class(**prior.sample())
                # george.GP keeps a reference to the kernel and updates it in place while optimising.
                # Give every band its own copy so later bands do not overwrite earlier bands' fits.
                band_gp = george.GP(copy.deepcopy(kernel), mean=mm, fit_mean=True)
                band_gp.compute(gp_x, gp_y_err[band_indices] + 1e-8)

                def band_nll(p, band_gp=band_gp, band_y=band_y):
                    band_gp.set_parameter_vector(p)
                    ll = band_gp.log_likelihood(band_y, quiet=True)
                    return -ll if np.isfinite(ll) else 1e25

                results = op.minimize(band_nll, band_gp.get_parameter_vector())
                band_gp.set_parameter_vector(results.x)
                redback.utils.logger.info(f"GP final loglikelihood: {band_gp.log_likelihood(band_y)}")
                redback.utils.logger.info(f"GP final parameters: {band_gp.get_parameter_dict()}")
                gp_dict[band] = band_gp
            output.gp = gp_dict
            output.scaled_y = scaled_y_dict
        else:
            mean_model_class = function_to_george_mean_model(mean_model_func)
            mm = mean_model_class(**prior.sample())
            gp = george.GP(kernel, mean=mm, fit_mean=True)
            gp.compute(X, gp_y_err + 1e-8)
            p0 = gp.get_parameter_vector()
            results = op.minimize(nll, p0)
            gp.set_parameter_vector(results.x)
            redback.utils.logger.info(f"GP final loglikelihood: {gp.log_likelihood(gp_y)}")
            redback.utils.logger.info(f"GP final parameters: {gp.get_parameter_dict()}")
            output.gp = gp
        return output
1019

1020
    def plot_multiband_lightcurve(
            self, model: callable, filename: str = None, outdir: str = None,
            figure: matplotlib.figure.Figure = None, axes: matplotlib.axes.Axes = None,
            save: bool = True, show: bool = True, random_models: int = 100, posterior: pd.DataFrame = None,
            model_kwargs: dict = None, **kwargs: object) -> matplotlib.axes.Axes:
        """
        Plot the data and posterior model lightcurves with one panel per band.

        :param model: The model used to plot the lightcurve.
        :param filename: The output filename. Otherwise, use default which starts with the name
                         attribute and ends with *lightcurve.png.
        :param figure: Figure can be given if defaults are not satisfying.
        :param axes: Axes to plot in if given.
        :param save: Whether to save the plot.
        :param show: Whether to show the plot.
        :param random_models: Number of random posterior samples plotted faintly. (Default value = 100)
        :param posterior: Posterior distribution to which to draw samples from. Is optional but must be given.
        :param outdir: Out directory in which to save the plot. Default is the current working directory.
        :param model_kwargs: Additional keyword arguments to be passed into the model.
        :param kwargs: Additional keyword arguments to pass in the Plotter methods.
        Available in the online documentation at `redback.plotting.Plotter`.
        `print(Transient.plot_multiband_lightcurve.__doc__)` to see all options!

        :return: The axes.
        :raises ValueError: If the data mode is not 'flux_density', 'magnitude' or 'flux'.
        """
        if self.data_mode not in ['flux_density', 'magnitude', 'flux']:
            raise ValueError(
                f'You cannot plot multiband data with {self.data_mode} data mode . Why are you doing this?')

        if self.magnitude_data:
            plotter_class = MagnitudePlotter
        elif self.flux_data:
            plotter_class = IntegratedFluxOpticalPlotter
        elif self.flux_density_data:
            plotter_class = FluxDensityPlotter
        else:
            return None

        plotter = plotter_class(
            transient=self, model=model, filename=filename, outdir=outdir,
            posterior=posterior, model_kwargs=model_kwargs, random_models=random_models, **kwargs)
        return plotter.plot_multiband_lightcurve(figure=figure, axes=axes, save=save, show=show)
1060

1061
    # Expand the "see all options" placeholder in each plotting docstring into the full list of
    # Plotter keyword arguments. Runs once, at class-definition time.
    _formatted_kwargs_options = redback.plotting.Plotter.keyword_docstring
    for _plot_method, _placeholder in (
            (plot_data, "`print(Transient.plot_data.__doc__)` to see all options!"),
            (plot_multiband, "`print(Transient.plot_multiband.__doc__)` to see all options!"),
            (plot_lightcurve, "`print(Transient.plot_lightcurve.__doc__)` to see all options!"),
            (plot_multiband_lightcurve,
             "`print(Transient.plot_multiband_lightcurve.__doc__)` to see all options!"),
            (plot_residual, "`print(Transient.plot_residual.__doc__)` to see all options!")):
        _plot_method.__doc__ = _plot_method.__doc__.replace(_placeholder, _formatted_kwargs_options)
    # Loop variables would otherwise remain as class attributes.
    del _plot_method, _placeholder
1072

1073

1074
class OpticalTransient(Transient):
    """Transient subclass for optical photometry.

    Adds loading from processed open access catalogue CSV files and helpers that estimate the
    blackbody temperature/radius and the bolometric luminosity from multi-band photometry.
    """
    DATA_MODES = ['flux', 'flux_density', 'magnitude', 'luminosity']

    @staticmethod
    def load_data(processed_file_path, data_mode="magnitude"):
        """Loads data from specified directory and file, and returns it as a tuple.

        :param processed_file_path: Path to the processed file to load
        :type processed_file_path: str
        :param data_mode: Name of the data mode.
                          Must be from ['magnitude', 'flux_density', 'flux', 'all']. Default is magnitude.
        :type data_mode: str, optional

        :return: Six elements (time_days, time_mjd, values, value_errors, bands, system) when querying
                 'magnitude', 'flux_density' or 'flux' data; ten elements (time_days, time_mjd,
                 flux_density, flux_density_err, magnitude, magnitude_err, flux, flux_err, bands, system)
                 for 'all'.
        :rtype: tuple
        :raises ValueError: If `data_mode` is not one of the supported values.
        """
        valid_modes = ('magnitude', 'flux_density', 'flux', 'all')
        if data_mode not in valid_modes:
            # Previously an unsupported mode silently returned None, which only failed later when
            # the caller tried to unpack the result.
            raise ValueError(f"data_mode must be one of {valid_modes}, got {data_mode!r}.")

        df = pd.read_csv(processed_file_path)
        time_days = np.array(df["time (days)"])
        time_mjd = np.array(df["time"])
        magnitude = np.array(df["magnitude"])
        magnitude_err = np.array(df["e_magnitude"])
        bands = np.array(df["band"])
        system = np.array(df["system"])
        flux_density = np.array(df["flux_density(mjy)"])
        flux_density_err = np.array(df["flux_density_error"])
        flux = np.array(df["flux(erg/cm2/s)"])
        flux_err = np.array(df['flux_error'])

        if data_mode == "magnitude":
            return time_days, time_mjd, magnitude, magnitude_err, bands, system
        if data_mode == "flux_density":
            return time_days, time_mjd, flux_density, flux_density_err, bands, system
        if data_mode == "flux":
            return time_days, time_mjd, flux, flux_err, bands, system
        # data_mode == "all"
        return time_days, time_mjd, flux_density, flux_density_err, \
            magnitude, magnitude_err, flux, flux_err, bands, system

    def __init__(
            self, name: str, data_mode: str = 'magnitude', time: np.ndarray = None, time_err: np.ndarray = None,
            time_mjd: np.ndarray = None, time_mjd_err: np.ndarray = None, time_rest_frame: np.ndarray = None,
            time_rest_frame_err: np.ndarray = None, Lum50: np.ndarray = None, Lum50_err: np.ndarray = None,
            flux: np.ndarray = None, flux_err: np.ndarray = None, flux_density: np.ndarray = None,
            flux_density_err: np.ndarray = None, magnitude: np.ndarray = None, magnitude_err: np.ndarray = None,
            redshift: float = np.nan, photon_index: float = np.nan, frequency: np.ndarray = None,
            bands: np.ndarray = None, system: np.ndarray = None, active_bands: Union[np.ndarray, str] = 'all',
            plotting_order: Union[np.ndarray, str] = None, use_phase_model: bool = False,
            optical_data: bool = True, **kwargs) -> None:
        """This is a general constructor for the OpticalTransient class. Note that you only need to give data
        corresponding to the data mode you are using. For luminosity data provide times in the rest frame, if using
        a phase model provide time in MJD, else use the default time (observer frame).

        :param name: Name of the transient.
        :type name: str
        :param data_mode: Data mode. Must be one from `OpticalTransient.DATA_MODES`.
        :type data_mode: str, optional
        :param time: Times in the observer frame.
        :type time: np.ndarray, optional
        :param time_err: Time errors in the observer frame.
        :type time_err: np.ndarray, optional
        :param time_mjd: Times in MJD. Used if using phase model.
        :type time_mjd: np.ndarray, optional
        :param time_mjd_err: Time errors in MJD. Used if using phase model.
        :type time_mjd_err: np.ndarray, optional
        :param time_rest_frame: Times in the rest frame. Used for luminosity data.
        :type time_rest_frame: np.ndarray, optional
        :param time_rest_frame_err: Time errors in the rest frame. Used for luminosity data.
        :type time_rest_frame_err: np.ndarray, optional
        :param Lum50: Luminosity values.
        :type Lum50: np.ndarray, optional
        :param Lum50_err: Luminosity error values.
        :type Lum50_err: np.ndarray, optional
        :param flux: Flux values.
        :type flux: np.ndarray, optional
        :param flux_err: Flux error values.
        :type flux_err: np.ndarray, optional
        :param flux_density: Flux density values.
        :type flux_density: np.ndarray, optional
        :param flux_density_err: Flux density error values.
        :type flux_density_err: np.ndarray, optional
        :param magnitude: Magnitude values for photometry data.
        :type magnitude: np.ndarray, optional
        :param magnitude_err: Magnitude error values for photometry data.
        :type magnitude_err: np.ndarray, optional
        :param redshift: Redshift value.
        :type redshift: float, optional
        :param photon_index: Photon index value.
        :type photon_index: float, optional
        :param frequency: Array of band frequencies in photometry data.
        :type frequency: np.ndarray, optional
        :param bands: Band values.
        :type bands: np.ndarray, optional
        :param system: System values.
        :type system: np.ndarray, optional
        :param active_bands: List or array of active bands to be used in the analysis.
                             Use all available bands if 'all' is given.
        :type active_bands: Union[list, np.ndarray], optional
        :param plotting_order: Order in which to plot the bands/and how unique bands are stored.
        :type plotting_order: Union[np.ndarray, str], optional
        :param use_phase_model: Whether we are using a phase model.
        :type use_phase_model: bool, optional
        :param optical_data: Whether we are fitting optical data, useful for plotting.
        :type optical_data: bool, optional
        :param kwargs:
            Additional callables:
            bands_to_frequency: Conversion function to convert a list of bands to frequencies. Use
                                  redback.utils.bands_to_frequency if not given.
        :type kwargs: dict, optional
        """
        super().__init__(time=time, time_err=time_err, time_rest_frame=time_rest_frame, time_mjd=time_mjd,
                         time_mjd_err=time_mjd_err, frequency=frequency,
                         time_rest_frame_err=time_rest_frame_err, Lum50=Lum50, Lum50_err=Lum50_err,
                         flux=flux, flux_err=flux_err, redshift=redshift, photon_index=photon_index,
                         flux_density=flux_density, flux_density_err=flux_density_err, magnitude=magnitude,
                         magnitude_err=magnitude_err, data_mode=data_mode, name=name,
                         use_phase_model=use_phase_model, optical_data=optical_data,
                         system=system, bands=bands, plotting_order=plotting_order,
                         active_bands=active_bands, **kwargs)
        # Placeholder directory structure; subclasses/loaders may replace it.
        self.directory_structure = redback.get_data.directory.DirectoryStructure(
            directory_path=".", raw_file_path=".", processed_file_path=".")

    @classmethod
    def _open_access_transient_type(cls) -> str:
        """Return the transient-type name used for this class's open access catalogue directories.

        TDEs are stored under "tidal_disruption_event"; every other class uses its lower-cased name.
        Shared by `from_open_access_catalogue` and `_get_transient_dir` so both resolve the same path.

        :return: The transient type string.
        :rtype: str
        """
        if cls.__name__ == "TDE":
            return "tidal_disruption_event"
        return cls.__name__.lower()

    @classmethod
    def from_open_access_catalogue(
            cls, name: str, data_mode: str = "magnitude", active_bands: Union[np.ndarray, str] = 'all',
            plotting_order: Union[np.ndarray, str] = None, use_phase_model: bool = False) -> OpticalTransient:
        """Constructor method to build object from Open Access Catalogue

        :param name: Name of the transient.
        :type name: str
        :param data_mode: Data mode used. Must be from `OpticalTransient.DATA_MODES`. Default is magnitude.
        :type data_mode: str, optional
        :param active_bands:
            Sets active bands based on array given.
            If argument is 'all', all unique bands in `self.bands` will be used.
        :type active_bands: Union[np.ndarray, str]
        :param plotting_order: Order in which to plot the bands/and how unique bands are stored.
        :type plotting_order: Union[np.ndarray, str], optional
        :param use_phase_model: Whether to use a phase model.
        :type use_phase_model: bool, optional

        :return: A class instance
        :rtype: OpticalTransient
        """
        directory_structure = redback.get_data.directory.open_access_directory_structure(
            transient=name, transient_type=cls._open_access_transient_type())
        # Load every column so the instance can later switch between data modes.
        time_days, time_mjd, flux_density, flux_density_err, magnitude, magnitude_err, flux, flux_err, bands, system = \
            cls.load_data(processed_file_path=directory_structure.processed_file_path, data_mode="all")
        return cls(name=name, data_mode=data_mode, time=time_days, time_err=None, time_mjd=time_mjd,
                   flux_density=flux_density, flux_density_err=flux_density_err, magnitude=magnitude,
                   magnitude_err=magnitude_err, bands=bands, system=system, active_bands=active_bands,
                   use_phase_model=use_phase_model, optical_data=True, flux=flux, flux_err=flux_err,
                   plotting_order=plotting_order)

    @property
    def event_table(self) -> str:
        """
        :return: Path to the metadata table.
        :rtype: str
        """
        return f"{self.directory_structure.directory_path}/{self.name}_metadata.csv"

    def _set_data(self) -> None:
        """Read the metadata table at `event_table` into `self.meta_data` (all columns as strings).

        If the file does not exist, log warnings and set `self.meta_data` to None instead of raising.
        """
        try:
            meta_data = pd.read_csv(self.event_table, on_bad_lines='skip', delimiter=',', dtype='str')
        except FileNotFoundError as e:
            redback.utils.logger.warning(e)
            redback.utils.logger.warning("Setting metadata to None. This is not an error, but a warning that no metadata could be found online.")
            meta_data = None
        self.meta_data = meta_data

    @property
    def transient_dir(self) -> str:
        """
        :return: The transient directory given the name of the transient.
        :rtype: str
        """
        return self._get_transient_dir()

    def _get_transient_dir(self) -> str:
        """Resolve the open access catalogue directory for this transient.

        Uses the same transient-type mapping as `from_open_access_catalogue` (e.g. TDE ->
        "tidal_disruption_event") so both point at the same location.

        :return: The transient directory path
        :rtype: str
        """
        transient_dir, _, _ = redback.get_data.directory.open_access_directory_structure(
            transient=self.name, transient_type=self._open_access_transient_type())
        return transient_dir

    def estimate_bb_params(self, distance: float = 1e27, bin_width: float = 1.0, min_filters: int = 3, **kwargs):
        """
        Estimate the blackbody temperature and photospheric radius as functions of time by fitting
        a blackbody SED to the multi‑band photometry.

        The method groups the photometric data into time bins (epochs) of width bin_width (in the
        same units as self.x, typically days), starting at the earliest data point. Every data point,
        including the latest one, falls into exactly one bin. For each epoch with at least min_filters
        measurements it fits a blackbody model to the data. When working with photometry
        provided in an effective flux density format (data_mode == "flux_density") the effective–wavelength
        approximation is used. When the data_mode is "flux" (or "magnitude") users have the option
        (via use_eff_wavelength=True) to instead use the effective wavelength approximation by converting AB
        magnitudes to flux density (using redback.utils.abmag_to_flux_density_and_error_inmjy). If this flag is
        not provided (or is False) then the full bandpass integration is applied.

        Parameters
        ----------
        distance : float, optional
            Distance to the transient in centimeters. Default is 1e27 cm.
        bin_width : float, optional
            Width of the time bins (in days) used to group the photometric data. Must be positive.
            Default is 1.0.
        min_filters : int, optional
            Minimum number of measurements required in a bin to perform the fit. Measurements are
            counted individually, so repeated observations in the same filter each count.
            Default is 3.
        kwargs : Additional keyword arguments
            maxfev : int, optional, default is 1000
            T_init : float, optional, default is 1e4, used as the initial guess for the fit.
            R_init : float, optional, default is 1e15, used as the initial guess for the fit.
            use_eff_wavelength : bool, optional, default is False.
                If True, then even for photometry provided as magnitudes (or bandpass fluxes),
                the effective wavelength approximation is used. In that case the AB magnitudes are
                converted to flux densities via redback.utils.abmag_to_flux_density_and_error_inmjy.
                If False, full bandpass integration is used.

        Returns
        -------
        df_bb : pandas.DataFrame or None
            A DataFrame containing columns:
              - epoch_times : binned epoch times,
              - temperature : best-fit blackbody temperatures (Kelvin),
              - radius : best-fit photospheric radii (cm),
              - temp_err : 1σ uncertainties on the temperatures,
              - radius_err : 1σ uncertainties on the radii.
            Returns None if insufficient data are available.

        Raises
        ------
        ValueError
            If bin_width is not positive, or if no positive redshift is set (needed for K-correction).
        """
        from scipy.optimize import curve_fit
        import astropy.units as uu

        if bin_width <= 0:
            raise ValueError("bin_width must be positive, got {}.".format(bin_width))

        # Get the filtered photometry.
        # Assumes self.get_filtered_data() returns (time, time_err, y, y_err)
        time_data, _, flux_data, flux_err_data = self.get_filtered_data()

        redback.utils.logger.info("Estimating blackbody parameters for {}.".format(self.name))
        redback.utils.logger.info("Using data mode = {}".format(self.data_mode))

        if len(time_data) == 0:
            redback.utils.logger.warning("No photometric data available. Fitting cannot proceed.")
            return None

        # Determine whether we are in bandpass mode.
        use_bandpass = False
        band_data = None
        freq_data = None
        if hasattr(self, "data_mode") and self.data_mode in ['flux', 'magnitude']:
            use_bandpass = True
            # Assume self.filtered_sncosmo_bands contains the (string) band names.
            band_data = self.filtered_sncosmo_bands
        else:
            # Otherwise the flux data and frequencies are assumed to be given.
            redback.utils.logger.info("Using effective wavelength approximation for {}".format(self.data_mode))
            freq_data = self.filtered_frequencies

        # Option: force effective wavelength approximation even if data_mode is bandpass.
        force_eff = kwargs.get('use_eff_wavelength', False)
        if use_bandpass and force_eff:
            redback.utils.logger.warning("Using effective wavelength approximation for {}".format(self.data_mode))

            if self.data_mode == 'magnitude':
                # Convert the AB magnitudes to flux density using the redback function.
                from redback.utils import abmag_to_flux_density_and_error_inmjy
                flux_data, flux_err_data = abmag_to_flux_density_and_error_inmjy(flux_data, flux_err_data)
                freq_data = redback.utils.bands_to_frequency(band_data)
            else:
                # Convert the bandpass fluxes to flux density using the redback function.
                from redback.utils import bandpass_flux_to_flux_density, bands_to_effective_width
                redback.utils.logger.warning("Ensure filters.csv has the correct bandpass effective widths for your filter.")
                effective_widths = bands_to_effective_width(band_data)
                freq_data = redback.utils.bands_to_frequency(band_data)
                flux_data, flux_err_data = bandpass_flux_to_flux_density(flux_data, flux_err_data, effective_widths)
            # Use the effective frequency approach.
            use_bandpass = False

        # Get initial guesses.
        T_init = kwargs.get('T_init', 1e4)
        R_init = kwargs.get('R_init', 1e15)
        maxfev = kwargs.get('maxfev', 1000)

        # Sort photometric data by time.
        sort_idx = np.argsort(time_data)
        time_data = time_data[sort_idx]
        flux_data = flux_data[sort_idx]
        flux_err_data = flux_err_data[sort_idx]
        if use_bandpass:
            band_data = np.array(band_data)[sort_idx]
        else:
            freq_data = np.array(freq_data)[sort_idx]

        # Retrieve redshift.
        redshift = np.nan_to_num(self.redshift)
        if redshift <= 0.:
            raise ValueError("Redshift must be provided to perform K-correction.")

        # For effective frequency mode, K-correct frequencies.
        if not use_bandpass:
            freq_data, _ = redback.utils.calc_kcorrected_properties(frequency=freq_data,
                                                                    redshift=redshift, time=0.)

        # Define the model functions.
        if not use_bandpass:
            # --- Effective-wavelength model ---
            def bb_model(freq, logT, logR):
                T = 10 ** logT
                R = 10 ** logR
                # Compute the model flux density in erg/s/cm^2/Hz.
                model_flux_cgs = redback.sed.blackbody_to_flux_density(T, R, distance, freq)
                # Convert to mJy. (1 Jy = 1e-23 erg/s/cm^2/Hz; 1 mJy = 1e-3 Jy = 1e-26 erg/s/cm^2/Hz)
                model_flux_mjy = (model_flux_cgs / (1e-26 * uu.erg / uu.s / uu.cm**2 / uu.Hz)).value
                return model_flux_mjy

            make_bandpass_model = None
        else:
            # --- Full bandpass integration model ---
            from redback.utils import calc_kcorrected_properties, lambda_to_nu, bandpass_magnitude_to_flux

            # The wavelength grid (100 to 80,000 Å) and its K-corrected frequencies do not depend on
            # T or R, so compute them once instead of on every curve_fit model evaluation.
            lambda_obs = np.geomspace(100, 80000, 300)
            grid_frequency, _ = calc_kcorrected_properties(frequency=lambda_to_nu(lambda_obs),
                                                           redshift=redshift, time=0.)
            output_format = self.data_mode

            def bb_model_bandpass(band, logT, logR):
                T = 10 ** logT
                R = 10 ** logR
                # Compute the model SED (flux density in erg/s/cm^2/Hz).
                model_flux = redback.sed.blackbody_to_flux_density(T, R, distance, grid_frequency)
                # Convert the SED to per-Å units.
                _spectra = model_flux.to(uu.erg / uu.cm ** 2 / uu.s / uu.Angstrom,
                                         equivalencies=uu.spectral_density(wav=lambda_obs * uu.Angstrom))
                # The source object needs a time series; repeat the same SED over a few phases.
                spectra = np.zeros((5, lambda_obs.size))
                spectra[:, :] = _spectra.value
                source = redback.sed.RedbackTimeSeriesSource(phase=np.array([0, 1, 2, 3, 4]),
                                                             wave=lambda_obs, flux=spectra)
                if output_format == 'flux':
                    # Convert bandpass magnitude to flux.
                    mag = source.bandmag(phase=0, band=band, magsys='ab')
                    return bandpass_magnitude_to_flux(magnitude=mag, bands=band)
                elif output_format == 'magnitude':
                    return source.bandmag(phase=0, band=band, magsys='ab')
                else:
                    raise ValueError("Unknown output_format in bb_model_bandpass.")

            def make_bandpass_model(epoch_bands):
                """Build a curve_fit model for one epoch, mapping integer x-values to that epoch's bands."""
                # curve_fit needs numeric x-values, so bands are encoded by their position.
                epoch_bands = np.array(list(epoch_bands))

                def bb_model_bandpass_from_index(x, logT, logR):
                    i_idx = np.round(x).astype(int)
                    return bb_model_bandpass(epoch_bands[i_idx], logT, logR)

                return bb_model_bandpass_from_index

        # Initialize lists to store fit results.
        epoch_times = []
        temperatures = []
        radii = []
        temp_errs = []
        radius_errs = []

        # Assign each point to a bin of width bin_width starting at the earliest time. Unlike edges from
        # np.arange with a strict upper bound, this keeps the latest point and works when all
        # observations share a single time.
        t_min = np.min(time_data)
        bin_index = np.floor((np.asarray(time_data, dtype=float) - t_min) / bin_width).astype(int)
        n_bins = int(bin_index.max()) + 1
        counts = np.bincount(bin_index, minlength=n_bins)
        redback.utils.logger.info("Number of bins: {}".format(n_bins))

        # Ensure at least one bin has enough points.
        if not np.any(counts >= min_filters):
            redback.utils.logger.warning("No time bins have at least {} measurements. Fitting cannot proceed.".format(min_filters))
            redback.utils.logger.warning("Try generating more data through GPs, increasing bin widths, or using fewer filters.")
            return None

        # Loop over bins (epochs): for each with enough data perform the fit.
        for i in range(n_bins):
            if counts[i] < min_filters:
                continue
            mask = bin_index == i
            t_epoch = np.mean(time_data[mask])
            if use_bandpass:
                fit_func = make_bandpass_model(band_data[mask])
                xdata = np.arange(int(counts[i]))
            else:
                # Use effective frequency array (numeric).
                fit_func = bb_model
                xdata = freq_data[mask]
            try:
                popt, pcov = curve_fit(
                    fit_func,
                    xdata,
                    flux_data[mask],
                    sigma=flux_err_data[mask],
                    p0=[np.log10(T_init), np.log10(R_init)],
                    absolute_sigma=True,
                    maxfev=maxfev
                )
            except Exception as e:
                # Best effort: a failed epoch is logged and skipped rather than aborting all epochs.
                redback.utils.logger.warning(f"Fit failed for epoch {i}: {e}")
                redback.utils.logger.warning(f"Skipping epoch {i} with time {t_epoch:.2f} days.")
                continue

            logT_fit, logR_fit = popt
            T_fit = 10 ** logT_fit
            R_fit = 10 ** logR_fit
            perr = np.sqrt(np.diag(pcov))
            # Propagate log10 uncertainties: d(10**p) = ln(10) * 10**p * dp.
            T_err = np.log(10) * T_fit * perr[0]
            R_err = np.log(10) * R_fit * perr[1]

            epoch_times.append(t_epoch)
            temperatures.append(T_fit)
            radii.append(R_fit)
            temp_errs.append(T_err)
            radius_errs.append(R_err)

        if len(epoch_times) == 0:
            redback.utils.logger.warning("No epochs with sufficient data yielded a successful fit.")
            return None

        df_bb = pd.DataFrame({
            'epoch_times': epoch_times,
            'temperature': temperatures,
            'radius': radii,
            'temp_err': temp_errs,
            'radius_err': radius_errs
        })

        # Drop fits whose relative uncertainty is >= 100% (or NaN/inf); these are unreliable.
        redback.utils.logger.info('Masking epochs with likely wrong extractions')
        df_bb = df_bb[df_bb['temp_err'] / df_bb['temperature'] < 1]
        df_bb = df_bb[df_bb['radius_err'] / df_bb['radius'] < 1]
        return df_bb

    def estimate_bolometric_luminosity(self, distance: float = 1e27, bin_width: float = 1.0,
                                       min_filters: int = 3, **kwargs):
        """
        Estimate the bolometric luminosity as a function of time by fitting the blackbody SED
        to the multi‑band photometry and then integrating that spectrum. For each epoch the bolometric
        luminosity is computed using the Stefan–Boltzmann law evaluated at the source:

            L_bol = 4 π R² σ_SB T⁴

        Uncertainties in T and R are propagated assuming

            (ΔL_bol / L_bol)² = (2 ΔR / R)² + (4 ΔT / T)².

        Optionally, two corrections can be applied:

        1. A boost–factor to “restore” missing blue flux. If a cutoff wavelength is provided via
           the keyword 'lambda_cut' (in angstroms), it is converted to centimeters and a boost factor is
           calculated as:

               Boost = (F_tot / F_red)

           where F_tot = σ_SB T⁴ and F_red is computed by numerically integrating π * B_λ(T)
           from the cutoff wavelength (in cm) to infinity. The final (boosted) luminosity becomes:

               L_boosted = Boost × (4π R² σ_SB T⁴).

        2. An extinction correction. If the bolometric extinction (A_ext, in magnitudes) is supplied via
           the keyword 'A_ext', the luminosity is increased by a factor of 10^(+0.4·A_ext) to undo
           dust extinction. (A_ext defaults to 0.)

        Parameters
        ----------
        distance : float, optional
            Distance to the transient in centimeters. (Default is 1e27 cm.)
        bin_width : float, optional
            Width of the time bins (in days) used for grouping photometry. (Default is 1.0.)
        min_filters : int, optional
            Minimum number of measurements required in a bin to perform a fit. (Default is 3.)
        kwargs : dict, optional
            Additional keyword arguments to pass to `estimate_bb_params` (e.g., maxfev, T_init, R_init,
            use_eff_wavelength, etc.). Additionally:
        - 'lambda_cut': If provided (in angstroms), the bolometric luminosity will be “boosted”
          to account for missing blue flux.
        - 'A_ext': Bolometric extinction in magnitudes. The observed luminosity is increased by a factor
          10^(+0.4·A_ext). (Default is 0.)

        Returns
        -------
        df_bol : pandas.DataFrame or None
            A DataFrame containing the columns of `estimate_bb_params` (epoch_times, temperature,
            radius, temp_err, radius_err) plus:
              - lum_bol: Derived bolometric luminosity (1e50 erg/s) computed as 4π R² σ_SB T⁴
                         (boosted and extinction-corrected if requested).
              - lum_bol_err: 1σ uncertainty on lum_bol (1e50 erg/s) from error propagation.
              - lum_bol_bb: Blackbody luminosity (1e50 erg/s) 4π R² σ_SB T⁴ without the boost,
                            but with the extinction correction applied.
              - lum_bol_bb_err: 1σ uncertainty on lum_bol_bb (1e50 erg/s).
              - time_rest_frame: Epoch time divided by (1+redshift), i.e., the rest-frame time in days.
            Returns None if no valid blackbody fits were obtained.
        """
        from redback.sed import boosted_bolometric_luminosity

        # Retrieve optional lambda_cut (in angstroms) for the boost correction.
        lambda_cut_angstrom = kwargs.pop('lambda_cut', None)
        if lambda_cut_angstrom is not None:
            redback.utils.logger.info("Including effects of missing flux due to line blanketing.")
            redback.utils.logger.info(
                "Using lambda_cut = {} Å for bolometric luminosity boost.".format(lambda_cut_angstrom))
            # Convert lambda_cut from angstroms to centimeters (1 Å = 1e-8 cm)
            lambda_cut = lambda_cut_angstrom * 1e-8
        else:
            redback.utils.logger.info("No lambda_cut provided; no correction applied. Assuming a pure blackbody SED.")
            lambda_cut = None

        # Retrieve optional extinction in magnitudes.
        A_ext = kwargs.pop('A_ext', 0.0)
        if A_ext != 0.0:
            redback.utils.logger.info("Applying extinction correction with A_ext = {} mag.".format(A_ext))
        # Undo extinction: intrinsic luminosity is brighter than observed.
        extinction_factor = 10 ** (0.4 * A_ext)

        # Retrieve blackbody parameters.
        df_bb = self.estimate_bb_params(distance=distance, bin_width=bin_width, min_filters=min_filters, **kwargs)
        if df_bb is None or len(df_bb) == 0:
            redback.utils.logger.warning("No valid blackbody fits were obtained; cannot estimate bolometric luminosity.")
            return None

        # Compute L_bol (or L_boosted) for each epoch and propagate uncertainties.
        L_bol = []
        L_bol_err = []
        L_bol_bb = []
        L_bol_bb_err = []
        for _, row in df_bb.iterrows():
            temp = row['temperature']
            radius = row['radius']
            T_err = row['temp_err']
            R_err = row['radius_err']

            # Use boosted luminosity if lambda_cut is provided.
            if lambda_cut is not None:
                lum, lum_bb = boosted_bolometric_luminosity(temp, radius, lambda_cut)
            else:
                lum = 4 * np.pi * (radius ** 2) * redback.constants.sigma_sb * (temp ** 4)
                lum_bb = lum

            # Apply extinction correction to both luminosities.
            lum *= extinction_factor
            lum_bb *= extinction_factor

            # Propagate uncertainties using:
            # (ΔL/L)² = (2 ΔR / R)² + (4 ΔT / T)².
            rel_err = np.sqrt((2 * R_err / radius) ** 2 + (4 * T_err / temp) ** 2)

            L_bol.append(lum)
            L_bol_bb.append(lum_bb)
            L_bol_err.append(lum * rel_err)
            L_bol_bb_err.append(lum_bb * rel_err)

        df_bol = df_bb.copy()
        df_bol['lum_bol'] = np.array(L_bol) / 1e50
        df_bol['lum_bol_err'] = np.array(L_bol_err) / 1e50
        df_bol['lum_bol_bb'] = np.array(L_bol_bb) / 1e50
        df_bol['lum_bol_bb_err'] = np.array(L_bol_bb_err) / 1e50
        df_bol['time_rest_frame'] = df_bol['epoch_times'] / (1 + self.redshift)

        # Drop estimates whose relative uncertainty is >= 100% (or NaN/inf).
        redback.utils.logger.info('Masking bolometric estimates with likely wrong extractions')
        df_bol = df_bol[df_bol['lum_bol_err'] / df_bol['lum_bol'] < 1]
        redback.utils.logger.info(
            "Estimated bolometric luminosity using blackbody integration (with boost and extinction corrections if specified).")
        return df_bol
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc