• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

aymgal / COOLEST / 4960685499

pending completion
4960685499

Pull #34

github

GitHub
Merge 7b7c1d7ba into d1de71ffa
Pull Request #34: Preparation for JOSS submission

184 of 184 new or added lines in 28 files covered. (100.0%)

1071 of 2324 relevant lines covered (46.08%)

0.46 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/coolest/api/plotting.py
1
__author__ = 'aymgal', 'lynevdv'
×
2

3

4
import copy
×
5
import logging
×
6
import numpy as np
×
7
import matplotlib.pyplot as plt
×
8
from matplotlib.colors import Normalize, LogNorm, TwoSlopeNorm
×
9
from matplotlib.colors import ListedColormap
×
10

11
from coolest.api.analysis import Analysis
×
12
from coolest.api.composable_models import *
×
13
from coolest.api import util
×
14
from coolest.api import plot_util as plut
×
15

16
# Module-wide matplotlib defaults applied on import: imshow() draws pixels
# without interpolation and puts the array origin at the lower-left corner.
plt.rcParams.update({'image.interpolation': 'none', 'image.origin': 'lower'})

# Module-wide logging default applied on import: the root logger's level is
# lowered to INFO so informational messages are emitted.
logging.root.setLevel(logging.INFO)
21

22

23
class ModelPlotter(object):
    """Create pyplot panels from a lens model stored in the COOLEST format.

    Parameters
    ----------
    coolest_object : COOLEST
        COOLEST instance
    coolest_directory : str, optional
        Directory which contains the COOLEST template, by default None
    color_bad_values : str, optional
        Color assigned to NaN values (typically negative values in log-scale), 
        by default '#111111' (dark gray)
    """

    def __init__(self, coolest_object, coolest_directory=None, 
                 color_bad_values='#111111'):
        self.coolest = coolest_object
        self._directory = coolest_directory

        # Work on a copy of the colormap so that set_bad() does not alter the
        # colormap instance returned by plt.get_cmap().
        self.cmap_flux = copy.copy(plt.get_cmap('magma'))
        self.cmap_flux.set_bad(color_bad_values)

        self.cmap_mag = plt.get_cmap('twilight_shifted')
        self.cmap_conv = plt.get_cmap('cividis')
        self.cmap_res = plt.get_cmap('RdBu_r')

    def plot_data_image(self, ax, title=None, norm=None, cmap=None, neg_values_as_bad=False):
        """plt.imshow panel with the data image.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes on which the image is drawn
        title : str, optional
            Panel title, by default None (no title)
        norm : matplotlib.colors.Normalize, optional
            Normalization passed to imshow, by default None
        cmap : colormap, optional
            Colormap, by default the flux colormap of this instance
        neg_values_as_bad : bool, optional
            If True, negative pixels are set to NaN and drawn with the
            'bad' color, by default False

        Returns
        -------
        image
            The data image, as returned by `observation.pixels.get_pixels()`
        """
        if cmap is None:
            cmap = self.cmap_flux
        coordinates = util.get_coordinates(self.coolest)
        extent = coordinates.plt_extent
        image = self.coolest.observation.pixels.get_pixels(directory=self._directory)
        ax, cb = self._plot_regular_grid(ax, image, extent=extent, 
                                         cmap=cmap,
                                         neg_values_as_bad=neg_values_as_bad, 
                                         norm=norm)
        cb.set_label("flux")
        if title is not None:
            ax.set_title(title)
        return image

    def plot_surface_brightness(self, ax, title=None, coordinates=None, 
                                extent=None, norm=None, cmap=None, neg_values_as_bad=True,
                                plot_points_irreg=False, kwargs_light=None):
        """plt.imshow panel showing the surface brightness of the (unlensed)
        lensing entity selected via kwargs_light (see ComposableLightModel docstring).

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes on which the map is drawn
        title : str, optional
            Panel title, by default None (no title)
        coordinates : Coordinates-like, optional
            If given, the light model is evaluated on its `pixel_coordinates`
            and drawn with its `plt_extent` (the `extent` argument is then
            ignored). If None, the model's own representation is used.
        extent : sequence, optional
            Plot extent used when `coordinates` is None; defaults to the
            extent returned by the light model
        norm, cmap : optional
            Normalization and colormap (default: flux colormap)
        neg_values_as_bad : bool, optional
            If True, negative values are drawn with the 'bad' color, by default True
        plot_points_irreg : bool, optional
            For irregular grids, also scatter the grid points, by default False
        kwargs_light : dict, optional
            Keyword arguments forwarded to ComposableLightModel

        Returns
        -------
        image : array or None
            The evaluated 2D map, or None when the model is on an irregular grid
        coordinates
            The coordinates used (given or returned by the light model)
        """
        if kwargs_light is None:
            kwargs_light = {}
        light_model = ComposableLightModel(self.coolest, self._directory, **kwargs_light)
        if cmap is None:
            cmap = self.cmap_flux
        if coordinates is not None:
            x, y = coordinates.pixel_coordinates
            image = light_model.evaluate_surface_brightness(x, y)
            extent = coordinates.plt_extent
            ax, cb = self._plot_regular_grid(ax, image, extent=extent, cmap=cmap,
                                             neg_values_as_bad=neg_values_as_bad, 
                                             norm=norm)
        else:
            values, extent_model, coordinates = light_model.surface_brightness(return_extra=True)
            if extent is None:
                extent = extent_model
            if isinstance(values, np.ndarray) and len(values.shape) == 2:
                # regular pixelated grid: draw with imshow
                image = values
                ax, cb = self._plot_regular_grid(ax, image, extent=extent, 
                                                 cmap=cmap, 
                                                 neg_values_as_bad=neg_values_as_bad,
                                                 norm=norm)
            else:
                # irregular grid: `values` is unpacked as (x, y, z) by the helper
                points = values
                ax, cb = self._plot_irregular_grid(ax, points, extent, norm=norm, cmap=cmap, 
                                                   neg_values_as_bad=neg_values_as_bad,
                                                   plot_points=plot_points_irreg)
                image = None
        cb.set_label("flux")
        if title is not None:
            ax.set_title(title)
        return image, coordinates

    def plot_model_image(self, ax, supersampling=5, convolved=False, title=None,
                         norm=None, cmap=None, neg_values_as_bad=False,
                         kwargs_source=None, kwargs_lens_mass=None):
        """plt.imshow panel showing the surface brightness of the (lensed)
        selected lensing entities (see ComposableLensModel docstring).

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes on which the image is drawn
        supersampling : int, optional
            Supersampling factor forwarded to `model_image`, by default 5
        convolved : bool, optional
            Forwarded to `model_image`, by default False
        title : str, optional
            Panel title, by default None (no title)
        norm, cmap : optional
            Normalization and colormap (default: flux colormap)
        neg_values_as_bad : bool, optional
            If True, negative values are drawn with the 'bad' color, by default False
        kwargs_source, kwargs_lens_mass : dict, optional
            Entity selections forwarded to ComposableLensModel

        Returns
        -------
        image
            The model image returned by `ComposableLensModel.model_image`
        """
        if cmap is None:
            cmap = self.cmap_flux
        lens_model = ComposableLensModel(self.coolest, self._directory,
                                         kwargs_selection_source=kwargs_source,
                                         kwargs_selection_lens_mass=kwargs_lens_mass)
        image, coordinates = lens_model.model_image(supersampling=supersampling, 
                                                    convolved=convolved)
        extent = coordinates.plt_extent
        ax, cb = self._plot_regular_grid(ax, image, extent=extent, 
                                         cmap=cmap,
                                         neg_values_as_bad=neg_values_as_bad, 
                                         norm=norm)
        cb.set_label("flux")
        if title is not None:
            ax.set_title(title)
        return image

    def plot_model_residuals(self, ax, supersampling=5, mask=None, title=None,
                             norm=None, cmap=None, add_chi2_label=False, chi2_fontsize=12,
                             kwargs_source=None, kwargs_lens_mass=None):
        """plt.imshow panel showing the normalized model residuals image.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes on which the residuals are drawn
        supersampling : int, optional
            Supersampling factor forwarded to `model_residuals`, by default 5
        mask : array, optional
            Forwarded to `model_residuals`; when given, its sum is used as
            the number of constraints in the reduced chi-square
        title : str, optional
            Panel title, by default None (no title)
        norm : matplotlib.colors.Normalize, optional
            By default Normalize(-6, 6)
        cmap : colormap, optional
            By default the residuals colormap of this instance
        add_chi2_label : bool, optional
            If truthy, write the reduced chi-square in the lower-left corner
        chi2_fontsize : int, optional
            Font size of the chi-square label, by default 12
        kwargs_source, kwargs_lens_mass : dict, optional
            Entity selections forwarded to ComposableLensModel

        Returns
        -------
        image
            The residuals map returned by `ComposableLensModel.model_residuals`
        """
        if cmap is None:
            cmap = self.cmap_res
        if norm is None:
            norm = Normalize(-6, 6)
        lens_model = ComposableLensModel(self.coolest, self._directory,
                                         kwargs_selection_source=kwargs_source,
                                         kwargs_selection_lens_mass=kwargs_lens_mass)
        image, coordinates = lens_model.model_residuals(supersampling=supersampling, 
                                                        mask=mask)
        extent = coordinates.plt_extent
        ax, cb = self._plot_regular_grid(ax, image, extent=extent, 
                                         cmap=cmap,
                                         neg_values_as_bad=False, 
                                         norm=norm)
        cb.set_label("(data $-$ model) / noise")
        # Truthiness (not `is True`) so that e.g. numpy booleans also work.
        if add_chi2_label:
            num_constraints = np.size(image) if mask is None else np.sum(mask)
            red_chi2 = np.sum(image**2) / num_constraints
            ax.text(0.05, 0.05, r'$\chi^2_\nu$='+f'{red_chi2:.2f}', color='black', alpha=1, 
                    fontsize=chi2_fontsize, va='bottom', ha='left', transform=ax.transAxes,
                    bbox={'color': 'white', 'alpha': 0.6})
        if title is not None:
            ax.set_title(title)
        return image

    def plot_convergence(self, ax, title=None,
                         norm=None, cmap=None, neg_values_as_bad=False,
                         kwargs_lens_mass=None):
        """plt.imshow panel showing the 2D convergence map associated to the
        selected lensing entities (see ComposableMassModel docstring).

        The map is evaluated on the pixel coordinates of the observation
        (`util.get_coordinates`). Returns the evaluated convergence map.
        """
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        mass_model = ComposableMassModel(self.coolest, self._directory,
                                         **kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_conv
        coordinates = util.get_coordinates(self.coolest)
        extent = coordinates.plt_extent
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_convergence(x, y)
        ax, cb = self._plot_regular_grid(ax, image, extent=extent, 
                                         cmap=cmap,
                                         neg_values_as_bad=neg_values_as_bad, 
                                         norm=norm)
        cb.set_label(r"$\kappa$")
        if title is not None:
            ax.set_title(title)
        return image

    def plot_magnification(self, ax, title=None,
                          norm=None, cmap=None, neg_values_as_bad=False,
                          kwargs_lens_mass=None):
        """plt.imshow panel showing the 2D magnification map associated to the
        selected lensing entities (see ComposableMassModel docstring).

        The map is evaluated on the pixel coordinates of the observation
        (`util.get_coordinates`); the default normalization is
        Normalize(-10, 10). Returns the evaluated magnification map.
        """
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        mass_model = ComposableMassModel(self.coolest, self._directory,
                                         **kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_mag
        if norm is None:
            norm = Normalize(-10, 10)
        coordinates = util.get_coordinates(self.coolest)
        extent = coordinates.plt_extent
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_magnification(x, y)
        ax, cb = self._plot_regular_grid(ax, image, extent=extent, 
                                         cmap=cmap,
                                         neg_values_as_bad=neg_values_as_bad, 
                                         norm=norm)
        cb.set_label(r"$\mu$")
        if title is not None:
            ax.set_title(title)
        return image

    @staticmethod
    def _mask_negative_values(image_, neg_values_as_bad):
        """Return `image_` with negative values replaced by NaN (as a copy).

        When `neg_values_as_bad` is False, `image_` is returned unchanged.
        Otherwise a copy is made; integer/boolean inputs are promoted to
        float first, because NaN cannot be stored in a non-float array.
        The input array is never modified.
        """
        if not neg_values_as_bad:
            return image_
        image = np.array(image_, copy=True)
        if not np.issubdtype(image.dtype, np.floating):
            image = image.astype(float)
        image[image < 0] = np.nan
        return image

    @staticmethod
    def _plot_regular_grid(ax, image_, neg_values_as_bad=True, **imshow_kwargs):
        """Draw a 2D array with imshow and attach a colorbar.

        Negative values are optionally drawn with the colormap's 'bad' color.
        Extra keyword arguments are forwarded to `ax.imshow`.
        Returns the axes and the colorbar.
        """
        image = ModelPlotter._mask_negative_values(image_, neg_values_as_bad)
        im = ax.imshow(image, **imshow_kwargs)
        im.set_rasterized(True)
        cb = plut.nice_colorbar(im, ax=ax, max_nbins=4)
        ax.xaxis.set_major_locator(plt.MaxNLocator(3))
        ax.yaxis.set_major_locator(plt.MaxNLocator(3))
        return ax, cb

    @staticmethod
    def _plot_irregular_grid(ax, points, extent, neg_values_as_bad=True,
                             norm=None, cmap=None, plot_points=False):
        """Draw values defined on irregular (x, y, z) points as Voronoi cells.

        `points` is unpacked as (x, y, z). The axes limits are set from
        `extent` = (xmin, xmax, ymin, ymax). If `plot_points` is True, the
        point positions are overlaid as small white dots.
        Returns the axes and the colorbar.
        """
        x, y, z = points
        m = plut.plot_voronoi(ax, x, y, z, neg_values_as_bad=neg_values_as_bad, 
                              norm=norm, cmap=cmap, zorder=1)
        ax.set_xlim(extent[0], extent[1])
        ax.set_ylim(extent[2], extent[3])
        ax.set_aspect('equal', 'box')
        ax.xaxis.set_major_locator(plt.MaxNLocator(3))
        ax.yaxis.set_major_locator(plt.MaxNLocator(3))
        cb = plut.nice_colorbar(m, ax=ax, max_nbins=4)
        if plot_points:
            # zorder=2 draws the points above the Voronoi cells (zorder=1)
            ax.scatter(x, y, s=5, c='white', marker='.', alpha=0.4, zorder=2)
        return ax, cb
242

243

244
class MultiModelPlotter(object):
    """Wrapper around a set of ModelPlotter instances to produce panels that
    consistently compare different models, evaluated on the same
    coordinates systems.

    Parameters
    ----------
    coolest_objects : list
        List of COOLEST instances
    coolest_directories : list, optional
        List of directories corresponding to each COOLEST instance, by default None
    kwargs_plotter : dict, optional
        Additional keyword arguments passed to ModelPlotter
    """

    def __init__(self, coolest_objects, coolest_directories=None, **kwargs_plotter):
        self.num_models = len(coolest_objects)
        if coolest_directories is None:
            coolest_directories = self.num_models * [None]
        self.plotter_list = [
            ModelPlotter(coolest, coolest_directory=c_dir, **kwargs_plotter)
            for coolest, c_dir in zip(coolest_objects, coolest_directories)
        ]

    def plot_surface_brightness(self, axes, titles=None, **kwargs):
        """Call ModelPlotter.plot_surface_brightness on each axis.

        Each value of `kwargs_light` (if given) must be a sequence indexed
        by model. Returns the list of (image, coordinates) tuples.
        """
        return self._plot_light_multi('plot_surface_brightness', "surf. brightness", axes, titles=titles, **kwargs)

    def plot_data_image(self, axes, titles=None, **kwargs):
        """Call ModelPlotter.plot_data_image on each axis; returns the images."""
        return self._plot_data_multi("data", axes, titles=titles, **kwargs)

    def plot_model_image(self, axes, titles=None, **kwargs):
        """Call ModelPlotter.plot_model_image on each axis; returns the images."""
        return self._plot_lens_model_multi('plot_model_image', "model", axes, titles=titles, **kwargs)

    def plot_model_residuals(self, axes, titles=None, **kwargs):
        """Call ModelPlotter.plot_model_residuals on each axis; returns the maps."""
        return self._plot_lens_model_multi('plot_model_residuals', "residuals", axes, titles=titles, **kwargs)

    def plot_convergence(self, axes, titles=None, **kwargs):
        """Call ModelPlotter.plot_convergence on each axis; returns the maps."""
        return self._plot_mass_multi('plot_convergence', "convergence", axes, titles=titles, **kwargs)

    def plot_magnification(self, axes, titles=None, **kwargs):
        """Call ModelPlotter.plot_magnification on each axis; returns the maps."""
        return self._plot_mass_multi('plot_magnification', "magnification", axes, titles=titles, **kwargs)

    def _plot_multi(self, method_name, default_title, axes, titles, kwargs,
                    per_model_keys=(), copy_kwargs=True):
        """Call `method_name` of every ModelPlotter on its own subplot axis.

        Parameters
        ----------
        method_name : str
            Name of the ModelPlotter method to call
        default_title : str
            Title used for every panel when `titles` is None
        axes : sequence
            One axis per model; a None entry skips that model
        titles : sequence or None
            One title per model
        kwargs : dict
            Keyword arguments forwarded to every call
        per_model_keys : tuple of str
            Keys of `kwargs` whose value is a dict of per-model sequences;
            for model i, each such dict is replaced by {k: v[i]}. A None
            value is forwarded unchanged.
        copy_kwargs : bool
            Deep-copy `kwargs` once before the loop (the per-model entries
            are then written into the copy, never into the caller's dict)

        Returns
        -------
        list
            The return value of each call, skipping models whose axis is None
        """
        assert len(axes) == self.num_models, "Inconsistent number of subplot axes"
        if titles is None:
            titles = self.num_models * [default_title]
        kwargs_ = copy.deepcopy(kwargs) if copy_kwargs else kwargs
        image_list = []
        for i, (ax, plotter) in enumerate(zip(axes, self.plotter_list)):
            if ax is None:
                continue
            for key in per_model_keys:
                selection = kwargs.get(key)
                if selection is not None:
                    kwargs_[key] = {k: v[i] for k, v in selection.items()}
            image = getattr(plotter, method_name)(ax, title=titles[i], **kwargs_)
            image_list.append(image)
        return image_list

    def _plot_light_multi(self, method_name, default_title, axes, titles=None, **kwargs):
        """Multi-model plot with per-model slicing of `kwargs_light`."""
        return self._plot_multi(method_name, default_title, axes, titles, kwargs,
                                per_model_keys=('kwargs_light',))

    def _plot_mass_multi(self, method_name, default_title, axes, titles=None, **kwargs):
        """Multi-model plot with per-model slicing of `kwargs_lens_mass`."""
        return self._plot_multi(method_name, default_title, axes, titles, kwargs,
                                per_model_keys=('kwargs_lens_mass',))

    def _plot_lens_model_multi(self, method_name, default_title, axes, titles=None, kwargs_select=None, **kwargs):
        """Multi-model plot with per-model slicing of `kwargs_source` and
        `kwargs_lens_mass`. `kwargs_select` is accepted for backward
        compatibility and ignored."""
        return self._plot_multi(method_name, default_title, axes, titles, kwargs,
                                per_model_keys=('kwargs_source', 'kwargs_lens_mass'))

    def _plot_data_multi(self, default_title, axes, titles=None, **kwargs):
        """Multi-model data plot; `kwargs` are forwarded as-is (not copied)."""
        return self._plot_multi('plot_data_image', default_title, axes, titles, kwargs,
                                copy_kwargs=False)
344

345

346
class Comparison_analytical(object):
    """Handles plot of analytical models in a comparative way

    Parameters
    ----------
    coolest_file_list : list
        List of paths to COOLEST templates
    nickname_file_list : list
        List of shorter names related to each COOLEST model
    posterior_bool_list : list
        List of bool to toggle errorbars on point-estimate values
    """

    # Marker cycle distinguishing the compared models (reused past 16 models).
    _MARKERS = ('*', '.', 's', '^', '<', '>', 'v', 'p', 'P', 'X', 'D', '1', '2', '3', '4', '+')
    # Angle parameters whose credible interval may wrap across the
    # +180/-180 edge (see `_errorbar_bounds`).
    _ANGLE_KEYS = ('SHEAR_0_phi_ext', 'PEMD_0_phi')
    # Number of panels per row in the figure.
    _NUM_COLS = 4

    def __init__(self, coolest_file_list, nickname_file_list, posterior_bool_list):
        self.file_names = nickname_file_list
        self.posterior_bool_list = posterior_bool_list
        self.param_lens, self.param_source = util.read_json_param(coolest_file_list,
                                                                  self.file_names, 
                                                                  lens_light=False)

    @staticmethod
    def _grid_layout(number_param, num_cols=4):
        """Return (num_lines, num_unused) for `number_param` panels laid out
        on rows of `num_cols`; at least one row is always returned."""
        num_lines = max(1, -(-number_param // num_cols))  # ceiling division
        num_unused = num_lines * num_cols - number_param
        return num_lines, num_unused

    @classmethod
    def _errorbar_bounds(cls, key, p):
        """Return the (lower, upper) error-bar lengths around p['median'].

        For angle parameters, a 16th (84th) percentile lying above (below)
        the median indicates the interval wraps across the +180/-180 edge,
        so it is shifted by 180 degrees. `p` itself is not modified.
        """
        low, high = p['percentile_16th'], p['percentile_84th']
        if key in cls._ANGLE_KEYS:
            if low > p['median']:
                low -= 180.
            if high < p['median']:
                high += 180.
        return p['median'] - low, high - p['median']

    def plotting_routine(self, param_dict, idx_file=0):
        """Plot the parameters of all models, one panel per parameter.

        Parameters
        ----------
        param_dict : dict
            Organized dictionary with the parameter results of the different
            files, keyed by nickname, then by parameter name
        idx_file : int, optional
            Index of the reference file whose parameters are plotted. For
            example, if file 0 has a Sersic fit and file 1 Sersic+shapelets,
            idx_file=0 plots the Sersic parameters of both files, while
            idx_file=1 plots all Sersic and shapelets parameters wherever
            available. Files lacking a parameter are skipped in that panel.

        Returns
        -------
        f, ax
            The figure and the 2D array of axes
        """
        ref_name = self.file_names[idx_file]
        ref_keys = list(param_dict[ref_name].keys())
        num_cols = self._NUM_COLS
        num_lines, num_unused = self._grid_layout(len(ref_keys), num_cols)

        # squeeze=False keeps `ax` 2D even when there is a single row
        f, ax = plt.subplots(num_lines, num_cols, figsize=(num_cols * 3.5, 2.5 * num_lines),
                             squeeze=False)

        for j, file_name in enumerate(self.file_names):
            result = param_dict[file_name]
            m = self._MARKERS[j % len(self._MARKERS)]
            for i, key in enumerate(ref_keys):
                if key not in result:
                    continue
                axis = ax[i // num_cols, i % num_cols]
                p = result[key]
                if self.posterior_bool_list[j]:
                    err_low, err_high = self._errorbar_bounds(key, p)
                    axis.errorbar(j, p['median'], [[err_low], [err_high]],
                                  marker=m, ls='', label=file_name)
                else:
                    axis.plot(j, p['point_estimate'], marker=m, ls='', label=file_name)

        # Axis decorations, using the reference file's LaTeX labels
        for i, key in enumerate(ref_keys):
            axis = ax[i // num_cols, i % num_cols]
            axis.get_xaxis().set_visible(False)
            axis.set_ylabel(param_dict[ref_name][key]['latex_str'], fontsize=12)
            axis.tick_params(axis='y', labelsize=12)

        ax[0, 0].legend()
        for idx in range(num_unused):
            ax[-1, -idx - 1].axis('off')
        plt.tight_layout()
        plt.show()
        return f, ax

    def plot_source(self, idx_file=0):
        """Plot the source parameters (see `plotting_routine`)."""
        f, ax = self.plotting_routine(self.param_source, idx_file)
        return f, ax

    def plot_lens(self, idx_file=0):
        """Plot the lens parameters (see `plotting_routine`)."""
        f, ax = self.plotting_routine(self.param_lens, idx_file)
        return f, ax
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc