• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

aymgal / COOLEST / 8240860511

11 Mar 2024 11:02PM UTC coverage: 49.76% (-0.1%) from 49.878%
8240860511

push

github

aymgal
Add preliminary support for evaluating composable model functions over posterior samples

35 of 76 new or added lines in 1 file covered. (46.05%)

34 existing lines in 1 file now uncovered.

1450 of 2914 relevant lines covered (49.76%)

0.5 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

43.48
/coolest/api/composable_models.py
1
__author__ = 'aymgal'
1✔
2

3
import os
1✔
4
import numpy as np
1✔
5
import math
1✔
6
import logging
1✔
7
from scipy import signal
1✔
8
import pandas as pd
1✔
9
from functools import partial
1✔
10

11
from coolest.api import util
1✔
12

13

14
# Logging settings: importing this module pins the *root* logger to WARNING,
# overriding any level configured on the root logger before the import.
logging.root.setLevel(logging.WARNING)
16

17

18
class BaseComposableModel(object):
    """Given a COOLEST object, evaluates a selection of mass or light profiles.
    This class serves as parent for more specific classes and should not be 
    instantiated by the user.

    Parameters
    ----------
    model_type : str
        Either 'light_model' or 'mass_model'
    coolest_object : COOLEST
        COOLEST instance
    coolest_directory : str, optional
        Directory which contains the COOLEST template, by default None.
        Required when grid-based profiles (FITS files) or posterior samples
        have to be loaded.
    load_posterior_samples : bool, optional
        If True, reads posterior samples of the (non-grid) profile parameters
        from the CSV file whose name is stored in the template metadata under
        the key 'chain_file_name', by default False
    entity_selection : list, optional
        List of indices of the lensing entities to consider; If None, 
        selects the first entity which has a model of type model_type, by default None
    profile_selection : str, int or list, optional
        'all', a single index, or a list of indices selecting which (mass or light)
        profiles to consider; the same selection is applied to every selected
        lensing entity. If None, selects all the profiles of the selected
        entities, by default None

    Raises
    ------
    ValueError
        Invalid model type, no valid entity found, no profiles found, or
        missing directory when files have to be loaded.
    """

    _chain_key = "chain_file_name"
    _supported_eval_modes = ('point', 'posterior')
    _supported_model_types = ('light_model', 'mass_model')

    def __init__(self, model_type, 
                 coolest_object, coolest_directory=None, 
                 load_posterior_samples=False,
                 entity_selection=None, profile_selection=None):
        if model_type not in self._supported_model_types:
            raise ValueError(f"Model type must be one of {self._supported_model_types} "
                             f"(received '{model_type}').")
        entities = coolest_object.lensing_entities
        if entity_selection is None:
            # finds the first entity that has a 'model_type' profile
            entity_selection = [self._find_first_entity(model_type, entities)]
        if profile_selection is None:
            profile_selection = 'all'
        self.directory = coolest_directory
        self._load_samples, self._csv_path = False, None
        if load_posterior_samples:
            self._csv_path = self._get_samples_path(coolest_object.meta)
            self._load_samples = self._csv_path is not None
        self.profile_list, self.param_list, self.post_param_list, self.info_list \
            = self.get_profiles_and_params(model_type, entities, 
                                           entity_selection, profile_selection)
        self.num_profiles = len(self.profile_list)
        if self.num_profiles == 0:
            raise ValueError("No profile has been selected!")

    @staticmethod
    def _find_first_entity(model_type, entities):
        """Returns the index of the first entity holding profiles of type `model_type`.

        Light profiles are only searched in entities of type 'galaxy'.

        Raises
        ------
        ValueError
            If no entity holds such profiles.
        """
        for i, entity in enumerate(entities):
            if model_type == 'light_model':
                has_profiles = entity.type == 'galaxy' and len(entity.light_model) > 0
            else:
                has_profiles = len(entity.mass_model) > 0
            if has_profiles:
                logging.warning(f"Found valid profile for lensing entity (index {i}) "
                                f"for model type '{model_type}'")
                return i
        raise ValueError(f"No lensing entity with profiles of type '{model_type}' "
                         f"has been found")

    def _get_samples_path(self, metadata):
        """Returns the path of the posterior samples CSV file, or None if the
        template metadata does not reference one.

        Raises
        ------
        ValueError
            If a samples file is referenced but no directory has been provided.
        """
        metadata = metadata or {}  # tolerate templates without metadata
        if self._chain_key not in metadata:
            logging.warning(f"Metadata key '{self._chain_key}' is missing "
                            f"from COOLEST template, hence no posterior samples "
                            f"will be loaded.")
            return None
        if self.directory is None:
            raise ValueError("The directory in which the COOLEST file is located "
                             "must be provided for loading posterior samples.")
        return os.path.join(self.directory, metadata[self._chain_key])

    def get_profiles_and_params(self, model_type, entities, 
                                entity_selection, profile_selection):
        """Instantiates the selected API profiles and collects their parameters.

        Returns
        -------
        tuple
            (profile_list, param_list, post_param_list, info_list) where
            param_list holds one dict of point-estimate values per profile,
            post_param_list is None or a list (one item per posterior sample)
            of lists of per-profile parameter dicts, and info_list holds
            (entity name, entity redshift) per profile.
        """
        profile_list = []
        param_list, post_param_list = [], []
        info_list = []
        for i, entity in enumerate(entities):
            if not self._selected(i, entity_selection):
                continue
            if model_type == 'light_model' and entity.type == 'external_shear':
                raise ValueError(f"External shear (entity index {i}) has no light model")
            for j, profile in enumerate(getattr(entity, model_type)):
                if not self._selected(j, profile_selection):
                    continue
                if 'Grid' in profile.type:
                    if self.directory is None:
                        raise ValueError("The directory in which the COOLEST file is located "
                                         "must be provided for loading FITS files.")
                    params, fixed_params = self._get_grid_params(profile, self.directory)
                    profile_list.append(self._get_api_profile(model_type, profile, *fixed_params))
                    post_params = None  # TODO: support samples for grid parameters
                else:
                    params, post_params = self._get_regular_params(
                        profile, samples_file_path=self._csv_path
                    )
                    profile_list.append(self._get_api_profile(model_type, profile))
                param_list.append(params)
                post_param_list.append(post_params)
                info_list.append((entity.name, entity.redshift))
        post_param_list = self._reorganize_post_list(post_param_list)
        return profile_list, param_list, post_param_list, info_list

    def estimate_center(self):
        """Returns (center_x, center_y) of the first selected profile that has a center.

        Raises
        ------
        ValueError
            If no selected profile has a 'center_x' parameter.
        """
        # TODO: improve this (for now simply considers the first profile that has a center)
        for profile, params in zip(self.profile_list, self.param_list):
            if 'center_x' in params:
                center_x = params['center_x']
                center_y = params['center_y']
                logging.info(f"Picked center from profile '{profile.type}'")
                return center_x, center_y
        raise ValueError("Could not estimate a center from the composed model")

    @staticmethod
    def _get_api_profile(model_type, profile_in, *extra_profile_args):
        """
        Takes as input a light or mass profile from the template submodule
        and instantiates the corresponding profile from the API submodule,
        passing `extra_profile_args` to its constructor.
        """
        if model_type == 'light_model':
            from coolest.api.profiles import light as profiles_module
        elif model_type == 'mass_model':
            from coolest.api.profiles import mass as profiles_module
        else:
            raise ValueError(f"Unsupported model type '{model_type}'")
        ProfileClass = getattr(profiles_module, profile_in.type)
        return ProfileClass(*extra_profile_args)

    @staticmethod
    def _get_regular_params(profile_in, samples_file_path=None):
        """Returns (point-estimate dict, samples dict or None) for a non-grid profile.

        Samples are read from the CSV file columns named after each parameter ID.
        """
        parameters = {name: param.point_estimate.value  # best-fit values
                      for name, param in profile_in.parameters.items()}
        if not samples_file_path:
            return parameters, None
        param_ids = {name: param.id for name, param in profile_in.parameters.items()}
        if not param_ids:
            return parameters, {}
        # read all the columns needed by this profile in a single pass over the file
        columns = pd.read_csv(
            samples_file_path, 
            usecols=list(dict.fromkeys(param_ids.values())), 
            delimiter=',',
        )
        samples = {name: np.array(columns[param_id])  # posterior samples
                   for name, param_id in param_ids.items()}
        return parameters, samples

    @staticmethod
    def _get_grid_params(profile_in, fits_dir):
        """Returns (parameters, fixed constructor arguments) for a grid-based profile.

        Raises
        ------
        ValueError
            If the grid profile type is not supported.
        """
        param_in = profile_in.parameters['pixels']
        if profile_in.type == 'PixelatedRegularGrid':
            parameters = {'pixels': param_in.get_pixels(directory=fits_dir)}
            fixed_parameters = (param_in.field_of_view_x, param_in.field_of_view_y,
                                param_in.num_pix_x, param_in.num_pix_y)
        elif profile_in.type == 'IrregularGrid':
            x, y, z = param_in.get_xyz(directory=fits_dir)
            parameters = {'x': x, 'y': y, 'z': z}
            fixed_parameters = (param_in.field_of_view_x, param_in.field_of_view_y,
                                param_in.num_pix)
        else:
            raise ValueError(f"Unsupported grid profile type '{profile_in.type}'")
        return parameters, fixed_parameters

    @staticmethod
    def _reorganize_post_list(param_list_of_samples):
        """
        Takes as input the samples grouped at the leaves of the nested container structure,
        and returns a list of items each organized as self.param_list.

        Returns None when no profile is given or when at least one profile has
        no samples (e.g. grid-based profiles), since then no complete set of
        posterior parameters can be built.
        """
        if not param_list_of_samples or any(s is None for s in param_list_of_samples):
            if any(s is not None for s in param_list_of_samples):
                logging.warning("Some selected profiles have no posterior samples, "
                                "hence posterior evaluation is not available.")
            return None
        # all parameters come from the same chain file, hence share the same length
        num_samples = next((len(next(iter(samples.values())))
                            for samples in param_list_of_samples if samples), 0)
        return [
            [{key: values[i] for key, values in samples.items()}
             for samples in param_list_of_samples]
            for i in range(num_samples)
        ]

    @staticmethod
    def _selected(index, selection):
        """Tells whether `index` is part of `selection` ('all', an index or a list of indices)."""
        if isinstance(selection, str):
            return selection.lower() == 'all'
        if isinstance(selection, (list, tuple, np.ndarray)):
            return index in selection
        if isinstance(selection, (int, float, np.integer)):
            return int(selection) == index
        return False

    def _check_eval_mode(self, mode):
        """Validates an evaluation mode.

        Raises
        ------
        NotImplementedError
            If `mode` is not one of the supported evaluation modes.
        ValueError
            If 'posterior' is requested but no posterior samples are available.
        """
        if mode not in self._supported_eval_modes:
            raise NotImplementedError(
                f"Only evaluation modes "
                f"{self._supported_eval_modes} are supported "
                f"(received '{mode}')."
            )
        samples_missing = (not self._load_samples
                           or getattr(self, 'post_param_list', None) is None)
        if mode == 'posterior' and samples_missing:
            raise ValueError(f"Selected evaluation mode '{mode}' "
                             f"but samples have not been loaded.")
220

221

222
class ComposableLightModel(BaseComposableModel):
    """Given a COOLEST object, evaluates a selection of entity and their light profiles.

    Parameters
    ----------
    coolest_object : COOLEST
        COOLEST instance
    coolest_directory : str, optional
        Directory which contains the COOLEST template, by default None
    entity_selection : list, optional
        List of indices of the lensing entities to consider; If None, 
        selects the first entity that has a light model, by default None
    profile_selection : list, optional
        'all', an index or a list of indices selecting which light profiles 
        of the selected lensing entities to consider. If None, selects all the 
        profiles within the corresponding entities, by default None

    Raises
    ------
    ValueError
        No valid entity found or no profiles found.
    """

    def __init__(self, coolest_object, coolest_directory=None, **kwargs_selection):
        super().__init__('light_model', coolest_object, 
                         coolest_directory=coolest_directory,
                         **kwargs_selection)
        size = coolest_object.instrument.pixel_size
        # area used to rescale profiles whose `units` is 'per_ang'
        # (presumably flux per unit angular area -> flux per pixel; TODO confirm)
        self.pixel_area = 1. if size is None else size**2

    def surface_brightness(self, return_extra=False):
        """Returns the surface brightness as stored in the COOLEST file.

        Only the first selected profile is used. With `return_extra`, also
        returns that profile's extent and coordinates.
        """
        if self.num_profiles > 1:
            logging.warning("When more than a single light profile has been selected, "
                            "the method `surface_brightness()` only considers the first profile")
        first_profile, first_params = self.profile_list[0], self.param_list[0]
        sb_values = first_profile.surface_brightness(**first_params)
        if not return_extra:
            return sb_values
        return sb_values, first_profile.get_extent(), first_profile.get_coordinates()

    def evaluate_surface_brightness(self, x, y):
        """Evaluates the summed surface brightness of all selected profiles at (x, y)."""
        total = np.zeros_like(x)
        for light_profile, light_params in zip(self.profile_list, self.param_list):
            contribution = light_profile.evaluate_surface_brightness(x, y, **light_params)
            if light_profile.units == 'per_ang':
                contribution *= self.pixel_area
            total += contribution
        return total
277

278

279
class ComposableMassModel(BaseComposableModel):
    """Given a COOLEST object, evaluates a selection of entity and their mass profiles.

    Parameters
    ----------
    coolest_object : COOLEST
        COOLEST instance
    coolest_directory : str, optional
        Directory which contains the COOLEST template, by default None
    load_posterior_samples : bool, optional
        If True, loads posterior samples of the profile parameters so that
        `evaluate_potential(..., mode='posterior')` can be used, by default False
    entity_selection : list, optional
        List of indices of the lensing entities to consider; If None, 
        selects the first entity that has a mass model, by default None
    profile_selection : list, optional
        'all', an index or a list of indices selecting which mass profiles 
        of the selected lensing entities to consider. If None, selects all the 
        profiles within the corresponding entities, by default None

    Raises
    ------
    ValueError
        No valid entity found or no profiles found.
    """

    def __init__(self, coolest_object, coolest_directory=None, 
                 load_posterior_samples=False,
                 **kwargs_selection):
        super().__init__('mass_model', coolest_object, 
                         coolest_directory=coolest_directory,
                         load_posterior_samples=load_posterior_samples,
                         **kwargs_selection)

    def evaluate_potential(self, x, y, mode='point'):
        """Evaluates the lensing potential field at given coordinates.

        Parameters
        ----------
        x, y : array_like
            Coordinates at which the potential is evaluated.
        mode : str, optional
            'point' (alias 'point_estimate') uses the point-estimate parameters;
            'posterior' evaluates the potential for each posterior sample,
            by default 'point'

        Returns
        -------
        np.ndarray
            Potential with the shape of `x` in 'point' mode, or the per-sample
            potentials stacked along a new leading axis in 'posterior' mode.
        """
        if mode == 'point_estimate':  # backward-compatible alias of 'point'
            mode = 'point'
        self._check_eval_mode(mode)
        if mode == 'point':
            return self._eval_pot_point(x, y, self.param_list)
        if self.post_param_list is None:
            raise ValueError("No posterior samples are available for the selected profiles.")
        return self._eval_pot_posterior(x, y, self.post_param_list)

    def _eval_pot_point(self, x, y, param_list):
        """Sums the potential of all profiles for one list of per-profile parameters."""
        psi = np.zeros_like(x)
        for profile, params in zip(self.profile_list, param_list):
            psi += profile.potential(x, y, **params)
        return psi

    def _eval_pot_posterior(self, x, y, post_param_list):
        """Evaluates the potential for each posterior sample and stacks the results."""
        return np.array([self._eval_pot_point(x, y, sample_params)
                         for sample_params in post_param_list])

    def evaluate_deflection(self, x, y):
        """Evaluates the lensing deflection field at given coordinates"""
        alpha_x, alpha_y = np.zeros_like(x), np.zeros_like(x)
        for profile, params in zip(self.profile_list, self.param_list):
            a_x, a_y = profile.deflection(x, y, **params)
            alpha_x += a_x
            alpha_y += a_y
        return alpha_x, alpha_y

    def evaluate_convergence(self, x, y):
        """Evaluates the lensing convergence (i.e., 2D mass density) at given coordinates"""
        kappa = np.zeros_like(x)
        for profile, params in zip(self.profile_list, self.param_list):
            kappa += profile.convergence(x, y, **params)
        return kappa

    def evaluate_magnification(self, x, y):
        """Evaluates the lensing magnification 1 / det(A) at given coordinates,
        with A = I - H the lensing Jacobian and H the summed potential Hessians."""
        H_xx_sum = np.zeros_like(x)
        H_xy_sum = np.zeros_like(x)
        H_yx_sum = np.zeros_like(x)
        H_yy_sum = np.zeros_like(x)
        for profile, params in zip(self.profile_list, self.param_list):
            H_xx, H_xy, H_yx, H_yy = profile.hessian(x, y, **params)
            H_xx_sum += H_xx
            H_xy_sum += H_xy
            H_yx_sum += H_yx
            H_yy_sum += H_yy
        det_A = (1 - H_xx_sum) * (1 - H_yy_sum) - H_xy_sum*H_yx_sum
        mu = 1. / det_A
        return mu

    def ray_shooting(self, x, y):
        """evaluates the lens equation beta = theta - alpha(theta)"""
        alpha_x, alpha_y = self.evaluate_deflection(x, y)
        x_rs, y_rs = x - alpha_x, y - alpha_y
        return x_rs, y_rs
364

365

366
class ComposableLensModel(object):
    """Given a COOLEST object, evaluates a selection of entity and 
    their mass and light profiles, typically to construct an image of the lens.

    Parameters
    ----------
    coolest_object : COOLEST
        COOLEST instance
    coolest_directory : str, optional
        Directory which contains the COOLEST template, by default None
    kwargs_selection_source : dict, optional
        Keyword arguments passed to ComposableLightModel to select the source
        light profiles (e.g. 'entity_selection', 'profile_selection'),
        by default None (no extra arguments)
    kwargs_selection_lens_mass : dict, optional
        Keyword arguments passed to ComposableMassModel to select the lens
        mass profiles, by default None (no extra arguments)

    Raises
    ------
    ValueError
        No valid entity found or no profiles found.
    """

    def __init__(self, coolest_object, coolest_directory=None, 
                 kwargs_selection_source=None, kwargs_selection_lens_mass=None):
        self.coolest = coolest_object
        self.coord_obs = util.get_coordinates(self.coolest)
        self.directory = coolest_directory
        if kwargs_selection_source is None:
            kwargs_selection_source = {}
        if kwargs_selection_lens_mass is None:
            kwargs_selection_lens_mass = {}
        self.lens_mass = ComposableMassModel(coolest_object, 
                                             coolest_directory,
                                             **kwargs_selection_lens_mass)
        self.source = ComposableLightModel(coolest_object, 
                                          coolest_directory,
                                          **kwargs_selection_source)

    def model_image(self, supersampling=5, convolved=True, super_convolution=True):
        """Generates an image of the lens based on the selected model components.

        Parameters
        ----------
        supersampling : int, optional
            Number of evaluation pixels per data pixel along each axis, by default 5.
            When convolving with a PSF sampled on a finer grid than the data, it is
            increased (with a warning) to a multiple of the PSF supersampling factor.
        convolved : bool, optional
            If True, convolves the model with the pixelated PSF, by default True
        super_convolution : bool, optional
            If True, convolves at the PSF resolution before downsampling to the
            data resolution; otherwise downsamples first, by default True

        Returns
        -------
        tuple
            (image, self.coord_obs)

        Raises
        ------
        NotImplementedError
            If `convolved` is True and the PSF is not a 'PixelatedPSF'.
        ValueError
            If `supersampling` < 1, or the PSF pixel size is larger than the data
            pixel size or not an integer fraction of it.
        """
        obs = self.coolest.observation
        psf = self.coolest.instrument.psf
        supersampling_conv = 1  # number of PSF pixels per data pixel along each axis
        if convolved is True:
            if psf.type != 'PixelatedPSF':
                raise NotImplementedError
            scale_factor = obs.pixels.pixel_size / psf.pixels.pixel_size
            supersampling_conv = int(round(scale_factor))
            if supersampling_conv < 1:
                raise ValueError("PSF pixel size larger than data pixel size")
            if not math.isclose(scale_factor, supersampling_conv):
                raise ValueError(f"PSF supersampling ({scale_factor}) not close to an integer?")
        if supersampling < 1:
            raise ValueError("Supersampling must be >= 1")
        if convolved is True:
            adapted = supersampling
            if supersampling_conv > supersampling:
                adapted = supersampling_conv
            elif super_convolution and supersampling % supersampling_conv != 0:
                # round up to a multiple so the image can be binned onto the PSF grid
                adapted = -(-supersampling // supersampling_conv) * supersampling_conv
            if adapted != supersampling:
                supersampling = adapted
                logging.warning(f"Supersampling adapted to the PSF pixel size ({supersampling})")
        coord_eval = self.coord_obs.create_new_coordinates(pixel_scale_factor=1./supersampling)
        x, y = coord_eval.pixel_coordinates
        image = self.evaluate_lensed_surface_brightness(x, y)
        if convolved is True:
            kernel = self._load_psf_kernel(psf)
            if np.isnan(image).any():
                np.nan_to_num(image, copy=False, nan=0., posinf=None, neginf=None)
                logging.warning("Found NaN values in image prior to convolution; "
                                "they have been replaced by zeros.")
            if super_convolution:
                # downscale to the PSF resolution, convolve, then downscale to the data resolution
                factor_to_psf = int(supersampling // supersampling_conv)
                if factor_to_psf > 1:
                    image = util.downsampling(image, factor=factor_to_psf)
                image = signal.fftconvolve(image, kernel, mode='same')
                if supersampling_conv > 1:
                    image = util.downsampling(image, factor=supersampling_conv)
            else:
                if supersampling_conv > 1:
                    logging.warning("Convolving at the data resolution with a PSF kernel "
                                    "sampled on a finer grid; use super_convolution=True "
                                    "for a consistent convolution.")
                # first downscale then convolve
                if supersampling > 1:
                    image = util.downsampling(image, factor=supersampling)
                image = signal.fftconvolve(image, kernel, mode='same')
        elif supersampling > 1:
            image = util.downsampling(image, factor=supersampling)
        return image, self.coord_obs

    def _load_psf_kernel(self, psf):
        """Loads the pixelated PSF kernel, normalized to unit sum.

        Normalization creates a new array so the loaded pixels are not modified.
        """
        kernel = psf.pixels.get_pixels(directory=self.directory)
        kernel_sum = kernel.sum()
        if not math.isclose(kernel_sum, 1., abs_tol=1e-3):
            kernel = kernel / kernel_sum
            logging.warning(f"PSF kernel is not normalized (sum={kernel_sum}), "
                            f"so it has been normalized before convolution")
        return kernel

    def model_residuals(self, mask=None, **model_image_kwargs):
        """Computes the normalized residuals map as (data - model) / sigma.

        Parameters
        ----------
        mask : array_like, optional
            Multiplicative mask applied to the residuals, by default all ones
        **model_image_kwargs
            Passed to `model_image()`.

        Raises
        ------
        NotImplementedError
            If the noise is not of type 'NoiseMap' (checked before the model
            image is computed).
        """
        noise = self.coolest.observation.noise
        if noise.type != 'NoiseMap':
            raise NotImplementedError
        model, _ = self.model_image(**model_image_kwargs)
        data = self.coolest.observation.pixels.get_pixels(directory=self.directory)
        sigma = noise.noise_map.get_pixels(directory=self.directory)
        if mask is None:
            mask = np.ones_like(model)
        return ((data - model) / sigma) * mask, self.coord_obs

    def evaluate_lensed_surface_brightness(self, x, y):
        """Evaluates the surface brightness of a lensed source at given coordinates"""
        # ray-shooting
        x_rs, y_rs = self.ray_shooting(x, y)
        # evaluates at ray-shooted coordinates
        lensed_image = self.source.evaluate_surface_brightness(x_rs, y_rs)
        return lensed_image

    def ray_shooting(self, x, y):
        """evaluates the lens equation beta = theta - alpha(theta)"""
        return self.lens_mass.ray_shooting(x, y)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc