• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

aymgal / COOLEST / 4960685499

pending completion
4960685499

Pull #34

github

GitHub
Merge 7b7c1d7ba into d1de71ffa
Pull Request #34: Preparation for JOSS submission

184 of 184 new or added lines in 28 files covered. (100.0%)

1071 of 2324 relevant lines covered (46.08%)

0.46 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

12.92
/coolest/api/util.py
1
__author__ = 'aymgal'
1✔
2

3

4
import os
1✔
5
import numpy as np
1✔
6
# from astropy.coordinates import SkyCoord
7

8
from coolest.template.json import JSONSerializer
1✔
9

10

11
def convert_image_to_data_units(image, pixel_size, mag_tot, mag_zero_point):
    """
    Rescale an image so that it has units of electrons per second (e/s),
    which is the default data units in COOLEST.

    The image is first normalised so that its integral over the field of view
    (sum of pixel values times the pixel area ``pixel_size**2``) equals one,
    then multiplied by the flux ``10**(-(mag_tot - mag_zero_point) / 2.5)``.
    For the returned image, ``np.sum(result) * pixel_size**2`` therefore equals
    that flux.

    NOTE(review): because of the pixel-area factor, the plain sum of the
    returned pixels is the target flux divided by ``pixel_size**2`` (i.e. the
    values behave like a surface brightness) — confirm this is the intended
    convention.

    :param image: input image (whatever units); if its pixels sum to zero the
        result is not finite
    :param pixel_size: pixel size (in arcsec) of the image
    :param mag_tot: target total magnitude, integrated over the whole image
    :param mag_zero_point: magnitude zero point of the observation (magnitude that corresponds to 1 e/s)
    :return: rescaled image, same shape as ``image``
    """
    # Integral of the input image over the field of view
    integrated_flux = np.sum(image) * pixel_size**2
    # Flux associated with mag_tot, relative to the zero point (1 e/s)
    target_flux = 10 ** (-(mag_tot - mag_zero_point) / 2.5)
    return image / integrated_flux * target_flux


def get_coolest_object(file_path, verbose=False, **kwargs_serializer):
    """
    Load a COOLEST template file through ``JSONSerializer``.

    :param file_path: path to the file; a relative path is first turned into an
        absolute one with ``os.path.abspath`` (i.e. resolved against the current
        working directory)
    :param verbose: forwarded to ``JSONSerializer.load``
    :param kwargs_serializer: extra keyword arguments forwarded to the
        ``JSONSerializer`` constructor
    :return: whatever ``JSONSerializer.load`` returns (presumably the loaded
        COOLEST object)
    """
    absolute_path = file_path if os.path.isabs(file_path) else os.path.abspath(file_path)
    serializer = JSONSerializer(absolute_path, **kwargs_serializer)
    return serializer.load(verbose=verbose)


def get_coordinates(coolest_object, offset_x=0., offset_y=0.):
    """
    Build the pixel-grid ``Coordinates`` of the observation stored in a COOLEST object.

    The grid has square pixels of size ``coolest_object.instrument.pixel_size``
    and is centred on the origin before the optional offsets are added.

    :param coolest_object: loaded COOLEST object (reads ``observation.pixels.shape``
        and ``instrument.pixel_size``)
    :param offset_x: shift added to the x coordinate of pixel (0, 0)
    :param offset_y: shift added to the y coordinate of pixel (0, 0)
    :return: ``coolest.api.coordinates.Coordinates`` instance
    """
    # Imported here rather than at module level to prevent circular import errors
    from coolest.api.coordinates import Coordinates

    # NOTE(review): the first axis of ``pixels.shape`` is taken as nx; for a
    # row-major (ny, nx) array this only matters for non-square images — confirm.
    nx, ny = coolest_object.observation.pixels.shape
    pix_scl = coolest_object.instrument.pixel_size

    # Centre of pixel (0, 0): half a field of view below the origin, plus half a pixel
    x_first_pixel = pix_scl / 2. - nx * pix_scl / 2.
    y_first_pixel = pix_scl / 2. - ny * pix_scl / 2.

    # Pixel index -> angular coordinate: the same pure scaling along both axes
    ij_to_xy = pix_scl * np.eye(2)

    return Coordinates(
        nx, ny,
        matrix_ij_to_xy=ij_to_xy,
        x_at_ij_0=x_first_pixel + offset_x,
        y_at_ij_0=y_first_pixel + offset_y,
    )


def get_coordinates_from_regular_grid(field_of_view_x, field_of_view_y, num_pix_x, num_pix_y):
    """
    Build ``Coordinates`` for a regular grid covering a rectangular field of view.

    :param field_of_view_x: pair of x limits; the two limits may be given in
        either order
    :param field_of_view_y: pair of y limits; the two limits may be given in
        either order
    :param num_pix_x: number of pixels along x
    :param num_pix_y: number of pixels along y
    :return: ``coolest.api.coordinates.Coordinates`` instance whose pixel (0, 0)
        is centred half a pixel inside the lower limit of each axis
    """
    from coolest.api.coordinates import Coordinates  # prevents circular import errors

    # Pixel scales are always positive, whatever the order of the limits
    pix_scl_x = np.abs(field_of_view_x[0] - field_of_view_x[1]) / num_pix_x
    pix_scl_y = np.abs(field_of_view_y[0] - field_of_view_y[1]) / num_pix_y
    matrix_pix2ang = np.array([[pix_scl_x, 0.], [0., pix_scl_y]])

    # With a positive diagonal matrix the grid grows towards larger coordinates
    # (assuming Coordinates maps index i to x_at_ij_0 + i * pix_scl_x), so it
    # must be anchored on the LOWER limit. Anchoring on field_of_view[0] would
    # put the whole grid outside a field of view given in descending order.
    x_at_ij_0 = min(field_of_view_x[0], field_of_view_x[1]) + pix_scl_x / 2.
    y_at_ij_0 = min(field_of_view_y[0], field_of_view_y[1]) + pix_scl_y / 2.
    return Coordinates(
        num_pix_x, num_pix_y, matrix_ij_to_xy=matrix_pix2ang,
        x_at_ij_0=x_at_ij_0, y_at_ij_0=y_at_ij_0,
    )


def get_coordinates_set(coolest_file_list, reference=0):
    """
    Build one ``Coordinates`` object per COOLEST object, using ``get_coordinates``.

    :param coolest_file_list: iterable of loaded COOLEST objects. Despite the
        name, each item is passed directly to ``get_coordinates``, which reads
        its ``observation`` and ``instrument`` attributes, so the items are not
        file paths.
    :param reference: currently unused; presumably meant to select the
        reference object once relative offsets are implemented (see TODO)
    :return: list of ``Coordinates``, in the same order as ``coolest_file_list``
    """
    # TODO: compute the correct relative offsets when the objects point to
    # different sky positions, e.g. starting from:
    # obs = self.coolest.observation
    # sky_coord = SkyCoord(obs.ra, obs.dec, frame='icrs')
    # ra, dec = sky_coord.to_string(style='hmsdms').split(' ')
    return [get_coordinates(coolest_object) for coolest_object in coolest_file_list]


def array2image(array, nx=0, ny=0):
    """Convert a 1d array into a 2d array.

    When ``nx`` or ``ny`` is 0 (the default), the length of ``array`` must be a
    perfect square n**2 and the result has shape (n, n); otherwise the result
    has shape (nx, ny).

    :param array: image values
    :type array: 1d array (anything supporting ``len`` and ``reshape``, e.g. numpy)
    :param nx: size of the first axis of the output (0 to infer a square shape)
    :param ny: size of the second axis of the output (0 to infer a square shape)
    :returns: 2d array
    :raises ValueError: if the shape must be inferred and the length of
        ``array`` is not a perfect square
    """
    # Local import: the module does not import math at top level, and the
    # square-inference path below needs it
    import math

    if nx == 0 or ny == 0:
        size = len(array)
        # Avoid turning n into a JAX-traced object with jax.numpy.sqrt
        n = int(math.sqrt(size))
        if n**2 != size:
            err_msg = f"Input array size {size} is not a perfect square."
            raise ValueError(err_msg)
        nx, ny = n, n
    image = array.reshape(int(nx), int(ny))
    return image


def image2array(image):
    """Flatten a 2d array into a 1d array.

    :param image: image values
    :type image: 2d array
    :returns: 1d array, as returned by ``image.ravel()`` (a view when possible,
        otherwise a copy)
    """
    flattened = image.ravel()
    return flattened


def downsampling(image, factor=1):
    """
    Downsample a 2d image by averaging non-overlapping ``factor`` x ``factor`` pixel blocks.

    :param image: 2d array of shape (nx, ny)
    :param factor: downsampling factor, must be >= 1. It is truncated to an
        integer. A factor equal to 1 returns ``image`` itself, unchanged.
    :return: 2d array of shape (nx // factor, ny // factor)
    :raises ValueError: if ``factor`` < 1, or if nx or ny is not a multiple of
        the (truncated) factor
    """
    if factor < 1:
        raise ValueError(f"Downscaling factor must be >= 1 (got {factor})")
    if factor == 1:
        return image
    f = int(factor)
    nx, ny = np.shape(image)
    # Integer divisibility test (no float comparison of nx / f)
    if nx % f != 0 or ny % f != 0:
        raise ValueError(f"Downscaling factor {factor} is not possible with shape ({nx}, {ny})")
    # Split each axis into (number of blocks, pixels per block), then average
    # over the two within-block axes
    return image.reshape([nx // f, f, ny // f, f]).mean(3).mean(1)


def read_json_param(file_list, file_names, lens_light=False):
    """
    Read several COOLEST json files containing a model and collect its fitted parameters.

    Progress and warnings (unsupported profile types, unknown entities, ...)
    are printed to stdout.

    INPUT
    -----
    file_list: list, list of paths (or names) of the files to read
    file_names: list, shorter labels used to distinguish the files; used as
        keys of the output dictionaries and must have at least as many entries
        as file_list
    lens_light: bool, if True, computes the lens light kwargs as well (not yet
        implemented; currently ignored)

    OUTPUT
    ------
    param_all_lens, param_all_source: dictionaries keyed by the labels in
        file_names. Each value is the flat parameter dictionary filled by
        read_pemd / read_sie / read_shear (lens) or read_sersic (source), which
        is readable by the plotting functions.
    """
    param_all_lens = {}
    param_all_source = {}
    for idx_file, file_name in enumerate(file_list):
        label = file_names[idx_file]
        print(label)
        decoder = JSONSerializer(file_name, indent=2)
        lens_coolest = decoder.load()

        if lens_coolest.mode == 'MAP':
            print('LENS COOLEST : ', lens_coolest.mode)
        else:
            print('LENS COOLEST IS NOT MAP, BUT IS ', lens_coolest.mode)

        lensing_entities_list = lens_coolest.lensing_entities

        param_lens = {}
        param_source = {}

        # An empty list would make np.min / np.max below raise, so it is
        # treated like None (nothing to read)
        if lensing_entities_list:
            red_list = [lensing_entity.redshift for lensing_entity in lensing_entities_list]
            min_red = np.min(red_list)
            max_red = np.max(red_list)

            # NOTE(review): read_* are called with their default prefix, so
            # several profiles of the same type overwrite each other's keys.
            for lensing_entity in lensing_entities_list:
                if lensing_entity.type == "galaxy":
                    galac = lensing_entity

                    # Galaxies at the lowest redshift only contribute mass,
                    # those at the highest only contribute source light;
                    # galaxies in between contribute both (multi-plane).
                    if galac.redshift > min_red:
                        # SOURCE OF LIGHT
                        for light in galac.light_model:
                            if light.type == 'Sersic':
                                read_sersic(light, param_source)
                            else:
                                print('Light Type ', light.type, ' not yet implemented.')

                    if galac.redshift < max_red:
                        # LENSING GALAXY
                        if galac.redshift > min_red:
                            print('Multiplane lensing to consider.')
                        for mass in galac.mass_model:
                            if mass.type == 'PEMD':
                                read_pemd(mass, param_lens)
                            elif mass.type == 'SIE':
                                read_sie(mass, param_lens)
                            else:
                                print('Mass Type ', mass.type, ' not yet implemented.')

                    # Neither branch above applied (only possible when all
                    # entities share the same redshift)
                    if (galac.redshift <= min_red) and (galac.redshift >= max_red):
                        print('REDSHIFT ', galac.redshift, ' is not in the range ]', min_red, ',', max_red, '[')

                elif lensing_entity.type == "external_shear":
                    for shear in lensing_entity.mass_model:
                        if shear.type == 'ExternalShear':
                            read_shear(shear, param_lens)
                        else:
                            print("type of Shear ", shear.type, " not implemented")
                else:
                    print("Lensing entity of type ", lensing_entity.type, " is unknown.")

        param_all_lens[label] = param_lens
        param_all_source[label] = param_source

    return param_all_lens, param_all_source


def read_shear(mass, param=None, prefix='SHEAR_0_'):
    """
    Reads the parameters of a coolest.template.classes.profiles.mass.ExternalShear object

    INPUT
    -----
    mass : coolest.template.classes.profiles.mass.ExternalShear object
    param : dict or None, already existing dictionary with ordered parameters
        readable by plotting function; a new empty dict is used when None
        (default)
    prefix : str, prefix to use in saving parameters names

    OUTPUT
    ------
    param : updated param (the same dict object that was passed in, if any).
        Each known parameter is stored under prefix + name as a dict with keys
        'point_estimate', 'percentile_84th', 'percentile_16th', 'median',
        'mean' and 'latex_str'; unknown parameter names are only printed.
    """
    # A fresh dict per call: a literal {} default would be shared by every call
    if param is None:
        param = {}
    # COOLEST parameter name -> key suffix used by the plotting functions
    key_suffixes = {
        'gamma_ext': 'gamma_ext',
        'phi_ext': 'phi_ext',
    }

    for mass_name, mass_param in mass.parameters.items():
        point = mass_param.point_estimate.value
        stats = mass_param.posterior_stats
        entry = {'point_estimate': point,
                 'percentile_84th': stats.percentile_84th,
                 'percentile_16th': stats.percentile_16th,
                 'median': stats.median,
                 'mean': stats.mean,
                 'latex_str': mass_param.latex_str}
        if mass_name in key_suffixes:
            param[prefix + key_suffixes[mass_name]] = entry
        else:
            print(mass_name, " not known")
    print('\t Shear correctly added')

    return param


def read_pemd(mass, param=None, prefix='PEMD_0_'):
    """
    Reads the parameters of a coolest.template.classes.profiles.mass.PEMD object

    INPUT
    -----
    mass : coolest.template.classes.profiles.mass.PEMD object
    param : dict or None, already existing dictionary with ordered parameters
        readable by plotting function; a new empty dict is used when None
        (default)
    prefix : str, prefix to use in saving parameters names

    OUTPUT
    ------
    param : updated param (the same dict object that was passed in, if any).
        Each known parameter is stored under prefix + short name (center_x and
        center_y become 'cx' and 'cy') as a dict with keys 'point_estimate',
        'percentile_84th', 'percentile_16th', 'median', 'mean' and 'latex_str';
        unknown parameter names are only printed.
    """
    # A fresh dict per call: a literal {} default would be shared by every call
    if param is None:
        param = {}
    # COOLEST parameter name -> key suffix used by the plotting functions
    key_suffixes = {
        'theta_E': 'theta_E',
        'q': 'q',
        'phi': 'phi',
        'center_x': 'cx',
        'center_y': 'cy',
        'gamma': 'gamma',
    }

    for mass_name, mass_param in mass.parameters.items():
        point = mass_param.point_estimate.value
        stats = mass_param.posterior_stats
        entry = {'point_estimate': point,
                 'percentile_84th': stats.percentile_84th,
                 'percentile_16th': stats.percentile_16th,
                 'median': stats.median,
                 'mean': stats.mean,
                 'latex_str': mass_param.latex_str}
        if mass_name in key_suffixes:
            param[prefix + key_suffixes[mass_name]] = entry
        else:
            print(mass_name, " not known")

    print('\t PEMD correctly added')

    return param


def read_sie(mass, param=None, prefix='SIE_0_'):
    """
    Reads the parameters of a coolest.template.classes.profiles.mass.SIE object

    INPUT
    -----
    mass : coolest.template.classes.profiles.mass.SIE object
    param : dict or None, already existing dictionary with ordered parameters
        readable by plotting function; a new empty dict is used when None
        (default)
    prefix : str, prefix to use in saving parameters names

    OUTPUT
    ------
    param : updated param (the same dict object that was passed in, if any).
        Each known parameter is stored under prefix + short name (center_x and
        center_y become 'cx' and 'cy') as a dict with keys 'point_estimate',
        'percentile_84th', 'percentile_16th', 'median', 'mean' and 'latex_str';
        unknown parameter names are only printed.
    """
    # A fresh dict per call: a literal {} default would be shared by every call
    if param is None:
        param = {}
    # COOLEST parameter name -> key suffix used by the plotting functions
    key_suffixes = {
        'theta_E': 'theta_E',
        'q': 'q',
        'phi': 'phi',
        'center_x': 'cx',
        'center_y': 'cy',
    }

    for mass_name, mass_param in mass.parameters.items():
        point = mass_param.point_estimate.value
        stats = mass_param.posterior_stats
        entry = {'point_estimate': point,
                 'percentile_84th': stats.percentile_84th,
                 'percentile_16th': stats.percentile_16th,
                 'median': stats.median,
                 'mean': stats.mean,
                 'latex_str': mass_param.latex_str}
        if mass_name in key_suffixes:
            param[prefix + key_suffixes[mass_name]] = entry
        else:
            print(mass_name, " not known")

    print('\t SIE correctly added')

    return param


def read_sersic(light, param=None, prefix='Sersic_0_'):
    """
    Reads the parameters of a coolest.template.classes.profiles.light.Sersic object

    INPUT
    -----
    light : coolest.template.classes.profiles.light.Sersic object
    param : dict or None, already existing dictionary with ordered parameters
        readable by plotting function; a new empty dict is used when None
        (default)
    prefix : str, prefix to use in saving parameters names

    OUTPUT
    ------
    param : updated param (the same dict object that was passed in, if any).
        Each known parameter is stored under prefix + plotting name (I_eff ->
        'A', n -> 'n_sersic', theta_eff -> 'R_sersic', center_x/center_y ->
        'cx'/'cy') as a dict with keys 'point_estimate', 'percentile_84th',
        'percentile_16th', 'median', 'mean' and 'latex_str'; unknown parameter
        names are only printed.
    """
    # A fresh dict per call: a literal {} default would be shared by every call
    if param is None:
        param = {}
    # COOLEST parameter name -> key suffix used by the plotting functions
    key_suffixes = {
        'I_eff': 'A',
        'n': 'n_sersic',
        'theta_eff': 'R_sersic',
        'q': 'q',
        'phi': 'phi',
        'center_x': 'cx',
        'center_y': 'cy',
    }

    for light_name, light_param in light.parameters.items():
        point = light_param.point_estimate.value
        stats = light_param.posterior_stats
        entry = {'point_estimate': point,
                 'percentile_84th': stats.percentile_84th,
                 'percentile_16th': stats.percentile_16th,
                 'median': stats.median,
                 'mean': stats.mean,
                 'latex_str': light_param.latex_str}
        if light_name in key_suffixes:
            param[prefix + key_suffixes[light_name]] = entry
        else:
            print(light_name, " not known")

    print('\t Sersic correctly added')

    return param
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc