• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

iprafols / SQUEzE / 20310179563

17 Dec 2025 04:38PM UTC coverage: 65.567% (+8.1%) from 57.439%
20310179563

Pull #84

gihtub

web-flow
Merge d656bc5ca into 741eeef2c
Pull Request #84: added new peak finder, optimized code

240 of 462 branches covered (51.95%)

Branch coverage included in aggregate %.

150 of 183 new or added lines in 11 files covered. (81.97%)

4 existing lines in 4 files now uncovered.

1032 of 1478 relevant lines covered (69.82%)

2.09 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

43.51
/py/squeze/model.py
1
"""
2
    SQUEzE
3
    ======
4

5
    This file implements the class Model, that is used to store, train, and
6
    execute the quasar finding model
7
"""
8
__author__ = "Ignasi Perez-Rafols (iprafols@gmail.com)"
3✔
9

10
import os
3✔
11

12
import numpy as np
3✔
13
import pandas as pd
3✔
14
import fitsio
3✔
15

16
from squeze.error import Error
3✔
17
from squeze.random_forest_classifier import RandomForestClassifier
3✔
18
from squeze.utils import save_json, load_json
3✔
19

20

21
def find_prob(row, columns):
    """ Find the probability of an instance being a quasar.

        The probability is obtained by adding the probabilities of
        classes 3 (quasar) and 30 (quasar, bal). If only one of these
        probability columns exists, its value is used on its own. If
        neither exists, np.nan is returned.

        Parameters
        ----------
        row : pd.Series
        A row in the DataFrame.

        columns: list of string
        The column labels of the Series.

        Returns
        -------
        The probability of the object being a quasar (np.nan when neither
        PROB_CLASS3 nor PROB_CLASS30 is among the columns).
        """
    has_class3 = "PROB_CLASS3" in columns
    has_class30 = "PROB_CLASS30" in columns

    # no quasar-like class available: probability is undefined
    if not (has_class3 or has_class30):
        return np.nan

    if has_class3 and has_class30:
        return row["PROB_CLASS3"] + row["PROB_CLASS30"]

    # exactly one of the two columns is present
    return row["PROB_CLASS3"] if has_class3 else row["PROB_CLASS30"]
3✔
52

53

54
class Model:
    """ Create, train and/or execute the quasar model to find quasars

    CLASS: Model
    PURPOSE: Create, train and/or execute the quasar model to find
    quasars. Depending on the random forest options, either a single
    classifier (options stored under "all") or two classifiers ("high"
    and "low", split at Z_TRY = 2.1) are used.
    """

    def __init__(self, config):
        """ Initialize class instance.

        Arguments
        ---------
        config: Config
        A configuration instance. Its section [model] must define the
        variables 'filename', 'selected cols', 'random state' and
        'random forest options'.

        Raise
        -----
        Error if any of the required variables is missing
        """
        self.config = config
        model_config = self.config.get_section("model")

        self.name = model_config.get("filename")
        if self.name is None:
            message = "In section [model], variable 'filename' is required"
            raise Error(message)

        selected_cols = model_config.get("selected cols")
        if selected_cols is None:
            message = "In section [model], variable 'selected cols' is required"
            raise Error(message)
        self.selected_cols = selected_cols.split()

        random_state = model_config.getint("random state")
        if random_state is None:
            message = "In section [model], variable 'random state' is required"
            raise Error(message)

        clf_options = model_config.get("random forest options")
        # check the value that was just read: a missing entry must raise
        # Error rather than a TypeError from os.path.expandvars(None)
        if clf_options is None:
            message = (
                "In section [model], variable 'random forest options' is required"
            )
            raise Error(message)
        self.clf_options = load_json(os.path.expandvars(clf_options))

        # initialize random forest classifier(s)
        if "high" in self.clf_options and "low" in self.clf_options:
            # separate classifiers for high and low redshift candidates
            self.highlow_split = True
            self.clf_options.get("high")["random_state"] = random_state
            self.clf_options.get("low")["random_state"] = random_state
            self.clf_high = RandomForestClassifier(
                **self.clf_options.get("high"))
            self.clf_low = RandomForestClassifier(**self.clf_options.get("low"))
        else:
            # a single classifier; wrap its options under "all" so that all
            # classifiers are addressed by name in the same way
            self.highlow_split = False
            self.clf_options = {"all": self.clf_options}
            self.clf_options.get("all")["random_state"] = random_state
            self.clf = RandomForestClassifier(**self.clf_options.get("all"))

    def __find_class(self, row, train):
        """ Find the class the instance belongs to.

        If train is set to True, then find the class from CLASS_PERSON.
        For quasars and galaxies add a new class if the redshift is wrong.
        If train is False, then find the class the instance belongs
        to from the highest of the computed probabilities.

        Arguments
        ---------
        row : pd.Series
        A row in the DataFrame.

        train : bool
        If True, then find the class from the truth table,
        otherwise find it from the computed probabilities

        Return
        ------
        The class the instance belongs to:
        "star": 1
        "quasar": 3
        "quasar, wrong z": 35
        "quasar, bal": 30
        "quasar, bal, wrong z": 305
        "galaxy": 4
        "galaxy, wrong z": 45
        When train is False and no class has a probability above zero
        (e.g. all probabilities are NaN), -1 is returned.
        """
        # find class from the truth table
        if train:
            if row["CLASS_PERSON"] == 30 and not row["CORRECT_REDSHIFT"]:
                data_class = 305
            elif row["CLASS_PERSON"] == 3 and not row["CORRECT_REDSHIFT"]:
                data_class = 35
            elif row["CLASS_PERSON"] == 4 and not row["CORRECT_REDSHIFT"]:
                data_class = 45
            else:
                data_class = row["CLASS_PERSON"]

        # find class from the probabilities
        else:
            data_class = -1
            aux_prob = 0.0
            if self.highlow_split:
                # NOTE(review): assumes clf_low uses the same labels as
                # clf_high — confirm
                class_labels = self.clf_high.classes
            else:
                class_labels = self.clf.classes
            for class_label in class_labels:
                if row[f"PROB_CLASS{int(class_label):d}"] > aux_prob:
                    aux_prob = row[f"PROB_CLASS{int(class_label):d}"]
                    data_class = int(class_label)

        return data_class

    def save_model(self):
        """ Save the model and its configuration.

        The model is saved as json if the model name ends with ".json",
        and as fits otherwise. The configuration is saved as an ini file.
        """
        if self.name.endswith(".json"):
            self.save_model_as_json()
        else:
            self.save_model_as_fits()
        self.save_model_config()

    def save_model_as_json(self):
        """ Save the model as a json file

        The configuration is temporarily detached so that it is not
        serialized; it is restored even if saving fails.
        """
        config = self.config
        del self.config
        try:
            save_json(os.path.expandvars(self.name), self)
        finally:
            self.config = config

    def save_model_as_fits(self):
        """ Save the model as a fits file

        For each classifier ("high" and "low", or "all") an HDU named
        "<classifier name>INFO" is written with the classifier options,
        the number of trees and categories in its header and the class
        labels as data. Then one HDU per tree, named
        "<classifier name><tree index>", is appended.
        """
        outname = os.path.expandvars(self.name.replace(".json", ".fits.gz"))

        # Create model HDU(s) to store the classifiers
        if self.highlow_split:
            classifier_names = ["high", "low"]
            classifiers = [self.clf_high, self.clf_low]
        else:
            classifier_names = ["all"]
            classifiers = [self.clf]

        with fitsio.FITS(outname, 'rw', clobber=True) as results:
            for classifier_name, classifier in zip(classifier_names,
                                                   classifiers):
                header = [{
                    "name": key,
                    "value": value,
                } for key, value in self.clf_options.get(classifier_name).items()]
                if classifier_name != "all":
                    comment = ("Options passed to the classifier for "
                               f"{classifier_name} redshift quasars. Redshifts "
                               "are split at 2.1")
                else:
                    comment = ("Options passed to the classifier for "
                               "all redshift quasars")
                header.append({"name": "COMMENT", "value": comment})

                num_trees = classifier.num_trees
                header += [{
                    "name": "N_TREES",
                    "value": num_trees,
                }, {
                    "name": "N_CAT",
                    "value": classifier.num_categories,
                }]

                names = ["CLASSES"]
                cols = [classifier.classes]

                # create HDU
                results.write(cols,
                              names=names,
                              header=header,
                              extname=f"{classifier_name}INFO")
                del header, names, cols

                # append classifier trees in different HDUs
                for index in range(num_trees):
                    names, cols = classifier.to_fits_hdu(index)
                    results.write(cols,
                                  names=names,
                                  extname=f"{classifier_name}{index}")
                    del names, cols
        # End of model HDU(s)

    def save_model_config(self):
        """ Save the model configuration

        The configuration is written next to the model, replacing the model
        extension (".json" or ".fits.gz") by ".ini". For any other name the
        last extension (if any) is replaced, so that the configuration file
        never overwrites the model file itself.
        """
        for suffix in (".json", ".fits.gz"):
            if self.name.endswith(suffix):
                base_name = self.name[:-len(suffix)]
                break
        else:
            base_name = os.path.splitext(self.name)[0]
        outname = os.path.expandvars(base_name + ".ini")
        with open(outname, 'w', encoding="utf-8") as config_file:
            self.config.write(config_file)

    def compute_probability(self, data_frame):
        """ Compute the probability of a list of candidates to be quasars

            Parameters
            ----------
            data_frame : pd.DataFrame
            The dataframe where the probabilities will be predicted. It must
            contain the columns Z_TRY and SPECID and the features listed in
            self.selected_cols (all but the last two entries).

            Returns
            -------
            data_frame : pd.DataFrame
            A new dataframe with the columns PROB_CLASS<label> (one per class),
            CLASS_PREDICTED, PROB and DUPLICATED added. Rows are regrouped by
            redshift range, so the original row order is not preserved.
            Rows with NaN Z_TRY get NaN class probabilities.
            """

        if self.highlow_split:
            # high-z split
            # compute probabilities for each of the classes
            data_frame_high = data_frame[data_frame["Z_TRY"] >= 2.1].copy()
            if data_frame_high.shape[0] > 0:
                aux = data_frame_high.fillna(-9999.99)
                # the last two selected columns are not features (presumably
                # the truth columns used in training — confirm)
                data_vector = aux[self.selected_cols[:-2]].values
                data_class_probs = self.clf_high.predict_proba(data_vector)

                # save the probability for each of the classes
                for index, class_label in enumerate(self.clf_high.classes):
                    data_frame_high[
                        f"PROB_CLASS{int(class_label):d}"] = data_class_probs[:,
                                                                              index]

            # low-z split
            # compute probabilities for each of the classes
            data_frame_low = data_frame[(data_frame["Z_TRY"] < 2.1)].copy()
            if data_frame_low.shape[0] > 0:
                aux = data_frame_low.fillna(-9999.99)
                data_vector = aux[self.selected_cols[:-2]].values
                data_class_probs = self.clf_low.predict_proba(data_vector)

                # save the probability for each of the classes
                for index, class_label in enumerate(self.clf_low.classes):
                    data_frame_low[
                        f"PROB_CLASS{int(class_label):d}"] = data_class_probs[:,
                                                                              index]

            # non-peaks
            data_frame_nonpeaks = data_frame[data_frame["Z_TRY"].isna()].copy()
            if data_frame_nonpeaks.shape[0] > 0:
                # save the probability for each of the classes
                for index, class_label in enumerate(self.clf_low.classes):
                    data_frame_nonpeaks[
                        f"PROB_CLASS{int(class_label):d}"] = np.nan

            # join datasets
            if (data_frame_high.shape[0] == 0 and
                    data_frame_low.shape[0] == 0 and
                    data_frame_nonpeaks.shape[0] == 0):
                data_frame = data_frame_high
            else:
                data_frame = pd.concat(
                    [data_frame_high, data_frame_low, data_frame_nonpeaks],
                    sort=False)

        else:
            # peaks
            # NOTE: rows with negative Z_TRY are neither peaks nor non-peaks
            # here and are therefore dropped from the output
            # compute probabilities for each of the classes
            data_frame_peaks = data_frame[data_frame["Z_TRY"] >= 0.0].copy()
            if data_frame_peaks.shape[0] > 0:
                data_vector = data_frame_peaks[self.selected_cols[:-2]].fillna(
                    -9999.99).astype(float).values
                data_class_probs = self.clf.predict_proba(data_vector)

                # save the probability for each of the classes
                for index, class_label in enumerate(self.clf.classes):
                    data_frame_peaks[
                        f"PROB_CLASS{int(class_label):d}"] = data_class_probs[:,
                                                                              index]

            # non-peaks
            data_frame_nonpeaks = data_frame[data_frame["Z_TRY"].isna()].copy()
            if not data_frame_nonpeaks.shape[0] == 0:
                # save the probability for each of the classes
                for index, class_label in enumerate(self.clf.classes):
                    data_frame_nonpeaks[
                        f"PROB_CLASS{int(class_label):d}"] = np.nan

            # join datasets
            if (data_frame_peaks.shape[0] == 0 and
                    data_frame_nonpeaks.shape[0] == 0):
                data_frame = data_frame_peaks
            else:
                data_frame = pd.concat([data_frame_peaks, data_frame_nonpeaks],
                                       sort=False)

        # predict class and find the probability of the candidate being a quasar
        data_frame["CLASS_PREDICTED"] = data_frame.apply(self.__find_class,
                                                         axis=1,
                                                         args=(False,))
        data_frame["PROB"] = data_frame.apply(find_prob,
                                              axis=1,
                                              args=(data_frame.columns,))

        # flag duplicated instances: for each SPECID only the entry with
        # the highest PROB is kept unflagged
        data_frame["DUPLICATED"] = data_frame.sort_values(
            ["SPECID", "PROB"],
            ascending=False).duplicated(subset=("SPECID",),
                                        keep="first").sort_index()

        return data_frame

    def train(self, data_frame):
        """ Train all the instances of the classifiers to estimate the probability
            of a candidate being a quasar

            Parameters
            ----------
            data_frame : pd.DataFrame
            The dataframe with which the model is trained. It must contain
            the columns Z_TRY, CLASS_PERSON and CORRECT_REDSHIFT and the
            columns listed in self.selected_cols. Rows with negative or
            NaN Z_TRY are not used.
            """
        # train classifier
        if self.highlow_split:
            # high-z split
            data_frame_high = data_frame[data_frame["Z_TRY"] >= 2.1].fillna(
                -9999.99)
            data_vector = data_frame_high[self.selected_cols[:-2]].values
            data_class = data_frame_high.apply(self.__find_class,
                                               axis=1,
                                               args=(True,))
            self.clf_high.fit(data_vector, data_class)
            # low-z split
            data_frame_low = data_frame[(data_frame["Z_TRY"] < 2.1) & (
                data_frame["Z_TRY"] >= 0.0)].fillna(-9999.99)
            data_vector = data_frame_low[self.selected_cols[:-2]].values
            data_class = data_frame_low.apply(self.__find_class,
                                              axis=1,
                                              args=(True,))
            self.clf_low.fit(data_vector, data_class)

        else:
            data_frame = data_frame[(
                data_frame["Z_TRY"]
                >= 0.0)][self.selected_cols].fillna(-9999.99)
            data_vector = data_frame[self.selected_cols[:-2]].values
            data_class = data_frame.apply(self.__find_class,
                                          axis=1,
                                          args=(True,))
            self.clf.fit(data_vector, data_class)

    @classmethod
    def from_file(cls, config, filename):
        """ Construct model from file

        Arguments
        ---------
        config: Config
        A configuration instance

        filename: str
        The name of the file containing the model. Files ending with
        ".json" are loaded with from_json, any other with from_fits.

        Return
        ------
        cls_instance: Model
        The loaded instance
        """
        if filename.endswith(".json"):
            cls_instance = cls.from_json(config, filename)
        else:
            cls_instance = cls.from_fits(config, filename)
        return cls_instance

    @classmethod
    def from_json(cls, config, filename):
        """ This function deserializes a json file to correctly build the class.

        The instance is first built from the configuration, and its
        classifier(s) are then replaced by those stored in the file, rebuilt
        with RandomForestClassifier.from_json. For this function to work,
        data should have been serialized using the serialization method
        specified in `save_json` function present on `utils.py`

        Arguments
        ---------
        config: Config
        A configuration instance

        filename: str
        The name of the json file containing the model. Environment
        variables in the name are expanded.

        Return
        ------
        cls_instance: Model
        The loaded instance
        """
        cls_instance = cls(config)

        # now update the instance to the current values
        data = load_json(os.path.expandvars(filename))
        if cls_instance.highlow_split:
            cls_instance.clf_high = RandomForestClassifier.from_json(
                data.get("clf_high"))
            cls_instance.clf_low = RandomForestClassifier.from_json(
                data.get("clf_low"))
        else:
            cls_instance.clf = RandomForestClassifier.from_json(data.get("clf"))

        return cls_instance

    @classmethod
    def from_fits(cls, config, filename):
        """ This function loads the model information from a fits file.

        The expected shape for the fits file is that provided by the
        function save_model_as_fits.

        Arguments
        ---------
        config: Config
        A configuration instance

        filename: str
        The name of the fits file containing the model. Environment
        variables in the name are expanded.

        Return
        ------
        cls_instance: Model
        The loaded instance
        """
        cls_instance = cls(config)

        # now update the instance to the current values
        with fitsio.FITS(os.path.expandvars(filename)) as hdul:
            if cls_instance.highlow_split:
                cls_instance.clf_high = RandomForestClassifier.from_fits_hdul(
                    hdul,
                    "high",
                    "HIGHINFO",
                    args=cls_instance.clf_options.get("high"))
                cls_instance.clf_low = RandomForestClassifier.from_fits_hdul(
                    hdul,
                    "low",
                    "LOWINFO",
                    args=cls_instance.clf_options.get("low"))
            else:
                # in the non-split case __init__ stores the options under
                # "all"; pass those, as done for "high"/"low" above
                cls_instance.clf = RandomForestClassifier.from_fits_hdul(
                    hdul,
                    "all",
                    "ALLINFO",
                    args=cls_instance.clf_options.get("all"))

        return cls_instance
×
500

501

502
if __name__ == "__main__":
    # This module only provides the Model class and helpers; running it
    # directly intentionally does nothing.
    pass
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc