• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

icecube / flarestack / 4536597083

pending completion
4536597083

Pull #268

github

GitHub
Merge 0d13f2846 into e550ab59c
Pull Request #268: Compliance with black v23

26 of 26 new or added lines in 11 files covered. (100.0%)

4405 of 5755 relevant lines covered (76.54%)

2.29 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

67.61
/flarestack/core/angular_error_modifier.py
1
import logging
3✔
2
import numpy as np
3✔
3
import os
3✔
4
from flarestack.core.energy_pdf import EnergyPDF
3✔
5
from flarestack.shared import (
3✔
6
    min_angular_err,
7
    base_floor_quantile,
8
    floor_pickle,
9
    pull_pickle,
10
    weighted_quantile,
11
)
12
from flarestack.utils.dynamic_pull_correction import (
3✔
13
    create_quantile_floor_0d,
14
    create_quantile_floor_0d_e,
15
    create_quantile_floor_1d,
16
    create_quantile_floor_1d_e,
17
    create_pull_0d_e,
18
    create_pull_1d,
19
    create_pull_1d_e,
20
    create_pull_2d,
21
    create_pull_2d_e,
22
)
23
import pickle as Pickle
3✔
24
from scipy.interpolate import interp1d, RectBivariateSpline
3✔
25
from flarestack.utils.make_SoB_splines import (
3✔
26
    get_gamma_support_points,
27
    get_gamma_precision,
28
    _around,
29
)
30
import numexpr
3✔
31
import inspect
3✔
32

33
logger = logging.getLogger(__name__)
3✔
34

35

36
class BaseFloorClass(object):
    """Base class for angular-error floors applied to event data.

    Concrete floors register themselves under a string name with
    :meth:`register_subclass` and are instantiated from a ``floor_dict`` via
    :meth:`create`. By default the floor is zero for every event and the
    static/dynamic hooks return the data unchanged.
    """

    # Registry mapping ``floor_dict["floor_name"]`` values to floor classes;
    # shared by every subclass.
    subclasses = {}

    def __init__(self, floor_dict):
        """Store the floor configuration.

        :param floor_dict: Floor configuration. ``"season"`` is required; the
            dict is also passed to ``floor_pickle`` to derive the cache path
            stored in ``self.pickle_name``.
        """
        self.floor_dict = floor_dict
        self.season = floor_dict["season"]
        self.pickle_name = floor_pickle(floor_dict)

    @classmethod
    def register_subclass(cls, floor_name):
        """Return a class decorator registering a subclass as ``floor_name``.

        :param floor_name: Key under which the decorated class is stored in
            ``BaseFloorClass.subclasses``.
        :return: Decorator that registers the class and returns it unchanged.
        """

        def decorator(subclass):
            cls.subclasses[floor_name] = subclass
            return subclass

        return decorator

    @classmethod
    def create(cls, floor_dict):
        """Instantiate the floor class registered as ``floor_dict["floor_name"]``.

        :param floor_dict: Floor configuration, passed on to the subclass.
        :return: Instance of the registered floor class.
        :raises ValueError: If no class is registered under that name.
        """
        floor_name = floor_dict["floor_name"]

        if floor_name not in cls.subclasses:
            raise ValueError("Bad floor name {}".format(floor_name))

        return cls.subclasses[floor_name](floor_dict)

    def floor(self, data):
        """Return the per-event minimum angular error (zero by default)."""
        return np.array([0.0 for _ in data])

    def apply_floor(self, data):
        """Return a copy of ``data`` with the floor folded into ``"sigma"``.

        For events whose ``"raw_sigma"`` is below the floor, ``"sigma"`` is
        set to the quadrature sum of the floor and ``"raw_sigma"``. All other
        events keep their existing ``"sigma"``. ``data`` itself is not
        modified.

        :param data: Event array with ``"sigma"`` and ``"raw_sigma"`` fields.
        :return: Modified copy of ``data``.
        """
        mask = data["raw_sigma"] < self.floor(data)
        new_data = data.copy()
        new_data["sigma"][mask] = np.sqrt(
            self.floor(data[mask].copy()) ** 2.0 + data["raw_sigma"][mask].copy() ** 2.0
        )
        return new_data

    def apply_dynamic(self, data, params=None):
        """Hook for floors re-applied every iteration; a no-op here.

        :param data: Event data.
        :param params: Optional fit parameters. Accepted because
            ``BaseAngularErrorModifier.pull_correct_dynamic`` passes them;
            unused by this default implementation.
        :return: ``data`` unchanged.
        """
        return data

    def apply_static(self, data):
        """Hook for floors applied once up-front; a no-op here.

        :param data: Event data.
        :return: ``data`` unchanged.
        """
        return data
81

82

83
@BaseFloorClass.register_subclass("no_floor")
class NoFloor(BaseFloorClass):
    """Floor that leaves angular errors untouched (registered as ``"no_floor"``).

    Everything is inherited from BaseFloorClass: a zero floor and no-op
    static/dynamic hooks.
    """
86

87

88
class BaseStaticFloor(BaseFloorClass):
    """Floor that is applied once and then kept permanently.

    ``apply_static`` applies the floor to ``"sigma"`` and copies the result
    into ``"raw_sigma"``. The floored errors thus become the new baseline,
    and nothing needs to be recomputed on later iterations. This makes it
    faster than a dynamic floor.
    """

    def apply_static(self, data):
        """Apply the floor and overwrite ``"raw_sigma"`` with the result.

        :param data: Event data with ``"sigma"`` and ``"raw_sigma"`` fields.
        :return: The array returned by ``apply_floor``, with ``"raw_sigma"``
            set equal to its ``"sigma"``.
        """
        floored = self.apply_floor(data)
        # Bake the floor in so later corrections start from floored errors.
        floored["raw_sigma"] = floored["sigma"]
        return floored
100

101

102
class BaseDynamicFloorClass(BaseFloorClass):
    """Floor that is re-applied on every iteration.

    ``apply_dynamic`` applies the floor to the ``"sigma"`` field only and
    leaves ``"raw_sigma"`` untouched, so the floor can be re-evaluated each
    iteration. This is slower than a static floor.
    """

    def apply_dynamic(self, data, params=None):
        """Return ``data`` with the floor applied to ``"sigma"`` only.

        Callers (e.g. the dynamic pull correctors) index the result as event
        data (``data["sigma"]``). The result must therefore be the floored
        data, not the bare floor values.

        :param data: Event data with ``"sigma"`` and ``"raw_sigma"`` fields.
        :param params: Optional fit parameters. Accepted for callers that
            pass them; not used by ``apply_floor``.
        :return: The result of ``self.apply_floor(data)``.
        """
        return self.apply_floor(data)
111

112

113
@BaseFloorClass.register_subclass("static_floor")
class StaticFloor(BaseStaticFloor):
    """Constant angular-error floor, applied once (``"static_floor"``).

    The floor is ``floor_dict["min_error_deg"]`` converted from degrees to
    radians when that key is present, and ``min_angular_err`` otherwise.
    """

    def __init__(self, floor_dict):
        BaseFloorClass.__init__(self, floor_dict)

        if "min_error_deg" in floor_dict:
            # Configured in degrees; the floor itself is kept in radians.
            self.min_error = np.deg2rad(floor_dict["min_error_deg"])
        else:
            self.min_error = min_angular_err

        logger.debug(
            "Applying an angular error floor of {0} degrees".format(
                np.degrees(self.min_error)
            )
        )

    def floor(self, data):
        """Return ``self.min_error`` once for every event in ``data``."""
        return np.array([self.min_error] * len(data))
131

132

133
class BaseQuantileFloor(BaseFloorClass):
    """Floor derived from a quantile, backed by a cached pickle.

    On construction the floor table at ``self.pickle_name`` is built with
    ``create_pickle`` if it does not exist. It is then loaded and converted
    into the callable ``self.f`` by ``create_function``. Subclasses
    implement both hooks.
    """

    def __init__(self, floor_dict):
        """
        :param floor_dict: Floor configuration. If ``"floor_quantile"`` is
            missing, it is filled in with ``base_floor_quantile`` (the dict
            is modified in place).
        """
        # Filled in before BaseFloorClass.__init__, presumably because
        # floor_pickle() keys the cache path on it — confirm.
        self.floor_quantile = floor_dict.setdefault(
            "floor_quantile", base_floor_quantile
        )

        BaseFloorClass.__init__(self, floor_dict)

        logger.debug(
            "Applying an angular error floor using quantile {0}".format(
                self.floor_quantile
            )
        )

        if not os.path.isfile(self.pickle_name):
            self.create_pickle()
        else:
            logger.debug("Loading from {0}".format(self.pickle_name))

        # Pickle files are binary; text mode makes pickle.load fail on
        # Python 3. pickle can execute code, so pickle_name must only ever
        # point at trusted cache files (normally those written by
        # create_pickle).
        with open(self.pickle_name, "rb") as f:
            pickled_data = Pickle.load(f)

        self.f = self.create_function(pickled_data)

    def create_pickle(self):
        """Write the floor cache to ``self.pickle_name`` (no-op here)."""
        pass

    def create_function(self, pickled_array):
        """Convert the cached table into a floor callable (returns None here)."""
        pass
164

165

166
@BaseFloorClass.register_subclass("quantile_floor_0d")
class QuantileFloor0D(BaseQuantileFloor, BaseStaticFloor):
    """Static quantile floor whose cached value is shared by every event."""

    def create_pickle(self):
        """Build the floor cache with ``create_quantile_floor_0d``."""
        create_quantile_floor_0d(self.floor_dict)

    def create_function(self, pickled_array):
        """Return ``f(data, params)`` repeating the cached value per event."""

        def constant_floor(data, params):
            return np.array([pickled_array for _ in data])

        return constant_floor
173

174

175
@BaseFloorClass.register_subclass("quantile_floor_0d_e")
class QuantileFloorEParam0D(BaseQuantileFloor, BaseDynamicFloorClass):
    """Dynamic quantile floor interpolated in the fit parameters.

    The interpolated value is the same for every event.
    """

    def create_pickle(self):
        """Build the floor cache with ``create_quantile_floor_0d_e``."""
        create_quantile_floor_0d_e(self.floor_dict)

    def create_function(self, pickled_array):
        """Return ``f(data, params)``: the cached table linearly interpolated
        at ``params``, repeated once per event."""
        interpolator = interp1d(pickled_array[0], pickled_array[1])

        def floor_at(data, params):
            return np.array([interpolator(params) for _ in data])

        return floor_at
183

184

185
@BaseFloorClass.register_subclass("quantile_floor_1d")
class QuantileFloor1D(BaseQuantileFloor, BaseStaticFloor):
    """Static quantile floor interpolated in each event's ``logE``."""

    def create_pickle(self):
        """Build the floor cache with ``create_quantile_floor_1d``."""
        create_quantile_floor_1d(self.floor_dict)

    def create_function(self, pickled_array):
        """Return ``f(data, params)`` evaluating the cached table at
        ``data["logE"]``; ``params`` is ignored."""
        grid, values = pickled_array[0], pickled_array[1]
        interpolator = interp1d(grid, values)

        def floor_at(data, params):
            return interpolator(data["logE"])

        return floor_at
193

194

195
@BaseFloorClass.register_subclass("quantile_floor_1d_e")
class QuantileFloorEParam1D(BaseQuantileFloor, BaseDynamicFloorClass):
    """Dynamic quantile floor depending on event ``logE`` and ``params[0]``.

    The name is distinct from the static ``QuantileFloor1D``
    (``"quantile_floor_1d"``) and mirrors ``QuantileFloorEParam0D``, so both
    classes remain reachable as module attributes. Registry lookup via
    ``"quantile_floor_1d_e"`` is unaffected.
    """

    def create_pickle(self):
        """Build the floor cache with ``create_quantile_floor_1d_e``."""
        create_quantile_floor_1d_e(self.floor_dict)

    def create_function(self, pickled_array):
        """Return ``f(data, params)`` from a bilinear spline of the cached table.

        The spline interpolates the log of ``pickled_array[2]`` over the grid
        ``pickled_array[0]`` x ``pickled_array[1]`` (``kx=ky=1``, ``s=0``).
        It is evaluated at each event's ``logE`` and ``params[0]``, then
        exponentiated.

        NOTE(review): each per-event element has shape (1,), so after ``.T``
        the result has shape (1, n_events) — confirm callers expect this.
        """
        func = RectBivariateSpline(
            pickled_array[0],
            pickled_array[1],
            np.log(pickled_array[2]),
            kx=1,
            ky=1,
            s=0,
        )
        return lambda data, params: np.array(
            [np.exp(func(x["logE"], params[0])[0]) for x in data]
        ).T
212

213

214
class BaseAngularErrorModifier(object):
    """Base class for angular-error (pull) corrections.

    An instance holds a floor object built from the same ``pull_dict``. It
    provides hooks for static (once) and dynamic (per-iteration)
    corrections, plus helpers that cache log(signal/background) values on a
    grid of spectral indices and interpolate between them. Subclasses
    register under a name with :meth:`register_subclass` and are built with
    :meth:`create`.
    """

    # Registry mapping ``aem_name`` strings to modifier classes.
    subclasses = {}

    def __init__(self, pull_dict):
        """Store the configuration and build the associated floor.

        :param pull_dict: Configuration dict as assembled by :meth:`create`.
            ``"season"`` is required here. The dict is also passed to
            ``BaseFloorClass.create`` and ``pull_pickle``. The optional
            ``"gamma_precision"`` entry defaults to ``"flarestack"``.
        """
        self.season = pull_dict["season"]
        self.floor = BaseFloorClass.create(pull_dict)
        self.pull_dict = pull_dict
        self.pull_name = pull_pickle(pull_dict)

        # Precision of the gamma grid used by the spatial-cache helpers.
        self.precision = pull_dict.get("gamma_precision", "flarestack")

    @classmethod
    def register_subclass(cls, aem_name):
        """Return a class decorator registering a subclass under ``aem_name``.

        :param aem_name: AngularErrorModifier name used as the registry key.
        :return: Decorator that stores the class in ``subclasses`` and
            returns it unchanged.
        """

        def decorator(subclass):
            cls.subclasses[aem_name] = subclass
            return subclass

        return decorator

    @classmethod
    def create(
        cls,
        season,
        e_pdf_dict,
        floor_name="static_floor",
        aem_name="no_modifier",
        **kwargs
    ):
        """Assemble a ``pull_dict`` and instantiate the registered modifier.

        :param season: Stored as ``pull_dict["season"]``.
        :param e_pdf_dict: Energy PDF configuration, stored as
            ``pull_dict["e_pdf_dict"]``.
        :param floor_name: Registry key of the floor class
            (``BaseFloorClass.subclasses``).
        :param aem_name: Registry key of the modifier class.
        :param kwargs: Extra entries merged into ``pull_dict`` (e.g.
            ``gamma_precision``).
        :return: Instance of the registered modifier class.
        :raises ValueError: If ``aem_name`` is not registered.
        """
        # NOTE(review): in this module the pass-through modifier is
        # registered as "no_pull"; nothing here registers the default
        # "no_modifier", so omitting aem_name raises ValueError unless it is
        # registered elsewhere — confirm.
        pull_dict = {
            "season": season,
            "e_pdf_dict": e_pdf_dict,
            "floor_name": floor_name,
            "aem_name": aem_name,
        }
        pull_dict.update(kwargs)

        if aem_name not in cls.subclasses:
            raise ValueError("Bad pull name {}".format(aem_name))

        return cls.subclasses[aem_name](pull_dict)

    def pull_correct(self, data, params):
        """Apply the pull correction (identity in the base class)."""
        return data

    def pull_correct_static(self, data):
        """Apply the floor's one-off (static) correction to ``data``."""
        data = self.floor.apply_static(data)
        return data

    def pull_correct_dynamic(self, data, params):
        """Apply the floor's per-iteration (dynamic) correction to ``data``."""
        data = self.floor.apply_dynamic(data, params)
        return data

    def create_spatial_cache(self, cut_data, SoB_pdf):
        """Precompute spatial signal-over-background values.

        If ``SoB_pdf`` has exactly two named positional arguments (per
        ``inspect.getfullargspec``), it is taken to depend on gamma. It is
        then evaluated at every gamma support point, and a dict mapping
        gamma to ``log(SoB)`` is returned. Otherwise ``SoB_pdf(cut_data)``
        is returned as is.

        :param cut_data: Subset of the data containing coincident events.
        :param SoB_pdf: Callable ``SoB_pdf(data)`` or ``SoB_pdf(data, gamma)``.
        :return: Dict of gamma -> log SoB, or the gamma-independent SoB.
        """
        if len(inspect.getfullargspec(SoB_pdf)[0]) == 2:
            SoB = dict()
            for gamma in get_gamma_support_points(precision=self.precision):
                SoB[gamma] = np.log(SoB_pdf(cut_data, gamma))
        else:
            SoB = SoB_pdf(cut_data)
        return SoB

    def estimate_spatial(self, gamma, spatial_cache):
        """Return S(gamma) from a cache built by :meth:`create_spatial_cache`.

        Dict caches (gamma-dependent) are interpolated with
        :meth:`estimate_spatial_dynamic`. Anything else is returned as is.
        """
        if isinstance(spatial_cache, dict):
            return self.estimate_spatial_dynamic(gamma, spatial_cache)
        else:
            return spatial_cache

    def estimate_spatial_dynamic(self, gamma, spatial_cache):
        """Estimate S(gamma) from a cache of log S values on a gamma grid.

        If ``gamma`` is itself a cache key, ``exp`` of the cached value is
        returned. Otherwise ``log S`` is approximated by a second-order
        Taylor expansion about ``g1 = _around(gamma, precision)``. The first
        and second derivatives are central finite differences over the
        neighbouring grid points ``g0 = g1 - dg`` and ``g2 = g1 + dg``. The
        result is exponentiated.

        :param gamma: Spectral index.
        :param spatial_cache: Dict mapping grid values of gamma to log S.
        :return: Estimated value of S(gamma).
        :raises KeyError: If ``g0``, ``g1`` or ``g2`` is not in the cache.
        """
        # Dict membership is O(1); no need to materialise the key list.
        if gamma in spatial_cache:
            return np.exp(spatial_cache[gamma])

        g1 = _around(gamma, self.precision)
        dg = get_gamma_precision(self.precision)

        g0 = _around(g1 - dg, self.precision)
        g2 = _around(g1 + dg, self.precision)

        S0 = spatial_cache[g0]
        S1 = spatial_cache[g1]
        S2 = spatial_cache[g2]

        # numexpr reads S0, S1, S2, dg, gamma and g1 from this function's
        # local variables, so these names must match the expression string.
        return numexpr.evaluate(
            "exp((S0 - 2.*S1 + S2) / (2. * dg**2) * (gamma - g1)**2"
            + " + (S2 -S0) / (2. * dg) * (gamma - g1) + S1)"
        )
324

325

326
@BaseAngularErrorModifier.register_subclass("no_pull")
class NoPull(BaseAngularErrorModifier):
    """Pass-through modifier (registered as ``"no_pull"``).

    No pull correction is applied. Only the floor hooks inherited from
    BaseAngularErrorModifier act on the data.
    """
329

330

331
class BaseMedianAngularErrorModifier(BaseAngularErrorModifier):
    """Base class for median-pull corrections backed by a cached pickle.

    On construction the pull table at ``self.pull_name`` is built with
    ``create_pickle`` if it is missing. It is then loaded into
    ``self.pickled_data``.
    """

    def __init__(self, pull_dict):
        BaseAngularErrorModifier.__init__(self, pull_dict)

        if not os.path.isfile(self.pull_name):
            self.create_pickle()
        else:
            logger.debug("Loading from {0}".format(self.pull_name))

        # Pickle files are binary; text mode makes pickle.load fail on
        # Python 3. pickle can execute code, so pull_name must only ever
        # point at trusted cache files (normally those written by
        # create_pickle).
        with open(self.pull_name, "rb") as f:
            self.pickled_data = Pickle.load(f)

    def pull_correct(self, f, data):
        """Set ``"sigma"`` to ``exp(f(data)) * "raw_sigma"``, in place.

        Note the argument order differs from the base class's
        ``pull_correct(data, params)``.

        :param f: Callable returning, per event, the log of the factor by
            which ``"raw_sigma"`` is scaled.
        :param data: Event data with ``"raw_sigma"`` and ``"sigma"`` fields;
            modified in place.
        :return: ``data``.
        """
        data["sigma"] = np.exp(f(data)) * data["raw_sigma"]
        return data

    def create_pickle(self):
        """Write the pull cache to ``self.pull_name`` (no-op here)."""
        pass

    def create_static(self):
        """Return the static correction callable (placeholder).

        NOTE(review): ``pull_correct`` exponentiates this value, so the
        constant 1.0 scales errors by e rather than leaving them unchanged.
        Subclasses in this module override it — confirm it is never used
        as is.
        """
        return lambda data: np.array([1.0 for _ in data])

    def create_dynamic(self, pickled_array):
        """Return the per-gamma correction callable (placeholder; see
        ``create_static`` for the same caveat about exponentiation)."""
        return lambda data: np.array([1.0 for _ in data])
355

356

357
class StaticMedianPullCorrector(BaseMedianAngularErrorModifier):
    """Median-pull correction applied once, before fitting.

    The correction callable from ``create_static`` is built at construction
    time. ``pull_correct_static`` applies the floor, then the pull
    correction, and finally bakes the result into ``"raw_sigma"``.
    """

    def __init__(self, pull_dict):
        BaseMedianAngularErrorModifier.__init__(self, pull_dict)
        # Built after the base __init__, which loads self.pickled_data that
        # subclasses' create_static reads.
        self.static_f = self.create_static()

    def pull_correct_static(self, data):
        """Apply floor and pull correction, then make them the new baseline."""
        floored = self.floor.apply_static(data)
        corrected = self.pull_correct(self.static_f, floored)
        corrected["raw_sigma"] = corrected["sigma"]
        return corrected
370

371

372
class DynamicMedianPullCorrector(BaseMedianAngularErrorModifier):
    """Median-pull correction re-evaluated for each key of the pull table.

    ``self.pickled_data`` is a dict whose keys are used as gamma values (see
    ``create_spatial_cache``). ``create_dynamic`` turns each entry into a
    per-event correction callable.
    """

    def __init__(self, pull_dict):
        # Go through BaseMedianAngularErrorModifier.__init__ so the pull
        # table is created/loaded into self.pickled_data, which both
        # pull_correct_dynamic and create_spatial_cache read.
        BaseMedianAngularErrorModifier.__init__(self, pull_dict)

    def estimate_spatial(self, gamma, spatial_cache):
        """Always interpolate the (dict) cache in gamma."""
        return self.estimate_spatial_dynamic(gamma, spatial_cache)

    def pull_correct_dynamic(self, data, param):
        """Apply the dynamic floor, then the pull correction for ``param``.

        :param data: Event data with ``"raw_sigma"`` and ``"sigma"`` fields.
        :param param: Key into ``self.pickled_data``.
        :return: Data with ``"sigma"`` set to
            ``exp(f(data)) * data["raw_sigma"]``.
        """
        data = self.floor.apply_dynamic(data)
        f = self.create_dynamic(self.pickled_data[param])
        data["sigma"] = np.exp(f(data)) * data["raw_sigma"]
        return data

    def create_spatial_cache(self, cut_data, SoB_pdf):
        """Evaluate log(signal/background) for each key of the pull table.

        For every key (in sorted order) the pull correction for that key is
        applied to the coincident events. The spatial PDF is then evaluated,
        and its log is stored under that key.

        :param cut_data: Subset of the data containing only coincident events.
        :param SoB_pdf: Callable ``SoB_pdf(data)`` or ``SoB_pdf(data, gamma)``.
        :return: Dict mapping each key to the per-event log SoB values.
        """
        spatial_cache = dict()

        # If gamma is needed to evaluate the spatial PDF (e.g. overlapping
        # PDFs whose weights depend on it), pass the key as well.
        # getfullargspec is used because inspect.getargspec was removed in
        # Python 3.11.
        takes_gamma = len(inspect.getfullargspec(SoB_pdf)[0]) == 2

        for key in sorted(self.pickled_data.keys()):
            cut_data = self.pull_correct_dynamic(cut_data, key)

            if takes_gamma:
                SoB = SoB_pdf(cut_data, key)
            else:
                SoB = SoB_pdf(cut_data)

            spatial_cache[key] = np.log(SoB)

        return spatial_cache
413

414

415
@BaseAngularErrorModifier.register_subclass("median_0d_e")
class MedianPullEParam0D(DynamicMedianPullCorrector):
    """Dynamic median pull with one cached value per key, shared by all events."""

    def create_pickle(self):
        """Build the pull cache with ``create_pull_0d_e``."""
        create_pull_0d_e(self.pull_dict)

    def create_dynamic(self, pickled_array):
        """Return a callable repeating ``pickled_array`` once per event."""

        def per_event(data):
            return np.array([pickled_array for _ in data])

        return per_event
422

423

424
@BaseAngularErrorModifier.register_subclass("median_1d")
class MedianPull1D(StaticMedianPullCorrector):
    """Static median pull interpolated in each event's ``logE``."""

    def create_pickle(self):
        """Build the pull cache with ``create_pull_1d``."""
        create_pull_1d(self.pull_dict)

    def create_static(self):
        """Return a callable interpolating the cached table at ``data["logE"]``."""
        grid, values = self.pickled_data[0], self.pickled_data[1]
        interpolator = interp1d(grid, values)

        def correction(data):
            return interpolator(data["logE"])

        return correction
432

433

434
@BaseAngularErrorModifier.register_subclass("median_1d_e")
class MedianPullEParam1D(DynamicMedianPullCorrector):
    """Dynamic median pull: per-key table interpolated in event ``logE``."""

    def create_pickle(self):
        """Build the pull cache with ``create_pull_1d_e``."""
        create_pull_1d_e(self.pull_dict)

    def create_dynamic(self, pickled_array):
        """Return a callable interpolating ``pickled_array`` at ``data["logE"]``."""
        grid, values = pickled_array[0], pickled_array[1]
        interpolator = interp1d(grid, values)

        def correction(data):
            return interpolator(data["logE"])

        return correction
442

443

444
@BaseAngularErrorModifier.register_subclass("median_2d")
class MedianPull2D(StaticMedianPullCorrector):
    """Static median pull from a 2D spline in (``logE``, ``sinDec``)."""

    def create_pickle(self):
        """Build the pull cache with ``create_pull_2d``."""
        create_pull_2d(self.pull_dict)

    def create_static(self):
        """Return a callable giving the spline value at each event's
        (``logE``, ``sinDec``) as a list with one entry per event."""
        spline = RectBivariateSpline(
            self.pickled_data[0], self.pickled_data[1], self.pickled_data[2]
        )

        def correction(data):
            values = []
            for event in data:
                values.append(spline(event["logE"], event["sinDec"])[0][0])
            return values

        return correction
456

457

458
@BaseAngularErrorModifier.register_subclass("median_2d_e")
class MedianPullEParam2D(DynamicMedianPullCorrector):
    """Dynamic median pull: per-key 2D spline in (``logE``, ``sinDec``)."""

    def create_pickle(self):
        """Build the pull cache with ``create_pull_2d_e``."""
        create_pull_2d_e(self.pull_dict)

    def create_dynamic(self, pickled_array):
        """Return a callable evaluating the spline at each event's
        (``logE``, ``sinDec``) pair.

        ``grid=False`` evaluates point-wise and gives one value per event,
        matching ``MedianPull2D`` and the per-event ``raw_sigma`` it scales.
        The default ``grid=True`` would instead evaluate on the outer product
        of the two coordinate arrays (an n x n result). It also requires
        both arrays to be sorted.
        """
        func = RectBivariateSpline(pickled_array[0], pickled_array[1], pickled_array[2])

        return lambda data: func(data["logE"], data["sinDec"], grid=False)
467

468

469
if __name__ == "__main__":
    # Manual check: compare the weighted median pull of IC86_1 MC before and
    # after a "median_2d" correction with no floor.
    from flarestack.data.icecube.ps_tracks.ps_v002_p01 import IC86_1_dict
    from flarestack.analyses.angular_error_floor.plot_bias import (
        get_data,
        weighted_quantile,
    )
    from scipy.stats import norm

    print(norm.cdf(1.0))

    def symmetric_gauss(sigma):
        """Standard-normal probability mass within +/- sigma."""
        return 1 - 2 * norm.sf(sigma)

    def gauss_2d(sigma):
        # Currently the same as the 1D containment (squared form disabled).
        return symmetric_gauss(sigma)

    for func, arg in ((symmetric_gauss, 1.0), (gauss_2d, 1.0), (gauss_2d, 1.177)):
        print(func(arg))

    e_pdf_dict = {"Name": "Power Law", "Gamma": 3.0}
    e_pdf = EnergyPDF.create(e_pdf_dict)

    modifier = BaseAngularErrorModifier.create(
        IC86_1_dict, e_pdf_dict, "no_floor", "median_2d"
    )
    mc, x, y = get_data(IC86_1_dict)[:10]

    weights = e_pdf.weight_mc(mc)
    reference_median = weighted_quantile(x / y, 0.5, weights)

    def corrected_median(data):
        """Weighted median of ``x / degrees(data["sigma"])``."""
        return weighted_quantile(x / np.degrees(data["sigma"]), 0.5, weights)

    print(mc["sigma"][:5])
    mc = modifier.pull_correct_static(mc)
    print(mc["sigma"][:5])
    print(corrected_median(mc))
    print(reference_median)
520

521
# @BaseAngularErrorModifier.register_subclass('static_pull_corrector')
522
# class Static1DPullCorrector(BaseAngularErrorModifier):
523
#
524
#     def __init__(self, season, e_pdf_dict, **kwargs):
525
#         BaseAngularErrorModifier.__init__(self, season, e_pdf_dict)
526

527

528
# x = BaseFloorClass.create(IC86_1_dict, e_pdf_dict, "quantile_floor_1d_e",
529
#                           floor_quantile=0.1)
530
# for gamma in np.linspace(1.0, 4.0, 4):
531
#     print data_loader(IC86_1_dict["exp_path"])["logE"][:8]
532
#     print np.degrees(x.f(data_loader(IC86_1_dict["exp_path"])[:8], [gamma]))
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc