• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

cylammarco / WDPhotTools / 20701577621

05 Jan 2026 12:32AM UTC coverage: 96.07% (-0.03%) from 96.096%
20701577621

Pull #49

github

cylammarco
avoid extrapolation in both ct and rbf interpolator for the atmosphere model
Pull Request #49: v0.0.13

196 of 208 new or added lines in 9 files covered. (94.23%)

1 existing line in 1 file now uncovered.

3300 of 3435 relevant lines covered (96.07%)

1.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.59
/src/WDPhotTools/atmosphere_model_reader.py
1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3

4
"""Handling the formatting of different atmosphere models"""
5

6
import os
2✔
7

8
import numpy as np
2✔
9
from scipy.interpolate import CloughTocher2DInterpolator
2✔
10
from scipy.interpolate import RBFInterpolator
2✔
11

12

13
class AtmosphereModelReader(object):
    """Handling the formatting of different atmosphere models"""

    # Columns that are interpolated in log10 space, both for the model grid
    # and for the query points passed to the returned interpolators.
    _LOG_COLUMNS = ("Teff", "age")

    def __init__(self):
        super(AtmosphereModelReader, self).__init__()

        self.this_file = os.path.dirname(os.path.abspath(__file__))

        self.model_list = {
            "montreal_co_da_20": "Bedard et al. 2020 CO DA",
            "montreal_co_db_20": "Bedard et al. 2020 CO DB",
            "lpcode_he_da_07": "Panei et al. 2007 He DA",
            "lpcode_co_da_07": "Panei et al. 2007 CO DA",
            "lpcode_he_da_09": "Althaus et al. 2009 He DA",
            "lpcode_co_da_10_z001": "Renedo et al. 2010 CO DA Z=0.01",
            "lpcode_co_da_10_z0001": "Renedo et al. 2010 CO DA Z=0.001",
            "lpcode_co_da_15_z00003": "Althaus et al. 2015 DA Z=0.00003",
            "lpcode_co_da_15_z0001": "Althaus et al. 2015 DA Z=0.0001",
            "lpcode_co_da_15_z0005": "Althaus et al. 2015 DA Z=0.0005",
            "lpcode_co_db_17_z00005": "Althaus et al. 2017 DB Y=0.4",
            "lpcode_co_db_17_z0001": "Althaus et al. 2017 DB Y=0.4",
            "lpcode_co_db_17": "Camisassa et al. 2017 DB",
            "lpcode_one_da_07": "Althaus et al. 2007 ONe DA",
            "lpcode_one_da_19": "Camisassa et al. 2019 ONe DA",
            "lpcode_one_db_19": "Camisassa et al. 2019 ONe DB",
            "lpcode_da_22": "Althaus et al. 2013 He DA, Camisassa et al. 2016 CO DA,  Camisassa et al. 2019 ONe DA",
            "lpcode_db_22": "Camisassa et al. 2017 CO DB, Camisassa et al. 2019 ONe DB",
        }

        # DA (hydrogen) and DB (helium) atmosphere synthetic photometry tables
        filepath_da = os.path.join(self.this_file, "wd_photometry/Table_DA_13012021.txt")
        filepath_db = os.path.join(self.this_file, "wd_photometry/Table_DB_13012021.txt")

        # Column names of the tables, in file order (also the recarray fields)
        self.column_key = np.array(
            [
                "Teff", "logg", "mass", "Mbol", "BC",
                "U", "B", "V", "R", "I",
                "J", "H", "Ks",
                "Y_mko", "J_mko", "H_mko", "K_mko",
                "W1", "W2", "W3", "W4",
                "S36", "S45", "S58", "S80",
                "u_sdss", "g_sdss", "r_sdss", "i_sdss", "z_sdss",
                "g_ps1", "r_ps1", "i_ps1", "z_ps1", "y_ps1",
                "G2", "G2_BP", "G2_RP",
                "G3", "G3_BP", "G3_RP",
                "FUV", "NUV",
                "age",
            ]
        )
        # LaTeX-formatted display names, aligned with column_key
        self.column_key_formatted = np.array(
            [
                r"T$_{\mathrm{eff}}$",
                "log(g)",
                "Mass",
                r"M$_{\mathrm{bol}}$",
                "BC",
                r"$U$",
                r"$B$",
                r"$V$",
                r"$R$",
                r"$I$",
                r"$J$",
                r"$H$",
                r"$K_{\mathrm{s}}$",
                r"$Y_{\mathrm{MKO}}$",
                r"$J_{\mathrm{MKO}}$",
                r"$H_{\mathrm{MKO}}$",
                r"$K_{\mathrm{MKO}}$",
                r"$W_{1}$",
                r"$W_{2}$",
                r"$W_{3}$",
                r"$W_{4}$",
                r"$S_{36}$",
                r"$S_{45}$",
                r"$S_{58}$",
                r"$S_{80}$",
                r"u$_{\mathrm{SDSS}}$",
                r"$g_{\mathrm{SDSS}}$",
                r"$r_{\mathrm{SDSS}}$",
                r"$i_{\mathrm{SDSS}}$",
                r"$z_{\mathrm{SDSS}}$",
                r"$g_{\mathrm{PS1}}$",
                r"$r_{\mathrm{PS1}}$",
                r"$i_{\mathrm{PS1}}$",
                r"$z_{\mathrm{PS1}}$",
                r"$y_{\mathrm{PS1}}$",
                r"$G_{\mathrm{DR2}}$",
                r"$G_{\mathrm{BP, DR2}}$",
                r"$G_{\mathrm{RP, DR2}}$",
                r"$G_{\mathrm{DR3}}$",
                r"$G_{\mathrm{BP, DR3}}$",
                r"$G_{\mathrm{RP, DR3}}$",
                "FUV",
                "NUV",
                "Age",
            ]
        )
        # Units, aligned with column_key: Teff, logg, mass, then 40 magnitude
        # columns (Mbol ... NUV), then age.
        self.column_key_unit = np.array(["K", r"(cm/s$^2$)", r"M$_\odot$"] + ["mag"] * 40 + ["yr"])
        # Wavelength per column, aligned with column_key; 0.0 for the
        # non-photometric columns (unit presumably Angstrom -- TODO confirm).
        self.column_key_wavelength = np.array(
            [
                0.0, 0.0, 0.0, 0.0, 0.0,
                3585.0, 4371.0, 5478.0, 6504.0, 8020.0,
                12350.0, 16460.0, 21600.0,
                10310.0, 12500.0, 16360.0, 22060.0,
                33682.0, 46179.0, 120717.0, 221944.0,
                35378.0, 44780.0, 56962.0, 77978.0,
                3557.0, 4702.0, 6175.0, 7491.0, 8946.0,
                4849.0, 6201.0, 7535.0, 8674.0, 9628.0,
                6229.0, 5037.0, 7752.0,
                6218.0, 5110.0, 7769.0,
                1535.0, 2301.0,
                0.0,
            ]
        )

        # Lookup tables keyed by column name
        self.column_names = dict(zip(self.column_key, self.column_key_formatted))
        self.column_units = dict(zip(self.column_key, self.column_key_unit))
        self.column_wavelengths = dict(zip(self.column_key, self.column_key_wavelength))

        self.column_type = np.array(([np.float64] * len(self.column_key)))
        self.dtype = list(zip(self.column_key, self.column_type))

        # Load the synthetic photometry file in a recarray
        self.model_da = np.loadtxt(filepath_da, skiprows=2, dtype=self.dtype)
        self.model_db = np.loadtxt(filepath_db, skiprows=2, dtype=self.dtype)

        # Shift ages <= 1 yr up by 1 yr; age is interpolated in log10 space,
        # so this presumably keeps log10(age) finite for the youngest entries.
        self.model_da["age"][self.model_da["age"] <= 1.0] += 1.0
        self.model_db["age"][self.model_db["age"] <= 1.0] += 1.0

    def list_atmosphere_parameters(self):
        """
        Print the formatted list of parameters available from the atmosphere
        models.

        """

        for key, formatted_name in self.column_names.items():
            print(f"Parameter: {formatted_name}, Column Name: {key}, Unit: {self.column_units[key]}")

    @classmethod
    def _to_interpolation_space(cls, name, values):
        """Return ``values`` in the space the grid of column ``name`` is interpolated in.

        Teff and age are interpolated in log10 space; every other column is
        used as-is. Non-positive inputs give -inf/nan without a RuntimeWarning
        (callers mask such out-of-range points themselves).
        """
        if name in cls._LOG_COLUMNS:
            with np.errstate(divide="ignore", invalid="ignore"):
                return np.log10(values)
        return values

    @staticmethod
    def _split_single_argument(value):
        """Split one array-like query into an ``(x_0, x_1)`` pair.

        The first two elements are used; a single element is used for both
        coordinates, and an empty input yields ``(nan, nan)``.
        """
        arr = np.asarray(value).reshape(-1)
        if arr.size >= 2:
            return arr[0], arr[1]
        if arr.size == 1:
            return arr[0], arr[0]
        return np.nan, np.nan

    @staticmethod
    def _broadcast_pair(x_0, x_1):
        """Convert two query coordinates into equal-length 1D float arrays.

        A size-1 input is repeated to match the size of the other input.

        Raises
        ------
        ValueError
            If the sizes differ and neither input is of size 1.
        """
        x_0 = np.asarray(x_0).reshape(-1)
        x_1 = np.asarray(x_1).reshape(-1)

        if x_0.size == x_1.size:
            pass
        elif (x_0.size == 1) and (x_1.size > 1):
            x_0 = np.repeat(x_0, x_1.size)
        elif (x_0.size > 1) and (x_1.size == 1):
            x_1 = np.repeat(x_1, x_0.size)
        else:
            raise ValueError(
                "Either one variable is a float, int or of size 1, or two variables should have the same size."
            )

        return x_0.astype(float), x_1.astype(float)

    def interp_am(
        self,
        dependent="G3",
        atmosphere="H",
        independent=("logg", "Mbol"),
        logg=8.0,
        interpolator="CT",
        kwargs_for_RBF=None,
        kwargs_for_CT=None,
    ):
        """
        This function interpolates the grid of synthetic photometry and a few
        other physical properties as a function of 2 independent variables,
        the Default choices are 'logg' and 'Mbol'.

        Parameters
        ----------
        dependent: str (Default: 'G3')
            The value to be interpolated over. Choose from:
            'Teff', 'logg', 'mass', 'Mbol', 'BC', 'U', 'B', 'V', 'R', 'I', 'J', 'H', 'Ks', 'Y_mko', 'J_mko', 'H_mko',
            'K_mko', 'W1', 'W2', 'W3', 'W4', 'S36', 'S45', 'S58', 'S80', 'u_sdss', 'g_sdss', 'r_sdss', 'i_sdss',
            'z_sdss', 'g_ps1', 'r_ps1', 'i_ps1', 'z_ps1', 'y_ps1', 'G2', 'G2_BP', 'G2_RP', 'G3', 'G3_BP', 'G3_RP',
            'FUV', 'NUV', 'age'.
        atmosphere: str (Default: 'H')
            The atmosphere type: 'H'/'hydrogen'/'DA' or 'He'/'helium'/'DB'
            (case-insensitive).
        independent: str or list of str (Default: ('logg', 'Mbol'))
            The one or two (case-insensitive) column names to interpolate
            over. With one name, 'logg' is prepended and held fixed at
            ``logg``. Teff and age are interpolated in log10 space.
        logg: float (Default: 8.0)
            Only used if independent is of length 1.
        interpolator: str (Default: 'CT')
            Choose between 'RBF' and 'CT' (case-insensitive).
        kwargs_for_RBF: dict or None (Default: None, meaning {"neighbors": None, "smoothing": 0.0,
            "kernel": "thin_plate_spline", "epsilon": None, "degree": None})
            Keyword arguments for the interpolator, overriding those defaults.
            See `scipy.interpolate.RBFInterpolator`.
        kwargs_for_CT: dict or None (Default: None, meaning {'fill_value': -np.inf, 'tol': 1e-10,
            'maxiter': 100000, 'rescale': True})
            Keyword arguments for the interpolator, overriding those defaults.
            See `scipy.interpolate.CloughTocher2DInterpolator`.

        Returns
        -------
            A callable wrapping the scipy interpolator.

            * 1D with CT: ``f(x)`` evaluates at (``logg``, x); points outside
              the grid get the CT ``fill_value``.
            * 1D with RBF: ``f(x)`` returns a 1D array; x is clipped to the
              grid range first.
            * 2D with CT: ``f(x0, x1)`` or ``f([x0, x1])``; 2D with RBF:
              ``f(x0, x1)`` or ``f([x0, x1])``. Either coordinate may be a
              scalar or size-1 array that is broadcast against the other.
              A 1D float array is returned, with NaN for queries outside
              the grid's range in either coordinate.

        Raises
        ------
        ValueError
            For an unknown atmosphere type, an unknown 1D independent column,
            an unknown interpolator, or (when calling the 2D interpolator)
            incompatible input sizes.
        TypeError
            If ``independent`` does not contain one or two names.

        """
        _kwargs_for_RBF = {
            "neighbors": None,
            "smoothing": 0.0,
            "kernel": "thin_plate_spline",
            "epsilon": None,
            "degree": None,
        }
        _kwargs_for_RBF.update(kwargs_for_RBF or {})

        _kwargs_for_CT = {
            "fill_value": -np.inf,
            "tol": 1e-10,
            "maxiter": 100000,
            "rescale": True,
        }
        _kwargs_for_CT.update(kwargs_for_CT or {})

        atmosphere_lower = atmosphere.lower()

        # DA atmosphere
        if atmosphere_lower in ("h", "hydrogen", "da"):
            model = self.model_da

        # DB atmosphere
        elif atmosphere_lower in ("he", "helium", "db"):
            model = self.model_db

        else:
            raise ValueError(
                'Please choose from "h", "hydrogen", "da", "he", "helium" or "db" as the atmosphere type, you have '
                f"provided {atmosphere_lower}."
            )

        independent = np.asarray(independent, dtype=object).reshape(-1)

        column_key = self.column_key
        column_key_lower = np.char.lower(column_key)

        if len(independent) == 1:
            # If only performing a 1D interpolation, the logg has to be assumed.
            if independent[0].lower() not in column_key_lower:
                raise ValueError(
                    "When only interpolating in 1-dimension, the independent variable has to be one of: Teff, mass, "
                    "Mbol, or age."
                )
            independent = np.array(("logg", independent[0]))
            one_dimensional = True

        elif len(independent) == 2:
            # A 2D grid, normally logg and another parameter
            one_dimensional = False

        else:
            raise TypeError("Please provide ONE variable name as a string or list, or TWO variable names in a list.")

        # Map the case-insensitive names onto the canonical column keys
        independent = np.array([column_key[np.where(name.lower() == column_key_lower)[0][0]] for name in independent])
        name_0, name_1 = independent[0], independent[1]

        arg_0 = model[name_0]
        arg_1 = model[name_1]

        # The grid range is taken in linear units, before any log10 transform,
        # because the query points are checked before they are transformed.
        arg_0_min = np.nanmin(arg_0)
        arg_0_max = np.nanmax(arg_0)
        arg_1_min = np.nanmin(arg_1)
        arg_1_max = np.nanmax(arg_1)

        points_0 = self._to_interpolation_space(name_0, arg_0)
        points_1 = self._to_interpolation_space(name_1, arg_1)

        method = interpolator.lower()
        if method == "ct":
            # Interpolate with the scipy CloughTocher2DInterpolator
            _atmosphere_interpolator = CloughTocher2DInterpolator(
                (points_0, points_1),
                model[dependent],
                **_kwargs_for_CT,
            )
        elif method == "rbf":
            # Interpolate with the scipy RBFInterpolator
            _atmosphere_interpolator = RBFInterpolator(
                np.stack((points_0, points_1), -1),
                model[dependent],
                **_kwargs_for_RBF,
            )
        else:
            raise ValueError(f"Interpolator should be CT or RBF, {interpolator} is given.")

        to_space = self._to_interpolation_space
        split_single = self._split_single_argument
        broadcast_pair = self._broadcast_pair

        if one_dimensional:
            if method == "ct":

                def atmosphere_interpolator(_x):
                    # Points outside the grid get the CT fill_value.
                    return _atmosphere_interpolator(logg, to_space(name_1, _x))

            else:

                def atmosphere_interpolator(_x):
                    _x_arr = np.asarray(_x).reshape(-1).astype(float)
                    _logg = np.full(_x_arr.size, logg, dtype=float)

                    # Clip to the grid range so the RBF never extrapolates
                    _x_arr = np.clip(_x_arr, arg_1_min, arg_1_max)
                    _x_arr = to_space(name_1, _x_arr)

                    return _atmosphere_interpolator(np.column_stack((_logg, _x_arr)))

            return atmosphere_interpolator

        def _evaluate(x_0, x_1):
            """Evaluate the 2D interpolator, masking out-of-range queries with NaN."""
            _x_0, _x_1 = broadcast_pair(x_0, x_1)

            # mark out-of-range inputs to avoid excessive extrapolation
            mask_oob = (_x_0 < arg_0_min) | (_x_0 > arg_0_max) | (_x_1 < arg_1_min) | (_x_1 > arg_1_max)

            _x_0 = to_space(name_0, _x_0)
            _x_1 = to_space(name_1, _x_1)

            if method == "ct":
                out = _atmosphere_interpolator(_x_0, _x_1)
            else:
                out = _atmosphere_interpolator(np.column_stack((_x_0, _x_1)))

            out = np.asarray(out).reshape(-1)
            out[mask_oob] = np.nan
            return out

        if method == "ct":

            def atmosphere_interpolator(x0, x1=None):
                # Accept (x0, x1) or a single array-like holding both values
                if x1 is None:
                    x0, x1 = split_single(x0)
                return _evaluate(x0, x1)

        else:

            def atmosphere_interpolator(*x):
                # Accept (x0, x1) or a single array-like; anything else -> NaN
                if len(x) == 2:
                    x_0, x_1 = x
                elif len(x) == 1:
                    x_0, x_1 = split_single(x[0])
                else:
                    x_0, x_1 = np.nan, np.nan
                return _evaluate(x_0, x_1)

        return atmosphere_interpolator
2✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc