• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

aymgal / COOLEST / 4960685499

pending completion
4960685499

Pull #34

github

GitHub
Merge 7b7c1d7ba into d1de71ffa
Pull Request #34: Preparation for JOSS submission

184 of 184 new or added lines in 28 files covered. (100.0%)

1071 of 2324 relevant lines covered (46.08%)

0.46 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/coolest/api/composable_models.py
1
# Module author attribution.
__author__ = "aymgal"
×
2

3

4
import numpy as np
×
5
import math
×
6
import logging
×
7
from scipy import signal
×
8

9
from coolest.api import util
×
10

11

12
# Logging settings: this module reports through the root logger
# (module-level `logging.info` / `logging.warning` calls), so INFO is enabled here.
# NOTE(review): this mutates the root logger level of any application that
# imports this module; consider a module-level logger instead.
logging.root.setLevel(logging.INFO)
×
14

15

16
class BaseComposableModel(object):
    """Given a COOLEST object, evaluates a selection of mass or light profiles.
    This class serves as parent for more specific classes and should not be 
    instantiated by the user.

    Parameters
    ----------
    model_type : str
        Either 'light_model' or 'mass_model'
    coolest_object : COOLEST
        COOLEST instance
    coolest_directory : str, optional
        Directory which contains the COOLEST template (required to load the
        FITS files of grid-based profiles), by default None
    entity_selection : list, optional
        Indices of the lensing entities to consider ('all', a single index or a
        list of indices). If None, selects the first entity which has a model
        of type model_type, by default None
    profile_selection : list, optional
        Indices of the (mass or light) profiles to consider within the selected
        entities. Either 'all', a single index or a flat list of indices, which
        is then applied to every selected entity; or a list with exactly one
        entry per selected entity (in increasing entity index order), each entry
        being 'all' or a list of indices. If None, selects all the profiles
        within the selected entities, by default None

    Raises
    ------
    ValueError
        Unknown model type, no valid entity found, inconsistent profile
        selection, or no profiles found.
    """

    def __init__(self, model_type, coolest_object, coolest_directory=None, 
                 entity_selection=None, profile_selection=None):
        if model_type not in ('light_model', 'mass_model'):
            raise ValueError(f"Unknown model type '{model_type}' "
                             "(must be 'light_model' or 'mass_model')")
        if entity_selection is None:
            # finds the first entity that has a 'model_type' profile
            for i, entity in enumerate(coolest_object.lensing_entities):
                if model_type == 'light_model':
                    # only galaxies are searched for light profiles
                    has_profiles = entity.type == 'galaxy' and len(entity.light_model) > 0
                else:
                    has_profiles = len(entity.mass_model) > 0
                if has_profiles:
                    entity_selection = [i]
                    logging.info(f"Found valid profile for lensing entity (index {i}) for model type '{model_type}'")
                    break
            else:
                # the message names the requested kind ('light' or 'mass')
                kind = model_type.split('_')[0]
                raise ValueError(f"No lensing entity with {kind} profiles have been found")
        if profile_selection is None:
            profile_selection = 'all'
        entities = coolest_object.lensing_entities
        self.directory = coolest_directory
        self.profile_list, self.param_list, self.info_list \
            = self.select_profiles(model_type, entities, 
                                   entity_selection, profile_selection,
                                   coolest_directory)
        self.num_profiles = len(self.profile_list)
        if self.num_profiles == 0:
            raise ValueError("No profile has been selected!")

    def select_profiles(self, model_type, entities, 
                        entity_selection, profile_selection, 
                        coolest_directory):
        """Instantiates the selected API profiles and gathers their parameters.

        Parameters
        ----------
        model_type : str
            Either 'light_model' or 'mass_model'
        entities : list
            Lensing entities of the COOLEST object
        entity_selection : str, int or list
            Selection of the lensing entities (see class docstring)
        profile_selection : str, int or list
            Selection of the profiles, either shared by all selected entities
            or given per selected entity (see class docstring)
        coolest_directory : str or None
            Directory used to load the FITS files of grid-based profiles

        Returns
        -------
        tuple of lists
            API profile instances, their parameter dictionaries, and
            (entity name, entity redshift) tuples, all in matching order.

        Raises
        ------
        ValueError
            If a light model is requested for an external shear, if a
            per-entity profile selection does not have one entry per selected
            entity, or if a grid profile is selected without coolest_directory.
        """
        selected_entities = [(i, entity) for i, entity in enumerate(entities)
                             if self._selected(i, entity_selection)]
        per_entity = self._is_per_entity_selection(profile_selection)
        if per_entity and len(profile_selection) != len(selected_entities):
            raise ValueError(f"The profile selection has {len(profile_selection)} entries "
                             f"but {len(selected_entities)} lensing entities are selected")
        profile_list = []
        param_list = []
        info_list = []
        for k, (i, entity) in enumerate(selected_entities):
            if model_type == 'light_model' and entity.type == 'external_shear':
                raise ValueError(f"External shear (entity index {i}) has no light model")
            # k-th selected entity uses the k-th entry of a per-entity selection
            entity_profile_selection = profile_selection[k] if per_entity else profile_selection
            for j, profile in enumerate(getattr(entity, model_type)):
                if not self._selected(j, entity_profile_selection):
                    continue
                if 'Grid' in profile.type:
                    if coolest_directory is None:
                        raise ValueError("The directory in which the COOLEST file is located "
                                         "must be provided for loading FITS files")
                    params, fixed = self._get_grid_params(profile, coolest_directory)
                    profile_list.append(self._get_api_profile(model_type, profile, *fixed))
                    param_list.append(params)
                else:
                    profile_list.append(self._get_api_profile(model_type, profile))
                    param_list.append(self._get_point_estimates(profile))
                info_list.append((entity.name, entity.redshift))
        return profile_list, param_list, info_list

    def estimate_center(self):
        """Returns (center_x, center_y) of the first selected profile having a center.

        Raises
        ------
        ValueError
            If no selected profile has a 'center_x' parameter.
        """
        # TODO: improve this (for now simply considers the first profile that has a center)
        for profile, params in zip(self.profile_list, self.param_list):
            if 'center_x' in params:
                center_x = params['center_x']
                center_y = params['center_y']
                logging.info(f"Picked center from profile '{profile.type}'")
                return center_x, center_y
        raise ValueError("Could not estimate a center from the composed model")

    @staticmethod
    def _get_api_profile(model_type, profile_in, *args):
        """
        Takes as input a (light or mass) profile from the template submodule
        and instantiates the corresponding profile from the API submodule,
        passing `args` to its constructor.
        """
        if model_type == 'light_model':
            from coolest.api.profiles import light as profiles_module
        elif model_type == 'mass_model':
            from coolest.api.profiles import mass as profiles_module
        else:
            raise ValueError(f"Unknown model type '{model_type}'")
        ProfileClass = getattr(profiles_module, profile_in.type)
        return ProfileClass(*args)

    @staticmethod
    def _get_point_estimates(profile_in):
        """Returns a dict mapping each parameter name to its point estimate value."""
        return {name: param.point_estimate.value
                for name, param in profile_in.parameters.items()}

    @staticmethod
    def _get_grid_params(profile_in, fits_dir):
        """Loads the values of a grid-based profile and its fixed grid settings.

        Returns
        -------
        tuple
            (parameters, fixed_parameters): the loaded values as a dict, and the
            field of view and number of pixels forwarded to the API constructor.

        Raises
        ------
        ValueError
            If the grid type is neither 'PixelatedRegularGrid' nor 'IrregularGrid'.
        """
        if profile_in.type not in ('PixelatedRegularGrid', 'IrregularGrid'):
            raise ValueError(f"Unsupported grid profile type '{profile_in.type}'")
        pixels = profile_in.parameters['pixels']
        if profile_in.type == 'PixelatedRegularGrid':
            parameters = {'pixels': pixels.get_pixels(directory=fits_dir)}
            fixed_parameters = (pixels.field_of_view_x, pixels.field_of_view_y,
                                pixels.num_pix_x, pixels.num_pix_y)
        else:
            x, y, z = pixels.get_xyz(directory=fits_dir)
            parameters = {'x': x, 'y': y, 'z': z}
            fixed_parameters = (pixels.field_of_view_x, pixels.field_of_view_y,
                                pixels.num_pix)
        return parameters, fixed_parameters

    @staticmethod
    def _is_per_entity_selection(selection):
        """True if `selection` holds one profile selection ('all' or indices) per entity."""
        return (isinstance(selection, (list, tuple))
                and len(selection) > 0
                and all(isinstance(s, (list, tuple, np.ndarray, str)) for s in selection))

    @staticmethod
    def _selected(index, selection):
        """Whether `index` belongs to `selection` ('all', a number or a sequence of indices)."""
        if isinstance(selection, str):
            return selection.lower() == 'all'
        if isinstance(selection, (list, tuple, np.ndarray)):
            return index in selection
        if isinstance(selection, (int, float, np.integer, np.floating)):
            return int(selection) == index
        return False
159

160

161
class ComposableLightModel(BaseComposableModel):
    """Given a COOLEST object, evaluates a selection of entities and their light profiles.

    Parameters
    ----------
    coolest_object : COOLEST
        COOLEST instance
    coolest_directory : str, optional
        Directory which contains the COOLEST template, by default None
    **kwargs_selection
        Selection keywords forwarded to BaseComposableModel:
        entity_selection (indices of the lensing entities to consider; if None,
        the first entity that has a light model is selected) and
        profile_selection ('all' or indices of the light profiles to consider
        within the selected entities; if None, all profiles are selected).

    Raises
    ------
    ValueError
        No valid entity found or no profiles found.
    """

    def __init__(self, coolest_object, coolest_directory=None, **kwargs_selection):
        super().__init__('light_model', coolest_object, 
                         coolest_directory=coolest_directory,
                         **kwargs_selection)
        pixel_size = coolest_object.instrument.pixel_size
        # unit pixel area when the instrument pixel size is unknown
        self.pixel_area = 1. if pixel_size is None else pixel_size**2

    def surface_brightness(self, return_extra=False):
        """Returns the surface brightness as stored in the COOLEST file.

        Only the first selected profile is used.

        Parameters
        ----------
        return_extra : bool, optional
            If True, also return the extent and the coordinates given by the
            profile, by default False

        Returns
        -------
        values, or (values, extent, coordinates) if return_extra is True
        """
        if self.num_profiles > 1:
            logging.warning("When more than a single light profile has been selected, "
                            "the method `surface_brightness()` only considers the first profile")
        first_profile, first_params = self.profile_list[0], self.param_list[0]
        values = first_profile.surface_brightness(**first_params)
        if not return_extra:
            return values
        return values, first_profile.get_extent(), first_profile.get_coordinates()

    def evaluate_surface_brightness(self, x, y):
        """Evaluates the summed surface brightness of all selected profiles at given coordinates.

        Profiles whose units are 'flux_per_ang' are multiplied by `self.pixel_area`
        before being summed.
        """
        total = np.zeros_like(x)
        for light_profile, light_params in zip(self.profile_list, self.param_list):
            flux = light_profile.evaluate_surface_brightness(x, y, **light_params)
            if light_profile.units == 'flux_per_ang':
                # presumably converts a flux per unit angular area into a flux per pixel
                flux *= self.pixel_area
            total += flux
        return total
216

217

218
class ComposableMassModel(BaseComposableModel):
    """Given a COOLEST object, evaluates a selection of entities and their mass profiles.

    Parameters
    ----------
    coolest_object : COOLEST
        COOLEST instance
    coolest_directory : str, optional
        Directory which contains the COOLEST template, by default None
    **kwargs_selection
        Selection keywords forwarded to BaseComposableModel:
        entity_selection (indices of the lensing entities to consider; if None,
        the first entity that has a mass model is selected) and
        profile_selection ('all' or indices of the mass profiles to consider
        within the selected entities; if None, all profiles are selected).

    Raises
    ------
    ValueError
        No valid entity found or no profiles found.
    """

    def __init__(self, coolest_object, coolest_directory=None, **kwargs_selection):
        super().__init__('mass_model', coolest_object, 
                         coolest_directory=coolest_directory,
                         **kwargs_selection)

    def evaluate_deflection(self, x, y):
        """Evaluates the total deflection angles at given coordinates.

        Parameters
        ----------
        x, y : array_like
            Coordinates at which the deflection is evaluated

        Returns
        -------
        tuple
            (alpha_x, alpha_y), the sums of the deflections of all selected profiles
        """
        alpha_x, alpha_y = np.zeros_like(x), np.zeros_like(x)
        for profile, params in zip(self.profile_list, self.param_list):
            a_x, a_y = profile.deflection(x, y, **params)
            alpha_x += a_x
            alpha_y += a_y
        return alpha_x, alpha_y

    def evaluate_convergence(self, x, y):
        """Evaluates the total convergence at given coordinates.

        Parameters
        ----------
        x, y : array_like
            Coordinates at which the convergence is evaluated

        Returns
        -------
        array_like
            Sum of the convergences of all selected profiles
        """
        kappa = np.zeros_like(x)
        for profile, params in zip(self.profile_list, self.param_list):
            kappa += profile.convergence(x, y, **params)
        return kappa

    def evaluate_magnification(self, x, y):
        """Evaluates the magnification at given coordinates.

        The Hessian matrices of all selected profiles are summed into H, and the
        magnification is mu = 1 / det(I - H). Where det(I - H) vanishes
        (critical curves), non-finite values are expected.

        Parameters
        ----------
        x, y : array_like
            Coordinates at which the magnification is evaluated

        Returns
        -------
        array_like
            Magnification mu
        """
        H_xx_sum = np.zeros_like(x)
        H_xy_sum = np.zeros_like(x)
        H_yx_sum = np.zeros_like(x)
        H_yy_sum = np.zeros_like(x)
        for profile, params in zip(self.profile_list, self.param_list):
            H_xx, H_xy, H_yx, H_yy = profile.hessian(x, y, **params)
            H_xx_sum += H_xx
            H_xy_sum += H_xy
            H_yx_sum += H_yx
            H_yy_sum += H_yy
        # determinant of the lensing Jacobian A = I - H
        det_A = (1 - H_xx_sum) * (1 - H_yy_sum) - H_xy_sum*H_yx_sum
        mu = 1. / det_A
        return mu
277

278

279
class ComposableLensModel(object):
    """Given a COOLEST object, evaluates a selection of entities and 
    their mass and light profiles, typically to construct an image of the lens.

    Parameters
    ----------
    coolest_object : COOLEST
        COOLEST instance
    coolest_directory : str, optional
        Directory which contains the COOLEST template, by default None
    kwargs_selection_source : dict, optional
        Selection keywords (e.g. 'entity_selection', 'profile_selection')
        passed to ComposableLightModel for the source light profiles.
        By default None, i.e. the first entity that has a light model.
    kwargs_selection_lens_mass : dict, optional
        Selection keywords (e.g. 'entity_selection', 'profile_selection')
        passed to ComposableMassModel for the lens mass profiles.
        By default None, i.e. the first entity that has a mass model.

    Raises
    ------
    ValueError
        No valid entity found or no profiles found.
    """

    def __init__(self, coolest_object, coolest_directory=None, 
                 kwargs_selection_source=None, kwargs_selection_lens_mass=None):
        self.coolest = coolest_object
        self.coord_obs = util.get_coordinates(self.coolest)
        self.directory = coolest_directory
        if kwargs_selection_source is None:
            kwargs_selection_source = {}
        if kwargs_selection_lens_mass is None:
            kwargs_selection_lens_mass = {}
        self.lens_mass = ComposableMassModel(coolest_object, 
                                             coolest_directory,
                                             **kwargs_selection_lens_mass)
        self.source = ComposableLightModel(coolest_object, 
                                          coolest_directory,
                                          **kwargs_selection_source)

    def model_image(self, supersampling=5, convolved=True):
        """Generates an image of the lens based on the selected model components.

        The lensed source is evaluated on a grid `supersampling` times finer
        than the observation grid. If `convolved` is True, the image is
        downsampled to the resolution of the pixelated PSF, convolved with it,
        and downsampled to the observation resolution. Otherwise it is
        directly downsampled to the observation resolution.

        Parameters
        ----------
        supersampling : int, optional
            Supersampling factor of the evaluation grid, by default 5. When
            convolving, it is raised (with a warning) to the smallest multiple
            of the PSF supersampling factor that is not lower than it.
        convolved : bool, optional
            Whether to convolve the image with the PSF, by default True

        Returns
        -------
        tuple
            The model image and the coordinates of the observation grid.

        Raises
        ------
        ValueError
            If supersampling is lower than 1, or if the PSF pixel size is not
            an integer fraction of the data pixel size.
        NotImplementedError
            If convolution is requested with a PSF that is not a 'PixelatedPSF'.
        """
        if supersampling < 1:
            raise ValueError("Supersampling must be >= 1")
        convolved = bool(convolved)
        psf = self.coolest.instrument.psf
        supersampling_conv = 1
        if convolved:
            # fail early: only pixelated PSFs can be convolved with
            if psf.type != 'PixelatedPSF':
                raise NotImplementedError
            obs = self.coolest.observation
            supersampling_conv = self._psf_supersampling(obs.pixels.pixel_size,
                                                         psf.pixels.pixel_size)
            supersampling = self._adapt_supersampling(supersampling, supersampling_conv)
        coord_eval = self.coord_obs.create_new_coordinates(pixel_scale_factor=1./supersampling)
        x, y = coord_eval.pixel_coordinates
        image = self.evaluate_lensed_surface_brightness(x, y)
        if not convolved:
            if supersampling > 1:
                image = util.downsampling(image, factor=supersampling)
            return image, self.coord_obs
        kernel = psf.pixels.get_pixels(directory=self.directory)
        kernel_sum = kernel.sum()
        if not math.isclose(kernel_sum, 1., abs_tol=1e-3):
            logging.warning(f"PSF kernel is not normalized (sum={kernel_sum}), "
                            "so it is normalized before convolution")
            # out-of-place, so the array returned by get_pixels() is left untouched
            kernel = kernel / kernel_sum
        if np.isnan(image).any():
            logging.warning("Found NaN values in image prior to convolution; "
                            "there are replaced by zeros before convolution")
            np.nan_to_num(image, copy=False, nan=0., posinf=None, neginf=None)
        # bring the image to the PSF resolution (supersampling is a multiple of supersampling_conv)
        factor_to_psf = supersampling // supersampling_conv
        if factor_to_psf > 1:
            image = util.downsampling(image, factor=factor_to_psf)
        image = signal.fftconvolve(image, kernel, mode='same')
        # then from the PSF resolution to the data resolution
        if supersampling_conv > 1:
            image = util.downsampling(image, factor=supersampling_conv)
        return image, self.coord_obs

    @staticmethod
    def _psf_supersampling(data_pixel_size, psf_pixel_size):
        """Returns the integer ratio between data and PSF pixel sizes.

        Raises
        ------
        ValueError
            If the PSF pixels are larger than the data pixels, or if the ratio
            is not close to an integer.
        """
        scale_factor = data_pixel_size / psf_pixel_size
        supersampling_conv = int(round(scale_factor))
        # checked first, otherwise ratios < 0.5 would be reported as non-integer
        if supersampling_conv < 1:
            raise ValueError("PSF pixel size larger than data pixel size")
        if not math.isclose(scale_factor, supersampling_conv):
            raise ValueError(f"PSF supersampling ({scale_factor}) not close to an integer?")
        return supersampling_conv

    @staticmethod
    def _adapt_supersampling(supersampling, supersampling_conv):
        """Returns the smallest multiple of `supersampling_conv` not lower than `supersampling`.

        A warning is logged whenever the value has to be changed.
        """
        if supersampling % supersampling_conv == 0:
            return supersampling
        adapted = supersampling_conv * math.ceil(supersampling / supersampling_conv)
        logging.warning(f"Supersampling adapted to the PSF pixel size ({adapted})")
        return adapted

    def model_residuals(self, supersampling=5, mask=None):
        """Computes the normalized residuals map as (data - model) / sigma.

        Parameters
        ----------
        supersampling : int, optional
            Supersampling factor passed to model_image(), by default 5
        mask : array_like, optional
            Multiplicative mask applied to the residuals, by default None (no masking)

        Returns
        -------
        tuple
            The (masked) normalized residuals and the observation coordinates.

        Raises
        ------
        NotImplementedError
            If the observation noise is not of type 'NoiseMap'.
        """
        model, _ = self.model_image(supersampling=supersampling, 
                                    convolved=True)
        data = self.coolest.observation.pixels.get_pixels(directory=self.directory)
        noise = self.coolest.observation.noise
        if noise.type != 'NoiseMap':
            raise NotImplementedError
        sigma = noise.noise_map.get_pixels(directory=self.directory)
        if mask is None:
            mask = np.ones_like(model)
        return ((data - model) / sigma) * mask, self.coord_obs

    def evaluate_lensed_surface_brightness(self, x, y):
        """Evaluates the surface brightness of a lensed source at given coordinates"""
        # ray-shooting
        x_rs, y_rs = self.ray_shooting(x, y)
        # evaluates at ray-shooted coordinates
        lensed_image = self.source.evaluate_surface_brightness(x_rs, y_rs)
        return lensed_image

    def ray_shooting(self, x, y):
        """Evaluates the lens equation beta = theta - alpha(theta)"""
        alpha_x, alpha_y = self.lens_mass.evaluate_deflection(x, y)
        x_rs, y_rs = x - alpha_x, y - alpha_y
        return x_rs, y_rs
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc