• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

nikhil-sarin / redback / 14430354752

13 Apr 2025 02:23PM UTC coverage: 86.635% (+6.0%) from 80.663%
14430354752

Pull #266

github

web-flow
Merge 8147dba2c into e087188ab
Pull Request #266: A big overhaul

1621 of 1828 new or added lines in 14 files covered. (88.68%)

4 existing lines in 2 files now uncovered.

12673 of 14628 relevant lines covered (86.64%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.09
/test/sampler_test.py
1
import unittest
1✔
2
import tempfile
1✔
3
import numpy as np
1✔
4
from types import SimpleNamespace
1✔
5
from unittest.mock import patch
1✔
6

7
import bilby
1✔
8

9
from redback.model_library import all_models_dict
1✔
10
from redback.result import RedbackResult
1✔
11
from redback.transient.afterglow import Afterglow
1✔
12
from redback.transient.prompt import PromptTimeSeries
1✔
13
from redback.transient.transient import OpticalTransient, Transient, Spectrum
1✔
14
from redback.sampler import fit_model
1✔
15

16

17
# --- Dummy Model and Result --- #
18
def dummy_model(x, **kwargs):
    """Return an array of ones with the same shape and dtype as ``x``.

    Any extra keyword arguments are accepted and ignored, so the function
    can be called with arbitrary model parameters.
    """
    unit_response = np.ones_like(x)
    return unit_response
×
21

22

23
# Register the dummy model under its string name so that code resolving
# models through ``all_models_dict`` can find it.
all_models_dict.update({"dummy_model": dummy_model})
1✔
25

26

27
class DummyResult(RedbackResult):
    """Minimal stand-in for a ``RedbackResult``.

    The parent initialiser is deliberately not invoked; the instance only
    carries a marker attribute and no-op plotting hooks.
    """

    def __init__(self):
        # Marker payload only; no real sampling output is stored.
        self.data = "dummy_result"

    def plot_spectrum(self, model):
        """Do nothing; placeholder for spectrum plotting."""
        return None

    def plot_lightcurve(self, model):
        """Do nothing; placeholder for lightcurve plotting."""
        return None
×
38

39

40
# --- Revised Dummy Transient Classes --- #
41
class DummySpectrum(Spectrum):
    """Spectrum stand-in filled with constant synthetic data.

    The parent initialiser is not called; attributes are assigned directly.
    """

    def __init__(self, outdir):
        n_wavelengths = 100

        # Attributes assigned by hand instead of going through the parent initialiser.
        self.data_mode = "flux_density"
        self.directory_structure = SimpleNamespace(directory_path=outdir)
        self.name = "DummySpectrum"
        self.use_phase_model = False  # Required by base transient.

        # Flat spectrum over a linear wavelength grid.
        self.angstroms = np.linspace(4000, 7000, n_wavelengths)
        self.flux_density = np.full(n_wavelengths, 1e-16)
        self.flux_density_err = np.full(n_wavelengths, 1e-18)

        # One dummy filter name per wavelength sample.
        # (Spectrum may not require _active_bands.)
        self._bands = np.array(["dummy"] * n_wavelengths)
1✔
53
        # (Spectrum may not require _active_bands)
54

55

56
class DummyAfterglow(Afterglow):
    """Afterglow stand-in filled with constant synthetic data.

    The parent initialiser is not called; attributes are assigned directly.
    Assignment order is kept as in the original because the ``frequency``
    setter relies on the band attributes being present first.
    """

    def __init__(self, outdir):
        n_points = 50

        self.data_mode = "flux_density"
        self.directory_structure = SimpleNamespace(directory_path=outdir)
        self.name = "DummyAfterglow"
        self.use_phase_model = False  # Required by base transient.

        # Constant data on a linear grid with zero x-uncertainty.
        self.x = np.linspace(0, 10, n_points)
        self.x_err = np.zeros((2, n_points))
        self.y = np.full(n_points, 10.0)
        self.y_err = np.ones(n_points)
        self.photon_index = 1.0

        # Bands must exist before frequency is assigned (see below).
        self._bands = np.array(["dummy"] * n_points)
        # _active_bands is provided so that filtered_indices works.
        self._active_bands = self._bands.copy()

        # Assigning frequency goes through the setter, which uses self.bands.
        self.frequency = np.ones(n_points)
1✔
73

74

75
class DummyPromptTimeSeries(PromptTimeSeries):
    """Prompt time-series stand-in filled with constant synthetic counts.

    The parent initialiser is not called; attributes are assigned directly.
    """

    def __init__(self, outdir):
        n_bins = 50

        self.data_mode = "counts"  # Acceptable for prompt data.
        self.directory_structure = SimpleNamespace(directory_path=outdir)
        self.name = "DummyPrompt"
        self.use_phase_model = False  # Required.

        # Constant counts on a linear grid.
        self.x = np.linspace(0, 10, n_bins)
        self.bin_size = 1.0
        self.y = np.full(n_bins, 5.0)
        self.y_err = np.ones(n_bins)

        # Dummy band labels, in case downstream code asks for them.
        self._bands = np.array(["dummy"] * n_bins)
1✔
87

88

89
class DummyOpticalTransient(OpticalTransient):
    """Optical-transient stand-in filled with constant synthetic data.

    The parent initialiser is not called; attributes are assigned directly.
    """

    def __init__(self, outdir):
        n_points = 50

        self.data_mode = "flux_density"
        self.directory_structure = SimpleNamespace(directory_path=outdir)
        self.name = "DummyOptical"
        self.use_phase_model = False  # Prevent errors in base transient.

        # Constant data on a linear grid with zero x-uncertainty.
        self.x = np.linspace(0, 10, n_points)
        self.x_err = np.zeros((2, n_points))
        self.y = np.full(n_points, 15.0)
        self.y_err = np.ones(n_points)

        # _bands backs the bands property; _active_bands is provided so
        # that filtered_indices works.
        self._bands = np.array(["dummy"] * n_points)
        self._active_bands = self._bands.copy()
1✔
103

104

105
class DummyTransient(Transient):
    """Base-transient stand-in filled with constant synthetic data.

    The parent initialiser is not called; attributes are assigned directly.
    """

    def __init__(self, outdir):
        n_points = 50

        self.data_mode = "flux_density"
        self.directory_structure = SimpleNamespace(directory_path=outdir)
        self.name = "DummyTransient"
        self.use_phase_model = False  # Required by base transient.

        # Constant data on a linear grid with zero x-uncertainty.
        self.x = np.linspace(0, 10, n_points)
        self.x_err = np.zeros((2, n_points))
        self.y = np.full(n_points, 20.0)
        self.y_err = np.ones(n_points)

        # Band labels and their active subset (all active here).
        self._bands = np.array(["dummy"] * n_points)
        self._active_bands = self._bands.copy()
1✔
118

119

120
# Dummy object that is not a recognized transient type.
121
class DummyNotTransient:
    """Plain object that inherits from none of the transient classes."""

    # Only attribute provided; set at class level so every instance shares it.
    data_mode = "flux_density"
1✔
123

124

125
# --- Tests for the fit_model function --- #
126
class TestFitModel(unittest.TestCase):
    """Tests for ``fit_model`` with the sampler replaced by a mock.

    Each ``test_fit_*`` case passes a different dummy transient class to
    ``fit_model`` and checks that the mocked ``bilby.run_sampler`` is called
    exactly once and that its return value is handed back to the caller.
    The remaining cases check the ``ValueError`` paths.
    """

    def setUp(self):
        # Temporary output directory. Cleanup is registered right away so it
        # runs even if a later statement in setUp raises (tearDown is skipped
        # by unittest when setUp fails, addCleanup callbacks are not).
        self.temp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self.temp_dir.cleanup)
        self.outdir = self.temp_dir.name
        # Default model_kwargs.
        self.model_kwargs = {"output_format": "flux_density"}
        self.sampler = "dynesty"
        self.nlive = 100
        self.walks = 50
        self.prior = bilby.prior.PriorDict()  # Empty PriorDict for testing.
        # Object the mocked sampler returns.
        self.dummy_result = DummyResult()

    def _call_fit_model(self, trans, label, model_kwargs):
        """Call ``fit_model`` with the shared sampler settings from ``setUp``."""
        return fit_model(
            transient=trans, model="dummy_model", outdir=self.outdir, label=label,
            sampler=self.sampler, nlive=self.nlive, prior=self.prior, walks=self.walks,
            model_kwargs=model_kwargs, plot=False
        )

    def _assert_mocked_fit(self, trans, label, n_points, mock_run_sampler):
        """Run a mocked fit and check the sampler result is passed through.

        :param trans: Dummy transient instance handed to ``fit_model``.
        :param label: Label string forwarded to ``fit_model``.
        :param n_points: Length of the ``frequency`` array added to model_kwargs.
        :param mock_run_sampler: The mock patched over ``bilby.run_sampler``.
        """
        model_kwargs = self.model_kwargs.copy()
        model_kwargs["frequency"] = np.linspace(1e14, 1e15, n_points)
        mock_run_sampler.return_value = self.dummy_result

        result = self._call_fit_model(trans, label, model_kwargs)

        self.assertEqual(result, self.dummy_result)
        mock_run_sampler.assert_called_once()

    @patch("redback.result.read_in_result", side_effect=Exception("No result"))
    @patch("bilby.run_sampler", autospec=True)
    def test_fit_spectrum(self, mock_run_sampler, mock_read_result):
        trans = DummySpectrum(self.outdir)
        # For spectrum, the frequency array matches the wavelength grid length.
        self._assert_mocked_fit(trans, "TestSpectrum", len(trans.angstroms), mock_run_sampler)

    @patch("redback.result.read_in_result", side_effect=Exception("No result"))
    @patch("bilby.run_sampler", autospec=True)
    def test_fit_afterglow(self, mock_run_sampler, mock_read_result):
        trans = DummyAfterglow(self.outdir)
        self._assert_mocked_fit(trans, "TestAfterglow", len(trans.x), mock_run_sampler)

    @patch("redback.result.read_in_result", side_effect=Exception("No result"))
    @patch("bilby.run_sampler", autospec=True)
    def test_fit_prompt(self, mock_run_sampler, mock_read_result):
        trans = DummyPromptTimeSeries(self.outdir)
        self._assert_mocked_fit(trans, "TestPrompt", len(trans.x), mock_run_sampler)

    @patch("redback.result.read_in_result", side_effect=Exception("No result"))
    @patch("bilby.run_sampler", autospec=True)
    def test_fit_optical_transient(self, mock_run_sampler, mock_read_result):
        trans = DummyOpticalTransient(self.outdir)
        self._assert_mocked_fit(trans, "TestOptical", len(trans.x), mock_run_sampler)

    @patch("redback.result.read_in_result", side_effect=Exception("No result"))
    @patch("bilby.run_sampler", autospec=True)
    def test_fit_transient_base(self, mock_run_sampler, mock_read_result):
        trans = DummyTransient(self.outdir)
        self._assert_mocked_fit(trans, "TestTransient", len(trans.x), mock_run_sampler)

    def test_inconsistent_data_mode(self):
        # If model_kwargs["output_format"] does not match the transient's
        # data_mode, a ValueError mentioning "inconsistent" is expected.
        trans = DummyTransient(self.outdir)
        trans.data_mode = "flux_density"
        inconsistent_kwargs = {"output_format": "magnitude"}
        with self.assertRaises(ValueError) as context:
            self._call_fit_model(trans, "TestInconsistency", inconsistent_kwargs)
        self.assertIn("inconsistent", str(context.exception))

    def test_unknown_transient_type(self):
        # An object that is not a recognised transient type is expected to
        # raise a ValueError mentioning "not known".
        trans = DummyNotTransient()
        with self.assertRaises(ValueError) as context:
            self._call_fit_model(trans, "TestUnknown", self.model_kwargs)
        self.assertIn("not known", str(context.exception))
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc