• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

iprafols / stacking / 9285461052

29 May 2024 11:43AM UTC coverage: 99.669% (-0.3%) from 100.0%
9285461052

push

github

iprafols
fixed test suite

487 of 489 branches covered (99.59%)

Branch coverage included in aggregate %.

2 of 2 new or added lines in 1 file covered. (100.0%)

4 existing lines in 2 files now uncovered.

1318 of 1322 relevant lines covered (99.7%)

2.99 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

99.03
/stacking/normalizers/multiple_regions_normalization.py
1
""" This module defines the class MultipleRegionsNormalization that normalizes
2
the spectra using multiple regions """
3
import logging
3✔
4
import multiprocessing
3✔
5
import os
3✔
6

7
from astropy.table import Table
3✔
8
import numpy as np
3✔
9
import pandas as pd
3✔
10

11
from stacking._version import __version__
3✔
12
from stacking.errors import NormalizerError
3✔
13
from stacking.normalizer import (Normalizer, defaults, accepted_options,
3✔
14
                                 required_options)
15
from stacking.spectrum import Spectrum
3✔
16
from stacking.utils import (update_accepted_options, update_default_options,
3✔
17
                            update_required_options)
18
from stacking.normalizers.multiple_regions_normalization_utils import (
3✔
19
    compute_norm_factors, save_correction_factors_ascii,
20
    save_norm_factors_ascii, save_norm_factors_fits, save_norm_intervals_ascii,
21
    select_final_normalisation_factor)
22

23
# Options understood by MultipleRegionsNormalization, added to the ones
# already accepted by the base Normalizer. The key must be spelled
# "min norm sn": that is the name used in `defaults` below and the one read
# by MultipleRegionsNormalization.__parse_config.
accepted_options = update_accepted_options(accepted_options, [
    "intervals", "load norm factors from", "log directory", "main interval",
    "min norm sn", "num processors", "save format", "sigma_I"
])
# Options that must be supplied
required_options = update_required_options(required_options, ["log directory"])
# Default values for the options that are not supplied
defaults = update_default_options(
    defaults, {
        "intervals": "1300 - 1500, 2000 - 2600, 4400 - 4800",
        "main interval": 1,
        "min norm sn": 0.05,
        "save format": "fits.gz",
        "sigma_I": 0.05,
    })

# File extensions that can be used to save/load the normalization factors
ACCEPTED_SAVE_FORMATS = ["csv", "fits", "fits.gz", "txt"]
38

39

40
class MultipleRegionsNormalization(Normalizer):
3✔
41
    """This class is set to compute the normalization factors using multiple
42
    normalization regions
43

44
    Methods
45
    -------
46
    __init__
47
    __parse_config
48
    compute_norm_factors
49
    normalize_spectrum
50

51
    Attributes
52
    ----------
53
    correction_factors: array of float
54
    Correction factors that relate the different intervals
55

56
    intervals:  array of (float, float)
57
    Array containing the selected intervals. Each item must contain
58
    two floats signaling the starting and ending wavelength of the interval.
59
    Naturally, the starting wavelength must be smaller than the ending wavelength.
60

61
    log_directory: str
62
    Directory where log data is saved. Normalization factors will be saved there
63

64
    logger: logging.Logger
65
    Logger object
66

67
    main_interval: int
68
    Index (zero-based) of the main normalization interval
69

70
    norm_factors: pd.DataFrame
71
    Pandas DataFrame with the normalization factors
72

73
    num_intervals: int
74
    Number of intervals
75

76
    save_format: str
77
    Saving format, e.g. 'csv', 'txt', 'fits' or 'fits.gz'
78

79
    sigma_i2: float
80
    Square of sigma_I, a correction to the weights so that pixels with very small variance do not
81
    dominate. Weights are computed as w = 1 / (sigma^2 + sigma_i^2)
82
    """
83

84
    def __init__(self, config):
3✔
85
        """ Initialize instance """
86

87
        self.logger = logging.getLogger(__name__)
3✔
88

89
        # load variables from config
90
        self.intervals = []
3✔
91
        self.log_directory = []
3✔
92
        self.main_interval = None
3✔
93
        self.num_intervals = None
3✔
94
        self.save_format = None
3✔
95
        self.sigma_i2 = None
3✔
96
        self.__parse_config(config)
3✔
97

98
        # initialize data frame to store normalization factors
99
        self.norm_factors = None
3✔
100
        self.correction_factors = np.zeros(self.num_intervals, dtype=float)
3✔
101

102
    def __parse_config(self, config):
3✔
103
        """Parse the configuration options
104

105
        Arguments
106
        ---------
107
        config: configparser.SectionProxy
108
        Parsed options to initialize class
109

110
        Raise
111
        -----
112
        NormalizerError upon missing required variables
113
        """
114
        intervals_str = config.get("intervals")
3✔
115
        if intervals_str is None:
3✔
116
            raise NormalizerError("Missing argument 'intervals' required by "
3✔
117
                                  "MultipleRegionsNormalization")
118
        try:
3✔
119
            self.intervals = np.array([
3✔
120
                (float(interval.split("-")[0]), float(interval.split("-")[1]))
121
                for interval in intervals_str.split(",")
122
            ])
123
        except (ValueError, IndexError) as error:
3✔
124
            raise NormalizerError(
3✔
125
                "Wrong format for variable 'intervals'. Expected "
126
                "'start0 - end0, start1 - end1, ..., startN - endN'"
127
                " where startX and endX are positive numbers. Found: "
128
                f"{intervals_str}") from error
129
        for interval in self.intervals:
3✔
130
            if interval[0] > interval[1]:
3✔
131
                raise NormalizerError(
3✔
132
                    f"Invalid interval found: {interval}. Starting wavelength "
133
                    "should be smaller than ending interval")
134
        self.num_intervals = len(self.intervals)
3✔
135

136
        self.load_norm_factors_from = config.get("load norm factors from")
3✔
137

138
        self.log_directory = config.get("log directory")
3✔
139
        if self.log_directory is None:
3✔
140
            raise NormalizerError(
3✔
141
                "Missing argument 'log directory' required by "
142
                "MultipleRegionsNormalization")
143

144
        self.main_interval = config.getint("main interval")
3✔
145
        if self.main_interval is None:
3✔
146
            raise NormalizerError(
3✔
147
                "Missing argument 'main interval' required by "
148
                "MultipleRegionsNormalization")
149
        if self.main_interval < 0:
3✔
150
            raise NormalizerError(
3✔
151
                "Invalid value for 'main interval'. Expected a positive integer. "
152
                f"Found: {self.main_interval}")
153
        if self.main_interval > self.num_intervals:
3✔
154
            raise NormalizerError(
3✔
155
                "Invalid value for 'main interval'. Selected interval "
156
                f"{self.main_interval} as main interval, but I only read "
157
                f"{len(self.intervals)} intervals (keep in mind the zero-based "
158
                "indexing in python)")
159

160
        self.min_nrom_sn = config.getfloat("min norm sn")
3✔
161
        if self.min_nrom_sn is None:
3✔
162
            raise NormalizerError("Missing argument 'min norm sn' required by "
3✔
163
                                  "MultipleRegionsNormalization")
164
        if self.min_nrom_sn < 0:
3!
UNCOV
165
            raise NormalizerError(
×
166
                "Invalid value for 'min norm sn'. Expected a positive number. "
167
                f"Found: {self.min_nrom_sn}")
168

169
        self.num_processors = config.getint("num processors")
3✔
170
        if self.num_processors is None:
3✔
171
            raise NormalizerError(
3✔
172
                "Missing argument 'num processors' required by "
173
                "MultipleRegionsNormalization")
174
        if self.num_processors == 0:
3✔
175
            self.num_processors = multiprocessing.cpu_count() // 2
3✔
176

177
        self.save_format = config.get("save format")
3✔
178
        if self.save_format is None:
3✔
179
            raise NormalizerError("Missing argument 'save format' required by "
3✔
180
                                  "MultipleRegionsNormalization")
181
        if self.save_format not in ACCEPTED_SAVE_FORMATS:
3✔
182
            raise NormalizerError(
3✔
183
                "Invalid save format. Accepted options are '" +
184
                " ".join(ACCEPTED_SAVE_FORMATS) +
185
                f"' Found: {self.save_format}")
186

187
        sigma_i = config.getfloat("sigma_I")
3✔
188
        if sigma_i is None:
3✔
189
            raise NormalizerError("Missing argument 'sigma_I' required by "
3✔
190
                                  "MultipleRegionsNormalization")
191
        if sigma_i < 0:
3✔
192
            raise NormalizerError(
3✔
193
                "Argument 'sigma_I' should be positive. Found "
194
                f"{sigma_i}")
195
        self.sigma_i2 = sigma_i * sigma_i
3✔
196

197
    def compute_correction_factors(self):
3✔
198
        """ Compute the correction factor that relate the differnt intervals
199

200
        Raise
201
        -----
202
        NormalizerError if any of the correction factor cannot be computed
203
        """
204
        for index in range(self.num_intervals):
3✔
205
            if index == self.main_interval:
3✔
206
                self.correction_factors[index] = 1
3✔
207
            else:
208
                aux = self.norm_factors[
3✔
209
                    ~self.norm_factors[f"norm factor {index}"].isna() & ~self.
210
                    norm_factors[f"norm factor {self.main_interval}"].isna()]
211
                if aux.shape[0] > 0:
3✔
212
                    self.correction_factors[index] = (
3✔
213
                        aux[f"norm factor {self.main_interval}"].mean() /
214
                        aux[f"norm factor {index}"].mean())
215
                else:
216
                    raise NormalizerError(
3✔
217
                        "Error computing the correction for normalisation "
218
                        f"factor interval {index}. No common measurements with "
219
                        "the main interval were found.")
220

221
    def compute_norm_factors(self, spectra):
        """Compute the normalization factors, or load them if requested

        If a folder to load the normalization factors from was configured,
        the normalization and correction factors are read from there.
        Otherwise, the module-level helper compute_norm_factors is run on
        every spectrum (in parallel if more than one processor is requested),
        the results are gathered in a DataFrame, and the correction factors
        and the final normalization factors are derived from it.

        Arguments
        ---------
        spectra: list of Spectrum
        The list of spectra
        """
        # factors are available on disk, no need to compute anything
        if self.load_norm_factors_from is not None:
            self.logger.progress("Found a folder to read them instead")
            self.norm_factors, self.correction_factors = self.load_norm_factors(
                self.load_norm_factors_from)
            return

        # one tuple of arguments per spectrum
        arguments = [(spectrum.flux_common_grid, spectrum.ivar_common_grid,
                      Spectrum.common_wavelength_grid, self.num_intervals,
                      self.intervals, self.sigma_i2) for spectrum in spectra]

        # individual normalisation factors, computed in parallel or in serial
        if self.num_processors > 1:
            context = multiprocessing.get_context('fork')
            with context.Pool(processes=self.num_processors) as pool:
                norm_factors = pool.starmap(compute_norm_factors, arguments)
        else:
            norm_factors = [compute_norm_factors(*args) for args in arguments]

        # four columns per interval, grouped by interval
        column_names = []
        for index in range(self.num_intervals):
            for col_type in [
                    "norm factor", "norm S/N", "num pixels", "total weight"
            ]:
                column_names.append(f"{col_type} {index}")

        # gather everything in a dataframe, one row per spectrum
        self.norm_factors = pd.DataFrame(norm_factors, columns=column_names)
        self.norm_factors["specid"] = [spectrum.specid for spectrum in spectra]

        # relate the secondary intervals to the main one
        self.compute_correction_factors()

        # choose the normalisation factor to be used for each spectrum
        self.select_final_normalisation_factor()
269

270
    def load_norm_factors(self, folder):
        """Load normalization factors from file

        The formats in ACCEPTED_SAVE_FORMATS are tried in order and the first
        existing file is used. For ascii formats ('csv', 'txt') the correction
        factors are read from a separate file named correction_factors with
        the same extension. For fits formats they are read from the
        CORRECTION_FACTORS HDU of the same file.

        Arguments
        ---------
        folder: str
        Folder where the normalization files are saved. Environment variables
        are expanded. The file names are appended to this string as is.
        Must contain a file named normalization_factors with a valid extension
        (see ACCEPTED_SAVE_FORMATS)

        Return
        ------
        norm_factors: pd.DataFrame
        A pandas DataFrame with the read normalization_factors

        correction_factors: array of float
        The correction factors that relate the different intervals.

        Raise
        -----
        NormalizerError if the normalization factors file cannot be found
        NormalizerError if the correction factors file cannot be found (ascii
        formats only)
        """
        # expand environment variables only once
        expanded_folder = os.path.expandvars(folder)

        file_format = None
        filename = None
        for item in ACCEPTED_SAVE_FORMATS:
            filename = f"{expanded_folder}normalization_factors.{item}"
            if os.path.exists(filename):
                file_format = item
                break

        if file_format is None:
            raise NormalizerError(
                "Unable to find file normalization_factors.EXT in the specified "
                "folder, where EXT is one of '" +
                " ".join(ACCEPTED_SAVE_FORMATS) +
                f"'. Specified folder: {folder}")

        if file_format in ["csv", "txt"]:
            # sep=r"\s+" is the replacement for delim_whitespace=True, which
            # is deprecated since pandas 2.2
            norm_factors = pd.read_csv(filename, sep=r"\s+")

            correction_factors_filename = (
                f"{expanded_folder}correction_factors.{file_format}")
            if not os.path.exists(correction_factors_filename):
                raise NormalizerError(
                    f"Unable to find file correction_factors.{file_format}. "
                    f"Specified folder: {expanded_folder}")
            correction_factors = pd.read_csv(
                correction_factors_filename,
                sep=r"\s+")["correction_factor"].values
        elif file_format in ["fits", "fits.gz"]:
            norm_factors = Table.read(filename,
                                      format='fits',
                                      hdu="NORM_FACTORS").to_pandas()

            correction_factors = Table.read(
                filename, format='fits',
                hdu="CORRECTION_FACTORS").to_pandas()["CORRECTION_FACTOR"].values
        # this should never enter unless new reading formats are not properly added
        else:  # pragma: no cover
            raise NormalizerError(
                f"Don't know what to do with file format {file_format}. "
                "This is one of the supported formats, maybe it "
                "was not properly coded. If you did the change yourself, check "
                "that you added the behaviour of the new mode to method `load_norm_factors`. "
                "Otherwise contact 'stacking' developpers.")

        return norm_factors, correction_factors
334

335
    def normalize_spectrum(self, spectrum):
3✔
336
        """ Set the flux as normalized flux
337

338
        Arguments
339
        ---------
340
        spectrum: Spectrum
341
        A spectrum to normalize
342

343
        Return
344
        ------
345
        spectrum: Spectrum
346
        The normalized spectrum
347
        """
348
        try:
3✔
349
            norm_factor = self.norm_factors[
3✔
350
                self.norm_factors["specid"] ==
351
                spectrum.specid]["norm factor"].values[0]
352
        except IndexError as error:
3✔
353
            raise NormalizerError(
3✔
354
                f"Failed to normalize spectrum with specid={spectrum.specid}. "
355
                "Could not find the specid in the norm_factor table. If you "
356
                "loaded the table, make sure the table is correct. Otherwise "
357
                "contact stacking developers") from error
358

359
        if norm_factor > 0.0:
3✔
360
            spectrum.set_normalized_flux(spectrum.flux_common_grid /
3✔
361
                                         norm_factor)
362
        else:
363
            spectrum.set_normalized_flux(
3✔
364
                np.zeros_like(spectrum.flux_common_grid) + np.nan)
365
        return spectrum
3✔
366

367
    def save_norm_factors(self):
        """Save the normalisation factors for future reference

        Nothing is saved if the normalization factors were loaded from file.
        Otherwise the files are written with names starting with
        log_directory. For ascii formats ('csv', 'txt'), three files are
        written: normalization factors, normalization intervals and
        correction factors. For fits formats a single file is written.
        """
        # norm factors loaded, do not save
        if self.load_norm_factors_from is not None:
            self.logger.progress(
                "Normalization factors were loaded from file. Skipping saving "
                "operation")
            return

        filename = f"{self.log_directory}normalization_factors.{self.save_format}"

        # fits: everything goes to the same file
        if self.save_format in ["fits", "fits.gz"]:
            save_norm_factors_fits(filename, self.norm_factors, self.intervals,
                                   self.correction_factors)
            return

        # ascii: one file per table
        if self.save_format in ["csv", "txt"]:
            save_norm_factors_ascii(filename, self.norm_factors)

            intervals_filename = (
                f"{self.log_directory}normalization_intervals.{self.save_format}"
            )
            save_norm_intervals_ascii(intervals_filename, self.intervals)

            corrections_filename = (
                f"{self.log_directory}correction_factors.{self.save_format}")
            save_correction_factors_ascii(corrections_filename,
                                          self.correction_factors)
            return

        # this should never enter unless new saving formats are not properly added
        raise NormalizerError(  # pragma: no cover
            f"Don't know what to do with save format {self.save_format}. "
            "This is one of the supported saving formats, maybe it "
            "was not properly coded. If you did the change yourself, check "
            "that you added the behaviour of the new mode to method `save_norm_factors`. "
            "Otherwise contact 'stacking' developpers.")
402

403
    def select_final_normalisation_factor(self):
        """Select the final normalization factors

        The module-level helper select_final_normalisation_factor is applied
        to every row of the table of normalization factors. Its output is
        stored in the columns "norm factor", "norm S/N" and "chosen interval".
        """
        # row-wise selection; the output is expanded into columns
        selection = self.norm_factors.apply(
            select_final_normalisation_factor,
            axis=1,
            args=(self.correction_factors, self.min_nrom_sn),
            result_type='expand',
        )
        target_columns = ["norm factor", "norm S/N", "chosen interval"]
        self.norm_factors[target_columns] = selection
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc