• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

California-Planet-Search / radvel / 17447393926

03 Sep 2025 10:11PM UTC coverage: 87.165%. First build
17447393926

Pull #395

github

bjfultn
Fix Python 3.11 compatibility: replace old Python 2 raise syntax and bare except clauses
Pull Request #395: testing new ci

170 of 241 new or added lines in 8 files covered. (70.54%)

3735 of 4285 relevant lines covered (87.16%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.12
/radvel/tests/test_api.py
1
import sys
1✔
2
import copy
1✔
3
import warnings
1✔
4
import time
1✔
5
import types
1✔
6
# pytest is now the test runner, but we don't need to import it in test files
7

8
import radvel
1✔
9
import radvel.driver
1✔
10
from radvel.nested_sampling import BACKENDS
1✔
11
import numpy as np
1✔
12
import scipy
1✔
13
import radvel.prior
1✔
14

15
warnings.simplefilter('ignore')
1✔
16

17
class _args(types.SimpleNamespace):
    """Default option namespace handed to the ``radvel.driver`` calls in this file.

    Every option is a class attribute, so ``_args()`` starts from these
    defaults and each test overrides what it needs on its own instance.
    """

    outputdir = '/tmp/'
    decorr = False
    name_in_title = False
    gp = False
    simple = False

    nwalkers = 50
    nsteps = 100
    ensembles = 8
    maxGR = 1.10
    burnGR = 1.30
    burnAfactor = 25
    minAfactor = 50
    maxArchange = .07
    minTz = 1000
    minsteps = 100
    minpercent = 5
    thin = 1
    serial = False
    save = True
    savename = 'rawchains.h5'
    proceed = False
    proceedname = None
    headless = True
    # Defined once: an earlier duplicate ``sampler = 'auto'`` was always
    # overridden by this value, so it was dead code.
    sampler = 'ultranest'
    run_kwargs = None
    sampler_kwargs = None
    overwrite = False
47

48

49
def _standard_run(setupfn, arguments, do_ns=True, do_mcmc=True):
    """
    Run through all of the standard steps

    Calls, in order, the ``radvel.driver`` steps ``fit``, ``mcmc`` and/or
    ``nested_sampling``, ``derive``, ``ic_compare``, ``tables``, ``plots``
    and ``report`` on the same namespace.

    Args:
        setupfn: setup file; stored on ``arguments.setupfn``.
        arguments: namespace handed to every driver call. It is mutated in
            place (``setupfn``, ``sampler``, ``type``, ``verbose``,
            ``plotkw``, ``comptype`` and ``latex_compiler`` are overwritten).
        do_ns: run the nested-sampling step.
        do_mcmc: run the MCMC step.

    Raises:
        ValueError: if both ``do_mcmc`` and ``do_ns`` are false.
    """
    # Reject an unusable configuration up front, before any fit is run or
    # the caller's namespace is modified.
    if not (do_mcmc or do_ns):
        raise ValueError('One of do_mcmc or do_ns must be true to run this test.')

    args = arguments
    args.setupfn = setupfn

    radvel.driver.fit(args)

    if do_mcmc:
        radvel.driver.mcmc(args)
    if do_ns:
        radvel.driver.nested_sampling(args)

    # For ns step, sampler gives the library
    # For subsequent steps, sampler should be mcmc, ns or auto
    args.sampler = 'auto'
    radvel.driver.derive(args)

    args.type = ['trend', 'jit', 'e', 'nplanets', 'gp']
    args.verbose = True
    radvel.driver.ic_compare(args)

    args.type = ['params', 'priors', 'rv', 'ic_compare', 'derived', 'crit']
    radvel.driver.tables(args)

    args.type = ['rv', 'corner', 'derived']
    if do_mcmc:
        # These two plot types are only requested when the MCMC step ran
        args.type += ['auto', 'trend']
    args.plotkw = {'highlight_last': True, 'show_rms': True}
    radvel.driver.plots(args)

    args.comptype = 'ic_compare'
    args.latex_compiler = 'pdflatex'
    radvel.driver.report(args)
87

88

89
def test_k2(setupfn='example_planets/epic203771098.py'):
    """
    Run through K2-24 example

    Executes the full ``_standard_run`` sequence with its output redirected
    to a throwaway directory.
    """
    import tempfile

    args = _args()
    # Isolated output directory; the context manager removes it on exit even
    # if the run raises.
    with tempfile.TemporaryDirectory() as temp_dir:
        args.outputdir = temp_dir
        args.setupfn = setupfn
        _standard_run(setupfn, args)
105

106
def test_mcmc_proceed(setupfn='example_planets/epic203771098.py'):
    """
    Run through K2-24 example and try to resume

    After an MCMC-only standard run, the chain is resumed with
    ``proceed=True``; resuming again with ``ensembles=1`` is expected to
    raise ``ValueError``; finally a fresh serial run is started.
    """
    import tempfile

    args = _args()
    # Isolated output directory, removed on exit even if a step fails
    with tempfile.TemporaryDirectory() as temp_dir:
        args.outputdir = temp_dir
        args.setupfn = setupfn

        # We always re-sample: ensure that standard run with MCMC only works
        _standard_run(setupfn, args, do_ns=False)

        # set the proceed flag and continue
        args.proceed = True
        radvel.driver.mcmc(args)

        args.ensembles = 1
        # Explicit raise instead of ``assert False``: assert statements are
        # removed under ``python -O``, which would silently disable the check.
        try:
            radvel.driver.mcmc(args)
        except ValueError:
            pass
        else:
            raise AssertionError("Expected ValueError")

        args.serial = True
        args.proceed = False
        radvel.driver.mcmc(args)
138

139

140
def test_ns_proceed(setupfn='example_planets/epic203771098.py'):
    """
    Run through K2-24 example and try to resume

    Checks, for the nested-sampling driver step, that:
      * re-running into an existing output raises ``FileExistsError`` while
        ``overwrite`` is false,
      * passing ``resume=True`` through ``sampler_kwargs`` raises
        ``ValueError``,
      * a run with ``proceed=True`` finishes in under one minute.
    """
    import tempfile

    args = _args()
    args.sampler = 'ultranest'  # Use default sampler
    # Isolated output directory, removed on exit even if a step fails
    with tempfile.TemporaryDirectory() as temp_dir:
        args.outputdir = temp_dir
        args.setupfn = setupfn

        # We always re-sample: ensure that standard run with NS only works
        _standard_run(setupfn, args, do_mcmc=False)

        # _standard_run resets the sampler to 'auto', so set it again
        args.sampler = 'ultranest'

        # Test that overwrite=False works.
        # Explicit raise instead of ``assert False`` so the check survives
        # ``python -O``.
        try:
            radvel.driver.nested_sampling(args)
        except FileExistsError:
            pass
        else:
            raise AssertionError("Expected FileExistsError")

        # Test that resume is not accepted for sampler/run kwargs
        args.overwrite = True
        args.sampler_kwargs = "resume=True"
        try:
            radvel.driver.nested_sampling(args)
        except ValueError:
            pass
        else:
            raise AssertionError("Expected ValueError")
        args.overwrite = False
        args.sampler_kwargs = None

        # Test that resume is not too long (that it actually resumes)
        args.proceed = True
        start = time.monotonic()
        radvel.driver.nested_sampling(args)
        end = time.monotonic()
        # end - start: the previous ``start - end`` was always negative, so
        # the bound below could never fail.
        time_minutes = (end - start) / 60
        assert time_minutes < 1.0
185

186
def test_hd(setupfn='example_planets/HD164922.py'):
    """
    Check multi-instrument fit

    Runs the fit and MCMC driver steps, then requests the 'rv' plot with
    no extra plotting keywords.
    """
    namespace = _args()
    namespace.setupfn = setupfn

    radvel.driver.fit(namespace)
    radvel.driver.mcmc(namespace)

    # Plot request: RV plot only, default keyword arguments
    namespace.type, namespace.plotkw = ['rv'], {}
    radvel.driver.plots(namespace)
199

200

201
def _assert_clean_input(label, values):
    """Print diagnostics and fail if *values* contains NaNs or infs."""
    for kind, mask in (("NaNs", np.isnan(values)), ("infs", np.isinf(values))):
        if np.any(mask):
            print(f"DEBUG: Found {kind} in {label} at indices: {np.where(mask)[0]}")
            print(f"DEBUG: {label.capitalize()} values with {kind}: {values[mask]}")
            raise AssertionError(f"Input {label} array contains {kind}")


def _assert_args_arrays_finite(args, stage):
    """Raise ValueError if any ndarray attribute of *args* holds NaNs or infs."""
    for attr_name in dir(args):
        attr = getattr(args, attr_name)
        if not isinstance(attr, np.ndarray):
            continue
        if np.any(np.isnan(attr)):
            raise ValueError(f"Array {attr_name} contains NaNs after {stage}")
        if np.any(np.isinf(attr)):
            raise ValueError(f"Array {attr_name} contains infs after {stage}")


def test_k2131(setupfn='example_planets/k2-131.py'):
    """
    Check GP fit

    Validates the ``k2-131.txt`` input columns (no NaNs/infs, strictly
    positive errors), then runs ``fit``, ``ic_compare`` (type 'gp') and
    ``plots`` (type 'rv', ``gp=True``). After each step, every ndarray
    attribute on the namespace is scanned for NaNs/infs.

    Raises:
        AssertionError: if the input data is unusable.
        ValueError: if an ndarray attribute of the namespace contains
            NaNs or infs after a driver step.
    """
    import os
    import tempfile

    import pandas as pd

    args = _args()
    args.setupfn = setupfn

    # Isolated output directory, removed on exit even if a step fails
    with tempfile.TemporaryDirectory() as temp_dir:
        args.outputdir = temp_dir

        # Load and validate input data before fitting
        data_path = os.path.join(radvel.DATADIR, 'k2-131.txt')
        data = pd.read_csv(data_path, sep=' ')
        t = np.array(data['time'])
        vel = np.array(data['mnvel'])
        errvel = np.array(data['errvel'])
        columns = (('time', t), ('velocity', vel), ('error', errvel))

        # Debug: Print data summary for CI debugging
        print(f"DEBUG: Data loaded from {data_path}")
        print(f"DEBUG: Data shape: {data.shape}")
        print(f"DEBUG: Time range: {t.min():.6f} to {t.max():.6f}")
        print(f"DEBUG: Velocity range: {vel.min():.2f} to {vel.max():.2f}")
        print(f"DEBUG: Error range: {errvel.min():.2f} to {errvel.max():.2f}")
        for label, values in columns:
            print(f"DEBUG: Any NaNs in {label}: {np.any(np.isnan(values))}")
            print(f"DEBUG: Any infs in {label}: {np.any(np.isinf(values))}")
        print(f"DEBUG: Any non-positive errors: {np.any(errvel <= 0)}")

        # Check for problematic values in input data
        for label, values in columns:
            _assert_clean_input(label, values)
        non_positive = errvel <= 0
        if np.any(non_positive):
            print(f"DEBUG: Found non-positive errors at indices: {np.where(non_positive)[0]}")
            print(f"DEBUG: Non-positive error values: {errvel[non_positive]}")
            raise AssertionError("Input error array contains non-positive values")

        radvel.driver.fit(args)
        _assert_args_arrays_finite(args, "fit")

        args.type = ['gp']
        args.verbose = True
        radvel.driver.ic_compare(args)
        _assert_args_arrays_finite(args, "ic_compare")

        args.type = ['rv']
        args.gp = True
        args.plotkw = {}
        radvel.driver.plots(args)
        _assert_args_arrays_finite(args, "plots")
304

305

306
def test_celerite(setupfn='example_planets/k2-131_celerite.py'):
    """
    Check celerite GP fit

    Skipped when ``radvel.gp._has_celerite`` is false. Otherwise runs the
    fit and an 'rv' plot with likelihoods plotted separately.
    """
    celerite_available = radvel.gp._has_celerite
    if not celerite_available:
        # pytest is imported lazily, only when a skip is actually needed
        import pytest
        pytest.skip("celerite not available")

    namespace = _args()
    namespace.setupfn = setupfn
    radvel.driver.fit(namespace)

    plot_options = {'plot_likelihoods_separately': True}
    namespace.type = ['rv']
    namespace.gp = True
    namespace.plotkw = plot_options
    radvel.driver.plots(namespace)
324

325

326
def test_basis():
    """
    Test basis conversions

    Converts a one-planet parameter set from the default basis into every
    basis in ``radvel.basis.BASIS_NAMES`` and back again, and checks that
    both the ``Parameters`` and the ``Vector`` representations survive the
    round trip to within an absolute tolerance.
    """
    tolerance = 1e-5

    basis_list = radvel.basis.BASIS_NAMES
    default_basis = 'per tc e w k'

    anybasis_params = radvel.Parameters(1, basis=default_basis)

    anybasis_params['per1'] = radvel.Parameter(value=20.885258)
    anybasis_params['tc1'] = radvel.Parameter(value=2072.79438)
    anybasis_params['e1'] = radvel.Parameter(value=0.01)
    anybasis_params['w1'] = radvel.Parameter(value=1.6)
    anybasis_params['k1'] = radvel.Parameter(value=10.0)

    anybasis_params['dvdt'] = radvel.Parameter(value=0.0)
    anybasis_params['curv'] = radvel.Parameter(value=0.0)

    anybasis_params['gamma_j'] = radvel.Parameter(1.0)
    anybasis_params['jit_j'] = radvel.Parameter(value=2.6)

    for new_basis in basis_list:
        # Converting the default basis to itself is not exercised
        if new_basis == default_basis:
            continue

        iparams = radvel.basis._copy_params(anybasis_params)
        ivector = radvel.Vector(iparams)

        # Forward conversion, for both representations
        new_vector = iparams.basis.v_to_any_basis(ivector, new_basis)
        new_params = iparams.basis.to_any_basis(iparams, new_basis)
        tmpv = new_vector.copy()
        tmp = radvel.basis._copy_params(new_params)

        # ...and back to the default basis
        old_vector = tmp.basis.v_to_any_basis(tmpv, default_basis)
        old_params = tmp.basis.to_any_basis(tmp, default_basis)

        for par in iparams:
            before = iparams[par].value
            after = old_params[par].value
            # abs(): with a signed difference, any error where
            # after > before would pass regardless of its size.
            assert abs(before - after) <= tolerance, \
                "Parameters do not match after basis conversion: {}, {} != {}".format(par, before, after)

        for i in range(ivector.vector.shape[0]):
            before = ivector.vector[i][0]
            after = old_vector[i][0]
            assert abs(before - after) <= tolerance, \
                "Vectors do not match after basis conversion: {} row, {} != {}".format(i, before, after)
373

374

375

376
def test_kernels():
    """
    Test basic functionality of all standard GP kernels

    For each entry of ``radvel.gp.KERNELS`` the kernel is built with all
    hyperparameters set to 1, distances and the covariance matrix are
    computed, and two malformed hyperparameter dicts are expected to be
    rejected (``AssertionError`` and ``AttributeError`` respectively).
    """
    # Skip if celerite is not available (needed for Celerite kernel)
    if not radvel.gp._has_celerite:
        import pytest
        pytest.skip("celerite not available")

    for kernel, hnames in radvel.gp.KERNELS.items():
        # hnames: list of hyperparameter name strings for this kernel
        hyperparams = dict((name, radvel.Parameter(value=1.)) for name in hnames)
        kernel_call = getattr(radvel.gp, kernel + "Kernel")
        test_kernel = kernel_call(hyperparams)

        x = np.array([1., 2., 3.])
        test_kernel.compute_distances(x, x)
        test_kernel.compute_covmatrix(x.T)

        print("Testing {}".format(kernel_call(hyperparams)))

        sys.stdout.write("Testing error catching with dummy hyperparameters... \n")

        # Case 1: a dict holding only an unknown hyperparameter name
        fakeparams1 = {'dummy': radvel.Parameter(value=1.0)}
        try:
            kernel_call(fakeparams1)
        except AssertionError:
            sys.stdout.write("passed #1\n")
        else:
            raise RuntimeError('Test #1 failed for {}'.format(kernel))

        # Case 2: correct names, but the first value is a bare float
        fakeparams2 = copy.deepcopy(hyperparams)
        fakeparams2[hnames[0]] = 1.
        try:
            kernel_call(fakeparams2)
        except AttributeError:
            sys.stdout.write("passed #2\n")
        else:
            raise RuntimeError('Test #2 failed for {}'.format(kernel))
416

417

418
def params_and_vector_for_priors():
    """Return a one-planet ``(Parameters, Vector)`` pair used by the prior tests.

    The parameters use the 'per tc secosw sesinw logk' basis.
    """
    initial_values = (
        ('per1', 10.0),
        ('tc1', 0.0),
        ('secosw1', 0.0),
        ('sesinw1', 0.0),
        ('logk1', 1.5),
    )

    params = radvel.Parameters(1, 'per tc secosw sesinw logk')
    for name, value in initial_values:
        params[name] = radvel.Parameter(value)

    return params, radvel.Vector(params)
429

430
def test_priors():
    """
    Test basic functionality of all Priors

    Each prior is evaluated on the shared params/vector pair; the
    exponential of its return value must match the expected number to
    within 0.01.
    """
    params, vector = params_and_vector_for_priors()

    testTex = r'Delta Function Prior on $\sqrt{e}\cos{\omega}_{b}$'

    def userdef_prior_func(inp_list):
        # 0 inside [0, 1), -inf everywhere else
        if 0. <= inp_list[0] < 1.:
            return 0.
        return -np.inf

    # Built first so the two 5M-sample draws are not buried in the mapping
    numerical_prior = radvel.prior.NumericalPrior(
        ['sesinw1'],
        np.random.randn(1, 5000000)
    )
    userdefined_prior = radvel.prior.UserDefinedPrior(
        ['secosw1'], userdef_prior_func, testTex
    )
    baseline_prior = radvel.prior.InformativeBaselinePrior(
        'per1', 5.0, duration=1.0
    )

    # Maps each prior instance to the value expected for the params above
    prior_tests = {
        radvel.prior.EccentricityPrior(1): 1/.99,
        radvel.prior.EccentricityPrior([1]): 1/.99,
        radvel.prior.PositiveKPrior(1): 1.0,
        radvel.prior.Gaussian('per1', 9.9, 0.1): scipy.stats.norm(9.9, 0.1).pdf(10.),
        radvel.prior.HardBounds('per1', 1.0, 9.0): 0.,
        radvel.prior.HardBounds('per1', 1.0, 11.0): 1./10.,
        radvel.prior.Jeffreys('per1', 0.1, 100.0): (1./10.)/np.log(100./0.1),
        radvel.prior.ModifiedJeffreys('per1', 0.1, 100.0, 0.): (1./10.)/np.log(100./0.1),
        radvel.prior.ModifiedJeffreys('per1', 2., 100.0, 1.): (1./9.)/np.log(99.),
        radvel.prior.SecondaryEclipsePrior(1, 5.0, 10.0): 1./np.sqrt(2.*np.pi),
    }
    prior_tests[numerical_prior] = scipy.stats.norm(0, 1).pdf(0.)
    prior_tests[userdefined_prior] = 1.0
    prior_tests[baseline_prior] = 6./10.

    for prior, val in prior_tests.items():
        print(prior.__repr__())
        print(prior.__str__())
        tolerance = .01
        print(abs(np.exp(prior(params, vector))))
        print(val)
        mismatch = abs(np.exp(prior(params, vector)) - val)
        assert mismatch < tolerance, \
            "Prior output does not match expectation"
477

478

479
# (radvel prior, matching scipy.stats distribution) pairs for transform tests
prior_scipy_list = [
    (
        radvel.prior.Gaussian("per1", 9.9, 0.1),
        scipy.stats.norm(9.9, 0.1),
    ),
    (
        radvel.prior.HardBounds("per1", 1.0, 9.0),
        scipy.stats.uniform(1.0, 9.0 - 1.0),
    ),
    (
        radvel.prior.Jeffreys("per1", 0.1, 100.0),
        scipy.stats.loguniform(0.1, 100.0),
    ),
    (
        radvel.prior.NumericalPrior(["sesinw1"], np.random.randn(1, 5000000)),
        scipy.stats.norm(0, 1),
    ),
    (
        radvel.prior.UserDefinedPrior(
            ["per1"],
            lambda x: scipy.stats.lognorm.pdf(x, 1e-1, 1e1),
            "lognorm",
            transform_func=lambda x: scipy.stats.lognorm.ppf(x, 1e-1, 1e1),
        ),
        scipy.stats.lognorm(1e-1, 1e1),
    ),
]
# Repeat for numerical prior to make sure statistically robust
# NOTE(review): the pair repeated below is a ModifiedJeffreys prior, not the
# NumericalPrior the comment above names -- confirm which one was intended.
prior_scipy_list.extend(
    10 * [
        (
            radvel.prior.ModifiedJeffreys("per1", 0.0, 100.0, -0.1),
            scipy.stats.loguniform(0.0 + 0.1, 100.0 + 0.1, loc=-0.1),
        )
    ]
)
498
def test_prior_transforms():
    """
    Test prior transforms for a subset of priors

    Only the first ``(prior, scipy distribution)`` pair of
    ``prior_scipy_list`` is checked: the prior's ``transform`` of 100
    uniform draws must agree with the scipy distribution's ``ppf``.
    """
    prior, scipy_dist = prior_scipy_list[0]  # Test first prior

    u = np.random.default_rng(3245).uniform(size=100)
    expected_val = scipy_dist.ppf(u)

    # Higher tolerance for numerical prior: interpolation of a histogram,
    # otherwise use default rtol=1e-7
    if isinstance(prior, radvel.prior.NumericalPrior):
        absolute_tolerance = 1.5e-2
    else:
        absolute_tolerance = 0.0

    np.testing.assert_allclose(
        prior.transform(u),
        expected_val,
        atol=absolute_tolerance,
        err_msg=f"Prior transform failed for {prior}",
    )
516

517
def test_userdefined_no_transform():
    """
    A ``UserDefinedPrior`` built without ``transform_func`` is expected to
    raise ``TypeError`` when ``transform`` is called.
    """
    rng = np.random.default_rng(3245)
    u = rng.uniform(size=100)

    prior = radvel.prior.UserDefinedPrior(
        ["per1"],
        lambda x: scipy.stats.lognorm.pdf(x, 1e-1, 1e1),
        "lognorm",
    )

    # Explicit raise instead of ``assert False``: assert statements are
    # removed under ``python -O``, which would silently disable the check.
    try:
        prior.transform(u)
    except TypeError:
        pass
    else:
        raise AssertionError("Expected TypeError")
530

531
def test_priors_no_transform():
    """
    Calling ``transform`` on a prior created without a transform function
    is expected to raise ``TypeError``.
    """
    rng = np.random.default_rng(3245)
    u = rng.uniform(size=100)

    # Create a prior that doesn't have a transform method
    prior = radvel.prior.UserDefinedPrior(
        ["per1"],
        lambda x: 1.0,  # Simple function
        "test"
    )

    # Explicit raise instead of ``assert False``: assert statements are
    # removed under ``python -O``, which would silently disable the check.
    try:
        prior.transform(u)
    except TypeError:
        pass
    else:
        raise AssertionError("Expected TypeError")
547

548

549
def likelihood_for_pt():
    """Build the ``RVLikelihood`` shared by the prior-transform tests.

    Uses the parameters from ``params_and_vector_for_priors`` with 100
    synthetic epochs (velocity 1, error 0.1). ``gamma`` and the four
    orbital parameters ``secosw1``, ``sesinw1``, ``per1`` and ``tc1`` are
    held fixed.
    """
    params = params_and_vector_for_priors()[0]

    # Synthetic data set
    t = np.linspace(0, 10, num=100)
    vel = np.ones_like(t)
    errvel = np.ones_like(t) * 0.1

    mod = radvel.RVModel(params)
    for trend_name, trend_value in (('dvdt', -0.02), ('curv', 0.01)):
        mod.params[trend_name] = radvel.Parameter(value=trend_value)

    like = radvel.likelihood.RVLikelihood(mod, t, vel, errvel)
    like.params['gamma'] = radvel.Parameter(value=0.1, vary=False)
    like.params['jit'] = radvel.Parameter(value=1.0)

    for fixed_name in ('secosw1', 'sesinw1', 'per1', 'tc1'):
        like.params[fixed_name].vary = False

    return like
565

566

567
def test_prior_transform_all_params():
    """
    Exercise ``Posterior.check_proper_priors`` on three prior sets.

    * priors on dvdt, curv, jit and logk1: expected to pass,
    * the same set without the logk1 prior: expected ``ValueError``,
    * the full set plus a second logk1 prior: expected ``ValueError``.
    """

    def posterior_with(extra_priors):
        # Fresh posterior carrying the three priors common to every case,
        # followed by the case-specific ones.
        post = radvel.posterior.Posterior(likelihood_for_pt())
        post.priors += [radvel.prior.Gaussian('dvdt', 0, 1.0)]
        post.priors += [radvel.prior.HardBounds('curv', 0.0, 1.0)]
        post.priors += [radvel.prior.ModifiedJeffreys('jit', 0, 10.0, -0.1)]
        post.priors += extra_priors
        return post

    def expect_value_error(post):
        # Explicit raise instead of ``assert False`` so the check survives
        # ``python -O``.
        try:
            post.check_proper_priors()
        except ValueError:
            pass
        else:
            raise AssertionError("Expected ValueError")

    # This should work
    posterior_with([radvel.prior.Gaussian('logk1', np.log(5), 5)]).check_proper_priors()

    # No prior on logk1
    expect_value_error(posterior_with([]))

    # Two priors on logk1
    expect_value_error(posterior_with([
        radvel.prior.Gaussian('logk1', np.log(5), 5),
        radvel.prior.Gaussian('logk1', 8, 5),
    ]))
599

600

601

602
def test_prior_transform_order():
    """
    Check that ``Posterior.prior_transform`` matches rows to parameters.

    The priors are added in an order that differs from
    ``post.name_vary_params()``; each row of the transformed sample must
    still equal the corresponding prior's own ``transform`` of that row.
    """
    post = radvel.posterior.Posterior(likelihood_for_pt())
    for new_prior in (
        radvel.prior.Gaussian('dvdt', 0, 1.0),
        radvel.prior.HardBounds('curv', 0.0, 1.0),
        radvel.prior.ModifiedJeffreys('jit', 0, 10.0, -0.1),
        radvel.prior.Gaussian('logk1', np.log(5), 5),
    ):
        post.priors += [new_prior]

    rng = np.random.default_rng(3245)
    sample_shape = (len(post.vary_params), 100)
    u = rng.uniform(size=sample_shape)
    p = post.prior_transform(u)

    prior_param_names = []
    for prior in post.priors:
        prior_param_names.append(prior.param)

    # Precondition of the test: the two orderings must actually differ
    assert prior_param_names != post.name_vary_params(), "Parameters and priors should have different order for this test"

    for prior in post.priors:
        row = post.name_vary_params().index(prior.param)
        np.testing.assert_allclose(
            prior.transform(u[row]),
            p[row],
            err_msg="Prior transform failed for {}".format(prior),
        )
622

623
def test_kepler():
    """
    Profile and test C-based Kepler solver

    Delegates entirely to ``radvel.kepler.profile``.
    """
    profile = radvel.kepler.profile
    profile()
628

629

630
def test_model_comp(setupfn='example_planets/HD164922.py'):
    """
    Test some additional model_comp lines

    Runs ``ic_compare`` for 'trend', then for 'e' with ``simple=True``,
    and finally with an unknown comparison type, which is expected to
    raise ``AssertionError``.
    """
    namespace = _args()
    namespace.setupfn = setupfn
    radvel.driver.fit(namespace)

    # also check some additional lines of model_comp
    namespace.verbose = True
    namespace.type = ['trend']
    radvel.driver.ic_compare(namespace)

    namespace.simple = True
    namespace.type = ['e']
    radvel.driver.ic_compare(namespace)

    namespace.simple = False
    namespace.type = ['something_else']
    try:
        radvel.driver.ic_compare(namespace)
    except AssertionError:  # expected result
        return

    # Reaching this point means the bogus type was accepted
    raise RuntimeError("Unexpected result from model_comp.")
655

656
if __name__ == '__main__':
    # Direct execution runs only the GP fit check; the other tests in this
    # file are run through the test runner.
    test_k2131()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc