• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

LSSTDESC / Spectractor / 4890731301

pending completion
4890731301

Pull #125

github

GitHub
Merge 79e1b14f6 into 3549ae5c3
Pull Request #125: Fitparams

892 of 892 new or added lines in 16 files covered. (100.0%)

7127 of 7963 relevant lines covered (89.5%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.0
/spectractor/tools.py
1
import os
1✔
2
import shutil
1✔
3
from photutils.detection import IRAFStarFinder
1✔
4
from scipy.optimize import curve_fit
1✔
5
import numpy as np
1✔
6
from astropy.modeling import models, fitting
1✔
7
from astropy.stats import sigma_clip, sigma_clipped_stats
1✔
8
from astropy.io import fits
1✔
9
from astropy import wcs as WCS
1✔
10

11
import matplotlib.pyplot as plt
1✔
12
import matplotlib.colors
1✔
13
from matplotlib.ticker import MaxNLocator
1✔
14

15
import json
1✔
16
import warnings
1✔
17
from scipy.signal import fftconvolve, gaussian
1✔
18
from scipy.ndimage import maximum_filter, generate_binary_structure, binary_erosion
1✔
19
from scipy.interpolate import interp1d
1✔
20
from scipy.integrate import quad
1✔
21

22
from skimage.feature import hessian_matrix
1✔
23
from spectractor.config import set_logger
1✔
24
from spectractor import parameters
1✔
25
from math import floor
1✔
26

27
from numba import njit
1✔
28

29

30
# Module-level flag, initialised to None; it is not read or written by any code in this section.
# NOTE(review): presumably a lazily-computed cache recording which skimage.feature.hessian_matrix
# API variant is installed (hessian_matrix is imported above) — confirm against its users.
_SCIKIT_IMAGE_NEW_HESSIAN = None
1✔
31

32

33
@njit(fastmath=True, cache=True)
def gauss(x, A, x0, sigma):
    """Return the Gaussian profile A * exp(-(x - x0)^2 / (2 sigma^2)) evaluated at x.

    Compiled with numba (``fastmath=True``, ``cache=True``).

    Parameters
    ----------
    x: array_like
        Abscissa values at which the profile is evaluated (size Nx).
    A: float
        Peak amplitude, reached at x = x0.
    x0: float
        Centre (mean) of the Gaussian function.
    sigma: float
        Standard deviation of the Gaussian function.

    Returns
    -------
    m: array_like
        The Gaussian function evaluated on the x array.

    Examples
    --------

    >>> x = np.arange(50)
    >>> y = gauss(x, 10, 25, 3)
    >>> print(y.shape)
    (50,)
    >>> y[25]
    10.0
    """
    offset = x - x0
    variance = sigma * sigma
    # 2 * (sigma * sigma) equals (2 * sigma) * sigma exactly: scaling by 2 is exact in floating point.
    return A * np.exp(-offset * offset / (2 * variance))
64

65

66
@njit(fastmath=True, cache=True)
def gauss_jacobian(x, A, x0, sigma):
    """Compute the partial derivatives of the Gaussian function with respect to its parameters.

    The derivatives are those of ``A * exp(-(x - x0)^2 / (2 sigma^2))`` (see ``gauss``).
    The derivative with respect to A is the unit-amplitude profile itself. It is computed
    directly rather than as ``gauss(...) / A``, so the Jacobian stays finite when A == 0.

    Parameters
    ----------
    x: array_like
        Abscissa array to evaluate the function of size Nx.
    A: float
        Amplitude of the Gaussian function.
    x0: float
        Mean of the Gaussian function.
    sigma: float
        Standard deviation of the Gaussian function.

    Returns
    -------
    dA, dx0, dsigma: tuple of array_like
        Derivatives with respect to A, x0 and sigma, each of size Nx. Use
        ``np.array(jac).T`` to obtain the Nx x 3 Jacobian matrix.

    Examples
    --------

    >>> x = np.arange(50)
    >>> jac = gauss_jacobian(x, 10, 25, 3)
    >>> print(np.array(jac).T.shape)
    (50, 3)
    >>> bool(np.all(np.isfinite(np.array(gauss_jacobian(x, 0., 25., 3.)))))
    True
    """
    offset = x - x0
    # d/dA: the unit-amplitude Gaussian (well defined even for A == 0).
    dA = np.exp(-offset * offset / (2 * sigma * sigma))
    dx0 = A * offset / (sigma * sigma) * dA
    dsigma = A * offset * offset / (sigma ** 3) * dA
    return dA, dx0, dsigma
99

100

101
@njit(fastmath=True, cache=True)
def line(x, a, b):
    """Evaluate the straight line a * x + b (numba-compiled).

    Parameters
    ----------
    x: array_like
        Abscissa values.
    a: float
        Slope of the line.
    b: float
        Intercept of the line.

    Returns
    -------
    y: array_like
        The line evaluated at x.
    """
    slope_term = a * x
    return slope_term + b
104

105

106
# noinspection PyTypeChecker
def fit_gauss(x, y, guess=(10, 1000, 1), bounds=(-np.inf, np.inf), sigma=None):
    """Fit a Gaussian profile to data, using curve_fit.

    The mean guess value of the Gaussian must not be far from the truth value, and bounds help a lot too.
    The fit uses the 'dogbox' method with the analytic Jacobian from ``gauss_jacobian``, an exact
    trust-region solver and very tight tolerances (xtol = ftol = 1e-15).

    Parameters
    ----------
    x: np.array
        The x data values.
    y: np.array
        The y data values.
    guess: tuple or list, [amplitude, mean, sigma], optional
        First guessed values for the Gaussian fit (default: (10, 1000, 1)).
    bounds: list, optional
        Boundaries for the parameters [[minima],[maxima]] (default: (-np.inf, np.inf)).
    sigma: np.array, optional
        The y data uncertainties.

    Returns
    -------
    popt: np.array
        Best fitting parameters of curve_fit.
    pcov: np.array
        Best fitting parameters covariance matrix from curve_fit.

    Examples
    --------

    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> x = np.arange(600.,700.,2)
    >>> p = [10, 650, 10]
    >>> y = gauss(x, *p)
    >>> y_err = np.ones_like(y)
    >>> print(y[25])
    10.0
    >>> guess = (2,630,2)
    >>> popt, pcov = fit_gauss(x, y, guess=guess, bounds=((1,600,1),(100,700,100)), sigma=y_err)

    .. doctest::
        :hide:

        >>> assert np.all(np.isclose(p,popt))
    """
    def gauss_jacobian_wrapper(*params):
        # curve_fit calls jac(x, A, x0, sigma) and expects an (Nx, 3) matrix,
        # whereas gauss_jacobian returns a tuple of three length-Nx arrays.
        return np.array(gauss_jacobian(*params)).T

    popt, pcov = curve_fit(gauss, x, y, p0=guess, bounds=bounds, tr_solver='exact', jac=gauss_jacobian_wrapper,
                           sigma=sigma, method='dogbox', verbose=0, xtol=1e-15, ftol=1e-15)
    return popt, pcov
155

156

157
def multigauss_and_line(x, *params):
    """Multiple Gaussian profiles plus a straight line.

    The order of the parameters is line slope, line intercept, and then blocks of 3 parameters
    for the Gaussian profiles: amplitude, mean and standard deviation. Trailing parameters that
    do not complete a block of 3 are ignored.

    Parameters
    ----------
    x: array
        The x data values.
    *params: list of float parameters as described above.

    Returns
    -------
    y: array
        The y profile values.

    Examples
    --------

    >>> x = np.arange(600.,800.,1)
    >>> y = multigauss_and_line(x, 1, 10, 20, 650, 3, 40, 750, 10)
    >>> print(y[0])
    610.0
    """
    out = line(x, params[0], params[1])
    for k in range((len(params) - 2) // 3):
        # Out-of-place addition lets the result upcast to float when line() returned an integer
        # array (integer x with integer slope/intercept); an in-place `+=` would raise there.
        out = out + gauss(x, *params[2 + 3 * k:2 + 3 * k + 3])
    return out
186

187

188
# noinspection PyTypeChecker
def fit_multigauss_and_line(x, y, guess=(0, 1, 10, 1000, 1), bounds=(-np.inf, np.inf)):
    """Fit multiple Gaussian profiles plus a straight line to data, using curve_fit.

    The mean guess value of each Gaussian must not be far from the truth value, and bounds help a lot too.
    The order of the parameters is line slope, line intercept, and then blocks of 3 parameters
    for the Gaussian profiles: amplitude, mean and standard deviation (see ``multigauss_and_line``).

    Parameters
    ----------
    x: array
        The x data values.
    y: array
        The y data values.
    guess: tuple or list, [slope, intercept, amplitude, mean, sigma, ...], optional
        First guessed values: the line parameters followed by one (amplitude, mean, sigma) block
        per Gaussian (default: (0, 1, 10, 1000, 1), i.e. a single Gaussian).
    bounds: 2D-list, optional
        Boundaries for the parameters [[minima],[maxima]] (default: (-np.inf, np.inf)).

    Returns
    -------
    popt: array
        Best fitting parameters of curve_fit.
    pcov: 2D-array
        Best fitting parameters covariance matrix from curve_fit. No uncertainties are passed and
        ``absolute_sigma=True``, so it assumes unit uncertainties and is not rescaled by the residuals.

    Examples
    --------

    >>> x = np.arange(600.,800.,1)
    >>> y = multigauss_and_line(x, 1, 10, 20, 650, 3, 40, 750, 10)
    >>> print(y[0])
    610.0
    >>> bounds = ((-np.inf,-np.inf,1,600,1,1,600,1),(np.inf,np.inf,100,800,100,100,800,100))
    >>> popt, pcov = fit_multigauss_and_line(x, y, guess=(0,1,3,630,3,3,770,3), bounds=bounds)
    >>> print(popt)
    [  1.  10.  20. 650.   3.  40. 750.  10.]
    """
    maxfev = 1000
    popt, pcov = curve_fit(multigauss_and_line, x, y, p0=guess, bounds=bounds, maxfev=maxfev, absolute_sigma=True)
    return popt, pcov
229

230

231
def rescale_x_to_legendre(x):
    """Affinely map x onto the Legendre interval [-1, 1].

    The values are centred on the midpoint of their range and divided by the half-range,
    so min(x) maps to -1 and max(x) maps to 1. If all values are equal, the centred values
    (all zero) are returned without scaling.

    Parameters
    ----------
    x: array_like
        The values to rescale.

    Returns
    -------
    x_norm: array
        The rescaled values.
    """
    centre = 0.5 * (np.max(x) + np.min(x))
    centred = x - centre
    half_range = np.max(centred)
    if half_range == 0:
        return centred
    return centred / half_range
238

239

240
def rescale_x_from_legendre(x, Xmax, Xmin):
    """Map values from the Legendre interval [-1, 1] back onto [Xmin, Xmax].

    This is the inverse affine map of ``rescale_x_to_legendre``: -1 goes to Xmin and 1 goes to Xmax.
    Note the argument order: Xmax comes before Xmin.

    Parameters
    ----------
    x: array_like
        Values in the normalised [-1, 1] coordinate.
    Xmax: float
        Upper end of the target interval.
    Xmin: float
        Lower end of the target interval.

    Returns
    -------
    X: array_like
        The values expressed in the original coordinate.
    """
    half_x = 0.5 * x
    return half_x * (Xmax - Xmin) + 0.5 * (Xmax + Xmin)
243

244

245
# noinspection PyTypeChecker
def multigauss_and_bgd(x, *params):
    """Multiple Gaussian profiles plus a Legendre polynomial background.

    The number of background coefficients is fixed by parameters.CALIB_BGD_NPARAMS. The parameters are:

    - a first block of CALIB_BGD_NPARAMS Legendre coefficients, from low to high degree (contrary to
      np.polyval), evaluated on x rescaled to [-1, 1] with ``rescale_x_to_legendre``;
    - then blocks of 3 parameters for the Gaussian profiles: amplitude, mean and standard deviation,
      evaluated on the original x. Trailing parameters that do not complete a block are ignored.

    Parameters
    ----------
    x: array
        The x data values.
    *params: list of float parameters as described above.

    Returns
    -------
    y: array
        The y profile values.

    Examples
    --------
    >>> parameters.CALIB_BGD_NPARAMS = 4
    >>> x = np.arange(600., 800., 1)
    >>> p = [20, 1, -1, -1, 20, 650, 3, 40, 750, 5]
    >>> y = multigauss_and_bgd(x, *p)
    >>> print(f'{y[0]:.2f}')
    19.00

    .. plot::

        import matplotlib.pyplot as plt
        from spectractor import parameters
        from spectractor.tools import multigauss_and_bgd
        import numpy as np
        parameters.CALIB_BGD_NPARAMS = 4
        x = np.arange(600., 800., 1)
        p = [20, 1, -1, -1, 20, 650, 3, 40, 750, 5]
        y = multigauss_and_bgd(x, *p)
        plt.plot(x,y,'r-')
        plt.show()

    """
    n_bgd = parameters.CALIB_BGD_NPARAMS
    profile = np.polynomial.legendre.legval(rescale_x_to_legendre(x), params[:n_bgd])
    gauss_params = params[n_bgd:]
    n_gauss = len(gauss_params) // 3
    for start in range(0, 3 * n_gauss, 3):
        profile += gauss(x, *gauss_params[start:start + 3])
    return profile
294

295

296
# noinspection PyTypeChecker
def multigauss_and_bgd_jacobian(x, *params):
    """Jacobian of the multiple Gaussian profiles plus Legendre polynomial background model.

    The parameter layout is that of ``multigauss_and_bgd``. There is a first block of
    parameters.CALIB_BGD_NPARAMS Legendre coefficients (from low to high degree, contrary to np.polyval),
    then blocks of 3 Gaussian parameters (amplitude, mean, standard deviation). The derivative with
    respect to the k-th Legendre coefficient is the k-th Legendre polynomial evaluated on x rescaled to
    [-1, 1]. The Gaussian derivatives come from ``gauss_jacobian`` evaluated on the original x.

    Parameters
    ----------
    x: array
        The x data values.
    *params: list of float parameters as described above.

    Returns
    -------
    y: array
        The jacobian values, one row per x value and one column per parameter.

    Examples
    --------

    >>> import spectractor.parameters as parameters
    >>> parameters.CALIB_BGD_NPARAMS = 4
    >>> x = np.arange(600.,800.,1)
    >>> p = [20, 1, -1, -1, 20, 650, 3, 40, 750, 5]
    >>> y = multigauss_and_bgd_jacobian(x, *p)
    >>> assert(np.all(np.isclose(y.T[0],np.ones_like(x))))
    >>> print(y.shape)
    (200, 10)
    """
    n_bgd = parameters.CALIB_BGD_NPARAMS
    x_legendre = rescale_x_to_legendre(x)
    columns = []
    for degree in range(n_bgd):
        # Unit coefficient vector selects the single Legendre polynomial of this degree.
        unit_coeffs = np.zeros(n_bgd)
        unit_coeffs[degree] = 1
        columns.append(np.polynomial.legendre.legval(x_legendre, unit_coeffs))
    n_gauss = (len(params) - n_bgd) // 3
    for g in range(n_gauss):
        start = n_bgd + 3 * g
        gauss_derivatives = np.array(gauss_jacobian(x, *params[start:start + 3]))
        columns.extend(list(derivative) for derivative in gauss_derivatives)
    return np.array(columns).T
341

342

343
# noinspection PyTypeChecker
def fit_multigauss_and_bgd(x, y, guess=(0, 1, 10, 1000, 1, 0), bounds=(-np.inf, np.inf), sigma=None):
    """Fit multiple Gaussian profiles plus a Legendre polynomial background to data, using curve_fit.

    The mean guess value of each Gaussian must not be far from the truth value, and bounds help a lot too.
    The model is ``multigauss_and_bgd``, and the number of background coefficients is fixed by
    parameters.CALIB_BGD_NPARAMS. The order of the parameters is a first block of CALIB_BGD_NPARAMS
    Legendre coefficients (from low to high degree, contrary to np.polyval), and then blocks of 3
    parameters for the Gaussian profiles: amplitude, mean and standard deviation. x values are
    renormalised on the [-1, 1] interval for the background only.

    The fit uses the 'trf' method with the analytic Jacobian ``multigauss_and_bgd_jacobian``,
    Jacobian-based parameter scaling (``x_scale='jac'``), tolerances xtol = ftol = 1e-4, and at most
    10000 function evaluations.

    Parameters
    ----------
    x: array
        The x data values.
    y: array
        The y data values.
    guess: tuple or list, [CALIB_BGD_NPARAMS parameters, 3 * number of Gaussians parameters], optional
        First guessed values for the fit (default: (0, 1, 10, 1000, 1, 0)).
    bounds: array, optional
        Boundaries for the parameters [[minima],[maxima]] (default: (-np.inf, np.inf)).
    sigma: array, optional
        The uncertainties on the y values (default: None). They are used in an absolute sense
        (``absolute_sigma=True``). When None, unit uncertainties are assumed.

    Returns
    -------
    popt: array
        Best fitting parameters of curve_fit.
    pcov: array
        Best fitting parameters covariance matrix from curve_fit.

    Examples
    --------

    >>> from spectractor.config import load_config
    >>> load_config("default.ini")
    >>> x = np.arange(600.,800.,1)
    >>> p = [20, 1, -1, -1, 20, 650, 3, 40, 750, 5]
    >>> y = multigauss_and_bgd(x, *p)
    >>> print(f'{y[0]:.2f}')
    19.00
    >>> err = 0.1 * np.sqrt(y)
    >>> guess = (15,0,0,0,10,640,2,20,750,7)
    >>> bounds = ((-np.inf,-np.inf,-np.inf,-np.inf,1,600,1,1,600,1),(np.inf,np.inf,np.inf,np.inf,100,800,100,100,800,100))
    >>> popt, pcov = fit_multigauss_and_bgd(x, y, guess=guess, bounds=bounds, sigma=err)
    >>> assert np.all(np.isclose(p,popt,rtol=1e-4))
    >>> fit = multigauss_and_bgd(x, *popt)

    .. plot::

        import matplotlib.pyplot as plt
        import numpy as np
        from spectractor import parameters
        from spectractor.tools import multigauss_and_bgd, fit_multigauss_and_bgd
        parameters.CALIB_BGD_NPARAMS = 4
        x = np.arange(600.,800.,1)
        p = [20, 1, -1, -1, 20, 650, 3, 40, 750, 5]
        y = multigauss_and_bgd(x, *p)
        err = 0.1 * np.sqrt(y)
        guess = (15,0,0,0,10,640,2,20,750,7)
        bounds = ((-np.inf,-np.inf,-np.inf,-np.inf,1,600,1,1,600,1),(np.inf,np.inf,np.inf,np.inf,100,800,100,100,800,100))
        popt, pcov = fit_multigauss_and_bgd(x, y, guess=guess, bounds=bounds, sigma=err)
        fit = multigauss_and_bgd(x, *popt)
        fig = plt.figure()
        plt.errorbar(x,y,yerr=err,linestyle='None')
        plt.plot(x,fit,'r-')
        plt.plot(x,multigauss_and_bgd(x, *guess),'k--')
        plt.show()
    """
    maxfev = 10000
    popt, pcov = curve_fit(multigauss_and_bgd, x, y, p0=guess, bounds=bounds, maxfev=maxfev, sigma=sigma,
                           absolute_sigma=True, method='trf', xtol=1e-4, ftol=1e-4, verbose=0,
                           jac=multigauss_and_bgd_jacobian, x_scale='jac')
    return popt, pcov
450

451

452
# noinspection PyTupleAssignmentBalance
def fit_poly1d(x, y, order, w=None):
    """Fit a 1D polynomial function to data. Use np.polyfit.

    Parameters
    ----------
    x: array
        The x data values.
    y: array
        The y data values.
    order: int
        The degree of the polynomial function.
    w: array, optional
        Weights on the y data (default: None).

    Returns
    -------
    fit: array
        The best fitting parameter values, highest degree first (np.polyval convention).
        An integer array of zeros when there are not more data points than ``order``.
    cov: 2D-array
        The covariance matrix. It is an empty array when it cannot be estimated, i.e. when
        ``len(x) <= order + 1``: np.polyfit needs strictly more points than coefficients to scale it.
    model: array
        The best fitting profile. ``y`` itself is returned unchanged when no fit is performed.

    Examples
    --------

    >>> x = np.arange(500., 1000., 1)
    >>> p = [3, 2, 1, 1]
    >>> y = np.polyval(p, x)
    >>> err = np.ones_like(y)
    >>> fit, cov, model = fit_poly1d(x, y, order=3)

    .. doctest::
        :hide:

        >>> assert np.all(np.isclose(p, fit, 1e-5))
        >>> assert np.all(np.isclose(model, y))
        >>> assert cov.shape == (4, 4)

    With uncertainties:

    >>> fit, cov2, model2 = fit_poly1d(x, y, order=3, w=err)

    .. doctest::
        :hide:

        >>> assert np.all(np.isclose(p, fit, 1e-5))

    Too few points to fit:

    >>> fit, cov3, model3 = fit_poly1d([0, 1], [1, 1], order=3, w=err)
    >>> print(fit)
    [0 0 0 0]

    Exactly determined system (as many points as coefficients): no covariance is returned.

    >>> fit, cov4, model4 = fit_poly1d([0., 1.], [1., 3.], order=1)
    >>> print(np.round(fit, 6), cov4.size)
    [2. 1.] 0
    """
    cov = np.array([])
    if len(x) > order:
        if len(x) > order + 1:
            fit, cov = np.polyfit(x, y, order, cov=True, w=w)
        else:
            # Exactly determined system: np.polyfit(..., cov=True) raises ValueError because the
            # covariance cannot be scaled by the residuals, so fit without it.
            fit = np.polyfit(x, y, order, w=w)
        model = np.polyval(fit, x)
    else:
        fit = np.array([0] * (order + 1))
        model = y
    return fit, cov, model
516

517

518
# noinspection PyTupleAssignmentBalance
def fit_poly1d_legendre(x, y, order, w=None):
    """Fit a 1D polynomial function to data using the Legendre polynomial orthogonal basis.

    The x values are first rescaled onto [-1, 1] with ``rescale_x_to_legendre``. The returned
    coefficients therefore apply to that normalised coordinate, not to the original x.

    Parameters
    ----------
    x: array
        The x data values.
    y: array
        The y data values.
    order: int
        The degree of the polynomial function.
    w: array, optional
        Weights on the y data (default: None).

    Returns
    -------
    fit: array
        The best fitting Legendre coefficients, from low to high degree. An integer array of
        zeros when there are not more data points than ``order``.
    cov: list or int
        Despite its name, this is NOT a covariance matrix. It is the diagnostic list
        ``[residuals, rank, singular_values, rcond]`` returned by
        ``np.polynomial.legendre.legfit(..., full=True)``, or -1 when no fit is performed.
    model: array
        The best fitting profile. ``y`` itself is returned unchanged when no fit is performed.

    Examples
    --------

    >>> x = np.arange(500., 1000., 1)
    >>> p = [-1e-6, -1e-4, 1, 1]
    >>> y = np.polyval(p, x)
    >>> err = np.ones_like(y)
    >>> fit, cov, model = fit_poly1d_legendre(x,y,order=3)
    >>> assert np.allclose(model, y)
    >>> fit, cov2, model2 = fit_poly1d_legendre(x,y,order=3,w=err)
    >>> assert np.allclose(model2, y)
    >>> residuals, rank, singular_values, rcond = cov
    >>> print(rank)
    4
    >>> fit, cov3, model3 = fit_poly1d_legendre([0, 1], [1, 1], order=3, w=err)
    >>> print(fit, cov3)
    [0 0 0 0] -1

    .. plot::

        import matplotlib.pyplot as plt
        import numpy as np
        from spectractor.tools import fit_poly1d_legendre
        p = [-1e-6, -1e-4, 1, 1]
        x = np.arange(500., 1000., 1)
        y = np.polyval(p, x)
        err = np.ones_like(y)
        fit, cov2, model2 = fit_poly1d_legendre(x,y,order=3,w=err)
        plt.errorbar(x,y,yerr=err,fmt='ro')
        plt.plot(x,model2)
        plt.show()
    """
    # NOTE(review): `cov` keeps its historical name for compatibility, but holds legfit's
    # full=True diagnostics (or -1), not a covariance matrix.
    cov = -1
    x_norm = rescale_x_to_legendre(x)
    if len(x) > order:
        fit, cov = np.polynomial.legendre.legfit(x_norm, y, deg=order, full=True, w=w)
        model = np.polynomial.legendre.legval(x_norm, fit)
    else:
        fit = np.array([0] * (order + 1))
        model = y
    return fit, cov, model
580

581

582
# noinspection PyTypeChecker,PyUnresolvedReferences
def fit_poly2d(x, y, z, order):
    """Fit a 2D polynomial z = P(x, y) of total degree ``order``, using astropy.modeling.

    The fit uses a Levenberg-Marquardt least-squares fitter. All warnings raised during the fit
    are silenced.

    Parameters
    ----------
    x: array
        The x data values.
    y: array
        The y data values.
    z: array
        The z data values.
    order: int
        The degree of the polynomial function.

    Returns
    -------
    model: Astropy model
        The best fitting astropy polynomial model

    Examples
    --------

    >>> x, y = np.mgrid[:50,:50]
    >>> z = x**2 + y**2 - 2*x*y
    >>> fit = fit_poly2d(x, y, z, order=2)

    .. doctest::
        :hide:

        >>> assert np.isclose(fit.c0_0.value, 0)
        >>> assert np.isclose(fit.c1_0.value, 0)
        >>> assert np.isclose(fit.c2_0.value, 1)
        >>> assert np.isclose(fit.c0_1.value, 0)
        >>> assert np.isclose(fit.c0_2.value, 1)
        >>> assert np.isclose(fit.c1_1.value, -2)
    """
    starting_model = models.Polynomial2D(degree=order)
    fitter = fitting.LevMarLSQFitter()
    with warnings.catch_warnings():
        # The fitter warns that the model is linear; that warning is irrelevant here.
        warnings.simplefilter('ignore')
        fitted_model = fitter(starting_model, x, y, z)
    return fitted_model
626

627

628
def fit_poly1d_outlier_removal(x, y, order=2, sigma=3.0, niter=3):
    """Fit a 1D polynomial function to data with iterative sigma-clipping. Use astropy.modeling.

    Parameters
    ----------
    x: array
        The x data values.
    y: array
        The y data values.
    order: int
        The degree of the polynomial function (default: 2).
    sigma: float
        Value of the sigma-clipping (default: 3.0).
    niter: int
        The number of iterations to converge (default: 3).

    Returns
    -------
    model: Astropy model
        The best fitting astropy Polynomial1D model.
    outliers: list
        Always an empty list. Outlier extraction is not implemented; the value is kept for
        interface compatibility.

    Examples
    --------

    >>> x = np.arange(500., 1000., 1)
    >>> p = [3,2,1,0]
    >>> y = np.polyval(p, x)
    >>> y[::10] = 0.
    >>> model, outliers = fit_poly1d_outlier_removal(x,y,order=3,sigma=3)
    >>> print('{:.2f}'.format(model.c0.value))
    0.00
    >>> print('{:.2f}'.format(model.c1.value))
    1.00
    >>> print('{:.2f}'.format(model.c2.value))
    2.00
    >>> print('{:.2f}'.format(model.c3.value))
    3.00

    """
    # No explicit initial coefficients: the linear least-squares fitter solves directly,
    # so starting values do not influence the result.
    gg_init = models.Polynomial1D(order)
    fit = fitting.LinearLSQFitter()
    or_fit = fitting.FittingWithOutlierRemoval(fit, sigma_clip, niter=niter, sigma=sigma)
    # The second returned value (clipping information) is currently unused.
    or_fitted_model, _ = or_fit(gg_init, x, y)
    # NOTE(review): outliers could be derived from the clipping information returned by or_fit
    # (a mask or a masked array depending on the astropy version); not implemented.
    outliers = []
    return or_fitted_model, outliers
689

690

691
def fit_poly2d_outlier_removal(x, y, z, order=2, sigma=3.0, niter=30):
    """Fit a 2D polynomial function to data with iterative sigma-clipping. Use astropy.modeling.

    A linear least-squares fitter is wrapped in astropy's outlier-removal fitter. The fitted model
    is logged at INFO level.

    Parameters
    ----------
    x: array
        The x data values.
    y: array
        The y data values.
    z: array
        The z data values.
    order: int
        The degree of the polynomial function (default: 2).
    sigma: float
        Value of the sigma-clipping (default: 3.0).
    niter: int
        The number of iterations to converge (default: 30).

    Returns
    -------
    model: Astropy model
        The best fitting astropy model.

    Examples
    --------

    >>> x, y = np.mgrid[:50,:50]
    >>> z = x**2 + y**2 - 2*x*y
    >>> z[::10,::10] = 0.
    >>> fit = fit_poly2d_outlier_removal(x,y,z,order=2,sigma=3)

    .. doctest::
        :hide:

        >>> assert np.isclose(fit.c0_0.value, 0)
        >>> assert np.isclose(fit.c1_0.value, 0)
        >>> assert np.isclose(fit.c2_0.value, 1)
        >>> assert np.isclose(fit.c0_1.value, 0)
        >>> assert np.isclose(fit.c0_2.value, 1)
        >>> assert np.isclose(fit.c1_1.value, -2)

    """
    logger = set_logger(__name__)
    starting_model = models.Polynomial2D(order)
    clipping_fitter = fitting.FittingWithOutlierRemoval(fitting.LinearLSQFitter(), sigma_clip,
                                                        niter=niter, sigma=sigma)
    fitted_model, _ = clipping_fitter(starting_model, x, y, z)
    logger.info(f'\n\t{fitted_model}')
    return fitted_model
742

743

744
def tied_circular_gauss2d(g1):
    """Tie function returning the x standard deviation of the model ``g1``.

    ``fit_gauss2d_outlier_removal`` assigns it to ``y_stddev.tied`` so that a circular
    Gaussian keeps y_stddev equal to x_stddev.

    Parameters
    ----------
    g1: Astropy model
        Model exposing an ``x_stddev`` attribute.

    Returns
    -------
    std: Parameter
        The ``x_stddev`` attribute of ``g1``.
    """
    return g1.x_stddev
747

748

749
def fit_gauss2d_outlier_removal(x, y, z, sigma=3.0, niter=3, guess=None, bounds=None, circular=False):
    """
    Fit an astropy Gaussian 2D model with parameters : amplitude, x_mean, y_mean, x_stddev, y_stddev, theta
    using outlier removal methods.

    A Levenberg-Marquardt fitter is wrapped in astropy's sigma-clipping outlier-removal fitter.
    All warnings raised during the fit are silenced, and the fitted model is logged at INFO level.

    Parameters
    ----------
    x: np.array
        2D array of the x coordinates from meshgrid.
    y: np.array
        2D array of the y coordinates from meshgrid.
    z: np.array
        the 2D array image.
    sigma: float
        value of sigma for the sigma rejection of outliers (default: 3)
    niter: int
        maximum number of iterations for the outlier detection (default: 3)
    guess: list, optional
        List containing a first guess for the PSF parameters, in the model's parameter order (default: None).
    bounds: list, optional
        2D list containing bounds for the PSF parameters with format ((min,...), (max...)) (default: None)
    circular: bool, optional
        If True, force the Gaussian shape to be circular: y_stddev is tied to x_stddev and theta is
        fixed (default: False)

    Returns
    -------
    fitted_model: Fittable
        Astropy Gaussian2D model

    Examples
    --------

    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from astropy.modeling import models
    >>> X, Y = np.mgrid[:50,:50]
    >>> PSF = models.Gaussian2D()
    >>> p = (50, 25, 25, 5, 5, 0)
    >>> Z = PSF.evaluate(X, Y, *p)

    .. plot::

        import numpy as np
        import matplotlib.pyplot as plt
        from astropy.modeling import models
        X, Y = np.mgrid[:50,:50]
        PSF = models.Gaussian2D()
        p = (50, 25, 25, 5, 5, 0)
        Z = PSF.evaluate(X, Y, *p)
        plt.imshow(Z, origin='lower')
        plt.show()

    >>> guess = (45, 20, 20, 7, 7, 0)
    >>> bounds = ((1, 10, 10, 1, 1, -90), (100, 40, 40, 10, 10, 90))
    >>> fit = fit_gauss2d_outlier_removal(X, Y, Z, guess=guess, bounds=bounds, circular=True)
    >>> res = [getattr(fit, p).value for p in fit.param_names]
    >>> print(res)
    [50.0, 25.0, 25.0, 5.0, 5.0, 0.0]

    .. plot::

        import numpy as np
        import matplotlib.pyplot as plt
        from astropy.modeling import models
        from spectractor.tools import fit_gauss2d_outlier_removal
        X, Y = np.mgrid[:50,:50]
        PSF = models.Gaussian2D()
        p = (50, 25, 25, 5, 5, 0)
        Z = PSF.evaluate(X, Y, *p)
        guess = (45, 20, 20, 7, 7, 0)
        bounds = ((1, 10, 10, 1, 1, -90), (100, 40, 40, 10, 10, 90))
        fit = fit_gauss2d_outlier_removal(X, Y, Z, guess=guess, bounds=bounds, circular=True)
        plt.imshow(Z-fit(X, Y), origin='lower')
        plt.show()

    """
    logger = set_logger(__name__)
    psf_model = models.Gaussian2D()
    if guess is not None:
        for index, name in enumerate(psf_model.param_names):
            getattr(psf_model, name).value = guess[index]
    if bounds is not None:
        for index, name in enumerate(psf_model.param_names):
            parameter = getattr(psf_model, name)
            parameter.min = bounds[0][index]
            parameter.max = bounds[1][index]
    if circular:
        # Circular shape: y width follows x width and the rotation angle stays at its current value.
        psf_model.y_stddev.tied = tied_circular_gauss2d
        psf_model.theta.fixed = True
    with warnings.catch_warnings():
        # Silence every warning raised while fitting (e.g. from the fitter about the model).
        warnings.simplefilter('ignore')
        clipping_fitter = fitting.FittingWithOutlierRemoval(fitting.LevMarLSQFitter(), sigma_clip,
                                                            niter=niter, sigma=sigma)
        fitted_model, _ = clipping_fitter(psf_model, x, y, z)
        logger.info(f'\n\t{fitted_model}')
    return fitted_model
847

848

849
def fit_moffat2d_outlier_removal(x, y, z, sigma=3.0, niter=3, guess=None, bounds=None):
    """
    Fit an astropy Moffat 2D model with parameters: amplitude, x_0, y_0, gamma, alpha
    using outlier removal methods.

    Parameters
    ----------
    x: np.array
        2D array of the x coordinates from meshgrid.
    y: np.array
        2D array of the y coordinates from meshgrid.
    z: np.array
        the 2D array image.
    sigma: float
        value of sigma for the sigma rejection of outliers (default: 3)
    niter: int
        maximum number of iterations for the outlier detection (default: 3)
    guess: list, optional
        List containing a first guess for the PSF parameters, in the order of
        ``models.Moffat2D().param_names`` (default: None).
    bounds: list, optional
        2D list containing bounds for the PSF parameters with format ((min,...), (max...)),
        in the order of ``models.Moffat2D().param_names`` (default: None)

    Returns
    -------
    fitted_model: Fittable
        Astropy Moffat2D model

    Examples
    --------

    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from astropy.modeling import models
    >>> X, Y = np.mgrid[:100,:100]
    >>> PSF = models.Moffat2D()
    >>> p = (50, 50, 50, 5, 2)
    >>> Z = PSF.evaluate(X, Y, *p)

    .. plot::

        import numpy as np
        import matplotlib.pyplot as plt
        from astropy.modeling import models
        X, Y = np.mgrid[:100,:100]
        PSF = models.Moffat2D()
        p = (50, 50, 50, 5, 2)
        Z = PSF.evaluate(X, Y, *p)
        plt.imshow(Z, origin='lower')
        plt.show()

    >>> guess = (45, 48, 52, 4, 2)
    >>> bounds = ((1, 10, 10, 1, 1), (100, 90, 90, 10, 10))
    >>> fit = fit_moffat2d_outlier_removal(X, Y, Z, guess=guess, bounds=bounds, niter=3)
    >>> res = [getattr(fit, p).value for p in fit.param_names]

    .. doctest::
        :hide:

        >>> assert(np.all(np.isclose(p, res, 1e-1)))

    .. plot::

        import numpy as np
        import matplotlib.pyplot as plt
        from astropy.modeling import models
        from spectractor.tools import fit_moffat2d_outlier_removal
        X, Y = np.mgrid[:100,:100]
        PSF = models.Moffat2D()
        p = (50, 50, 50, 5, 2)
        Z = PSF.evaluate(X, Y, *p)
        guess = (45, 48, 52, 4, 2)
        bounds = ((1, 10, 10, 1, 1), (100, 90, 90, 10, 10))
        fit = fit_moffat2d_outlier_removal(X, Y, Z, guess=guess, bounds=bounds, niter=3)
        plt.imshow(Z-fit(X, Y), origin='lower')
        plt.show()
    """
    my_logger = set_logger(__name__)
    gg_init = models.Moffat2D()
    if guess is not None:
        for ip, p in enumerate(gg_init.param_names):
            getattr(gg_init, p).value = guess[ip]
    if bounds is not None:
        for ip, p in enumerate(gg_init.param_names):
            getattr(gg_init, p).min = bounds[0][ip]
            getattr(gg_init, p).max = bounds[1][ip]
    with warnings.catch_warnings():
        # Ignore model linearity warning from the fitter
        warnings.simplefilter('ignore')
        fit = fitting.LevMarLSQFitter()
        or_fit = fitting.FittingWithOutlierRemoval(fit, sigma_clip, niter=niter, sigma=sigma)
        # The second returned item (outlier mask / filtered data) is not used here.
        or_fitted_model, _ = or_fit(gg_init, x, y, z)
        my_logger.info(f'\n\t{or_fitted_model}')
        return or_fitted_model
944

945

946
def fit_moffat1d_outlier_removal(x, y, sigma=3.0, niter=3, guess=None, bounds=None):
    """
    Fit an astropy Moffat 1D model with parameters: amplitude, x_0, gamma, alpha
    using outlier removal methods.

    Parameters
    ----------
    x: np.array
        1D array of the x coordinates.
    y: np.array
        the 1D array amplitudes.
    sigma: float
        value of sigma for the sigma rejection of outliers (default: 3)
    niter: int
        maximum number of iterations for the outlier detection (default: 3)
    guess: list, optional
        List containing a first guess for the PSF parameters, in the order of
        ``models.Moffat1D().param_names`` (default: None).
    bounds: list, optional
        2D list containing bounds for the PSF parameters with format ((min,...), (max...)),
        in the order of ``models.Moffat1D().param_names`` (default: None)

    Returns
    -------
    fitted_model: Fittable
        Astropy Moffat1D model

    Examples
    --------

    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from astropy.modeling import models
    >>> X = np.arange(100)
    >>> PSF = models.Moffat1D()
    >>> p = (50, 50, 5, 2)
    >>> Y = PSF.evaluate(X, *p)

    .. plot::

        import numpy as np
        import matplotlib.pyplot as plt
        from astropy.modeling import models
        X = np.arange(100)
        PSF = models.Moffat1D()
        p = (50, 50, 5, 2)
        Y = PSF.evaluate(X, *p)
        plt.plot(X, Y)
        plt.show()

    >>> guess = (45, 48, 4, 2)
    >>> bounds = ((1, 10, 1, 1), (100, 90, 10, 10))
    >>> fit = fit_moffat1d_outlier_removal(X, Y, guess=guess, bounds=bounds, niter=3)
    >>> res = [getattr(fit, p).value for p in fit.param_names]

    .. doctest::
        :hide:

        >>> assert(np.all(np.isclose(p, res, 1e-6)))

    .. plot::

        import numpy as np
        import matplotlib.pyplot as plt
        from astropy.modeling import models
        from spectractor.tools import fit_moffat1d_outlier_removal
        X = np.arange(100)
        PSF = models.Moffat1D()
        p = (50, 50, 5, 2)
        Y = PSF.evaluate(X, *p)
        guess = (45, 48, 4, 2)
        bounds = ((1, 10, 1, 1), (100, 90, 10, 10))
        fit = fit_moffat1d_outlier_removal(X, Y, guess=guess, bounds=bounds, niter=3)
        plt.plot(X, Y-fit(X))
        plt.show()
    """
    my_logger = set_logger(__name__)
    gg_init = models.Moffat1D()
    if guess is not None:
        for ip, p in enumerate(gg_init.param_names):
            getattr(gg_init, p).value = guess[ip]
    if bounds is not None:
        for ip, p in enumerate(gg_init.param_names):
            getattr(gg_init, p).min = bounds[0][ip]
            getattr(gg_init, p).max = bounds[1][ip]
    with warnings.catch_warnings():
        # Ignore model linearity warning from the fitter
        warnings.simplefilter('ignore')
        fit = fitting.LevMarLSQFitter()
        or_fit = fitting.FittingWithOutlierRemoval(fit, sigma_clip, niter=niter, sigma=sigma)
        # The second returned item (outlier mask / filtered data) is not used here.
        or_fitted_model, _ = or_fit(gg_init, x, y)
        my_logger.debug(f'\n\t{or_fitted_model}')
        return or_fitted_model
1039

1040

1041
def fit_moffat1d(x, y, guess=None, bounds=None):
    """Fit an astropy Moffat 1D model with parameters :
        amplitude, x_0, gamma, alpha

    Parameters
    ----------
    x: np.array
        1D array of the x coordinates.
    y: np.array
        the 1D array amplitudes.
    guess: list, optional
        List containing a first guess for the PSF parameters, in the order of
        ``models.Moffat1D().param_names`` (default: None).
    bounds: list, optional
        2D list containing bounds for the PSF parameters with format ((min,...), (max...)),
        in the order of ``models.Moffat1D().param_names`` (default: None)

    Returns
    -------
    fitted_model: Fittable
        Astropy Moffat1D model

    Examples
    --------

    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from astropy.modeling import models
    >>> X = np.arange(100)
    >>> PSF = models.Moffat1D()
    >>> p = (50, 50, 5, 2)
    >>> Y = PSF.evaluate(X, *p)

    .. plot::

        import numpy as np
        import matplotlib.pyplot as plt
        from astropy.modeling import models
        X = np.arange(100)
        PSF = models.Moffat1D()
        p = (50, 50, 5, 2)
        Y = PSF.evaluate(X, *p)
        plt.plot(X, Y)
        plt.show()

    >>> guess = (45, 48, 4, 2)
    >>> bounds = ((1, 10, 1, 1), (100, 90, 10, 10))
    >>> fit = fit_moffat1d(X, Y, guess=guess, bounds=bounds)
    >>> res = [getattr(fit, p).value for p in fit.param_names]
    >>> assert(np.all(np.isclose(p, res, 1e-6)))

    .. plot::

        import numpy as np
        import matplotlib.pyplot as plt
        from astropy.modeling import models
        from spectractor.tools import fit_moffat1d
        X = np.arange(100)
        PSF = models.Moffat1D()
        p = (50, 50, 5, 2)
        Y = PSF.evaluate(X, *p)
        guess = (45, 48, 4, 2)
        bounds = ((1, 10, 1, 1), (100, 90, 10, 10))
        fit = fit_moffat1d(X, Y, guess=guess, bounds=bounds)
        plt.plot(X, Y-fit(X))
        plt.show()
    """
    my_logger = set_logger(__name__)
    gg_init = models.Moffat1D()
    if guess is not None:
        for ip, p in enumerate(gg_init.param_names):
            getattr(gg_init, p).value = guess[ip]
    if bounds is not None:
        for ip, p in enumerate(gg_init.param_names):
            getattr(gg_init, p).min = bounds[0][ip]
            getattr(gg_init, p).max = bounds[1][ip]
    with warnings.catch_warnings():
        # Ignore model linearity warning from the fitter
        warnings.simplefilter('ignore')
        fit = fitting.LevMarLSQFitter()
        fitted_model = fit(gg_init, x, y)
        my_logger.info(f'\n\t{fitted_model}')
        return fitted_model
1123

1124

1125
class LevMarLSQFitterWithNan(fitting.LevMarLSQFitter):
    """Levenberg-Marquardt least-squares fitter that ignores non-finite residuals.

    Residuals that evaluate to NaN or infinity (e.g. because the measurements
    contain NaN values) are set to zero so that they do not contribute to the
    minimized cost.
    """

    def objective_function(self, fps, *args):
        """
        Function to minimize.

        Parameters
        ----------
        fps : list
            parameters returned by the fitter
        args : list
            [model, [weights], [input coordinates], measurements]

        Returns
        -------
        residuals: np.ndarray
            Flattened residuals ``model(inputs) - measurements``, multiplied by the
            weights when given, with every non-finite entry replaced by 0.
        """
        model = args[0]
        weights = args[1]
        # NOTE(review): private astropy helper; its name differs in some astropy
        # versions — confirm against the installed astropy.
        fitting._fitter_to_model_params(model, fps)
        meas = args[-1]
        residuals = model(*args[2:-1]) - meas
        if weights is not None:
            residuals = weights * residuals
        a = np.ravel(residuals)
        # Zero out NaN/inf residuals (previously inverted in the unweighted branch,
        # which zeroed all finite residuals and kept the NaNs).
        a[~np.isfinite(a)] = 0
        return a
1151

1152

1153
def compute_fwhm(x, y, minimum=0, center=None, full_output=False):
    """
    Compute the full width half maximum of y(x) curve,
    using an interpolation of the data points and a dichotomy method.

    The half maximum crossing is searched on the right-hand side of the maximum of y,
    and the FWHM is twice the distance between this crossing and the curve center
    (the curve is therefore assumed to be symmetric around its center).

    Parameters
    ----------
    x: array_like
        The abscisse array.
    y: array_like
        The function array.
    minimum: float, optional
        The minimum reference (baseline) from which to compute half the height: the half
        maximum level is ``minimum + 0.5 * (max(y) - minimum)`` (default: 0).
    center: float, optional
        The center of the curve. If None, the weighted average of the y(x) distribution is computed (default: None).
    full_output: bool, optional
        If True, the half maximum, the curve center and the edges of the curve at half maximum
        are given in output after the FWHM (default: False).

    Returns
    -------
    FWHM: float
        The full width half maximum of the curve. Returns -1 if y is not 1D.
    half: float, optional
        The half maximum level. Only if full_output=True.
    center: float, optional
        The y(x) center value. Only if full_output=True.
    right_edge: float, optional
        The right edge at half maximum, found by dichotomy. Only if full_output=True.
    left_edge: float, optional
        The left edge at half maximum, mirrored from the right edge around the center.
        Only if full_output=True.

    Examples
    --------

    Gaussian example

    >>> x = np.arange(0, 100, 1)
    >>> stddev = 4
    >>> middle = 40
    >>> psf = gauss(x, 1, middle, stddev)
    >>> fwhm, half, center, a, b = compute_fwhm(x, psf, full_output=True)
    >>> print(f"{fwhm:.4f} {2.355*stddev:.4f} {center:.4f}")
    9.4329 9.4200 40.0000

    .. doctest::
        :hide:

        >>> assert np.isclose(fwhm, 2.355*stddev, atol=2e-1)
        >>> assert np.isclose(center, middle, atol=1e-3)

    .. plot ::

        import matplotlib.pyplot as plt
        import numpy as np
        from spectractor.tools import gauss, compute_fwhm
        x = np.arange(0, 100, 1)
        stddev = 4
        middle = 40
        psf = gauss(x, 1, middle, stddev)
        fwhm, half, center, a, b = compute_fwhm(x, psf, full_output=True)
        plt.figure()
        plt.plot(x, psf, label="function")
        plt.axvline(center, color="gray", label="center")
        plt.axvline(a, color="k", label="edges at half max")
        plt.axvline(b, color="k", label="edges at half max")
        plt.axhline(half, color="r", label="half max")
        plt.legend()
        plt.title(f"FWHM={fwhm:.3f}")
        plt.xlabel("x")
        plt.ylabel("y")
        plt.show()

    Defocused PSF example

    >>> from spectractor.extractor.psf import MoffatGauss
    >>> p = [2,40,40,4,2,-0.4,1,10]
    >>> psf = MoffatGauss(p)
    >>> fwhm, half, center, a, b = compute_fwhm(x, psf.evaluate(x), full_output=True)

    .. doctest::
        :hide:

        >>> assert np.isclose(fwhm, 7.05, atol=1e-2)
        >>> assert np.isclose(center, p[1], atol=1e-2)

    .. plot ::

        import matplotlib.pyplot as plt
        import numpy as np
        from spectractor.tools import gauss, compute_fwhm
        from spectractor.extractor.psf import MoffatGauss
        x = np.arange(0, 100, 1)
        p = [2,40,40,4,2,-0.4,1,10]
        psf = MoffatGauss(p)
        fwhm, half, center, a, b = compute_fwhm(x, psf.evaluate(x), full_output=True)
        plt.figure()
        plt.plot(x, psf.evaluate(x, p), label="function")
        plt.axvline(center, color="gray", label="center")
        plt.axvline(a, color="k", label="edges at half max")
        plt.axvline(b, color="k", label="edges at half max")
        plt.axhline(half, color="r", label="half max")
        plt.legend()
        plt.title(f"FWHM={fwhm:.3f}")
        plt.xlabel("x")
        plt.ylabel("y")
        plt.show()
    """
    if y.ndim > 1:
        # TODO: implement fwhm for 2D curves
        return -1
    interp = interp1d(x, y, kind="linear", bounds_error=False, fill_value="extrapolate")
    # Height of the peak above the baseline; all reference levels are measured from `minimum`.
    amplitude = np.max(y) - minimum
    imax = np.argmax(y)
    right_side = y[imax:]
    # Bracket the half-maximum crossing on the right side of the peak between the points
    # closest to 90% and 10% of the height above the baseline.
    a = x[imax + np.argmin(np.abs(right_side - (minimum + 0.9 * amplitude)))]
    b = x[imax + np.argmin(np.abs(right_side - (minimum + 0.1 * amplitude)))]
    half = minimum + 0.5 * amplitude

    def eq(xx):
        return interp(xx) - half

    res = dichotomie(eq, a, b, 1e-3)
    if center is None:
        center = np.average(x, weights=y)
    fwhm = abs(2 * (res - center))
    if not full_output:
        return fwhm
    else:
        return fwhm, half, center, res, center - abs(res - center)
1280

1281

1282
def compute_integral(x, y, bounds=None):
    """
    Integrate a sampled curve y(x) after cubic-spline interpolation.

    The samples are interpolated (and extrapolated outside the sampled range) with a
    cubic spline, which is then integrated numerically with adaptive quadrature.

    Parameters
    ----------
    x: array_like
        The abscisse array.
    y: array_like
        The function array.
    bounds: array_like, optional
        The (lower, upper) integration bounds. If None, the extreme values of the
        x array are used (default bounds=None).

    Returns
    -------
    result: float
        The integral of the curve between the bounds.

    Examples
    --------

    Gaussian example

    .. doctest::

        >>> x = np.arange(0, 100, 0.5)
        >>> stddev = 4
        >>> middle = 40
        >>> psf = gauss(x, 1/(stddev*np.sqrt(2*np.pi)), middle, stddev)
        >>> integral = compute_integral(x, psf)
        >>> print(f"{integral:.6f}")
        1.000000

    Defocused PSF example

    .. doctest::

        >>> from spectractor.extractor.psf import MoffatGauss
        >>> p = [2,30,30,4,2,-0.5,1,10]
        >>> psf = MoffatGauss(p)
        >>> integral = compute_integral(x, psf.evaluate(x))
        >>> assert np.isclose(integral, p[0], atol=1e-2)

    """
    if bounds is not None:
        lower, upper = bounds[0], bounds[1]
    else:
        lower, upper = np.min(x), np.max(x)
    spline = interp1d(x, y, kind="cubic", bounds_error=False, fill_value="extrapolate")
    quad_output = quad(spline, lower, upper, limit=200)
    # quad returns (value, absolute error estimate); only the value is reported.
    return quad_output[0]
1332

1333

1334
def find_nearest(array, value):
    """Locate the element of an array closest to a given value.

    Parameters
    ----------
    array: array
        The array to search.
    value: float
        The target value.

    Returns
    -------
    index: int
        Index of the element with the smallest absolute difference to *value*
        (the first one in case of ties).
    val: float
        The array element at that index.

    Examples
    --------
    >>> x = np.arange(0.,10.)
    >>> idx, val = find_nearest(x, 3.3)
    >>> print(idx, val)
    3 3.0
    """
    distances = np.abs(array - value)
    best = distances.argmin()
    nearest_value = array[best]
    return best, nearest_value
1360

1361

1362
def ensure_dir(directory_name):
    """Ensure that *directory_name* directory exists. If not, create it
    (including missing parent directories).

    Parameters
    ----------
    directory_name: str
        The directory name.

    Raises
    ------
    FileExistsError
        If *directory_name* exists but is not a directory.

    Examples
    --------
    >>> ensure_dir('tests')
    >>> os.path.exists('tests')
    True
    >>> ensure_dir('tests/mytest')
    >>> os.path.exists('tests/mytest')
    True
    >>> os.rmdir('./tests/mytest')
    """
    # exist_ok=True avoids the race between an existence check and the creation,
    # e.g. when several processes create the same output directory concurrently.
    os.makedirs(directory_name, exist_ok=True)
1382

1383

1384
def weighted_avg_and_std(values, weights):
    """Compute the weighted mean and weighted standard deviation of a sample.

    For a PSF profile, for instance, *values* are the pixel positions x and
    *weights* the pixel intensities y=f(x).

    Parameters
    ----------
    values: array_like
        The sample values.
    weights: array_like
        The weights, with the same shape as *values*.

    Returns
    -------
    average: float
        The weighted mean of *values*.
    std: float
        The weighted standard deviation of *values* around that mean.

    Examples
    --------
    >>> avg, std = weighted_avg_and_std(np.array([0., 2.]), np.array([1., 1.]))
    >>> print(f"{avg:.1f} {std:.1f}")
    1.0 1.0
    """
    mean = np.average(values, weights=weights)
    squared_deviations = (values - mean) ** 2
    # Two-pass formula: numerically stabler than E[x^2] - E[x]^2.
    variance = np.average(squared_deviations, weights=weights)
    return mean, np.sqrt(variance)
1402

1403

1404
def hessian_and_theta(data, margin_cut=1):
    """Compute the Hessian eigenvalue maps and the Hessian orientation angle of an image.

    The Hessian matrix is computed with ``skimage.feature.hessian_matrix`` (Gaussian
    smoothing with sigma=3), then its two eigenvalues and orientation angle are
    evaluated at every pixel.

    Parameters
    ----------
    data: array_like
        The 2D image.
    margin_cut: int, optional
        Number of pixels removed from each edge of the output maps (default: 1).
        ``margin_cut=0`` keeps the full maps.

    Returns
    -------
    lambda_plus: array_like
        Largest eigenvalue of the Hessian matrix at each pixel.
    lambda_minus: array_like
        Smallest eigenvalue of the Hessian matrix at each pixel.
    theta: array_like
        Angle ``0.5 * arctan2(2 Hxy, Hxx - Hyy)`` at each pixel, in degrees.
    """
    global _SCIKIT_IMAGE_NEW_HESSIAN

    # Check for unannounced API change on hessian_matrix in scikit-image>=0.20
    # See https://github.com/scikit-image/scikit-image/pull/6624
    if _SCIKIT_IMAGE_NEW_HESSIAN is None:
        from importlib import metadata
        # A bare ``import packaging`` does not load the ``version`` submodule.
        from packaging import version

        vers = version.parse(metadata.version("scikit-image"))
        _SCIKIT_IMAGE_NEW_HESSIAN = vers >= version.parse("0.20.0")

    # compute hessian matrices on the image
    order = "xy" if _SCIKIT_IMAGE_NEW_HESSIAN else "rc"
    Hxx, Hxy, Hyy = hessian_matrix(data, sigma=3, order=order)
    # Closed-form eigenvalues of the symmetric 2x2 matrix [[Hxx, Hxy], [Hxy, Hyy]].
    root = np.sqrt((Hxx - Hyy) ** 2 + 4 * Hxy * Hxy)
    lambda_plus = 0.5 * ((Hxx + Hyy) + root)
    lambda_minus = 0.5 * ((Hxx + Hyy) - root)
    theta = 0.5 * np.arctan2(2 * Hxy, Hxx - Hyy) * 180 / np.pi
    # Remove the margins. Explicit stop indices keep margin_cut=0 valid
    # (a ``[0:-0]`` slice would be empty).
    ny, nx = lambda_plus.shape
    crop = (slice(margin_cut, ny - margin_cut), slice(margin_cut, nx - margin_cut))
    return lambda_plus[crop], lambda_minus[crop], theta[crop]
1430

1431

1432
def filter_stars_from_bgd(data, margin_cut=1):
    """Set to NaN the pixels of an image flagged by a Hessian eigenvalue threshold.

    A pixel is flagged when the smallest Hessian eigenvalue computed by
    ``hessian_and_theta`` is below ``median(lambda_minus) - 2 * std(lambda_minus)``.

    Parameters
    ----------
    data: array_like
        The 2D image. It is modified in place and must have a floating point dtype
        to hold NaN values.
    margin_cut: int, optional
        Margin passed to ``hessian_and_theta``; pixels within this margin of the
        image edges are never flagged (default: 1).

    Returns
    -------
    data: array_like
        The same array, with the flagged pixels set to NaN.
    """
    _, lambda_minus, _ = hessian_and_theta(np.copy(data), margin_cut=margin_cut)
    # thresholds
    lambda_threshold = np.median(lambda_minus) - 2 * np.std(lambda_minus)
    rows, cols = np.where(lambda_minus < lambda_threshold)
    # lambda_minus is cropped by margin_cut on each side: shift the indices back
    # into the frame of the full image before masking.
    data[rows + margin_cut, cols + margin_cut] = np.nan
    return data
1439

1440

1441
def fftconvolve_gaussian(array, reso):
    """Smooth a 1D array, or each row of a 2D array, with a normalized Gaussian kernel.

    The kernel has as many samples as the last axis of the array and is normalized
    to unit sum; convolution is done with FFTs in ``'same'`` mode. A 2D input is
    smoothed row by row and modified in place. Other dimensions are not supported:
    an error is logged and the input is returned unchanged.

    Parameters
    ----------
    array: array
        The array to convolve.
    reso: float
        The standard deviation of the Gaussian profile, in samples.

    Returns
    -------
    convolved: array
        The convolved array, same size and shape as input.

    Examples
    --------
    >>> array = np.ones(100)
    >>> output = fftconvolve_gaussian(array, 3)
    >>> print(output[:3])
    [0.5        0.63114657 0.74850168]
    >>> array = np.ones((100, 100))
    >>> output = fftconvolve_gaussian(array, 3)
    >>> print(output[0][:3])
    [0.5        0.63114657 0.74850168]
    >>> array = np.ones((100, 100, 100))
    >>> output = fftconvolve_gaussian(array, 3)
    """
    my_logger = set_logger(__name__)
    ndim = array.ndim
    if ndim not in (1, 2):
        my_logger.error(f'\n\tArray dimension must be 1 or 2. Here I have array.ndim={array.ndim}.')
        return array
    # For 1D the last axis length equals array.size; for 2D it is the row length.
    window = gaussian(array.shape[-1], reso)
    window /= np.sum(window)
    if ndim == 1:
        return fftconvolve(array, window, mode='same')
    for row_index, row in enumerate(array):
        array[row_index] = fftconvolve(row, window, mode='same')
    return array
1482

1483

1484
def formatting_numbers(value, error_high, error_low, std=None, label=None):
    """Format a physical value with its asymmetric uncertainties.

    The number of decimals is driven by the most significant digit of the smaller
    uncertainty. The value is printed with that many decimals, and the uncertainties
    are printed with one or two significant digits.

    Parameters
    ----------
    value: float
        The physical value.
    error_high: float
        Upper uncertainty.
    error_low: float
        Lower uncertainty
    std: float, optional
        The RMS of the physical parameter (default: None).
    label: str, optional
        The name of the physical parameter to output (default: None).

    Returns
    -------
    text: tuple
        The formatted strings ``([label,] value, error_high, error_low[, std])``.

    Examples
    --------
    >>> formatting_numbers(3., 0.789, 0.500, std=0.45, label='test')
    ('test', '3.0', '0.8', '0.5', '0.5')
    >>> formatting_numbers(3., 0.07, 0.008, std=0.03, label='test')
    ('test', '3.000', '0.07', '0.008', '0.03')
    >>> formatting_numbers(3240., 0.2, 0.4, std=0.3)
    ('3240.0', '0.2', '0.4', '0.3')
    >>> formatting_numbers(3240., 230, 420, std=330)
    ('3240', '230', '420', '330')
    >>> formatting_numbers(0, 0.008, 0.04, std=0.03)
    ('0.000', '0.008', '0.040', '0.030')
    >>> formatting_numbers(-55, 0.008, 0.04, std=0.03)
    ('-55.000', '0.008', '0.04', '0.03')
    """
    exp_high = int(floor(np.log10(np.abs(error_high))))
    exp_low = int(floor(np.log10(np.abs(error_low))))
    power10 = min(exp_high, exp_low)
    ndigits = abs(power10)

    def fixed(number):
        # Fixed-point formatting with the precision set by the uncertainties.
        return "%.*f" % (ndigits, number)

    if np.isclose(0.0, float(fixed(value))):
        # The value rounds to zero at this precision: print everything in fixed point.
        str_value = fixed(0)
        str_error_high = fixed(error_high)
        str_error_low = fixed(error_low)
        str_std = fixed(std) if std is not None else ""
    elif power10 > 0:
        # Uncertainties larger than 10: no decimals at all.
        str_value, str_error_high, str_error_low = (f"{v:.0f}" for v in (value, error_high, error_low))
        str_std = f"{std:.0f}" if std is not None else ""
    else:
        str_value = fixed(value)
        # The larger uncertainty keeps a second significant digit when the orders differ.
        if exp_high == exp_low:
            spec_high, spec_low, spec_std = ".1g", ".1g", ".1g"
        elif exp_high > exp_low:
            spec_high, spec_low, spec_std = ".2g", ".1g", ".2g"
        else:
            spec_high, spec_low, spec_std = ".1g", ".2g", ".2g"
        str_error_high = format(error_high, spec_high)
        str_error_low = format(error_low, spec_low)
        str_std = format(std, spec_std) if std is not None else ""

    fields = [] if label is None else [label]
    fields += [str_value, str_error_high, str_error_low]
    if std is not None:
        fields.append(str_std)
    return tuple(fields)
1564

1565

1566
def pixel_rotation(x, y, theta, x0=0, y0=0):
    """Rotate the vector (x - x0, y - y0) clockwise by the angle theta.

    Parameters
    ----------
    x: float
        x coordinate
    y: float
        y coordinate
    theta: float
        rotation angle in radians
    x0: float, optional
        x position of the center of rotation (default: 0)
    y0: float, optional
        y position of the center of rotation (default: 0)

    Returns
    -------
    u: float
        rotated x coordinate, relative to the center of rotation
    v: float
        rotated y coordinate, relative to the center of rotation

    Examples
    --------
    >>> u, v = pixel_rotation(0, 0, 45)
    >>> u, v = pixel_rotation(1, 0, np.pi/4)

    .. doctest::
        :hide:

        >>> assert np.isclose(u, 1/np.sqrt(2))
        >>> assert np.isclose(v, -1/np.sqrt(2))
        >>> u, v = pixel_rotation(1, 2, -np.pi/2, x0=1, y0=0)
        >>> assert np.isclose(u, -2)
        >>> assert np.isclose(v, 0)
    """
    dx = x - x0
    dy = y - y0
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    u = cos_t * dx + sin_t * dy
    v = -sin_t * dx + cos_t * dy
    return u, v
1607

1608

1609
def detect_peaks(image):
    """
    Build a boolean mask of the local maxima of an image.

    A pixel is a peak when it equals the maximum of its 3x3 (8-connected)
    neighbourhood, excluding the interior of the zero-valued background.
    Only positive peaks are detected (take absolute value or negative value of the
    image to detect the negative ones).

    Parameters
    ----------
    image: array_like
        The image 2D array.

    Returns
    -------
    detected_peaks: array_like
        Boolean mask of the peaks.

    Examples
    --------
    >>> im = np.zeros((50,50))
    >>> im[4,6] = 2
    >>> im[10,20] = -3
    >>> im[49,49] = 1
    >>> detected_peaks = detect_peaks(im)

    .. doctest::
        :hide:

        >>> assert detected_peaks[4,6]
        >>> assert not detected_peaks[10,20]
        >>> assert detected_peaks[49,49]
    """
    # 3x3 footprint with full (8-neighbour) connectivity
    footprint = generate_binary_structure(2, 2)
    is_local_max = maximum_filter(image, footprint=footprint) == image
    # Flat zero regions are trivially "local maxima". The background is eroded first
    # so that no spurious ring remains along its border once it is removed by XOR.
    background_core = binary_erosion(image == 0, structure=footprint, border_value=50)
    return is_local_max ^ background_core
1666

1667

1668
def clean_target_spikes(data, saturation):
    """Flatten saturated spikes along the x axis of a 2D image, in place.

    Pixels above ``saturation`` are first clipped to ``saturation``. Then, pass after
    pass, every interior pixel whose x gradient is above 80% of the largest x gradient
    is replaced by its left neighbour. A pixel whose x gradient is below 80% of the
    smallest one is replaced by its right neighbour. The passes stop when the number of
    pixels at or above ``saturation`` no longer decreases.

    Parameters
    ----------
    data: np.ndarray
        The 2D image. It is modified in place.
    saturation: float
        The saturation level.

    Returns
    -------
    data: np.ndarray
        The same array, cleaned.
    """
    saturated_pixels = np.where(data > saturation)
    data[saturated_pixels] = saturation
    NY, NX = data.shape
    delta = len(saturated_pixels[0])
    while delta > 0:
        n_saturated = len(saturated_pixels[0])
        _, gradx = np.gradient(data)
        # gradx does not change during the pass: compute the thresholds once instead
        # of once per pixel (which made every pass quadratic in the image size).
        high_threshold = 0.8 * np.max(gradx)
        low_threshold = 0.8 * np.min(gradx)
        for iy in range(1, NY - 1):
            for ix in range(1, NX - 1):
                if gradx[iy, ix] > high_threshold:
                    data[iy, ix] = data[iy, ix - 1]
                if gradx[iy, ix] < low_threshold:
                    data[iy, ix] = data[iy, ix + 1]
        saturated_pixels = np.where(data >= saturation)
        # Continue only while the pass removed some saturated pixels; the count is
        # bounded below, so the loop terminates.
        delta = n_saturated - len(saturated_pixels[0])
    return data
1689

1690

1691
def plot_image_simple(ax, data, scale="lin", title="", units="Image units", cmap=None,
1✔
1692
                      target_pixcoords=None, vmin=None, vmax=None, aspect=None, cax=None):
1693
    """Simple function to plot a spectrum with error bars and labels.
1694

1695
    Parameters
1696
    ----------
1697
    ax: Axes
1698
        Axes instance to make the plot
1699
    data: array_like
1700
        The image data 2D array.
1701
    scale: str
1702
        Scaling of the image (choose between: lin, log or log10, symlog) (default: lin)
1703
    title: str
1704
        Title of the image (default: "")
1705
    units: str
1706
        Units of the image to be written in the color bar label (default: "Image units")
1707
    cmap: colormap
1708
        Color map label (default: None)
1709
    target_pixcoords: array_like, optional
1710
        2D array giving the (x,y) coordinates of the targets on the image: add a scatter plot (default: None)
1711
    vmin: float
1712
        Minimum value of the image (default: None)
1713
    vmax: float
1714
        Maximum value of the image (default: None)
1715
    aspect: str
1716
        Aspect keyword to be passed to imshow (default: None)
1717
    cax: Axes, optional
1718
        Color bar axes if necessary (default: None).
1719

1720
    Examples
1721
    --------
1722

1723
    .. plot::
1724
        :include-source:
1725

1726
        >>> import matplotlib.pyplot as plt
1727
        >>> from spectractor.extractor.images import Image
1728
        >>> from spectractor import parameters
1729
        >>> from spectractor.tools import plot_image_simple
1730
        >>> f, ax = plt.subplots(1,1)
1731
        >>> im = Image('tests/data/reduc_20170605_028.fits', config="./config/ctio.ini")
1732
        >>> plot_image_simple(ax, im.data, scale="symlog", units="ADU", target_pixcoords=(815,580),
1733
        ...                     title="tests/data/reduc_20170605_028.fits")
1734
        >>> if parameters.DISPLAY: plt.show()
1735
    """
1736
    if scale == "log" or scale == "log10":
1✔
1737
        # removes the zeros and negative pixels first
1738
        zeros = np.where(data <= 0)
1✔
1739
        min_noz = np.min(data[np.where(data > 0)])
1✔
1740
        data[zeros] = min_noz
1✔
1741
        # apply log
1742
        # data = np.log10(data)
1743
    if scale == "log10" or scale == "log":
1✔
1744
        norm = matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax)
1✔
1745
    elif scale == "symlog":
1✔
1746
        norm = matplotlib.colors.SymLogNorm(vmin=vmin, vmax=vmax, linthresh=10, base=10)
1✔
1747
    else:
1748
        norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
1✔
1749
    im = ax.imshow(data, origin='lower', cmap=cmap, norm=norm, aspect=aspect)
1✔
1750
    ax.grid(color='silver', ls='solid')
1✔
1751
    ax.grid(True)
1✔
1752
    ax.set_xlabel(parameters.PLOT_XLABEL)
1✔
1753
    ax.set_ylabel(parameters.PLOT_YLABEL)
1✔
1754
    cb = plt.colorbar(im, ax=ax, cax=cax)
1✔
1755
    if scale == "lin":
1✔
1756
        cb.formatter.set_powerlimits((0, 0))
1✔
1757
        cb.locator = MaxNLocator(7, prune=None)
1✔
1758
        cb.update_ticks()
1✔
1759
    cb.set_label('%s (%s scale)' % (units, scale))  # ,fontsize=16)
1✔
1760
    if title != "":
1✔
1761
        ax.set_title(title)
1✔
1762
    if target_pixcoords is not None:
1✔
1763
        ax.scatter(target_pixcoords[0], target_pixcoords[1], marker='o', s=100, edgecolors='k', facecolors='none',
1✔
1764
                   label='Target', linewidth=2)
1765

1766

1767
def plot_spectrum_simple(ax, lambdas, data, data_err=None, xlim=None, color='r', linestyle='none', lw=2, label='',
1✔
1768
                         title='', units='', marker='o'):
1769
    """Simple function to plot a spectrum with error bars and labels.
1770

1771
    Parameters
1772
    ----------
1773
    ax: Axes
1774
        Axes instance to make the plot.
1775
    lambdas: array
1776
        The wavelengths array.
1777
    data: array
1778
        The spectrum data array.
1779
    data_err: array, optional
1780
        The spectrum uncertainty array (default: None).
1781
    xlim: list, optional
1782
        List of minimum and maximum abscisses (default: None).
1783
    color: str, optional
1784
        String for the color of the spectrum (default: 'r').
1785
    linestyle: str, optional
1786
        String for the linestyle of the spectrum (default: 'none').
1787
    lw: int, optional
1788
        Integer for line width (default: 2).
1789
    marker: str, optional
1790
        Character for marker style (default: 'o').
1791
    label: str, optional
1792
        String label for the plot legend (default: '').
1793
    title: str, optional
1794
        String label for the plot title (default: '').
1795
    units: str, optional
1796
        String label for the plot units (default: '').
1797

1798

1799
    Examples
1800
    --------
1801

1802
    .. plot::
1803
        :include-source:
1804

1805
        >>> import matplotlib.pyplot as plt
1806
        >>> from spectractor.extractor.spectrum import Spectrum
1807
        >>> from spectractor import parameters
1808
        >>> from spectractor.tools import plot_spectrum_simple
1809
        >>> f, ax = plt.subplots(1,1)
1810
        >>> s = Spectrum(file_name='tests/data/reduc_20170530_134_spectrum.fits')
1811
        >>> plot_spectrum_simple(ax, s.lambdas, s.data, data_err=s.err, xlim=None, color='r', label='test')
1812
        >>> if parameters.DISPLAY: plt.show()
1813
    """
1814
    xs = lambdas
1✔
1815
    if xs is None:
1✔
1816
        xs = np.arange(data.size)
×
1817
    if data_err is not None:
1✔
1818
        ax.errorbar(xs, data, yerr=data_err, color=color, marker=marker, lw=lw, label=label,
1✔
1819
                    zorder=0, markersize=2, linestyle=linestyle)
1820
    else:
1821
        ax.plot(xs, data, color=color, lw=lw, label=label, linestyle=linestyle)
×
1822
    ax.grid(True)
1✔
1823
    if xlim is None and lambdas is not None:
1✔
1824
        xlim = [parameters.LAMBDA_MIN, parameters.LAMBDA_MAX]
1✔
1825
    ax.set_xlim(xlim)
1✔
1826
    try:
1✔
1827
        ax.set_ylim(0., np.nanmax(data[np.logical_and(xs > xlim[0], xs < xlim[1])]) * 1.2)
1✔
1828
    except ValueError:
×
1829
        pass
×
1830
    if lambdas is not None:
1✔
1831
        ax.set_xlabel(r'$\lambda$ [nm]')
1✔
1832
    else:
1833
        ax.set_xlabel('X [pixels]')
×
1834
    if units != '':
1✔
1835
        ax.set_ylabel(f'Flux [{units}]')
1✔
1836
    else:
1837
        ax.set_ylabel(f'Flux')
1✔
1838
    if title != '':
1✔
1839
        ax.set_title(title)
1✔
1840

1841

1842
def plot_compass_simple(ax, parallactic_angle=None, arrow_size=0.1, origin=[0.15, 0.15]):
1✔
1843
    """Plot small (N,W) compass, and optionally zenith direction.
1844

1845
    Parameters
1846
    ----------
1847
    ax: Axes
1848
        Axes instance to make the plot.
1849
    parallactic_angle: float, optional
1850
        Value is the parallactic angle with respect to North eastward and plot the zenith direction (default: None).
1851
    arrow_size: float, optional
1852
        Length of the arrow as a fraction of axe sizes (default: 0.1)
1853
    origin: array_like, optional
1854
        (x0, y0) position of the compass as axes fraction (default: [0.15, 0.15]).
1855

1856
    Examples
1857
    --------
1858

1859
    >>> from spectractor.extractor.images import Image
1860
    >>> from spectractor import parameters
1861
    >>> from spectractor.tools import plot_image_simple, plot_compass_simple
1862
    >>> f, ax = plt.subplots(1,1)
1863
    >>> im = Image('tests/data/reduc_20170605_028.fits', config="./config/ctio.ini")
1864
    >>> plot_image_simple(ax, im.data, scale="symlog", units="ADU", target_pixcoords=(750,700),
1865
    ...                   title='tests/data/reduc_20170530_134.fits')
1866
    >>> plot_compass_simple(ax, im.parallactic_angle)
1867
    >>> if parameters.DISPLAY: plt.show()
1868

1869
    """
1870
    # North arrow
1871
    N_arrow = [0, arrow_size]
1✔
1872
    N_xy = np.asarray(flip_and_rotate_radec_vector_to_xy_vector(N_arrow[0], N_arrow[1],
1✔
1873
                                                                camera_angle=parameters.OBS_CAMERA_ROTATION,
1874
                                                                flip_ra_sign=parameters.OBS_CAMERA_RA_FLIP_SIGN,
1875
                                                                flip_dec_sign=parameters.OBS_CAMERA_DEC_FLIP_SIGN))
1876
    ax.annotate("N", xy=origin, xycoords='axes fraction', xytext=N_xy + origin, textcoords='axes fraction',
1✔
1877
                arrowprops=dict(arrowstyle="<|-", fc="yellow", ec="yellow"), color="yellow",
1878
                horizontalalignment='center', verticalalignment='center')
1879
    # West arrow
1880
    W_arrow = [arrow_size, 0]
1✔
1881
    W_xy = np.asarray(flip_and_rotate_radec_vector_to_xy_vector(W_arrow[0], W_arrow[1],
1✔
1882
                                                                camera_angle=parameters.OBS_CAMERA_ROTATION,
1883
                                                                flip_ra_sign=parameters.OBS_CAMERA_RA_FLIP_SIGN,
1884
                                                                flip_dec_sign=parameters.OBS_CAMERA_DEC_FLIP_SIGN))
1885
    ax.annotate("W", xy=origin, xycoords='axes fraction', xytext=W_xy + origin, textcoords='axes fraction',
1✔
1886
                arrowprops=dict(arrowstyle="<|-", fc="yellow", ec="yellow"), color="yellow",
1887
                horizontalalignment='center', verticalalignment='center')
1888
    # Central dot
1889
    xmin, xmax = ax.get_xlim()
1✔
1890
    ymin, ymax = ax.get_ylim()
1✔
1891
    ax.scatter(origin[0] * xmax, origin[1] * ymax, color="yellow", s=20)
1✔
1892
    # Zenith direction
1893
    if parallactic_angle is not None:
1✔
1894
        p_arrow = [0, arrow_size]  # angle with respect to North in RADEC counterclockwise
1✔
1895
        angle = parameters.OBS_CAMERA_ROTATION + parameters.OBS_CAMERA_RA_FLIP_SIGN * parallactic_angle
1✔
1896
        p_xy = np.asarray(flip_and_rotate_radec_vector_to_xy_vector(p_arrow[0], p_arrow[1],
1✔
1897
                                                                    camera_angle=angle,
1898
                                                                    flip_ra_sign=parameters.OBS_CAMERA_RA_FLIP_SIGN,
1899
                                                                    flip_dec_sign=parameters.OBS_CAMERA_DEC_FLIP_SIGN))
1900
        ax.annotate("Z", xy=origin, xycoords='axes fraction', xytext=p_xy + origin, textcoords='axes fraction',
1✔
1901
                    arrowprops=dict(arrowstyle="<|-", fc="lightgreen", ec="lightgreen"), color="lightgreen",
1902
                    horizontalalignment='center', verticalalignment='center')
1903

1904

1905
def load_fits(file_name, hdu_index=0):
1✔
1906
    """Generic function to load a FITS file.
1907

1908
    Parameters
1909
    ----------
1910
    file_name: str
1911
        The FITS file name.
1912
    hdu_index: int, str, optional
1913
        The HDU index in the file (default: 0).
1914

1915
    Returns
1916
    -------
1917
    header: fits.Header
1918
        Header of the FITS file.
1919
    data: np.array
1920
        The data array.
1921

1922
    Examples
1923
    --------
1924
    >>> header, data = load_fits("./tests/data/reduc_20170530_134.fits")
1925
    >>> header["DATE-OBS"]
1926
    '2017-05-31T02:53:52.356'
1927
    >>> data.shape
1928
    (2048, 2048)
1929

1930
    """
1931
    hdu_list = fits.open(file_name)
1✔
1932
    header = hdu_list[hdu_index].header
1✔
1933
    data = hdu_list[hdu_index].data
1✔
1934
    hdu_list.close()  # need to free allocation for file description
1✔
1935
    return header, data
1✔
1936

1937

1938
def save_fits(file_name, header, data, overwrite=False):
1✔
1939
    """Generic function to save a FITS file.
1940

1941
    Parameters
1942
    ----------
1943
    file_name: str
1944
        The FITS file name.
1945
    header: fits.Header
1946
        Header of the FITS file.
1947
    data: np.array
1948
        The data array.
1949
    overwrite: bool, optional
1950
        If True and the file already exists, it is overwritten (default: False).
1951

1952
    Examples
1953
    --------
1954

1955
    >>> header, data = load_fits("./tests/data/reduc_20170530_134.fits")
1956
    >>> save_fits("./outputs/save_fits_test.fits", header, data, overwrite=True)
1957
    >>> assert os.path.isfile("./outputs/save_fits_test.fits")
1958

1959
    .. doctest:
1960
        :hide:
1961

1962
        >>> os.remove("./outputs/save_fits_test.fits")
1963

1964
    """
1965
    hdu = fits.PrimaryHDU()
1✔
1966
    hdu.header = header
1✔
1967
    hdu.data = data
1✔
1968
    output_directory = '/'.join(file_name.split('/')[:-1])
1✔
1969
    ensure_dir(output_directory)
1✔
1970
    hdu.writeto(file_name, overwrite=overwrite)
1✔
1971

1972

1973
def dichotomie(f, a, b, epsilon):
1✔
1974
    """
1975
    Dichotomie method to find a function root.
1976

1977
    Parameters
1978
    ----------
1979
    f: callable
1980
        The function
1981
    a: float
1982
        Left bound to the expected root
1983
    b: float
1984
        Right bound to the expected root
1985
    epsilon: float
1986
        Precision
1987

1988
    Returns
1989
    -------
1990
    root: float
1991
        The root of the function.
1992

1993
    Examples
1994
    --------
1995

1996
    Search for the Gaussian FWHM:
1997

1998
    >>> p = [1,0,1]
1999
    >>> xx = np.arange(-10,10,0.1)
2000
    >>> PSF = gauss(xx, *p)
2001
    >>> def eq(x):
2002
    ...     return np.interp(x, xx, PSF) - 0.5
2003
    >>> root = dichotomie(eq, 0, 10, 1e-6)
2004
    >>> assert np.isclose(2*root, 2.355*p[2], 1e-3)
2005
    """
2006
    x = 0.5 * (a + b)
1✔
2007
    N = 1
1✔
2008
    while b - a > epsilon and N < 100:
1✔
2009
        x = 0.5 * (a + b)
1✔
2010
        if f(x) * f(a) > 0:
1✔
2011
            a = x
1✔
2012
        else:
2013
            b = x
1✔
2014
        N += 1
1✔
2015
    return x
1✔
2016

2017

2018
def wavelength_to_rgb(wavelength, gamma=0.8):
1✔
2019
    """ taken from http://www.noah.org/wiki/Wavelength_to_RGB_in_Python
2020
    This converts a given wavelength of light to an
2021
    approximate RGB color value. The wavelength must be given
2022
    in nanometers in the range from 380 nm through 750 nm
2023
    (789 THz through 400 THz).
2024

2025
    Based on code by Dan Bruton
2026
    http://www.physics.sfasu.edu/astro/color/spectra.html
2027
    Additionally alpha value set to 0.5 outside range
2028
    """
2029
    wavelength = float(wavelength)
1✔
2030
    if 380 <= wavelength <= 750:
1✔
2031
        A = 1.
1✔
2032
    else:
2033
        A = 0.5
1✔
2034
    if wavelength < 380:
1✔
2035
        wavelength = 380.
1✔
2036
    if wavelength > 750:
1✔
2037
        wavelength = 750.
1✔
2038
    if 380 <= wavelength <= 440:
1✔
2039
        attenuation = 0.3 + 0.7 * (wavelength - 380) / (440 - 380)
1✔
2040
        R = ((-(wavelength - 440) / (440 - 380)) * attenuation) ** gamma
1✔
2041
        G = 0.0
1✔
2042
        B = (1.0 * attenuation) ** gamma
1✔
2043
    elif 440 <= wavelength <= 490:
1✔
2044
        R = 0.0
1✔
2045
        G = ((wavelength - 440) / (490 - 440)) ** gamma
1✔
2046
        B = 1.0
1✔
2047
    elif 490 <= wavelength <= 510:
1✔
2048
        R = 0.0
1✔
2049
        G = 1.0
1✔
2050
        B = (-(wavelength - 510) / (510 - 490)) ** gamma
1✔
2051
    elif 510 <= wavelength <= 580:
1✔
2052
        R = ((wavelength - 510) / (580 - 510)) ** gamma
1✔
2053
        G = 1.0
1✔
2054
        B = 0.0
1✔
2055
    elif 580 <= wavelength <= 645:
1✔
2056
        R = 1.0
1✔
2057
        G = (-(wavelength - 645) / (645 - 580)) ** gamma
1✔
2058
        B = 0.0
1✔
2059
    elif 645 <= wavelength <= 750:
1✔
2060
        attenuation = 0.3 + 0.7 * (750 - wavelength) / (750 - 645)
1✔
2061
        R = (1.0 * attenuation) ** gamma
1✔
2062
        G = 0.0
1✔
2063
        B = 0.0
1✔
2064
    else:
2065
        R = 0.0
×
2066
        G = 0.0
×
2067
        B = 0.0
×
2068
    return R, G, B, A
1✔
2069

2070

2071
def from_lambda_to_colormap(lambdas):
1✔
2072
    """Convert an array of wavelength in nm into a color map.
2073

2074
    Parameters
2075
    ----------
2076
    lambdas: array_like
2077
        Wavelength array in nm.
2078

2079
    Returns
2080
    -------
2081
    spectral_map: matplotlib.colors.LinearSegmentedColormap
2082
        Color map.
2083

2084
    Examples
2085
    --------
2086
    >>> lambdas = np.arange(300, 1000, 10)
2087
    >>> spec = from_lambda_to_colormap(lambdas)
2088
    >>> plt.scatter(lambdas, np.zeros(lambdas.size), cmap=spec, c=lambdas)  #doctest: +ELLIPSIS
2089
    <matplotlib.collections.PathCollection object at ...>
2090
    >>> plt.grid()
2091
    >>> plt.xlabel("Wavelength [nm]")  #doctest: +ELLIPSIS
2092
    Text(..., 'Wavelength [nm]')
2093
    >>> plt.show()
2094

2095
    ..plot::
2096

2097
        import numpy as np
2098
        import matplotlib.pyplot as plt
2099
        from spectractor.tools import from_lambda_to_colormap
2100
        lambdas = np.arange(300, 1000, 10)
2101
        spec = from_lambda_to_colormap(lambdas)
2102
        plt.scatter(lambdas, np.zeros(lambdas.size), cmap=spec, c=lambdas)
2103
        plt.xlabel("Wavelength [nm]")
2104
        plt.grid()
2105
        plt.show()
2106

2107
    """
2108
    colorlist = [wavelength_to_rgb(lbda) for lbda in lambdas]
1✔
2109
    spectralmap = matplotlib.colors.LinearSegmentedColormap.from_list("spectrum", colorlist)
1✔
2110
    return spectralmap
1✔
2111

2112

2113
def rebin(arr, new_shape, FLAG_MAKESUM=True):
1✔
2114
    """Rebin and reshape a numpy array.
2115

2116
    Parameters
2117
    ----------
2118
    arr: np.array
2119
        Numpy array to be reshaped.
2120
    new_shape: array_like
2121
        New shape of the array.
2122

2123
    Returns
2124
    -------
2125
    arr_rebinned: np.array
2126
        Rebinned array.
2127

2128
    Examples
2129
    --------
2130
    >>> a = 4 * np.ones((10, 10))
2131
    >>> b = rebin(a, (5, 5))
2132
    >>> b
2133
    array([[4., 4., 4., 4., 4.],
2134
           [4., 4., 4., 4., 4.],
2135
           [4., 4., 4., 4., 4.],
2136
           [4., 4., 4., 4., 4.],
2137
           [4., 4., 4., 4., 4.]])
2138
    """
2139

2140

2141

2142
    if np.any(new_shape * parameters.CCD_REBIN != arr.shape):
1✔
2143
        shape_cropped = new_shape * parameters.CCD_REBIN
1✔
2144
        margins = np.asarray(arr.shape) - shape_cropped
1✔
2145
        arr = arr[:-margins[0], :-margins[1]]
1✔
2146
    shape = (new_shape[0], arr.shape[0] // new_shape[0],
1✔
2147
             new_shape[1], arr.shape[1] // new_shape[1])
2148

2149
    if FLAG_MAKESUM:
1✔
2150
        # SDC : conservation of energy
2151
        return arr.reshape(shape).sum(-1).sum(1)
1✔
2152

2153
    else:
2154
        # SDC : sure of conservation of energy
2155
        return arr.reshape(shape).mean(-1).mean(1)
×
2156

2157

2158
def set_wcs_output_directory(file_name, output_directory=""):
1✔
2159
    """Returns the WCS output directory corresponding to the analyzed image. The name of the directory is
2160
    the anme of the image with the suffix _wcs.
2161

2162
    Parameters
2163
    ----------
2164
    file_name: str
2165
        File name of the image.
2166
    output_directory: str, optional
2167
        If not set, the main output directory is the one the image,
2168
        otherwise the specified directory is taken (default: "").
2169

2170
    Returns
2171
    -------
2172
    output: str
2173
        The name of the output directory
2174

2175
    Examples
2176
    --------
2177
    >>> set_wcs_output_directory("image.fits", output_directory="")
2178
    'image_wcs'
2179
    >>> set_wcs_output_directory("image.png", output_directory="outputs")
2180
    'outputs/image_wcs'
2181

2182
    """
2183
    outdir = os.path.dirname(file_name)
1✔
2184
    if output_directory != "":
1✔
2185
        outdir = output_directory
1✔
2186
    output_directory = os.path.join(outdir, os.path.splitext(os.path.basename(file_name))[0]) + "_wcs"
1✔
2187
    return output_directory
1✔
2188

2189

2190
def set_wcs_tag(file_name):
1✔
2191
    """Returns the WCS tag name associated to the analyzed image: the file name without the extension.
2192

2193
    Parameters
2194
    ----------
2195
    file_name: str
2196
        File name of the image.
2197

2198
    Returns
2199
    -------
2200
    tag: str
2201
        The tag.
2202

2203
    Examples
2204
    --------
2205
    >>> set_wcs_tag("image.fits")
2206
    'image'
2207

2208
    """
2209
    tag = os.path.splitext(os.path.basename(file_name))[0]
1✔
2210
    return tag
1✔
2211

2212

2213
def set_wcs_file_name(file_name, output_directory=""):
1✔
2214
    """Returns the WCS file name associated to the analyzed image, placed in the output directory.
2215
    The extension is .wcs.
2216

2217
    Parameters
2218
    ----------
2219
    file_name: str
2220
        File name of the image.
2221
    output_directory: str, optional
2222
        If not set, the main output directory is the one the image,
2223
        otherwise the specified directory is taken (default: "").
2224

2225
    Returns
2226
    -------
2227
    wcs_file_name: str
2228
        The WCS file name.
2229

2230
    Examples
2231
    --------
2232
    >>> set_wcs_file_name("image.fits", output_directory="")
2233
    'image_wcs/image.wcs'
2234
    >>> set_wcs_file_name("image.png", output_directory="outputs")
2235
    'outputs/image_wcs/image.wcs'
2236

2237
    """
2238
    output_directory = set_wcs_output_directory(file_name, output_directory=output_directory)
1✔
2239
    tag = set_wcs_tag(file_name)
1✔
2240
    return os.path.join(output_directory, tag + '.wcs')
1✔
2241

2242

2243
def set_sources_file_name(file_name, output_directory=""):
1✔
2244
    """Returns the file name containing the deteted sources associated to the analyzed image,
2245
    placed in the output directory. The suffix is _source.fits.
2246

2247
    Parameters
2248
    ----------
2249
    file_name: str
2250
        File name of the image.
2251
    output_directory: str, optional
2252
        If not set, the main output directory is the one the image,
2253
        otherwise the specified directory is taken (default: "").
2254

2255
    Returns
2256
    -------
2257
    sources_file_name: str
2258
        The detected sources file name.
2259

2260
    Examples
2261
    --------
2262
    >>> set_sources_file_name("image.fits", output_directory="")
2263
    'image_wcs/image.axy'
2264
    >>> set_sources_file_name("image.png", output_directory="outputs")
2265
    'outputs/image_wcs/image.axy'
2266

2267
    """
2268
    output_directory = set_wcs_output_directory(file_name, output_directory=output_directory)
1✔
2269
    tag = set_wcs_tag(file_name)
1✔
2270
    return os.path.join(output_directory, f"{tag}.axy")
1✔
2271

2272

2273
def set_gaia_catalog_file_name(file_name, output_directory=""):
1✔
2274
    """Returns the file name containing the Gaia catalog associated to the analyzed image,
2275
    placed in the output directory. The suffix is _gaia.ecsv.
2276

2277
    Parameters
2278
    ----------
2279
    file_name: str
2280
        File name of the image.
2281
    output_directory: str, optional
2282
        If not set, the main output directory is the one the image,
2283
        otherwise the specified directory is taken (default: "").
2284

2285
    Returns
2286
    -------
2287
    sources_file_name: str
2288
        The Gaia catalog file name.
2289

2290
    Examples
2291
    --------
2292
    >>> set_gaia_catalog_file_name("image.fits", output_directory="")
2293
    'image_wcs/image_gaia.ecsv'
2294
    >>> set_gaia_catalog_file_name("image.png", output_directory="outputs")
2295
    'outputs/image_wcs/image_gaia.ecsv'
2296

2297
    """
2298
    output_directory = set_wcs_output_directory(file_name, output_directory=output_directory)
1✔
2299
    tag = set_wcs_tag(file_name)
1✔
2300
    return os.path.join(output_directory, f"{tag}_gaia.ecsv")
1✔
2301

2302

2303
def load_wcs_from_file(file_name):
1✔
2304
    """Open the WCS FITS file and returns a WCS astropy object.
2305

2306
    Parameters
2307
    ----------
2308
    file_name: str
2309
        File name of the WCS FITS file.
2310

2311
    Returns
2312
    -------
2313
    wcs: WCS
2314
        WCS Astropy object.
2315

2316
    """
2317
    # Load the FITS hdulist using astropy.io.fits
2318
    hdulist = fits.open(file_name)
1✔
2319
    # Parse the WCS keywords in the primary HDU
2320
    with warnings.catch_warnings():
1✔
2321
        warnings.filterwarnings("ignore")
1✔
2322
        wcs = WCS.WCS(hdulist[0].header, fix=False)
1✔
2323
    return wcs
1✔
2324

2325

2326
def imgslice(slicespec):
1✔
2327
    """
2328
    Utility function: convert a FITS slice specification (1-based)
2329
    into the corresponding numpy array slice spec (0-based as python does, xy swapped).
2330

2331
    Parameters
2332
    ----------
2333
    slicespec: str
2334
        FITS slice specification with the format [xmin:xmax,ymin:ymax]
2335

2336
    Returns
2337
    -------
2338
    slice: slice
2339
        Slice object to be injected in a np.array for instance.
2340

2341
    Examples
2342
    --------
2343
    >>> imgslice('[11:522,1:2002]')
2344
    (slice(0, 2002, None), slice(10, 522, None))
2345
    """
2346

2347
    parts = slicespec.replace('[', '').replace(']', '').split(',')
1✔
2348
    xbegin, xend = [int(i) for i in parts[0].split(':')]
1✔
2349
    ybegin, yend = [int(i) for i in parts[1].split(':')]
1✔
2350
    xbegin -= 1
1✔
2351
    ybegin -= 1
1✔
2352
    return np.s_[ybegin:yend, xbegin:xend]
1✔
2353

2354

2355
def compute_correlation_matrix(cov):
1✔
2356
    rho = np.zeros_like(cov)
1✔
2357
    for i in range(cov.shape[0]):
1✔
2358
        for j in range(cov.shape[1]):
1✔
2359
            rho[i, j] = cov[i, j] / np.sqrt(cov[i, i] * cov[j, j])
1✔
2360
    return rho
1✔
2361

2362

2363
def plot_correlation_matrix_simple(ax, rho, axis_names=None, ipar=None):
1✔
2364
    if ipar is None:
1✔
2365
        ipar = np.arange(rho.shape[0]).astype(int)
1✔
2366
    im = plt.imshow(rho[ipar[:, None], ipar], interpolation="nearest", cmap='bwr', vmin=-1, vmax=1)
1✔
2367
    ax.set_title("Correlation matrix")
1✔
2368
    if axis_names is not None:
1✔
2369
        names = [axis_names[ip] for ip in ipar]
1✔
2370
        plt.xticks(np.arange(ipar.size), names, rotation='vertical', fontsize=15)
1✔
2371
        plt.yticks(np.arange(ipar.size), names, fontsize=15)
1✔
2372
    cbar = plt.colorbar(im)
1✔
2373
    cbar.ax.tick_params(labelsize=15)
1✔
2374
    plt.gcf().tight_layout()
1✔
2375

2376

2377
def resolution_operator(cov, Q, reg):
1✔
2378
    N = cov.shape[0]
×
2379
    return np.eye(N) - reg * cov @ Q
×
2380

2381

2382
def flip_and_rotate_radec_vector_to_xy_vector(ra, dec, camera_angle=0, flip_ra_sign=1, flip_dec_sign=1):
1✔
2383
    """Flip and rotate the vectors in pixels along (RA,DEC) directions to (x, y) image coordinates.
2384
    The parity transformations are applied first, then rotation.
2385

2386
    Parameters
2387
    ----------
2388
    ra: array_like
2389
        Vector coordinates along RA direction.
2390
    dec: array_like
2391
        Vector coordinates along DEC direction.
2392
    camera_angle: float
2393
        Angle of the camera between y axis and the North Celestial Pole counterclockwise, or equivalently between
2394
        the x axis and the West direction counterclokwise. Units are degrees. (default: 0).
2395
    flip_ra_sign: -1, 1, optional
2396
        Flip RA axis is value is -1 (default: 1).
2397
    flip_dec_sign: -1, 1, optional
2398
        Flip DEC axis is value is -1 (default: 1).
2399

2400
    Returns
2401
    -------
2402
    x: array_like
2403
       Vector coordinates along the x direction.
2404
    y: array_like
2405
       Vector coordinates along the y direction.
2406

2407
    Examples
2408
    --------
2409

2410
    >>> from spectractor import parameters
2411
    >>> parameters.OBS_CAMERA_ROTATION = 180
2412
    >>> parameters.OBS_CAMERA_DEC_FLIP_SIGN = 1
2413
    >>> parameters.OBS_CAMERA_RA_FLIP_SIGN = 1
2414

2415
    North vector
2416

2417
    >>> N_ra, N_dec = [0, 1]
2418

2419
    Compute North direction in (x, y) frame
2420

2421
    >>> flip_and_rotate_radec_vector_to_xy_vector(N_ra, N_dec, 0, flip_ra_sign=1, flip_dec_sign=1)
2422
    (0.0, 1.0)
2423
    >>> "%.1f, %.1f" % flip_and_rotate_radec_vector_to_xy_vector(N_ra, N_dec, 180, flip_ra_sign=1, flip_dec_sign=1)
2424
    '-0.0, -1.0'
2425
    >>> "%.1f, %.1f" % flip_and_rotate_radec_vector_to_xy_vector(N_ra, N_dec, 90, flip_ra_sign=1, flip_dec_sign=1)
2426
    '-1.0, 0.0'
2427
    >>> "%.1f, %.1f" % flip_and_rotate_radec_vector_to_xy_vector(N_ra, N_dec, 90, flip_ra_sign=1, flip_dec_sign=-1)
2428
    '1.0, -0.0'
2429
    >>> "%.1f, %.1f" % flip_and_rotate_radec_vector_to_xy_vector(N_ra, N_dec, 90, flip_ra_sign=-1, flip_dec_sign=-1)
2430
    '1.0, -0.0'
2431
    >>> "%.1f, %.1f" % flip_and_rotate_radec_vector_to_xy_vector(N_ra, N_dec, 0, flip_ra_sign=1, flip_dec_sign=-1)
2432
    '0.0, -1.0'
2433
    >>> "%.1f, %.1f" % flip_and_rotate_radec_vector_to_xy_vector(N_ra, N_dec, 0, flip_ra_sign=-1, flip_dec_sign=1)
2434
    '0.0, 1.0'
2435

2436
    """
2437
    flip = np.array([[flip_ra_sign, 0], [0, flip_dec_sign]], dtype=float)
1✔
2438
    a = - camera_angle * np.pi / 180
1✔
2439
    # minus sign as rotation matrix is apply on the right on the adr vector
2440
    rotation = np.array([[np.cos(a), -np.sin(a)], [np.sin(a), np.cos(a)]], dtype=float)
1✔
2441
    transformation = flip @ rotation
1✔
2442
    x, y = (np.asarray([ra, dec]).T @ transformation).T
1✔
2443
    return x, y
1✔
2444

2445

2446
def get_uvspec_binary():
1✔
2447
    """Get the path to the libradtran uvspec binary if available.
2448

2449
    Returns
2450
    -------
2451
    uvspec_binary : `str`
2452
        Path to the uvspec binary if available, else ``None``.
2453
    """
2454
    return shutil.which('uvspec')
1✔
2455

2456

2457
def uvspec_available():
1✔
2458
    """Check if the uvspec binary is available.
2459

2460
    Returns
2461
    -------
2462
    is_available : `bool`
2463
        Is the binary available?
2464
    """
2465
    return get_uvspec_binary() is not None
1✔
2466

2467

2468
if __name__ == "__main__":
1✔
2469
    import doctest
1✔
2470

2471
    doctest.testmod()
1✔
2472

2473

2474
def iraf_source_detection(data_wo_bkg, sigma=3.0, fwhm=3.0, threshold_std_factor=5, mask=None):
    """Function to detect point-like sources in a data array.

    This function uses the photutils IRAFStarFinder module to search for sources in an image. This finder
    is better than DAOStarFinder for the astrometry of isolated sources but less good for photometry.
    The median and standard deviation that set the detection threshold are computed with sigma-clipped
    statistics on the unmasked pixels only, so that masked pixels do not bias the threshold.

    Parameters
    ----------
    data_wo_bkg: array_like
        The image data array. It works better if the background was subtracted before.
    sigma: float
        Number of standard deviations used by the sigma clipping function before finding sources (default: 3.0).
    fwhm: float
        Full width half maximum for the source detection algorithm (default: 3.0).
    threshold_std_factor: float
        Only sources above this value times the sigma-clipped standard deviation of the image are kept
        (default: 5).
    mask: array_like, optional
        Boolean array to mask image pixels, True for masked pixels (default: None).

    Returns
    -------
    sources: Table or None
        Astropy table containing the source centroids and fluxes, sorted by increasing magnitude
        (brightest sources first), or None if no source is detected.

    Examples
    --------

    >>> N = 100
    >>> data = np.ones((N, N))
    >>> yy, xx = np.mgrid[:N, :N]
    >>> x_center, y_center = 20, 30
    >>> data += 10*np.exp(-((x_center-xx)**2+(y_center-yy)**2)/10)
    >>> sources = iraf_source_detection(data)
    >>> print(float(sources["xcentroid"][0]), float(sources["ycentroid"][0]))
    20.0 30.0

    .. doctest:
        :hide:

        >>> assert len(sources) == 1
        >>> assert sources["xcentroid"] == x_center
        >>> assert sources["ycentroid"] == y_center

    .. plot:

        from spectractor.tools import plot_image_simple, iraf_source_detection
        import numpy as np
        import matplotlib.pyplot as plt

        N = 100
        data = np.ones((N, N))
        yy, xx = np.mgrid[:N, :N]
        x_center, y_center = 20, 30
        data += 10*np.exp(-((x_center-xx)**2+(y_center-yy)**2)/10)
        sources = iraf_source_detection(data)
        fig = plt.figure(figsize=(6,5))
        plot_image_simple(plt.gca(), data, target_pixcoords=(sources["xcentroid"], sources["ycentroid"]))
        fig.tight_layout()
        plt.show()

    """
    if mask is None:
        mask = np.zeros(data_wo_bkg.shape, dtype=bool)
    # Exclude masked pixels from the background/noise estimate that sets the detection threshold.
    _, median, std = sigma_clipped_stats(data_wo_bkg, sigma=sigma, mask=mask)
    iraffind = IRAFStarFinder(fwhm=fwhm, threshold=threshold_std_factor * std, exclude_border=True)
    sources = iraffind(data_wo_bkg - median, mask=mask)
    if sources is None:
        # IRAFStarFinder returns None when no source passes the threshold.
        return None
    for col in sources.colnames:
        sources[col].info.format = '%.8g'  # for consistent table output
    sources.sort('mag')  # ascending magnitude, i.e. brightest sources first
    if parameters.DEBUG:
        positions = np.array((sources['xcentroid'], sources['ycentroid']))
        plot_image_simple(plt.gca(), data_wo_bkg, scale="symlog", target_pixcoords=positions)
        if parameters.DISPLAY:
            plt.show()
        if parameters.PdfPages:
            parameters.PdfPages.savefig()

    return sources
2557

2558

2559
class NumpyArrayEncoder(json.JSONEncoder):
    """JSON encoder that also serializes NumPy arrays and NumPy scalars.

    Use it as ``json.dumps(obj, cls=NumpyArrayEncoder)``. NumPy arrays are converted to (nested)
    Python lists, and NumPy scalar types (e.g. ``np.int64``, ``np.float32``, ``np.bool_``) to the
    equivalent built-in Python scalars. Any other unsupported object raises ``TypeError``, as with
    the default encoder.

    Examples
    --------
    >>> json.dumps({"a": np.arange(3), "b": np.int64(2)}, cls=NumpyArrayEncoder)
    '{"a": [0, 1, 2], "b": 2}'
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # NumPy scalars such as np.int64 or np.float32 are not subclasses of the built-in
            # int/float, so the base encoder rejects them; convert them to Python scalars.
            return obj.item()
        return json.JSONEncoder.default(self, obj)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc