• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

desihub / desispec / 19375428690

14 Nov 2025 07:29PM UTC coverage: 37.704% (-0.009%) from 37.713%
19375428690

Pull #2562

github

sbailey
more --no-build-isolation
Pull Request #2562: remove pin on pip version

12985 of 34439 relevant lines covered (37.7%)

0.38 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

54.4
/py/desispec/sky.py
1
"""
2
desispec.sky
3
============
4

5
Utility functions to compute a sky model and subtract it.
6
"""
7

8
import os
1✔
9
import numpy as np
1✔
10
from collections import OrderedDict
1✔
11

12
from desispec.resolution import Resolution
1✔
13
from desispec.linalg import cholesky_solve
1✔
14
from desispec.linalg import cholesky_invert
1✔
15
from desispec.linalg import spline_fit
1✔
16
from desiutil.log import get_logger
1✔
17
from desispec import util
1✔
18
from desiutil import stats as dustat
1✔
19
import scipy,scipy.sparse,scipy.stats,scipy.ndimage
1✔
20
from scipy.signal import fftconvolve
1✔
21
import sys
1✔
22
from desispec.fiberbitmasking import get_fiberbitmasked_frame_arrays, get_fiberbitmasked_frame
1✔
23
import scipy.ndimage
1✔
24
from desispec.maskbits import specmask, fibermask
1✔
25
from desispec.preproc import get_amp_ids,parse_sec_keyword
1✔
26
from desispec.io import findfile,read_xytraceset
1✔
27
from desispec.calibfinder import CalibFinder
1✔
28
from desispec.preproc import get_amp_ids
1✔
29
from desispec.tpcorrparam import tpcorrmodel
1✔
30
import desispec.skygradpca
1✔
31

32

33
def _model_variance(frame,cskyflux,cskyivar,skyfibers) :
    """Increase the sky model variance around sky lines to reach chi2/ndf ~ 1.

    Bright sky lines are the local maxima of the mean sky model that exceed
    10% of its maximum.  For each of them, the residuals of the sky fibers in
    a small window are used to find the smallest extra wavelength error
    ``sigma_wave`` (trial grid 0.005 to 2 A) for which the median reduced chi2
    is <= 1.  That term, plus a small flat-field term, is then added in
    quadrature to the sky model variance in a slightly larger window.

    Args:
        frame: Frame object; uses frame.wave, frame.flux and frame.ivar.
        cskyflux: 2D convolved sky model flux [nspec, nwave].
        cskyivar: 2D inverse variance of the convolved sky model [nspec, nwave].
        skyfibers: indices of the sky fibers in frame.

    Returns:
        2D array [nspec, nwave] with the updated sky model inverse variance;
        entries where cskyivar == 0 stay 0.
    """
    log = get_logger()

    tivar = util.combine_ivar(frame.ivar[skyfibers], cskyivar[skyfibers])
    residuals = frame.flux[skyfibers] - cskyflux[skyfibers]

    # mean sky and |d(sky)/d(wave)| (central differences, 0 at the edges)
    msky = np.mean(cskyflux,axis=0)
    dskydw = np.zeros(msky.shape)
    dskydw[1:-1]=(msky[2:]-msky[:-2])/(frame.wave[2:]-frame.wave[:-2])
    dskydw = np.abs(dskydw)

    # the chi2 at a given wavelength can be large because of a cosmic and not
    # a psf error or sky non uniformity, so we compare the residuals with a
    # worst possible sky model error (20% error on flat, 0.5A) and exclude
    # those inconsistent with it at 3 sigma
    max_possible_var = 1./(tivar+(tivar==0)) + (0.2*msky)**2 + (0.5*dskydw)**2
    bad = residuals**2 > 3**2*max_possible_var
    tivar[bad]=0

    # we evaluate a sky model error only around sky flux peaks (>0.1*max)
    is_peak = ((msky[1:-1] > msky[2:]) & (msky[1:-1] > msky[:-2])
               & (msky[1:-1] > 0.1*np.max(msky)))
    peaks = np.where(is_peak)[0]+1
    dpix  = 2 # half-width of the window used to evaluate the error
    dpix2 = 3 # half-width of the (larger) window where the error is applied

    input_skyvar = 1./(cskyivar+(cskyivar==0))
    skyvar = input_skyvar + 0.

    # sky model variance = (sigma_flat*msky)**2 + (sigma_wave*dskydw)**2
    # the fiber flat error is already included in the flux ivar, but
    # empirical evidence shows we need an extra term
    sigma_flat = 0.005
    # trial values of the wavelength error, in A
    sigma_wave = np.arange(0.005, 2, 0.005)
    # normalization from median to mean for chi2 with 3 d.o.f.
    norm = 0.7888

    for peak in peaks :
        # clamp the window starts: a negative slice start would wrap around
        # and select an empty (or wrong) range for peaks at the first pixels
        b  = max(peak-dpix, 0)
        e  = peak+dpix+1
        b2 = max(peak-dpix2, 0)
        e2 = peak+dpix2+1

        win_tivar = tivar[:,b:e]
        valid = win_tivar > 0
        res2 = residuals[:,b:e]**2
        var = 1./(win_tivar+(win_tivar==0))

        #- pivar has shape (nskyfibers, npix, nsigma_wave)
        pivar = valid[..., np.newaxis]/(
            (var+(sigma_flat*msky[b:e])**2)[..., np.newaxis]
            + ((sigma_wave[np.newaxis,:]*dskydw[b:e, np.newaxis])**2)[np.newaxis, ...])

        # only fibers with at least one valid pixel in the window contribute;
        # the others would give 0/0 = nan and spoil the median
        nvalid = np.sum(valid, axis=1)
        use = nvalid > 0
        if np.any(use):
            #- chi2_of_sky_fibers has shape (nusedfibers, nsigma_wave)
            chi2_of_sky_fibers = (np.sum(pivar[use]*res2[use][..., np.newaxis], axis=1)
                                  / nvalid[use][:, np.newaxis])
            #- median_chi2 has shape (nsigma_wave,)
            median_chi2 = np.median(chi2_of_sky_fibers, axis=0)/norm
        else:
            median_chi2 = np.full(sigma_wave.shape, np.inf)

        below = np.flatnonzero(median_chi2 <= 1)
        if below.size > 0:
            #- first sigma_wave with median_chi2 <= 1 is the peak
            sigma_wave_peak = sigma_wave[below[0]]
        else :
            sigma_wave_peak = 2.
        log.info("peak at {}A : sigma_wave={}".format(int(frame.wave[peak]),sigma_wave_peak))
        skyvar[:,b2:e2] = input_skyvar[:,b2:e2] + (sigma_flat*msky[b2:e2])**2 + (sigma_wave_peak*dskydw[b2:e2])**2

    return (cskyivar>0)/(skyvar+(skyvar==0))
112

113

114
def get_sector_masks(frame):
    """Build the CCD sector masks and offset templates used in the sky fit.

    For each amplifier, the calibration keywords OFFCOLS<amp> and
    CTECOLS<amp> (comma-separated "xmin:xmax" column ranges) define sectors.
    Each sector is clipped to the amplifier CCDSEC.  An offset is fit for
    these sectors as part of the sky subtraction, to correct for CTE issues.
    If CTECOLS<amp> is set, another correction is also applied at preproc
    (see also doc/cte-correction.rst).

    Args:
        frame: Frame object; uses frame.meta, frame.nspec, frame.wave and the
            shape of frame.flux.

    Returns:
        (masks, templates).  masks is a list of boolean arrays shaped like
        frame.flux.  Each is True where the PSF trace position of a
        (fiber, wavelength) pixel falls inside the sector.  templates holds,
        for each sector, [constant_template, linear_template] arrays of the
        same shape.  Returns ([], [[]]) when no sector is defined.

    Raises:
        KeyError: if a sector keyword value cannot be decoded.
        IOError: if the PSF file of the exposure does not exist.
    """
    meta = frame.meta
    cfinder = CalibFinder([meta])
    amp_ids = get_amp_ids(meta)
    log = get_logger()

    sectors = []
    for amp in amp_ids:
        ccdsec = parse_sec_keyword(frame.meta['CCDSEC'+amp])
        ystart, ystop = ccdsec[0].start, ccdsec[0].stop
        for key in ("OFFCOLS"+amp, "CTECOLS"+amp):
            if not cfinder.haskey(key):
                continue
            value = cfinder.value(key)
            for colrange in value.split(","):
                bounds = colrange.split(":")
                if len(bounds) != 2:
                    mess = "cannot decode {}={}".format(key, value)
                    log.error(mess)
                    raise KeyError(mess)
                sector = [ystart, ystop,
                          max(ccdsec[1].start, int(bounds[0])),
                          min(ccdsec[1].stop, int(bounds[1]))]
                sectors.append(sector)
                log.info("Adding CCD sector in amp {} with offset: {}".format(
                    amp, sector))

    if not sectors:
        return [], [[]]

    # locate every (fiber, wavelength) pixel on the CCD with the PSF traces
    psf_filename = findfile('psf', meta["NIGHT"], meta["EXPID"],
                            meta["CAMERA"])
    if not os.path.isfile(psf_filename):
        log.error("No PSF file "+psf_filename)
        raise IOError("No PSF file "+psf_filename)
    log.info("Using PSF {}".format(psf_filename))
    tset = read_xytraceset(psf_filename)

    xpos = np.zeros(frame.flux.shape, dtype=float)
    ypos = np.zeros(frame.flux.shape, dtype=float)
    for fiber in np.arange(frame.nspec):
        xpos[fiber] = tset.x_vs_wave(fiber=fiber, wavelength=frame.wave)
        ypos[fiber] = tset.y_vs_wave(fiber=fiber, wavelength=frame.wave)

    nrows, npix = frame.flux.shape
    masks, templates = [], []
    for ymin, ymax, xmin, xmax in sectors:
        mask = (ypos >= ymin) & (ypos < ymax) & (xpos >= xmin) & (xpos < xmax)
        masks.append(mask)
        # ramp in wavelength-pixel index, normalized per fiber inside the mask
        ramp = np.ones(nrows)[:, None] * np.arange(npix)[None, :]
        # NOTE(review): ramp >= 0 and ramp*mask is 0 outside the mask, so this
        # minimum is 0 unless the mask covers a whole row, i.e. the offset
        # subtraction is usually a no-op.  Kept as is because changing it would
        # change the fitted sector model.
        ramp -= np.min(ramp*mask, axis=1, keepdims=True)
        rampmax = np.max(ramp*mask, axis=1, keepdims=True)
        ramp /= (rampmax + (rampmax == 0))
        ramp *= mask
        templates.append([1.0*mask, ramp])
    return masks, templates
182

183
def get_sky_fibers(fibermap, override_sky_targetids=None, exclude_sky_targetids=None):
    """
    Retrieve the fiber indices of sky fibers

    Args:
        fibermap: Table from frame FIBERMAP HDU (frame.fibermap)

    Options:
        override_sky_targetids (array of int): TARGETIDs to use, overriding fibermap
        exclude_sky_targetids (array of int): TARGETIDs to exclude

    Returns:
        array of indices (rows of fibermap, not FIBER numbers) of the sky
        fibers to use; may be empty if no fiber qualifies

    By default we rely on fibermap['OBJTYPE']=='SKY', and fibers with
    FIBERSTATUS bit VARIABLETHRU are excluded. Some targetids can also be
    excluded by providing a list of them through exclude_sky_targetids.

    Alternatively all the sky targetids can be provided directly through
    override_sky_targetids. In that case OBJTYPE, FIBERSTATUS and
    exclude_sky_targetids are all ignored.
    """
    log = get_logger()
    # Grab sky fibers on this frame
    if override_sky_targetids is not None:
        log.info('Overriding default sky fiber list using override_sky_targetids')
        # we ignore OBJTYPEs (and FIBERSTATUS)
        skyfibers = np.where(np.isin(fibermap['TARGETID'], override_sky_targetids))[0]
    else:
        oksky = (fibermap['OBJTYPE'] == 'SKY')
        oksky &= ((fibermap['FIBERSTATUS'] & fibermask.VARIABLETHRU) == 0)
        skyfibers = np.where(oksky)[0]
        if exclude_sky_targetids is not None:
            log.info('Excluding default sky fibers using exclude_sky_targetids')
            bads = np.isin(fibermap['TARGETID'][skyfibers], exclude_sky_targetids)
            skyfibers = skyfibers[~bads]

    #- indices, not fiber numbers.  np.max raises on an empty array, so only
    #- check a non-empty selection; an empty one is returned so that callers
    #- (e.g. compute_sky) can report "no valid sky fibers" themselves
    assert skyfibers.size == 0 or np.max(skyfibers) < len(fibermap)
    return skyfibers
221

222
def compute_sky_linear(
        flux, ivar, Rframe, sectors, skyfibers, skygradpca, fibermap,
        fiberflat=None,
        min_iterations=5, max_iterations=100, nsig_clipping=4,
        tpcorrparam=None):
    """Iteratively fit a linear sky model to the sky fibers, with outlier rejection.

    The model of the sky fiber fluxes is::

        R*(sky + sum(PC * (a*dx + b*dy)))*tpcorr + sector offsets

    ``sky`` is the deconvolved sky flux at each wavelength.  The PCs are the
    optional sky gradient principal components, with amplitudes a and b.
    ``tpcorr`` is the optional fiber throughput correction, and the sector
    offsets are the optional per-CCD-sector templates.

    Args:
        flux: 2D flux of the sky fibers [nskyfibers, nwave].
        ivar: 2D inverse variance of flux [nskyfibers, nwave]; not modified.
        Rframe: array-like of the resolution matrices of all the fibers of the
            frame, indexed by fiber index.
        sectors: (masks, templates) as returned by get_sector_masks, or None
            for no sector offsets.
        skyfibers: indices of the sky fibers in the frame (rows of flux).
        skygradpca: SkyGradPCA object, or None.
        fibermap: fibermap Table of the frame.

    Options:
        fiberflat: FiberFlat object dividing the sector templates, or None.
        min_iterations: minimum number of iterations.
        max_iterations: maximum number of iterations (must be >= 1).
        nsig_clipping: outlier rejection threshold, in sigma.
        tpcorrparam: TPCorrParam object used to fit fiber throughput
            variations, or None.

    Returns:
        A tuple with the following elements:

        - param: best-fit parameter vector.
        - parameter_covar: covariance of param.
        - modeled_sky: sky model for all the frame fibers.
        - current_ivar: sky fiber ivar after outlier rejection.
        - nout_tot: total number of rejected pixels.
        - skytpcorr: throughput correction for all the frame fibers.
        - bad_skyfibers: list of discarded sky fiber indices.
        - bad_wavelengths: boolean array, True for unconstrained wavelengths.
        - sector_offsets: sector offsets for all the frame fibers.
        - skygradpcacoeff: sky gradient amplitudes.

    Raises:
        ValueError: if max_iterations < 1.
    """
    log = get_logger()
    if max_iterations < 1:
        # param, A and w are only defined once the fit loop has run
        raise ValueError(
            "max_iterations must be >= 1, got {}".format(max_iterations))
    nfibers, nwave = flux.shape
    nskygradpc = skygradpca.flux.shape[0] if skygradpca is not None else 0
    current_ivar = ivar.copy()
    chi2 = np.zeros(flux.shape)
    nout_tot = 0
    bad_skyfibers = []
    Rsky = Rframe[skyfibers]

    if tpcorrparam is None:
        skytpcorrfixed = np.ones(nfibers)
    else:
        skytpcorrfixed = tpcorrmodel(tpcorrparam,
                                     fibermap['FIBER_X'], fibermap['FIBER_Y'])
        skytpcorrfixed = skytpcorrfixed[skyfibers]

    skytpcorr = skytpcorrfixed.copy()
    if sectors is None:
        # no sector offsets; None cannot be unpacked into (masks, templates)
        sectors, sectemplates = [], []
    else:
        sectors, sectemplates = sectors

    # note: the design matrix we set up has the following parameters:
    # first nwave columns: deconvolved flux at each wavelength
    # next nsectemplate columns: sector offset template amplitudes
    # next 2*nskygradpc columns: sky gradient amplitudes in x & y
    # direction for each PC.
    # In a separate step of each iteration we also fit a 'tpcorr' model,
    # reflecting different throughputs of each fiber.
    # These sizes and the data vector do not change between iterations.
    nsector = len(sectors)
    nsectemplate = sum(len(x) for x in sectemplates)
    npar = nwave + nsectemplate + nskygradpc*2

    # data vector: sky fiber fluxes, concatenated fiber after fiber
    yy = np.array(flux, dtype=float).reshape(-1)

    for iteration in range(max_iterations) :
        # We minimize chi2 = sum_fiber sum_wave ivar*(flux-model)**2.
        # With the design matrix D[(fiber,w), i] = d(model)[fiber,w]/d(param_i)
        # this gives the normal equations A param = B with
        #   A = D^T diag(ivar) D   (1/2 of the second derivative of chi2)
        #   B = D^T diag(ivar) flux
        # For the sky parameters (unconvolved sky flux at wavelength i),
        # d(model)/di[fiber,w] = R[fiber][w,i] (times the throughput correction).
        # We could consider adding a mild prior to deal with ill-conditioned
        # matrices.
        SD = scipy.sparse.dia_matrix((nwave*nfibers,nwave*nfibers))
        SD.setdiag(current_ivar.reshape(-1))

        # loop on fiber to handle resolution
        allrows = []
        allcols = []
        allvals = []
        for fiber in range(nfibers):
            if fiber % 10 == 0:
                log.info("iter %d sky fiber %d/%d"%(iteration,fiber,nfibers))
            R = Rsky[fiber]
            rows, cols, vals = scipy.sparse.find(R)
            allrows.append(rows+fiber*nwave)
            allcols.append(cols)
            allvals.append(vals)
            if skygradpca is not None:
                dx = skygradpca.dx[skygradpca.skyfibers[fiber]]
                dy = skygradpca.dy[skygradpca.skyfibers[fiber]]
                fiberrows = np.arange(nwave)+fiber*nwave
                for skygradpcind in range(nskygradpc):
                    convskygradpc = R.dot(skygradpca.deconvflux[skygradpcind])
                    col = nwave + nsectemplate + skygradpcind*2
                    allrows.append(fiberrows)
                    allcols.append(col + np.zeros(nwave, dtype='i4'))
                    allvals.append(convskygradpc * dx)
                    allrows.append(fiberrows)
                    allcols.append(col + 1 + np.zeros(nwave, dtype='i4'))
                    allvals.append(convskygradpc * dy)
        # boost model by throughput corrections (sky and sky gradient terms
        # only; the sector offsets added below are instrumental)
        for i in range(len(allvals)):
            allvals[i] *= skytpcorr[allrows[i] // nwave]

        i = 0
        for j, secmask in enumerate(sectors):
            for template in sectemplates[j]:
                rows = np.flatnonzero(secmask[skyfibers])
                cols = np.full(len(rows), nwave+i)
                if fiberflat is not None:
                    flat = (
                        fiberflat.fiberflat[skyfibers][secmask[skyfibers]].ravel())
                else:
                    flat = np.ones(rows.shape)
                vals = (template[skyfibers][secmask[skyfibers]].ravel()/
                        (flat + (flat == 0)))
                allrows.append(rows)
                allcols.append(cols)
                allvals.append(vals)
                i += 1

        design = scipy.sparse.coo_matrix(
            (np.concatenate(allvals),
             (np.concatenate(allrows), np.concatenate(allcols))),
            shape=(nwave*nfibers, npar))
        design = design.tocsr()

        A = design.T.dot(SD.dot(design))
        A = A.toarray()
        B = design.T.dot(SD.dot(yy))

        log.info("iter %d solving"%iteration)
        # solve only for the parameters constrained by the data
        w = A.diagonal() > 0
        A_pos_def = A[w,:]
        A_pos_def = A_pos_def[:,w]
        param = B*0
        try:
            param[w]=cholesky_solve(A_pos_def,B[w])
        except Exception:
            # not a bare except: do not swallow KeyboardInterrupt/SystemExit
            log.info("cholesky failed, trying svd in iteration {}".format(iteration))
            param[w]=np.linalg.lstsq(A_pos_def,B[w], rcond=None)[0]
        modeled_sky = design.dot(param).reshape(flux.shape)
        modeled_secoffs = (
            design[:, nwave:nwave + nsectemplate].dot(
                param[nwave:nwave + nsectemplate]))
        modeled_secoffs = modeled_secoffs.reshape(flux.shape)

        log.info("iter %d compute chi2"%iteration)

        medflux=np.zeros(nfibers)
        for fiber in range(nfibers) :
            chi2[fiber]=current_ivar[fiber]*(flux[fiber]-modeled_sky[fiber])**2
            ok=(current_ivar[fiber]>0)
            if np.sum(ok)>0 :
                medflux[fiber] = np.median((flux[fiber]-modeled_sky[fiber])[ok])

        log.info("rejecting")

        # whole fiber with excess flux
        if np.sum(medflux!=0) > 2 : # at least 3 valid sky fibers
            rms_from_nmad = 1.48*np.median(np.abs(medflux[medflux!=0]))
            # discard fibers that are 7 sigma away
            badfibers=np.where(np.abs(medflux)>7*rms_from_nmad)[0]
            for fiber in badfibers :
                log.warning("discarding fiber {} with median flux = {:.2f} > 7*{:.2f}".format(skyfibers[fiber],medflux[fiber],rms_from_nmad))
                current_ivar[fiber]=0
                # set a mask bit here
                bad_skyfibers.append(skyfibers[fiber])
        nout_iter=0
        if iteration<1 :
            # only remove worst outlier per wave
            # apply rejection iteratively, only one entry per wave among fibers
            # find waves with outlier (fastest way)
            nout_per_wave=np.sum(chi2>nsig_clipping**2,axis=0)
            selection=np.where(nout_per_wave>0)[0]
            for i in selection :
                worst_entry=np.argmax(chi2[:,i])
                current_ivar[worst_entry,i]=0
                nout_iter += 1

        else :
            # remove all of them at once
            bad=(chi2>nsig_clipping**2)
            current_ivar *= (bad==0)
            nout_iter += np.sum(bad)

        if tpcorrparam is not None:
            # the throughput of each fiber varies, usually following
            # the tpcorrparam pca.  We want to find the coefficients
            # for these principal components.
            # In the iterative scheme, we have already applied some PCA
            # correction in the previous iteration.  Here we remove the
            # previous PCA correction, and then re-fit the result.

            # the _current_ pca-tracked bit of the tpcorr we are using is
            # in tppca0.  This is the current total skytpcorr, divided by the
            # fixed bit that comes from the mean and the spatial
            # within-patrol-radius model.
            tppca0 = skytpcorr[:, None]/skytpcorrfixed[:, None]
            tppcam = tpcorrparam.pca[:, skyfibers]
            # in the design matrix and flux residuals, we divide out tppca0
            # from modeled_sky so that we have only the pre-PCA skies.
            # we use the modeled_sky without the offsets since this is a
            # throughput effect and not an instrumental effect.
            sky_no_offsets = modeled_sky - modeled_secoffs
            aa = np.array([(sky_no_offsets*tppcam0[:, None]/tppca0).reshape(-1)
                           for tppcam0 in tppcam]).T
            fluxresid = flux - modeled_secoffs - sky_no_offsets / tppca0
            # then we solve for the PCA coefficients that best take the
            # pre-PCA skies to the pre-PCA sky residuals (fluxresid).
            skytpcorrcoeff = np.linalg.lstsq(
                aa.T.dot(current_ivar.reshape(-1)[:, None]*aa),
                aa.T.dot((current_ivar*fluxresid).reshape(-1)),
                rcond=None)[0]
            skytpcorr = skytpcorrfixed.copy()
            for coeff, vec in zip(skytpcorrcoeff,
                                  tpcorrparam.pca[:, skyfibers]):
                skytpcorr += coeff*vec

        nout_tot += nout_iter

        sum_chi2=float(np.sum(chi2))
        ndf=int(np.sum(chi2>0)-nwave)
        chi2pdf=0.
        if ndf>0 :
            chi2pdf=sum_chi2/ndf
        log.info("iter #%d chi2=%f ndf=%d chi2pdf=%f nout=%d"%(iteration,sum_chi2,ndf,chi2pdf,nout_iter))

        # at least min_iterations
        if (nout_iter == 0) & (iteration >= min_iterations - 1):
            break

    if nsectemplate > 0:
        log.info('sectors: %d sectors fit, values %s' %
                 (nsector, ' '.join(
                     [str(x) for x in param[nwave:nwave+nsectemplate]])))

    if nskygradpc > 0:
        log.info(('Fit with %d spatial PCs, amplitudes ' % nskygradpc) +
                 ' '.join(['%.1f' % x for x in param[nwave+nsectemplate:]]))

    log.info("compute the parameter covariance")
    # we may have to use a different method to compute this
    # covariance
    try :
        parameter_covar=cholesky_invert(A)
        # the above is too slow
        # maybe invert per block, sandwich by R
    except np.linalg.LinAlgError :
        log.warning("cholesky_solve_and_invert failed, switching to np.linalg.lstsq and np.linalg.pinv")
        parameter_covar = np.linalg.pinv(A)

    if tpcorrparam is None:
        skytpcorr = np.ones(len(fibermap), dtype='f4')
    else:
        skytpcorr = tpcorrmodel(tpcorrparam,
                                fibermap['FIBER_X'], fibermap['FIBER_Y'],
                                skytpcorrcoeff)

    unconvflux = param[:nwave].copy()
    skygradpcacoeff = param[
        nwave + nsectemplate:nwave+nsectemplate+nskygradpc*2]
    if skygradpca is not None:
        modeled_sky = desispec.skygradpca.evaluate_model(
            skygradpca, Rframe, skygradpcacoeff, mean=unconvflux)
    else:
        modeled_sky = np.zeros((len(Rframe), nwave), dtype='f8')
        for i in range(len(Rframe)):
            modeled_sky[i] = Rframe[i].dot(unconvflux)

    sector_offsets = np.zeros((len(fibermap), flux.shape[1]), dtype='f4')
    i = 0
    for j, secmask in enumerate(sectors):
        for sectemplate in sectemplates[j]:
            sector_offsets[secmask] += param[nwave+i] * sectemplate[secmask]
            i += 1
    if len(sectors) > 0 and fiberflat is not None:
        flat = fiberflat.fiberflat + (fiberflat.fiberflat == 0)
        sector_offsets /= flat

    modeled_sky *= skytpcorr[:, None]
    # wavelengths with no constraint from the data (zero diagonal in A)
    bad_wavelengths = ~(w[:nwave])
    modeled_sky += sector_offsets

    return (param, parameter_covar, modeled_sky, current_ivar, nout_tot,
            skytpcorr, bad_skyfibers, bad_wavelengths, sector_offsets,
            skygradpcacoeff)
517

518

519
def compute_sky(
1✔
520
    frame, nsig_clipping=4., max_iterations=100, model_ivar=False,
521
    add_variance=True, adjust_wavelength=False, adjust_lsf=False,
522
    only_use_skyfibers_for_adjustments=True, pcacorr=None,
523
    fit_offsets=False, fiberflat=None, skygradpca=None,
524
        min_iterations=5, tpcorrparam=None,
525
        exclude_sky_targetids=None, override_sky_targetids=None):
526
    """Compute a sky model.
527

528
    Sky[fiber,i] = R[fiber,i,j] Flux[j]
529

530
    Input flux are expected to be flatfielded!
531
    We don't check this in this routine.
532

533
    Args:
534
        frame : Frame object, which includes attributes
535
          - wave : 1D wavelength grid in Angstroms
536
          - flux : 2D flux[nspec, nwave] density
537
          - ivar : 2D inverse variance of flux
538
          - mask : 2D inverse mask flux (0=good)
539
          - resolution_data : 3D[nspec, ndiag, nwave]  (only sky fibers)
540
        nsig_clipping : [optional] sigma clipping value for outlier rejection
541
        max_iterations : int, maximum number of iterations
542
        model_ivar : replace ivar by a model to avoid bias due to correlated flux and ivar. this has a negligible effect on sims.
543
        add_variance : evaluate calibration error and add this to the sky model variance
544
        adjust_wavelength : adjust the wavelength of the sky model on sky lines to improve the sky subtraction
545
        adjust_lsf : adjust the LSF width of the sky model on sky lines to improve the sky subtraction
546
        only_use_skyfibers_for_adjustments : interpolate adjustments using sky fibers only
547
        pcacorr : SkyCorrPCA object to interpolate the wavelength or LSF adjustment from sky fibers to all fibers
548
        fit_offsets : fit offsets for regions defined in calib
549
        fiberflat : desispec.FiberFlat object used for the fit of offsets
550
        skygradpca : SkyGradPCA object to use to fit sky gradients, or None
551
        min_iterations : int, minimum number of iterations
552
        tpcorrparam : TPCorrParam object to use to fit fiber throughput
553
            variations, or None
554

555
    returns SkyModel object with attributes wave, flux, ivar, mask
556
    """
557

558
    log=get_logger()
1✔
559
    log.info("starting")
1✔
560

561
    skyfibers = get_sky_fibers(frame.fibermap, override_sky_targetids=override_sky_targetids,
1✔
562
                              exclude_sky_targetids=exclude_sky_targetids)
563

564
    #- Hack: test tile 81097 (observed 20210430/00086750) had set
565
    #- FIBERSTATUS bit UNASSIGNED for sky targets on stuck positioners.
566
    #- Undo that.
567
    if (frame.meta is not None) and ('TILEID' in frame.meta) and (frame.meta['TILEID'] == 81097):
1✔
568
        log.info('Unsetting FIBERSTATUS UNASSIGNED for tileid 81097 sky fibers')
×
569
        frame.fibermap['FIBERSTATUS'][skyfibers] &= ~1
×
570

571
    nwave=frame.nwave
1✔
572

573
    current_ivar = get_fiberbitmasked_frame_arrays(frame,bitmask='sky',ivar_framemask=True,return_mask=False)
1✔
574

575
    # checking ivar because some sky fibers have been disabled
576
    bad=(np.sum(current_ivar[skyfibers]>0,axis=1)==0)
1✔
577
    good=~bad
1✔
578

579
    if np.any(bad) :
1✔
580
        log.warning("{} sky fibers discarded (because ivar=0 or bad FIBERSTATUS), only {} left.".format(np.sum(bad),np.sum(good)))
×
581
        skyfibers = skyfibers[good]
×
582

583
    if np.sum(good)==0 :
1✔
584
        message = "no valid sky fibers"
×
585
        log.error(message)
×
586
        raise RuntimeError(message)
×
587

588
    nfibers=len(skyfibers)
1✔
589

590
    current_ivar = current_ivar[skyfibers]
1✔
591
    flux = frame.flux[skyfibers]
1✔
592

593
    input_ivar=None
1✔
594
    if model_ivar :
1✔
595
        log.info("use a model of the inverse variance to remove bias due to correlated ivar and flux")
×
596
        input_ivar=current_ivar.copy()
×
597
        median_ivar_vs_wave  = np.median(current_ivar,axis=0)
×
598
        median_ivar_vs_fiber = np.median(current_ivar,axis=1)
×
599
        median_median_ivar   = np.median(median_ivar_vs_fiber)
×
600
        for f in range(current_ivar.shape[0]) :
×
601
            threshold=0.01
×
602
            current_ivar[f] = median_ivar_vs_fiber[f]/median_median_ivar * median_ivar_vs_wave
×
603
            # keep input ivar for very low weights
604
            ii=(input_ivar[f]<=(threshold*median_ivar_vs_wave))
×
605
            #log.info("fiber {} keep {}/{} original ivars".format(f,np.sum(ii),current_ivar.shape[1]))
606
            current_ivar[f][ii] = input_ivar[f][ii]
×
607

608

609
    chi2=np.zeros(flux.shape)
1✔
610

611
    #max_iterations=2 ; log.warning("DEBUGGING LIMITING NUMBER OF ITERATIONS")
612

613
    if fit_offsets:
1✔
614
        sectors = get_sector_masks(frame)
×
615
    else:
616
        sectors = [], [[]]
1✔
617

618
    if skygradpca is not None:
1✔
619
        desispec.skygradpca.configure_for_xyr(
×
620
            skygradpca, frame.fibermap['FIBERASSIGN_X'],
621
            frame.fibermap['FIBERASSIGN_Y'],
622
            frame.R, skyfibers=skyfibers)
623

624
    res = compute_sky_linear(
1✔
625
        flux, current_ivar, frame.R, sectors, skyfibers, skygradpca,
626
        frame.fibermap,
627
        fiberflat=fiberflat, min_iterations=min_iterations,
628
        max_iterations=max_iterations, nsig_clipping=nsig_clipping,
629
        tpcorrparam=tpcorrparam)
630
    (param, parameter_covar, modeled_sky, current_ivar, nout_tot, skytpcorr,
1✔
631
     bad_skyfibers, bad_wavelengths, background, skygradpcacoeff) = res
632
    deconvolved_sky = param[:nwave]
1✔
633

634
    log.info("compute mean resolution")
1✔
635
    # we make an approximation for the variance to save CPU time
636
    # we use the average resolution of all fibers in the frame:
637
    mean_res_data=np.mean(frame.resolution_data,axis=0)
1✔
638
    Rmean = Resolution(mean_res_data)
1✔
639

640
    log.info("compute convolved sky and ivar")
1✔
641

642
    parameter_sky_covar = parameter_covar[:nwave, :nwave]
1✔
643

644
    # The parameters are directly the unconvolved sky
645
    # First convolve with average resolution :
646
    convolved_sky_covar=Rmean.dot(parameter_sky_covar).dot(Rmean.T.todense())
1✔
647

648
    # and keep only the diagonal
649
    convolved_sky_var=np.diagonal(convolved_sky_covar)
1✔
650

651
    # inverse
652
    convolved_sky_ivar=(convolved_sky_var>0)/(convolved_sky_var+(convolved_sky_var==0))
1✔
653

654
    # and simply consider it's the same for all spectra
655
    cskyivar = np.tile(convolved_sky_ivar, frame.nspec).reshape(frame.nspec, nwave)
1✔
656

657
    # remove background for line fitting; add back at end
658
    cskyflux = modeled_sky - background
1✔
659
    frame.flux -= background
1✔
660

661
    # See if we can improve the sky model by readjusting the wavelength and/or the width of sky lines
662
    dwavecoeff = None
1✔
663
    dlsfcoeff = None
1✔
664
    if adjust_wavelength or adjust_lsf :
1✔
665
        log.info("adjust the wavelength of sky spectrum on sky lines to improve sky subtraction ...")
×
666

667
        if adjust_wavelength :
×
668
            # compute derivative of sky w.r.t. wavelength
669
            dskydwave = np.gradient(cskyflux,axis=1)/np.gradient(frame.wave)
×
670
        else :
671
            dskydwave = None
×
672

673
        if adjust_lsf :
×
674
            # compute derivative of sky w.r.t. lsf width
675
            dwave = np.mean(np.gradient(frame.wave))
×
676
            dsigma_A   = 0.3 #A
×
677
            dsigma_bin = dsigma_A/dwave # consider this extra width for the PSF (sigma' = sqrt(sigma**2+dsigma**2))
×
678
            hw=int(4*dsigma_bin)+1
×
679
            x=np.arange(-hw,hw+1)
×
680
            k=np.zeros((3,x.size)) # a Gaussian kernel
×
681
            k[1]=np.exp(-x**2/dsigma_bin**2/2.)
×
682
            k/=np.sum(k)
×
683
            tmp = fftconvolve(cskyflux,k,mode="same")
×
684
            dskydlsf = (tmp-cskyflux)/dsigma_A # variation of line shape with width
×
685
        else :
686
            dskydlsf = None
×
687

688
        # detect peaks in mean sky spectrum
689
        # peaks = local maximum larger than 10% of max peak
690
        meansky = np.mean(cskyflux,axis=0)
×
691
        tmp   = (meansky[1:-1]>meansky[2:])*(meansky[1:-1]>meansky[:-2])*(meansky[1:-1]>0.1*np.max(meansky))
×
692
        peaks = np.where(tmp)[0]+1
×
693
        # remove edges
694
        peaks = peaks[(peaks>10)&(peaks<meansky.size-10)]
×
695
        peak_wave=frame.wave[peaks]
×
696

697
        log.info("Number of peaks: {}".format(peaks.size))
×
698
        if  peaks.size < 10 :
×
699
            log.info("Wavelength of peaks: {}".format(peak_wave))
×
700

701
        # define area around each sky line to adjust
702
        dwave = np.mean(np.gradient(frame.wave))
×
703
        dpix = int(3//dwave)+1
×
704

705
        # number of parameters to fit for each peak: delta_wave , delta_lsf , scale of sky , a background (to absorb source signal)
706
        nparam = 2
×
707
        if adjust_wavelength : nparam += 1
×
708
        if adjust_lsf : nparam += 1
×
709

710
        AA=np.zeros((nparam,nparam))
×
711
        BB=np.zeros((nparam))
×
712

713
        # temporary arrays with best fit parameters on peaks
714
        # for each fiber, with errors and chi2/ndf
715
        peak_scale=np.zeros((frame.nspec,peaks.size))
×
716
        peak_scale_err=np.zeros((frame.nspec,peaks.size))
×
717
        peak_dw=np.zeros((frame.nspec,peaks.size))
×
718
        peak_dw_err=np.zeros((frame.nspec,peaks.size))
×
719
        peak_dlsf=np.zeros((frame.nspec,peaks.size))
×
720
        peak_dlsf_err=np.zeros((frame.nspec,peaks.size))
×
721

722
        peak_chi2pdf=np.zeros((frame.nspec,peaks.size))
×
723

724
        # interpolated values across peaks, after selection
725
        # based on precision and chi2
726
        interpolated_sky_dwave=np.zeros(frame.flux.shape)
×
727
        interpolated_sky_dlsf=np.zeros(frame.flux.shape)
×
728

729
        # loop on fibers and then on sky spectrum peaks
730
        if only_use_skyfibers_for_adjustments :
×
731
            fibers_in_fit = skyfibers
×
732
        else :
733
            fibers_in_fit = np.arange(frame.nspec)
×
734

735
        # restrict to fibers with ivar!=0
736
        ok = np.sum(frame.ivar[fibers_in_fit],axis=1)>0
×
737
        fibers_in_fit = fibers_in_fit[ok]
×
738

739
        # loop on sky spectrum peaks, compute for all fibers simultaneously
740
        for j,peak in enumerate(peaks) :
×
741
            b = peak-dpix
×
742
            e = peak+dpix+1
×
743
            npix = e - b
×
744
            flux = frame.flux[fibers_in_fit][:,b:e]
×
745
            ivar = frame.ivar[fibers_in_fit][:,b:e]
×
746
            if b < 0 or e > frame.flux.shape[1] :
×
747
                log.warning("skip peak on edge of spectrum with b={} e={}".format(b,e))
×
748
                continue
×
749
            M = np.zeros((fibers_in_fit.size, nparam, npix))
×
750
            index = 0
×
751
            M[:, index] = np.ones(npix); index += 1
×
752
            M[:, index] = cskyflux[fibers_in_fit][:, b:e]; index += 1
×
753
            if adjust_wavelength : M[:, index] = dskydwave[fibers_in_fit][:, b:e]; index += 1
×
754
            if adjust_lsf        : M[:, index] = dskydlsf[fibers_in_fit][:, b:e]; index += 1
×
755
            # Solve (M * W * M.T) X = (M * W * flux)
756
            BB = np.einsum('ijk,ik->ij', M, ivar*flux)
×
757
            AA = np.einsum('ijk,ik,ilk->ijl', M, ivar, M)
×
758
            # solve linear system
759
            #- TODO: replace with X = np.linalg.solve(AA, BB) ?
760
            try:
×
761
                AAi=np.linalg.inv(AA)
×
762
            except np.linalg.LinAlgError as e:
×
763
                log.warning(str(e))
×
764
                continue
×
765
            # save best fit parameter and errors
766
            X = np.einsum('ijk,ik->ij', AAi, BB)
×
767
            X_err = np.sqrt(AAi*(AAi>0))
×
768
            index = 1
×
769
            peak_scale[fibers_in_fit,j] = X[:, index]
×
770
            peak_scale_err[fibers_in_fit,j] = X_err[:, index, index]
×
771
            index += 1
×
772
            if adjust_wavelength:
×
773
                peak_dw[fibers_in_fit, j] = X[:, index]
×
774
                peak_dw_err[fibers_in_fit, j] = X_err[:, index, index]
×
775
                index += 1
×
776
            if adjust_lsf:
×
777
                peak_dlsf[fibers_in_fit, j] = X[:, index]
×
778
                peak_dlsf_err[fibers_in_fit, j] = X_err[:, index, index]
×
779
                index += 1
×
780

781
            residuals = flux
×
782
            for index in range(nparam) :
×
783
                #for index in range(3) : # needed for compatibility with master (but this was a bug)
784
                residuals -= X[:,index][:, np.newaxis]*M[:,index]
×
785

786
            variance = 1.0/(ivar+(ivar==0)) + (0.05*M[:,1])**2
×
787
            peak_chi2pdf[fibers_in_fit, j] = np.sum((ivar>0)/variance*(residuals)**2, axis=1)/(npix-nparam)
×
788

789
        for i in fibers_in_fit :
×
790
            # for each fiber, select valid peaks and interpolate
791
            ok=(peak_chi2pdf[i]<2)
×
792
            if adjust_wavelength :
×
793
                ok &= (peak_dw_err[i]>0.)&(peak_dw_err[i]<0.1) # error on wavelength shift
×
794
            if adjust_lsf :
×
795
                ok &= (peak_dlsf_err[i]>0.)&(peak_dlsf_err[i]<0.3) # error on line width (quadratic, so 0.3 mean a change of width of 0.3**2/2~5%)
×
796
            # piece-wise linear interpolate across the whole spectrum between the sky line peaks
797
            # this interpolation will be used to alter the whole sky spectrum
798
            if np.sum(ok)>0 :
×
799
                if adjust_wavelength :
×
800
                    interpolated_sky_dwave[i]=np.interp(frame.wave,peak_wave[ok],peak_dw[i,ok])
×
801
                if adjust_lsf :
×
802
                    interpolated_sky_dlsf[i]=np.interp(frame.wave,peak_wave[ok],peak_dlsf[i,ok])
×
803
                line=""
×
804
                if adjust_wavelength :
×
805
                    line += " dlambda mean={:4.3f} rms={:4.3f} A".format(np.mean(interpolated_sky_dwave[i]),np.std(interpolated_sky_dwave[i]))
×
806
                if adjust_lsf :
×
807
                    line += " dlsf mean={:4.3f} rms={:4.3f} A".format(np.mean(interpolated_sky_dlsf[i]),np.std(interpolated_sky_dlsf[i]))
×
808
                log.info(line)
×
809

810
        # we ignore the interpolated_sky_scale which is too sensitive
811
        # to CCD defects or cosmic rays
812

813
        if pcacorr is None :
×
814
            if only_use_skyfibers_for_adjustments :
×
815
                goodfibers=fibers_in_fit
×
816
            else : # keep all except bright objects and interpolate over them
817
                mflux=np.median(frame.flux,axis=1)
×
818
                mmflux=np.median(mflux)
×
819
                rms=1.48*np.median(np.abs(mflux-mmflux))
×
820
                selection=(mflux<mmflux+2*rms)
×
821
                # at least 80% of good pixels
822
                ngood=np.sum((frame.ivar>0)*(frame.mask==0),axis=1)
×
823
                selection &= (ngood>0.8*frame.flux.shape[1])
×
824
                goodfibers=np.where(mflux<mmflux+2*rms)[0]
×
825
                log.info("number of good fibers=",goodfibers.size)
×
826
            allfibers=np.arange(frame.nspec)
×
827
            # the actual median filtering
828
            if adjust_wavelength :
×
829
                for j in range(interpolated_sky_dwave.shape[1]) :
×
830
                    interpolated_sky_dwave[:,j] = np.interp(np.arange(interpolated_sky_dwave.shape[0]),goodfibers,interpolated_sky_dwave[goodfibers,j])
×
831
                cskyflux += interpolated_sky_dwave*dskydwave
×
832
            if adjust_lsf : # simple interpolation over fibers
×
833
                for j in range(interpolated_sky_dlsf.shape[1]) :
×
834
                    interpolated_sky_dlsf[:,j] = np.interp(np.arange(interpolated_sky_dlsf.shape[0]),goodfibers,interpolated_sky_dlsf[goodfibers,j])
×
835
                cskyflux += interpolated_sky_dlsf*dskydlsf
×
836

837
        else :
838

839

840
            def fit_and_interpolate(delta,skyfibers,mean,components,label="") :
×
841
                mean_and_components = np.zeros((components.shape[0]+1,
×
842
                                                components.shape[1],
843
                                                components.shape[2]))
844
                mean_and_components[0]  = mean
×
845
                mean_and_components[1:] = components
×
846
                ncomp=mean_and_components.shape[0]
×
847
                log.info("Will fit a linear combination on {} components for {}".format(ncomp,label))
×
848
                AA=np.zeros((ncomp,ncomp))
×
849
                BB=np.zeros(ncomp)
×
850
                for i in range(ncomp) :
×
851
                    BB[i] = np.sum(delta[skyfibers]*mean_and_components[i][skyfibers])
×
852
                    for j in range(i,ncomp) :
×
853
                        AA[i,j] = np.sum(mean_and_components[i][skyfibers]*mean_and_components[j][skyfibers])
×
854
                        if j!=i :
×
855
                            AA[j,i]=AA[i,j]
×
856
                AAi=np.linalg.inv(AA)
×
857
                X=AAi.dot(BB)
×
858
                log.info("Best fit linear coefficients for {} = {}".format(label,list(X)))
×
859
                result = np.zeros_like(delta)
×
860
                for i in range(ncomp) :
×
861
                    result += X[i]*mean_and_components[i]
×
862
                return result, X
×
863

864

865
            # we are going to fit a linear combination of the PCA coefficients only on the sky fibers
866
            # and then apply the linear combination to all fibers
867
            log.info("Use PCA skycorr")
×
868

869
            if adjust_wavelength :
×
870
                correction, dwavecoeff = fit_and_interpolate(
×
871
                    interpolated_sky_dwave, skyfibers,
872
                    pcacorr.dwave_mean, pcacorr.dwave_eigenvectors,
873
                    label="wavelength")
874
                cskyflux  += correction*dskydwave
×
875
            if adjust_lsf :
×
876
                correction, dlsfcoeff = fit_and_interpolate(
×
877
                    interpolated_sky_dlsf,skyfibers,
878
                    pcacorr.dlsf_mean,pcacorr.dlsf_eigenvectors,label="LSF")
879
                cskyflux  += correction*dskydlsf
×
880

881

882
    # look at chi2 per wavelength and increase sky variance to reach chi2/ndf=1
883
    if skyfibers.size > 1 and add_variance :
1✔
884
        modified_cskyivar = _model_variance(frame,cskyflux,cskyivar,skyfibers)
1✔
885
    else :
886
        modified_cskyivar = cskyivar.copy()
×
887

888
    cskyflux += background
1✔
889
    frame.flux += background
1✔
890

891
    # set sky flux and ivar to zero to poorly constrained regions
892
    # and add margins to avoid expolation issues with the resolution matrix
893
    # limit to sky spectrum part of A
894
    wmask = bad_wavelengths.astype(float)
1✔
895
    # empirically, need to account for the full width of the resolution band
896
    # (realized here by applying twice the resolution)
897
    wmask = Rmean.dot(Rmean.dot(wmask))
1✔
898
    bad = np.where(wmask!=0)[0]
1✔
899
    cskyflux[:,bad]=0.
1✔
900
    modified_cskyivar[:,bad]=0.
1✔
901

902
    # minimum number of fibers at each wavelength
903
    min_number_of_fibers = min(10,max(1,skyfibers.size//2))
1✔
904
    fibers_with_signal=np.sum(current_ivar>0,axis=0)
1✔
905
    bad = (fibers_with_signal<min_number_of_fibers)
1✔
906
    # increase by 1 pixel
907
    bad[1:-1] |= bad[2:]
1✔
908
    bad[1:-1] |= bad[:-2]
1✔
909
    cskyflux[:,bad]=0.
1✔
910
    modified_cskyivar[:,bad]=0.
1✔
911

912
    mask = (modified_cskyivar==0).astype(np.uint32)
1✔
913

914
    # add mask bits for bad sky fibers
915
    bad_skyfibers = np.unique(bad_skyfibers)
1✔
916
    if bad_skyfibers.size > 0 :
1✔
917
        mask[bad_skyfibers] |= specmask.mask("BADSKY")
×
918

919
    skymodel = SkyModel(frame.wave.copy(), cskyflux, modified_cskyivar, mask,
1✔
920
                        nrej=nout_tot, stat_ivar = cskyivar,
921
                        dwavecoeff=dwavecoeff, dlsfcoeff=dlsfcoeff,
922
                        throughput_corrections_model=skytpcorr,
923
                        skygradpcacoeff=skygradpcacoeff,
924
                        skytargetid=frame.fibermap['TARGETID'][skyfibers])
925
    # keep a record of the statistical ivar for QA
926
    if adjust_wavelength :
1✔
927
        skymodel.dwave = interpolated_sky_dwave
×
928
    if adjust_lsf :
1✔
929
        skymodel.dlsf  = interpolated_sky_dlsf
×
930

931
    skymodel.throughput_corrections = calculate_throughput_corrections(
1✔
932
        frame, skymodel)
933

934
    return skymodel
1✔
935

936

937
class SkyModel(object):
    """Per-fiber sky model (flux, ivar, mask) on a common wavelength grid."""

    def __init__(self, wave, flux, ivar, mask, header=None, nrej=0,
                 stat_ivar=None, throughput_corrections=None,
                 throughput_corrections_model=None,
                 dwavecoeff=None, dlsfcoeff=None, skygradpcacoeff=None,
                 skytargetid=None):
        """Create SkyModel object

        Args:
            wave  : 1D[nwave] wavelength in Angstroms
            flux  : 2D[nspec, nwave] sky model to subtract
            ivar  : 2D[nspec, nwave] inverse variance of the sky model
            mask  : 2D[nspec, nwave] 0=ok or >0 if problems; 32-bit
            header : (optional) header from FITS file HDU0
            nrej : (optional) Number of rejected pixels in fit
            stat_ivar  : (optional) 2D[nspec, nwave] statistical inverse variance of the sky model
            throughput_corrections : 1D (optional) Residual multiplicative throughput corrections for each fiber
            throughput_corrections_model : 1D (optional) Model multiplicative throughput corrections for each fiber
            dwavecoeff : (optional) 1D[ncoeff] vector of PCA coefficients for wavelength offsets
            dlsfcoeff : (optional) 1D[ncoeff] vector of PCA coefficients for LSF size changes
            skygradpcacoeff : (optional) 1D[ncoeff] vector of gradient amplitudes for
                sky gradient spectra.
            skytargetid : (optional) 1D[nsky] vector of TARGETIDs of fibers used for sky determination

        All input arguments become attributes. The attributes ``dwave`` and
        ``dlsf`` (per-fiber wavelength / LSF corrections) start as None and
        may be set after construction.
        """
        assert wave.ndim == 1
        assert flux.ndim == 2
        assert ivar.shape == flux.shape
        assert mask.shape == flux.shape

        self.nspec, self.nwave = flux.shape
        self.wave = wave
        self.flux = flux
        self.ivar = ivar
        self.mask = util.mask32(mask)
        self.header = header
        self.nrej = nrej
        self.stat_ivar = stat_ivar
        self.throughput_corrections = throughput_corrections
        self.throughput_corrections_model = throughput_corrections_model
        self.dwave = None # wavelength corrections
        self.dlsf  = None # LSF corrections
        self.dwavecoeff = dwavecoeff
        self.dlsfcoeff = dlsfcoeff
        self.skygradpcacoeff = skygradpcacoeff
        self.skytargetid = skytargetid

    def __getitem__(self, index):
        """
        Return a subset of the fibers for this skymodel

        e.g. `stdsky = sky[stdstar_indices]`

        Per-fiber attributes (flux, ivar, mask, stat_ivar,
        throughput_corrections, throughput_corrections_model, dwave, dlsf)
        are sliced with ``index``; exposure-level attributes (wave, header,
        nrej, dwavecoeff, dlsfcoeff, skygradpcacoeff, skytargetid) are
        passed through unchanged.
        """
        #- convert index to 1d array to maintain dimensionality of sliced arrays
        if not isinstance(index, slice):
            index = np.atleast_1d(index)

        def _subset(values):
            # optional per-fiber attribute: keep None, otherwise select the fibers
            return None if values is None else values[index]

        sky2 = SkyModel(self.wave, self.flux[index], self.ivar[index], self.mask[index],
                        header=self.header, nrej=self.nrej,
                        stat_ivar=_subset(self.stat_ivar),
                        throughput_corrections=_subset(self.throughput_corrections),
                        throughput_corrections_model=_subset(self.throughput_corrections_model),
                        dwavecoeff=self.dwavecoeff, dlsfcoeff=self.dlsfcoeff,
                        skygradpcacoeff=self.skygradpcacoeff,
                        skytargetid=self.skytargetid)

        # dwave and dlsf are both per-fiber [nspec, nwave] corrections, so they
        # must be sliced the same way as flux to stay consistent with sky2.nspec
        sky2.dwave = _subset(self.dwave)
        sky2.dlsf = _subset(self.dlsf)

        return sky2
1✔
1016

1017

1018
def subtract_sky(frame, skymodel, apply_throughput_correction_to_lines = True, apply_throughput_correction = False, zero_ivar=True) :
    """Subtract skymodel from frame, altering frame.flux, .ivar, and .mask

    Args:
        frame : desispec.Frame object
        skymodel : desispec.SkyModel object

    Options:
        apply_throughput_correction_to_lines : if True (default) and
            apply_throughput_correction is False, multiply only the sky
            emission-line part of each fiber's sky model by
            skymodel.throughput_corrections. The line part is the sky model
            minus a continuum interpolated from pixels more than 2 A away
            from every sky line redward of 5700 A, with negative values set
            to 0. Not applied for blue ("b") cameras.
        apply_throughput_correction : if True, multiply the whole sky model of
            each fiber by skymodel.throughput_corrections (an achromatic
            throughput correction, to absorb variations of Focal Ratio
            Degradation with fiber flexure).  This applies the residual
            throughput corrections on top of the model throughput corrections
            already included in the sky model.
        zero_ivar : if True , set ivar=0 for masked pixels

    Both throughput options are ignored if skymodel.throughput_corrections is None.

    Raises:
        ValueError: if frame and skymodel are not on the same wavelength grid.
    """
    assert frame.nspec == skymodel.nspec
    assert frame.nwave == skymodel.nwave

    log=get_logger()
    log.info("starting with apply_throughput_correction_to_lines = {} apply_throughput_correction = {} and zero_ivar = {}".format(apply_throughput_correction_to_lines,apply_throughput_correction, zero_ivar))

    # Set fibermask flagged spectra to have 0 flux and variance
    frame = get_fiberbitmasked_frame(frame,bitmask='sky',ivar_framemask=zero_ivar)

    # check same wavelength, die if not the case
    if not np.allclose(frame.wave, skymodel.wave):
        message = "frame and sky not on same wavelength grid"
        log.error(message)
        raise ValueError(message)

    skymodel_flux = skymodel.flux.copy() # always use a copy to avoid overwriting model

    if skymodel.throughput_corrections is not None :
        # a multiplicative factor + background around
        # each of the bright sky lines has been fit.
        # here we apply this correction to the emission lines only or to the whole
        # sky spectrum

        if apply_throughput_correction :
            # scale the whole sky spectrum of each fiber
            skymodel_flux *= skymodel.throughput_corrections[:,None]

        elif apply_throughput_correction_to_lines :

            if frame.meta is not None and "CAMERA" in frame.meta and frame.meta["CAMERA"] is not None and frame.meta["CAMERA"][0].lower() == "b" :
                log.info("Do not apply throughput correction to sky lines for blue cameras")
            else :
                # ignore b-arm sky lines, because there is really only one significant line
                # at 5579A. without other lines, we could erase a target emission line.
                # (this is a duplication of test on the camera ID above)
                red_sky_lines = get_sky_lines()
                red_sky_lines = red_sky_lines[red_sky_lines >= 5700]
                # continuum pixels are more than 2 A away from every remaining sky line
                in_cont_boolean = np.all(
                    np.abs(skymodel.wave[:, None] - red_sky_lines[None, :]) > 2., axis=1)
                in_cont = np.where(in_cont_boolean)[0]

                if in_cont.size > 0 :
                    # apply this correction to the sky lines only
                    for fiber in range(frame.flux.shape[0]) :
                        # estimate and subtract continuum for this fiber specifically
                        cont = np.interp(skymodel.wave,skymodel.wave[in_cont],skymodel.flux[fiber][in_cont])
                        skylines = skymodel.flux[fiber] - cont
                        skylines[skylines<0]=0
                        # apply correction to the sky lines only
                        skymodel_flux[fiber] += (skymodel.throughput_corrections[fiber]-1.)*skylines
                else :
                    log.warning("Could not determine sky continuum, do not apply throughput correction on sky lines")

    frame.flux -= skymodel_flux
    frame.ivar = util.combine_ivar(frame.ivar, skymodel.ivar)
    frame.mask |= skymodel.mask

    log.info("done")
1✔
1093

1094
#: Hardcoded dark-sky emission lines at KPNO, in Angstroms (vacuum, Earth
#: frame). Faint lines are discarded. Stored as an immutable tuple.
_KPNO_DARK_SKY_LINES = (
    4359.55,5199.27,5462.38,5578.85,5891.47,5897.51,5917.04,5934.63,
    5955.11,5978.80,6172.39,6204.48,6223.56,6237.36,6259.62,6266.96,
    6289.21,6302.06,6308.70,6323.09,6331.63,6350.26,6358.10,6365.57,
    6388.31,6467.83,6472.63,6500.48,6506.85,6524.31,6534.88,6545.95,
    6555.44,6564.59,6570.69,6579.14,6606.02,6831.19,6836.16,6843.66,
    6865.78,6873.06,6883.14,6891.22,6902.79,6914.52,6925.11,6941.43,
    6950.99,6971.84,6980.36,7005.76,7013.36,7050.01,7242.08,7247.18,
    7255.13,7278.23,7286.57,7298.03,7305.76,7318.31,7331.26,7342.93,
    7360.71,7371.44,7394.23,7403.92,7431.82,7440.56,7468.66,7473.57,
    7475.80,7481.74,7485.52,7495.74,7526.02,7532.80,7559.59,7573.90,
    7588.21,7600.49,7620.20,7630.86,7655.22,7664.52,7694.11,7701.75,
    7714.64,7718.99,7725.76,7728.26,7737.35,7752.67,7759.19,7762.17,
    7775.53,7782.58,7794.17,7796.27,7810.62,7823.65,7843.43,7851.81,
    7855.53,7860.06,7862.85,7870.17,7872.94,7880.89,7883.89,7892.04,
    7915.75,7921.90,7923.29,7933.51,7947.89,7951.41,7966.84,7970.48,
    7980.89,7982.00,7995.58,8016.34,8022.30,8028.03,8030.18,8054.15,
    8064.43,8087.35,8096.04,8104.77,8141.35,8149.00,8190.40,8197.67,
    8280.73,8283.97,8290.79,8298.52,8301.21,8305.09,8313.02,8320.68,
    8346.77,8352.10,8355.21,8363.19,8367.11,8384.59,8401.53,8417.58,
    8432.54,8436.62,8448.87,8454.61,8467.72,8477.00,8495.82,8507.19,
    8523.03,8541.04,8551.10,8590.62,8599.43,8620.40,8623.91,8627.34,
    8630.96,8634.50,8638.26,8642.53,8643.38,8649.38,8652.98,8657.70,
    8662.46,8667.36,8672.49,8677.88,8683.25,8689.10,8695.03,8702.18,
    8709.64,8762.40,8763.78,8770.00,8780.36,8793.54,8812.37,8829.35,
    8831.74,8835.98,8838.89,8847.90,8852.27,8864.96,8870.02,8888.31,
    8898.80,8905.60,8911.52,8922.12,8930.34,8945.87,8960.56,8973.21,
    8984.22,8990.85,8994.99,9003.85,9023.44,9030.80,9040.55,9052.03,
    9067.36,9095.13,9105.32,9154.58,9163.70,9219.04,9227.34,9265.26,
    9288.76,9296.23,9309.49,9315.79,9320.36,9326.25,9333.54,9340.43,
    9351.84,9364.59,9370.52,9378.34,9385.63,9399.67,9404.72,9422.34,
    9425.23,9442.27,9450.81,9461.13,9479.60,9486.22,9493.22,9505.47,
    9522.06,9532.66,9539.71,9555.14,9562.95,9570.04,9610.34,9623.30,
    9656.01,9661.88,9671.38,9676.93,9684.39,9693.15,9701.98,9714.38,
    9722.52,9737.51,9740.80,9748.49,9793.49,9799.16,9802.55,9810.34,
    9814.71,9819.99,
)


def get_sky_lines() :
    """Return the wavelengths of the bright dark-sky emission lines at KPNO.

    A hardcoded list is used because it is more robust than detecting the
    lines from data.

    Returns:
        1D numpy float64 array of wavelengths in Angstroms (vacuum, Earth
        frame). A new array is built on every call, so callers may modify it.
    """
    return np.array(_KPNO_DARK_SKY_LINES)
1134

1135

1136
def calculate_throughput_corrections(frame,skymodel):
    """
    Calculate the throughput corrections for each fiber based on the skymodel.

    For every sky line from get_sky_lines() strictly inside the frame
    wavelength range, fit ``frame.flux ~ c + s * skymodel.flux`` over
    +-4 A around the line for each fiber. The per-line scales ``s`` are then
    combined with inverse-variance weights and iterative outlier rejection.
    The input frame and skymodel are not modified.

    Args:
        frame (Frame object): frame containing the data that may need to be corrected
        skymodel (SkyModel object): skymodel object that contains the information about the sky for the given exposure/frame

    Output:
        corrections (1D array):  1D array where the index corresponds to the fiber % 500 and the values are the multiplicative corrections that would
                             be applied to the fluxes in frame.flux to correct them based on the input skymodel.
                             Fibers for which no reliable correction could be
                             measured (no usable line, or error > 0.1) get 1.
    """
    # need to fit for a multiplicative factor of the sky model
    # before subtraction
    # we are going to use a set of bright sky lines,
    # and fit a multiplicative factor + background around
    # each of them individually, and then combine the results
    # with outlier rejection in case a source emission line
    # coincides with one of the sky lines.

    skyline=get_sky_lines()

    # half width of wavelength region around each sky line
    # larger values give a better statistical precision
    # but also a larger sensitivity to source features
    # best solution on one dark night exposure obtained with
    # a half width of 4A.
    hw=4#A

    # Build new weights instead of multiplying frame.ivar in place, which
    # would silently modify the caller's frame. Pixels where the sky model
    # has ivar=0 are unconstrained (the model is zeroed there), so they are
    # excluded whether or not the frame carries a mask.
    tivar = frame.ivar * (skymodel.ivar > 0)
    if frame.mask is not None :
        tivar *= (frame.mask==0)

    # we precompute the quantities needed to fit each sky line + continuum
    # the sky "line profile" is the actual sky model
    # and we consider an additive constant
    sw,swf,sws,sws2,swsf=[],[],[],[],[]
    for line in skyline :
        if line<=frame.wave[0] or line>=frame.wave[-1] : continue
        ii=np.where((frame.wave>=line-hw)&(frame.wave<=line+hw))[0]
        if ii.size<2 : continue
        w = tivar[:,ii]
        f = frame.flux[:,ii]
        s = skymodel.flux[:,ii]
        sw.append(np.sum(w,axis=1))
        swf.append(np.sum(w*f,axis=1))
        swsf.append(np.sum(w*f*s,axis=1))
        sws.append(np.sum(w*s,axis=1))
        sws2.append(np.sum(w*s**2,axis=1))

    log=get_logger()
    nlines=len(sw)
    nfibers=frame.flux.shape[0]
    corrections = np.ones(nfibers, dtype='f8')
    nsig=3.
    for fiber in range(nfibers) :
        # we solve the 2x2 linear system for each fiber and sky line
        # and save the results for each fiber
        coef=[] # list of scale values
        var=[] # list of variance on scale values
        for line in range(nlines) :
            if sw[line][fiber]<=0 : continue
            A=np.array([[sw[line][fiber],sws[line][fiber]],[sws[line][fiber],sws2[line][fiber]]])
            B=np.array([swf[line][fiber],swsf[line][fiber]])
            try :
                Ai=np.linalg.inv(A)
            except np.linalg.LinAlgError :
                # singular system for this line (e.g. flat sky model): skip it
                continue
            X=Ai.dot(B)
            coef.append(X[1]) # the scale coef (marginalized over cst background)
            var.append(Ai[1,1])

        if len(coef)==0 :
            log.warning("cannot corr. throughput. for fiber %d"%fiber)
            continue

        coef=np.array(coef)
        var=np.array(var)
        ivar=(var>0)/(var+(var==0)+0.005**2)
        ivar_for_outliers=(var>0)/(var+(var==0)+0.02**2)

        # loop for outlier rejection
        failed=False
        for loop in range(50) :
            a=np.sum(ivar)
            if a <= 0 :
                log.warning("cannot corr. throughput. ivar=0 everywhere on sky lines for fiber %d"%fiber)
                failed=True
                break

            mcoef=np.sum(ivar*coef)/a
            mcoeferr=1/np.sqrt(a)

            chi2=ivar_for_outliers*(coef-mcoef)**2
            positive_chi2=chi2[chi2>0]
            if positive_chi2.size == 0 :
                # nothing to compare against (e.g. a single measurement): stop
                break
            worst=np.argmax(chi2)
            if chi2[worst]>nsig**2*np.median(positive_chi2) : # with rough scaling of errors
                ivar[worst]=0
                ivar_for_outliers[worst]=0
            else :
                break

        if failed :
            continue

        log.info("fiber #%03d throughput corr = %5.4f +- %5.4f (mean fiber flux=%f)"%(fiber,mcoef,mcoeferr,np.median(frame.flux[fiber])))
        if mcoeferr>0.1 :
            log.warning("throughput corr error = %5.4f > 0.1 is too large for fiber %d, do not apply correction"%(mcoeferr,fiber))
        else :
            corrections[fiber] = mcoef

    return corrections
1✔
1245

1246

1247
def qa_skysub(param, frame, skymodel, quick_look=False):
    """Compute sky-subtraction QA metrics for one frame.

    The sky model is subtracted from a deep copy of ``frame``, so the frame
    passed in is left untouched for the downstream pipeline. The pipeline
    does its own sky subtraction separately, in the fluxcalib stage.

    Note: pixels rejected while building the SkyModel are NOT excluded from
    the statistics computed here. Excluding them would require carrying
    ``current_ivar`` along.

    Args:
        param : dict of QA parameters; see qa_frame.init_skysub for an example.
            When ``quick_look`` is False, it is looked up with the key
            ``'<CAMERA ARM>_CONT'``, e.g. ``'B_CONT'``. That entry must unpack
            into two values, which are forwarded to ``qalib.sky_continuum``.
        frame : desispec.Frame object; should have been flat fielded.
        skymodel : desispec.SkyModel object.
        quick_look : bool, optional
            If True, run the QuickLook flavour of the QA. The sky-continuum
            measurement is skipped and signal-vs-noise outputs are added
            instead.

    Returns:
        qadict: dict of QA outputs. Values are meant to be simple Python
            objects (str, float, int) so they can be recorded as yaml.
    """
    from desispec.qa import qalib
    import copy

    log = get_logger()

    # subtract_sky alters the frame it is given. Work on a copy so the
    # caller's (unsubtracted) frame is what the rest of the pipeline sees.
    skysub_frame = copy.deepcopy(frame)
    subtract_sky(skysub_frame, skymodel)

    # Sky residuals go first; the remaining QA outputs are merged into this dict.
    qadict = qalib.sky_resid(param, skysub_frame, skymodel, quick_look=quick_look)

    if quick_look:
        # Signal-vs-noise outputs. Beware: this can be a *large* dict.
        qadict.update(qalib.SignalVsNoise(skysub_frame, param))
    else:
        # Sky continuum. In QuickLook it is measured after flat fielding
        # instead. It is measured on the original frame, not the subtracted copy.
        camera_arm = frame.meta['CAMERA'][0].upper()
        wrange1, wrange2 = param[camera_arm + '_CONT']
        skyfiber, _, _, meancontfiber, skycont = qalib.sky_continuum(frame, wrange1, wrange2)
        qadict.update(
            SKYFIBERID=skyfiber.tolist(),
            SKYCONT=skycont,
            SKYCONT_FIBER=meancontfiber,
        )

    return qadict
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc