• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nikhil-sarin / redback / 14430354752

13 Apr 2025 02:23PM UTC coverage: 86.635% (+6.0%) from 80.663%
14430354752

Pull #266

github

web-flow
Merge 8147dba2c into e087188ab
Pull Request #266: A big overhaul

1621 of 1828 new or added lines in 14 files covered. (88.68%)

4 existing lines in 2 files now uncovered.

12673 of 14628 relevant lines covered (86.64%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.29
/test/analysis_test.py
1
import os
1✔
2
import unittest
1✔
3
from unittest.mock import MagicMock, patch
1✔
4
from collections import namedtuple
1✔
5
import numpy as np
1✔
6
import matplotlib.pyplot as plt
1✔
7
import bilby
1✔
8
from os import listdir
1✔
9
from os.path import dirname
1✔
10
import pandas as pd
1✔
11
from pathlib import Path
1✔
12
from shutil import rmtree
1✔
13
import redback
1✔
14
from redback.analysis import (plot_evolution_parameters, plot_spectrum, plot_gp_lightcurves,
1✔
15
                              fit_temperature_and_radius_gp, generate_new_transient_data_from_gp)
16

17
_dirname = dirname(__file__)
1✔
18

19
class TestPlotModels(unittest.TestCase):
    """Smoke-test lightcurve plotting in ``flux_density`` mode for a few models.

    For each prior file in ``redback/priors`` whose name (minus ``.prior``) is in
    ``valid_models``, two samples drawn from that prior are plotted with
    ``plot_lightcurve`` and ``plot_multiband_lightcurve`` on a tiny synthetic
    single-band ('r') Supernova. No output is inspected: the test passes if the
    plotting calls do not raise.
    """

    # Scratch directory created for the class and removed afterwards.
    # NOTE(review): this path is not passed to the plotting calls; confirm whether
    # redback writes into it at all.
    outdir = "testing_plotting"

    @classmethod
    def setUpClass(cls) -> None:
        Path(cls.outdir).mkdir(exist_ok=True, parents=True)

    @classmethod
    def tearDownClass(cls) -> None:
        rmtree(cls.outdir)

    def setUp(self) -> None:
        # Prior files are read from redback/priors, resolved relative to this module.
        self.path_to_files = f"{_dirname}/../redback/priors/"
        self.prior_files = listdir(self.path_to_files)

    def tearDown(self) -> None:
        pass

    def get_prior(self, file):
        """Load the bilby PriorDict stored in ``file`` (a file name inside the priors dir)."""
        prior_dict = bilby.prior.PriorDict()
        prior_dict.from_file(f"{self.path_to_files}{file}")
        return prior_dict

    def get_posterior(self, file):
        """Return two samples from the prior in ``file`` as a DataFrame (a stand-in posterior)."""
        return pd.DataFrame.from_dict(self.get_prior(file=file).sample(2))

    def test_plotting(self):
        times = np.array([1, 2, 3])
        yobs = np.array([1e-3, 1e-3, 1e-3])
        yerr = np.ones_like(yobs) * 1e-4
        bands = np.array(['r', 'r', 'r'])
        valid_models = ['arnett', 'one_component_kilonova_model', 'slsn',
                        'tde_analytical', 'basic_mergernova']
        exercised = []
        for f in self.prior_files:
            model_name = f.replace(".prior", "")
            if model_name not in valid_models:
                continue
            exercised.append(model_name)
            # subTest labels any failure with the model that triggered it.
            with self.subTest(model=model_name):
                posterior = self.get_posterior(file=f)
                transient = redback.supernova.Supernova(time=times, flux_density=yobs,
                                                        flux_density_err=yerr, bands=bands,
                                                        name='test', data_mode='flux_density',
                                                        use_phase_model=False)
                # Fresh kwargs per model so nothing set (or mutated by a plotting
                # call) for one model leaks into the next iteration.
                kwargs = dict(frequency=2e14, output_format='flux_density')
                try:
                    redback.analysis.plot_lightcurve(transient=transient, parameters=posterior,
                                                     model=model_name, model_kwargs=kwargs)
                    redback.analysis.plot_multiband_lightcurve(transient=transient, parameters=posterior,
                                                               model=model_name, model_kwargs=kwargs)
                finally:
                    # Close whatever figures the plotting calls opened so they do not
                    # accumulate across models and test classes.
                    plt.close('all')
        # Guard against a vacuous pass if the prior files are renamed or missing.
        self.assertTrue(exercised, "no prior file in redback/priors matches valid_models")
68

69
class TestPlotDifferentBands(unittest.TestCase):
    """Smoke-test lightcurve plotting in ``magnitude`` mode with three distinct bands.

    For each prior file in ``redback/priors`` whose name (minus ``.prior``) is in
    ``valid_models``, two samples drawn from that prior are plotted with
    ``plot_lightcurve`` and ``plot_multiband_lightcurve`` on a synthetic Supernova
    observed once each in 'sdssr', 'sdssz' and 'sdssu'. No output is inspected:
    the test passes if the plotting calls do not raise.
    """

    # Scratch directory created for the class and removed afterwards.
    # NOTE(review): this path is not passed to the plotting calls; confirm whether
    # redback writes into it at all.
    outdir = "testing_plotting"

    @classmethod
    def setUpClass(cls) -> None:
        Path(cls.outdir).mkdir(exist_ok=True, parents=True)

    @classmethod
    def tearDownClass(cls) -> None:
        rmtree(cls.outdir)

    def setUp(self) -> None:
        # Prior files are read from redback/priors, resolved relative to this module.
        self.path_to_files = f"{_dirname}/../redback/priors/"
        self.prior_files = listdir(self.path_to_files)

    def tearDown(self) -> None:
        pass

    def get_prior(self, file):
        """Load the bilby PriorDict stored in ``file`` (a file name inside the priors dir)."""
        prior_dict = bilby.prior.PriorDict()
        prior_dict.from_file(f"{self.path_to_files}{file}")
        return prior_dict

    def get_posterior(self, file):
        """Return two samples from the prior in ``file`` as a DataFrame (a stand-in posterior)."""
        return pd.DataFrame.from_dict(self.get_prior(file=file).sample(2))

    def test_plotting(self):
        times = np.array([1, 2, 3])
        yobs = np.array([1e-3, 1e-3, 1e-3])
        yerr = np.ones_like(yobs) * 1e-4
        bands = np.array(['sdssr', 'sdssz', 'sdssu'])
        valid_models = ['arnett', 'one_component_kilonova_model', 'slsn',
                        'tde_analytical', 'basic_mergernova']
        exercised = []
        for f in self.prior_files:
            model_name = f.replace(".prior", "")
            if model_name not in valid_models:
                continue
            exercised.append(model_name)
            # subTest labels any failure with the model that triggered it.
            with self.subTest(model=model_name):
                posterior = self.get_posterior(file=f)
                transient = redback.supernova.Supernova(time=times, magnitude=yobs,
                                                        magnitude_err=yerr, bands=bands,
                                                        name='test', data_mode='magnitude',
                                                        use_phase_model=False)
                # Fresh kwargs per model so nothing set (or mutated by a plotting
                # call) for one model leaks into the next iteration.
                kwargs = dict(frequency=2e14, output_format='magnitude')
                try:
                    redback.analysis.plot_lightcurve(transient=transient, parameters=posterior,
                                                     model=model_name, model_kwargs=kwargs)
                    redback.analysis.plot_multiband_lightcurve(transient=transient, parameters=posterior,
                                                               model=model_name, model_kwargs=kwargs)
                finally:
                    # Close whatever figures the plotting calls opened so they do not
                    # accumulate across models and test classes.
                    plt.close('all')
        # Guard against a vacuous pass if the prior files are renamed or missing.
        self.assertTrue(exercised, "no prior file in redback/priors matches valid_models")
118

119
class TestMagnitudePlot(unittest.TestCase):
    """Smoke-test lightcurve plotting in ``magnitude`` mode for a single band.

    For each prior file in ``redback/priors`` whose name (minus ``.prior``) is in
    ``valid_models``, two samples drawn from that prior are plotted with
    ``plot_lightcurve`` and ``plot_multiband_lightcurve`` on a synthetic Supernova
    observed three times in 'sdssr'. No output is inspected: the test passes if
    the plotting calls do not raise.
    """

    # Scratch directory created for the class and removed afterwards.
    # NOTE(review): this path is not passed to the plotting calls; confirm whether
    # redback writes into it at all.
    outdir = "testing_plotting"

    @classmethod
    def setUpClass(cls) -> None:
        Path(cls.outdir).mkdir(exist_ok=True, parents=True)

    @classmethod
    def tearDownClass(cls) -> None:
        rmtree(cls.outdir)

    def setUp(self) -> None:
        # Prior files are read from redback/priors, resolved relative to this module.
        self.path_to_files = f"{_dirname}/../redback/priors/"
        self.prior_files = listdir(self.path_to_files)

    def tearDown(self) -> None:
        pass

    def get_prior(self, file):
        """Load the bilby PriorDict stored in ``file`` (a file name inside the priors dir)."""
        prior_dict = bilby.prior.PriorDict()
        prior_dict.from_file(f"{self.path_to_files}{file}")
        return prior_dict

    def get_posterior(self, file):
        """Return two samples from the prior in ``file`` as a DataFrame (a stand-in posterior)."""
        return pd.DataFrame.from_dict(self.get_prior(file=file).sample(2))

    def test_plotting(self):
        times = np.array([1, 2, 3])
        yobs = np.array([1e-3, 1e-3, 1e-3])
        yerr = np.ones_like(yobs) * 1e-4
        bands = np.array(['sdssr', 'sdssr', 'sdssr'])
        valid_models = ['arnett', 'one_component_kilonova_model', 'slsn',
                        'tde_analytical', 'basic_mergernova']
        exercised = []
        for f in self.prior_files:
            model_name = f.replace(".prior", "")
            if model_name not in valid_models:
                continue
            exercised.append(model_name)
            # subTest labels any failure with the model that triggered it.
            with self.subTest(model=model_name):
                posterior = self.get_posterior(file=f)
                transient = redback.supernova.Supernova(time=times, magnitude=yobs,
                                                        magnitude_err=yerr, bands=bands,
                                                        name='test', data_mode='magnitude',
                                                        use_phase_model=False)
                # Fresh kwargs per model so nothing set (or mutated by a plotting
                # call) for one model leaks into the next iteration.
                kwargs = dict(frequency=2e14, output_format='magnitude')
                try:
                    redback.analysis.plot_lightcurve(transient=transient, parameters=posterior,
                                                     model=model_name, model_kwargs=kwargs)
                    redback.analysis.plot_multiband_lightcurve(transient=transient, parameters=posterior,
                                                               model=model_name, model_kwargs=kwargs)
                finally:
                    # Close whatever figures the plotting calls opened so they do not
                    # accumulate across models and test classes.
                    plt.close('all')
        # Guard against a vacuous pass if the prior files are renamed or missing.
        self.assertTrue(exercised, "no prior file in redback/priors matches valid_models")
168

169
class TestFluxPlot(unittest.TestCase):
    """Smoke-test lightcurve plotting in ``flux`` mode for a single band.

    For each prior file in ``redback/priors`` whose name (minus ``.prior``) is in
    ``valid_models``, two samples drawn from that prior are plotted with
    ``plot_lightcurve`` and ``plot_multiband_lightcurve`` on a synthetic Supernova
    observed three times in 'sdssr'. No output is inspected: the test passes if
    the plotting calls do not raise.
    """

    # Scratch directory created for the class and removed afterwards.
    # NOTE(review): this path is not passed to the plotting calls; confirm whether
    # redback writes into it at all.
    outdir = "testing_plotting"

    @classmethod
    def setUpClass(cls) -> None:
        Path(cls.outdir).mkdir(exist_ok=True, parents=True)

    @classmethod
    def tearDownClass(cls) -> None:
        rmtree(cls.outdir)

    def setUp(self) -> None:
        # Prior files are read from redback/priors, resolved relative to this module.
        self.path_to_files = f"{_dirname}/../redback/priors/"
        self.prior_files = listdir(self.path_to_files)

    def tearDown(self) -> None:
        pass

    def get_prior(self, file):
        """Load the bilby PriorDict stored in ``file`` (a file name inside the priors dir)."""
        prior_dict = bilby.prior.PriorDict()
        prior_dict.from_file(f"{self.path_to_files}{file}")
        return prior_dict

    def get_posterior(self, file):
        """Return two samples from the prior in ``file`` as a DataFrame (a stand-in posterior)."""
        return pd.DataFrame.from_dict(self.get_prior(file=file).sample(2))

    def test_plotting(self):
        times = np.array([1, 2, 3])
        yobs = np.array([1e-3, 1e-3, 1e-3])
        yerr = np.ones_like(yobs) * 1e-4
        bands = np.array(['sdssr', 'sdssr', 'sdssr'])
        valid_models = ['arnett', 'one_component_kilonova_model', 'slsn',
                        'tde_analytical', 'basic_mergernova']
        exercised = []
        for f in self.prior_files:
            model_name = f.replace(".prior", "")
            if model_name not in valid_models:
                continue
            exercised.append(model_name)
            # subTest labels any failure with the model that triggered it.
            with self.subTest(model=model_name):
                posterior = self.get_posterior(file=f)
                transient = redback.supernova.Supernova(time=times, flux=yobs,
                                                        flux_err=yerr, bands=bands,
                                                        name='test', data_mode='flux',
                                                        use_phase_model=False)
                # Fresh kwargs per model so nothing set (or mutated by a plotting
                # call) for one model leaks into the next iteration.
                kwargs = dict(frequency=2e14, output_format='flux')
                try:
                    redback.analysis.plot_lightcurve(transient=transient, parameters=posterior,
                                                     model=model_name, model_kwargs=kwargs)
                    redback.analysis.plot_multiband_lightcurve(transient=transient, parameters=posterior,
                                                               model=model_name, model_kwargs=kwargs)
                finally:
                    # Close whatever figures the plotting calls opened so they do not
                    # accumulate across models and test classes.
                    plt.close('all')
        # Guard against a vacuous pass if the prior files are renamed or missing.
        self.assertTrue(exercised, "no prior file in redback/priors matches valid_models")
218

219
class TestPlotPhaseModels(unittest.TestCase):
    """Smoke-test plotting through the 't0_base_model' wrapper.

    The synthetic Supernova is built from MJD times (55856-55858) with
    ``use_phase_model=True``. Each base model whose prior file is in
    ``redback/priors`` and listed in ``valid_models`` is plotted via
    ``model='t0_base_model'`` with ``t0=55855`` and ``base_model`` set to that
    model. No output is inspected: the test passes if the plotting calls do not raise.
    """

    # Scratch directory created for the class and removed afterwards.
    # NOTE(review): this path is not passed to the plotting calls; confirm whether
    # redback writes into it at all.
    outdir = "testing_plotting"

    @classmethod
    def setUpClass(cls) -> None:
        Path(cls.outdir).mkdir(exist_ok=True, parents=True)

    @classmethod
    def tearDownClass(cls) -> None:
        rmtree(cls.outdir)

    def setUp(self) -> None:
        # Prior files are read from redback/priors, resolved relative to this module.
        self.path_to_files = f"{_dirname}/../redback/priors/"
        self.prior_files = listdir(self.path_to_files)

    def tearDown(self) -> None:
        pass

    def get_prior(self, file):
        """Load the bilby PriorDict stored in ``file`` (a file name inside the priors dir)."""
        prior_dict = bilby.prior.PriorDict()
        prior_dict.from_file(f"{self.path_to_files}{file}")
        return prior_dict

    def get_posterior(self, file):
        """Return two samples from the prior in ``file`` as a DataFrame (a stand-in posterior)."""
        return pd.DataFrame.from_dict(self.get_prior(file=file).sample(2))

    def test_plotting(self):
        t0 = 55855
        times = np.array([1, 2, 3]) + t0
        yobs = np.array([1e-3, 1e-3, 1e-3])
        yerr = np.ones_like(yobs) * 1e-4
        bands = np.array(['r', 'r', 'r'])
        valid_models = ['arnett', 'one_component_kilonova_model', 'slsn',
                        'tde_analytical', 'basic_mergernova']
        model = 't0_base_model'
        exercised = []
        for f in self.prior_files:
            model_name = f.replace(".prior", "")
            if model_name not in valid_models:
                continue
            exercised.append(model_name)
            # subTest labels any failure with the base model that triggered it.
            with self.subTest(base_model=model_name):
                posterior = self.get_posterior(file=f)
                transient = redback.supernova.Supernova(time_mjd=times, flux_density=yobs,
                                                        flux_density_err=yerr, bands=bands,
                                                        name='test', data_mode='flux_density',
                                                        use_phase_model=True)
                # Fresh kwargs per model so nothing set (or mutated by a plotting
                # call) for one model leaks into the next iteration.
                kwargs = dict(frequency=2e14, t0=t0, base_model=model_name,
                              output_format='flux_density')
                try:
                    redback.analysis.plot_lightcurve(transient=transient, parameters=posterior,
                                                     model=model, model_kwargs=kwargs)
                    redback.analysis.plot_multiband_lightcurve(transient=transient, parameters=posterior,
                                                               model=model, model_kwargs=kwargs)
                finally:
                    # Close whatever figures the plotting calls opened so they do not
                    # accumulate across models and test classes.
                    plt.close('all')
        # Guard against a vacuous pass if the prior files are renamed or missing.
        self.assertTrue(exercised, "no prior file in redback/priors matches valid_models")
270

271
# Stand-in for the "evolving_magnetar_only" model. TestPlotEvolutionParameters
# registers it in redback.model_library.all_models_dict before calling
# plot_evolution_parameters.
DummyEvolvingMagnetarOutput = namedtuple("DummyEvolvingMagnetarOutput", ["nn", "mu", "alpha"])


def dummy_evolving_magnetar_only(time, **kwargs):
    """Return constant ``nn``, ``mu`` and ``alpha`` arrays shaped like ``time``.

    :param time: array-like of evaluation times; only its shape is used.
    :param kwargs: accepted for signature compatibility and ignored.
    :return: DummyEvolvingMagnetarOutput with nn=3.0, mu=1e30, alpha=2.5 at every point.
    """
    unit = np.ones_like(time)
    return DummyEvolvingMagnetarOutput(nn=3.0 * unit, mu=1e30 * unit, alpha=2.5 * unit)
280

281
# Stand-in spectrum model; TestPlotSpectrum registers it under
# "dummy_spectrum_model" in redback.model_library.all_models_dict.
DummySpectrumOutput = namedtuple("DummySpectrumOutput", ["lambdas", "time", "spectra"])


def dummy_spectrum_model(time_to_plot, **kwargs):
    """Return a fixed synthetic spectrum for each entry of ``time_to_plot``.

    Only ``len(time_to_plot)`` is used; the time values themselves are ignored.

    :param time_to_plot: sized sequence of requested epochs.
    :param kwargs: accepted for signature compatibility and ignored.
    :return: DummySpectrumOutput where ``lambdas`` is 100 wavelengths from 4000 to
        7000 (Angstroms), ``time`` is one epoch per entry spaced 86400 s apart
        from 0, and ``spectra`` has shape (len(time_to_plot), 100) with every row
        equal to a linear ramp from 1 to 100.
    """
    n_epochs = len(time_to_plot)
    wavelengths = np.linspace(4000, 7000, 100)
    # One epoch per requested time, one day (86400 s) apart, starting at 0.
    epochs = np.linspace(0, (n_epochs - 1) * 86400, n_epochs)
    ramp = np.linspace(1, 100, 100)
    flux_grid = np.repeat(ramp[np.newaxis, :], n_epochs, axis=0)
    return DummySpectrumOutput(lambdas=wavelengths, time=epochs, spectra=flux_grid)
1✔
291

292
# Stand-in for redback.utils.find_nearest, patched in by TestPlotSpectrum.
# NOTE(review): the Coveralls report for build 14430354752 marks this body as never
# executed, so the patch may not reach plot_spectrum. If redback.analysis imports
# find_nearest by name, the patch target would need to be
# 'redback.analysis.find_nearest' (TODO confirm).
def dummy_find_nearest(arr, target):
    """Return ``(value, index)`` of the element of ``arr`` closest to ``target``.

    :param arr: array-like of numbers; lists are accepted as well as ndarrays.
    :param target: value to match.
    :return: tuple of the nearest element and its position (first one on ties).
    :raises ValueError: if ``arr`` is empty (from ``argmin``).
    """
    # Convert first: ``list - number`` raises TypeError, whereas an ndarray broadcasts.
    values = np.asarray(arr)
    idx = np.abs(values - target).argmin()
    return values[idx], idx
×
296

297
# Stand-in for a fitted GP over (time, frequency) inputs, used by the GP-plotting tests.
class DummyGP:
    """GP stub whose predictions are always mean 2.0 with covariance 0.25 * I."""

    def predict(self, scaled_y, X_new, return_var=True, return_cov=True):
        """Return ``(mean, cov)`` for the points in ``X_new``.

        ``scaled_y``, ``return_var`` and ``return_cov`` are ignored. The number of
        points is ``X_new.shape[0]`` for a 2-D ndarray, otherwise ``len(X_new)``.
        """
        two_dim = isinstance(X_new, np.ndarray) and X_new.ndim == 2
        n_points = X_new.shape[0] if two_dim else len(X_new)
        mean = np.full(n_points, 2.0)
        covariance = 0.25 * np.eye(n_points)
        return mean, covariance
1✔
304

305
class Dummy1DGP:
    """Per-band GP stub: predictions are mean 2.0 with covariance 0.25 * I."""

    def predict(self, scaled_y, t_new, return_var=True, return_cov=True):
        """Return ``(mean, cov)`` for the ``len(t_new)`` requested times.

        ``scaled_y``, ``return_var`` and ``return_cov`` are ignored.
        """
        size = len(t_new)
        return np.full(size, 2.0), np.diag(np.full(size, 0.25))
1✔
311

312
# Dummy GP output container, in either the single-GP (use_frequency=True) or the
# per-band (use_frequency=False) layout; passed to plot_gp_lightcurves by the tests.
class DummyGPOutput:
    """Minimal stand-in for a GP fit result handed to plot_gp_lightcurves."""

    def __init__(self, use_frequency=True, unique_bands=None):
        """
        Build the dummy GP output.

        :param use_frequency: If True, expose one DummyGP and one scaled-data array;
            otherwise expose dicts of Dummy1DGP objects and scaled-data arrays keyed by band.
        :param unique_bands: Band names used as dict keys when ``use_frequency`` is False;
            None is treated as no bands.
        """
        self.use_frequency = use_frequency
        self.y_scaler = 1.0
        if not use_frequency:
            band_names = [] if unique_bands is None else unique_bands
            self.gp = {band: Dummy1DGP() for band in band_names}
            self.scaled_y = {band: np.ones(10) for band in band_names}
            return
        self.gp = DummyGP()
        self.scaled_y = np.ones(10)  # placeholder scaled data
1✔
333

334

335
# Minimal transient passed to plot_gp_lightcurves by TestPlotGPLightcurves.
class DummyTransientForGP:
    """Transient stub exposing ``x``, ``use_phase_model``, ``data_mode``,
    ``unique_bands`` and a derived ``unique_frequencies``."""

    def __init__(self, x, unique_bands, data_mode):
        self.x = np.array(x)
        self.use_phase_model = False
        self.data_mode = data_mode
        self.unique_bands = unique_bands  # iterable of band names

    @property
    def unique_frequencies(self):
        """One placeholder frequency (8.435e14) per entry of ``unique_bands``, as a float array."""
        return np.fromiter((8.43500e+14 for _ in self.unique_bands), dtype=float)
×
346

347
# Minimal OpticalTransient-like object passed to generate_new_transient_data_from_gp.
class DummyOpticalTransient:
    """Transient stub with name, data_mode, redshift, unique_frequencies and unique_bands.

    When omitted, ``unique_frequencies`` defaults to ``np.array([8.435e14])`` and
    ``unique_bands`` to ``["dummy"]``.
    """

    def __init__(self, name, data_mode, unique_frequencies=None, unique_bands=None, redshift=0.1):
        self.name = name
        self.data_mode = data_mode
        self.redshift = redshift
        if unique_frequencies is None:
            unique_frequencies = np.array([8.43500e+14])
        self.unique_frequencies = unique_frequencies
        if unique_bands is None:
            unique_bands = ["dummy"]
        self.unique_bands = unique_bands
1✔
355

356
# GP-output stub handed to generate_new_transient_data_from_gp.
class DummyGPOutputForGenerate:
    """Expose ``use_frequency``, ``y_scaler``, ``scaled_y`` and ``gp``.

    With ``use_frequency`` truthy, ``scaled_y`` is a length-10 array of ones and
    ``gp`` a DummyGP; otherwise both are single-entry dicts keyed by "dummy".
    """

    def __init__(self, use_frequency=True):
        self.use_frequency = use_frequency
        self.y_scaler = 1.0
        if use_frequency:
            self.scaled_y = np.ones(10)
            self.gp = DummyGP()
        else:
            self.scaled_y = {"dummy": np.ones(10)}
            self.gp = {"dummy": Dummy1DGP()}
1✔
363

364
# Stand-in for redback.utils.bands_to_frequency, patched in by TestPlotGPLightcurves.
def dummy_bands_to_frequency(band_list):
    """Return one placeholder frequency (8.435e14) per band in ``band_list``.

    The previous version returned a bare scalar whatever the input. That disagrees
    with the one-value-per-band shape of DummyTransientForGP.unique_frequencies,
    and presumably with the real helper it replaces (TODO confirm against
    redback.utils).

    :param band_list: a band name or a sequence of band names; None means no bands.
    :return: float ndarray of shape (n_bands,).
    """
    if band_list is None:
        band_list = []
    # atleast_1d lets a single band name count as one band.
    n_bands = len(np.atleast_1d(band_list))
    return np.full(n_bands, 8.43500e+14)
×
368

369
# === Test classes below ===
370

371
class TestPlotEvolutionParameters(unittest.TestCase):
    """Check plot_evolution_parameters against a stubbed evolving-magnetar model."""

    def setUp(self):
        # Lightweight result stand-in: a class object (not an instance) used as an
        # attribute bag exposing ``metadata`` (with a "time" array) and ``posterior``.
        self.dummy_metadata = {"time": np.array([1, 10, 100])}
        # Posterior rows only need to exist; their values are not critical.
        df = pd.DataFrame({"param1": [0.1, 0.2], "param2": [1, 2]})
        self.dummy_result = type("DummyResult", (), {"metadata": self.dummy_metadata, "posterior": df})
        # Swap the registered model for the stub. Record whether the key existed so
        # tearDown can restore the global model library exactly.
        models = redback.model_library.all_models_dict
        self._had_evolving = "evolving_magnetar_only" in models
        self.orig_evolving = models.get("evolving_magnetar_only")
        models["evolving_magnetar_only"] = dummy_evolving_magnetar_only

    def tearDown(self):
        models = redback.model_library.all_models_dict
        if self._had_evolving:
            models["evolving_magnetar_only"] = self.orig_evolving
        else:
            # setUp added the stub itself; remove it so it cannot leak into other tests.
            models.pop("evolving_magnetar_only", None)

    def test_plot_evolution_parameters_returns_fig_and_axes(self):
        # Call with a small number of random models.
        fig, ax = plot_evolution_parameters(self.dummy_result, random_models=3)
        # Close the figure even if an assertion below fails.
        self.addCleanup(plt.close, fig)
        self.assertIsInstance(fig, plt.Figure)
        self.assertEqual(len(ax), 3)
        # Every axis should carry a y-label.
        for a in ax:
            self.assertNotEqual(a.get_ylabel(), "")
1✔
395

396
class TestPlotSpectrum(unittest.TestCase):
    """Check plot_spectrum against a spectrum model stubbed into the model library."""

    def setUp(self):
        models = redback.model_library.all_models_dict
        # Record whether the key existed so tearDown can restore the library exactly.
        # In the recorded coverage run the old restore branch never executed, i.e.
        # "dummy_spectrum_model" was absent and the stub stayed registered afterwards.
        self._had_spec_model = "dummy_spectrum_model" in models
        self.orig_spec_model = models.get("dummy_spectrum_model")
        models["dummy_spectrum_model"] = dummy_spectrum_model
        # NOTE(review): that coverage run also marked dummy_find_nearest as never
        # executed. If redback.analysis imports find_nearest by name, this target
        # should be 'redback.analysis.find_nearest' (TODO confirm).
        self.find_nearest_patcher = patch('redback.utils.find_nearest', dummy_find_nearest)
        self.find_nearest_patcher.start()
        # addCleanup stops the patch even if setUp or the test fails part-way.
        self.addCleanup(self.find_nearest_patcher.stop)

    def tearDown(self):
        models = redback.model_library.all_models_dict
        if self._had_spec_model:
            models["dummy_spectrum_model"] = self.orig_spec_model
        else:
            # setUp added the stub itself; remove it so it cannot leak into other tests.
            models.pop("dummy_spectrum_model", None)

    def test_plot_spectrum_returns_axes(self):
        parameters = {"some_parameter": 1}
        time_to_plot = np.array([1, 2])  # in days
        fig, tmp_ax = plt.subplots()
        # Close the figure even if an assertion below fails.
        self.addCleanup(plt.close, fig)
        ax = plot_spectrum("dummy_spectrum_model", parameters, time_to_plot, axes=tmp_ax)
        self.assertIsNotNone(ax)
        self.assertIn("Wavelength", ax.get_xlabel())
1✔
420

421
class TestPlotGPLightcurves(unittest.TestCase):
    """Smoke-test plot_gp_lightcurves with dummy GP outputs in both layouts.

    NOTE(review): the Coveralls report for build 14430354752 marks
    dummy_bands_to_frequency and DummyTransientForGP.unique_frequencies as never
    executed, so the 'redback.utils.bands_to_frequency' patch may intercept nothing.
    If redback.analysis imports the helper by name, patch
    'redback.analysis.bands_to_frequency' instead (TODO confirm).
    """

    def _plot_and_check(self, gp_output):
        """Plot ``gp_output`` for a two-band dummy transient and check lines were drawn."""
        dummy_trans = DummyTransientForGP(x=np.linspace(0, 10, 50), unique_bands=["g", "r"],
                                          data_mode="flux_density")
        fig, ax = plt.subplots()
        # Close the figure even if an assertion below fails.
        self.addCleanup(plt.close, fig)
        ax_out = plot_gp_lightcurves(dummy_trans, gp_output, axes=ax)
        self.assertIsNotNone(ax_out)
        self.assertGreater(len(ax_out.get_lines()), 0)

    @patch('redback.utils.bands_to_frequency', side_effect=dummy_bands_to_frequency)
    def test_plot_gp_lightcurves_with_frequency(self, mock_btf):
        self._plot_and_check(DummyGPOutput(use_frequency=True))

    @patch('redback.utils.bands_to_frequency', side_effect=dummy_bands_to_frequency)
    def test_plot_gp_lightcurves_without_frequency(self, mock_btf):
        self._plot_and_check(DummyGPOutput(use_frequency=False, unique_bands=["g", "r"]))
1✔
441

442
class TestFitTemperatureAndRadiusGP(unittest.TestCase):
    """Exercise fit_temperature_and_radius_gp on a smooth synthetic T/R history."""

    def setUp(self):
        n_epochs = 20
        # Monotonic temperature and radius tracks with constant uncertainties.
        self.data = pd.DataFrame({
            "epoch_times": np.linspace(1, 100, n_epochs),
            "temperature": np.linspace(10000, 5000, n_epochs),
            "radius": np.linspace(1e14, 5e14, n_epochs),
            "temp_err": np.full(n_epochs, 500),
            "radius_err": np.full(n_epochs, 1e13),
        })
        from george.kernels import ExpSquaredKernel
        # Separate (identically configured) kernels for the temperature and radius GPs.
        self.kernelT, self.kernelR = (ExpSquaredKernel(metric=1.0) for _ in range(2))

    def test_fit_temperature_and_radius_gp_without_plot(self):
        gp_T, gp_R = fit_temperature_and_radius_gp(self.data, self.kernelT, self.kernelR, plot=False)
        from george.gp import GP
        for fitted in (gp_T, gp_R):
            self.assertIsInstance(fitted, GP)

    def test_fit_temperature_and_radius_gp_with_plot(self):
        result = fit_temperature_and_radius_gp(self.data, self.kernelT, self.kernelR,
                                               plot=True, fit_in_log=True)
        # With plot=True the function returns (gp_T, gp_R, fig, axes).
        self.assertEqual(len(result), 4)
        _, _, fig, axes = result
        self.assertIsInstance(fig, plt.Figure)
        self.assertEqual(len(axes), 2)
        plt.close(fig)
1✔
471

472
class TestGenerateNewTransientDataFromGP(unittest.TestCase):
    """Check that generate_new_transient_data_from_gp sets mode-specific attributes."""

    def setUp(self):
        # Frequency-layout GP stub, a coarse new time grid, and a minimal transient.
        self.gp_out = DummyGPOutputForGenerate(use_frequency=True)
        self.t_new = np.linspace(0, 100, 10)
        self.transient = DummyOpticalTransient(name="TestTransient", data_mode="flux_density", redshift=0.1)

    def _generate(self, data_mode=None):
        """Optionally switch the transient's data_mode, then regenerate from the GP."""
        if data_mode is not None:
            self.transient.data_mode = data_mode
        return generate_new_transient_data_from_gp(self.gp_out, self.t_new, self.transient)

    def _assert_has_attributes(self, obj, *names):
        for attr in names:
            self.assertTrue(hasattr(obj, attr))

    def test_generate_new_transient_data_from_gp_flux_density(self):
        new_transient = self._generate()
        # The generated transient's name is expected to carry a '_gp' suffix.
        self.assertTrue(new_transient.name.endswith("_gp"))
        self._assert_has_attributes(new_transient, "flux_density", "flux_density_err")

    def test_generate_new_transient_data_from_gp_flux(self):
        self._assert_has_attributes(self._generate("flux"), "flux", "flux_err")

    def test_generate_new_transient_data_from_gp_magnitude(self):
        self._assert_has_attributes(self._generate("magnitude"), "magnitude", "magnitude_err")

    def test_generate_new_transient_data_from_gp_luminosity(self):
        self._assert_has_attributes(self._generate("luminosity"), "Lum50", "Lum50_err")
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc