• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

aymgal / COOLEST / 27645366736

16 Jun 2026 08:16PM UTC coverage: 43.458% (-2.0%) from 45.463%
27645366736

Pull #75

github

web-flow
Merge 79a70e2d3 into 9ef83563c
Pull Request #75: Adding multi-source plane lens functionality

31 of 242 new or added lines in 7 files covered. (12.81%)

6 existing lines in 4 files now uncovered.

1528 of 3516 relevant lines covered (43.46%)

0.43 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/coolest/api/plotting.py
1
__author__ = 'aymgal', 'lynevdv', 'gvernard'
×
2

3

4
import os
×
5
import copy
×
6
import logging
×
7
import numpy as np
×
8
import matplotlib.pyplot as plt
×
9
from matplotlib.colors import Normalize, LogNorm, TwoSlopeNorm
×
10
from matplotlib.colors import ListedColormap
×
11
from getdist import plots, chains, MCSamples
×
12

13
from coolest.api.analysis import Analysis
×
14
from coolest.api.composable_models import *
×
15
from coolest.api import util
×
16
from coolest.api import plot_util as plut
×
17

18
import pandas as pd
×
19

20

21
# Default imshow behaviour for every panel produced by this module: pixels are
# drawn without interpolation and the array origin sits in the lower-left corner.
_IMSHOW_DEFAULTS = {'interpolation': 'none', 'origin': 'lower'}
plt.rc('image', **_IMSHOW_DEFAULTS)

# NOTE: this sets the level of the *root* logger at import time, so it affects
# every logger in the process, not only this module.
logging.getLogger().setLevel(logging.INFO)

# TODO: separate ParametersPlotter from ModelPlotter to avoid dependencies on getdist

# Public API of this module (what `from ... import *` exposes).
__all__ = [
    'ModelPlotter',
    'MultiModelPlotter',
    'ParametersPlotter',
]
34

35
class ModelPlotter(object):
    """Create pyplot panels from a lens model stored in the COOLEST format.

    Parameters
    ----------
    coolest_object : COOLEST
        COOLEST instance
    coolest_directory : str, optional
        Directory which contains the COOLEST template, by default None
    color_bad_values : str, optional
        Color assigned to NaN values (typically negative values in log-scale),
        by default '#222222' (dark gray)
    """

    def __init__(self, coolest_object, coolest_directory=None,
                 color_bad_values='#222222'):
        self.coolest = coolest_object
        self._directory = coolest_directory

        # Work on a copy so that set_bad() does not modify the colormap object
        # returned by plt.get_cmap('magma') for everybody else.
        self.cmap_flux = copy.copy(plt.get_cmap('magma'))
        self.cmap_flux.set_bad(color_bad_values)

        self.cmap_mag = plt.get_cmap('viridis')
        self.cmap_conv = plt.get_cmap('cividis')
        self.cmap_res = plt.get_cmap('RdBu_r')

    def plot_data_image(self, ax, title=None, norm=None, cmap=None, xylim=None,
                        neg_values_as_bad=False, add_colorbar=True,
                        add_scalebar=True, scalebar_size=1):
        """plt.imshow panel with the data image.

        Returns
        -------
        image
            The observed image pixels, as returned by
            `coolest.observation.pixels.get_pixels`.
        """
        if cmap is None:
            cmap = self.cmap_flux
        coordinates = util.get_coordinates(self.coolest)
        extent = coordinates.plt_extent
        image = self.coolest.observation.pixels.get_pixels(directory=self._directory)
        ax, im = plut.plot_regular_grid(ax, title, image, extent=extent,
                                        cmap=cmap, norm=norm,
                                        neg_values_as_bad=neg_values_as_bad,
                                        xylim=xylim)
        if add_colorbar:
            cb = plut.nice_colorbar(im, ax=ax)
            cb.set_label("flux")
        if add_scalebar:
            plut.scale_bar(ax, scalebar_size, color='white', loc='lower right')
        return image

    def plot_surface_brightness(self, ax, title=None, coordinates=None,
                                extent_irreg=None, norm=None, cmap=None,
                                xylim=None, neg_values_as_bad=True,
                                plot_points_irreg=False, add_colorbar=True,
                                add_scalebar=False, scalebar_size=0.4,
                                kwargs_light=None,
                                plot_caustics=None, caustics_color='white', caustics_alpha=0.5,
                                coordinates_lens=None, kwargs_lens_mass=None):
        """plt.imshow panel showing the surface brightness of the (unlensed)
        lensing entity selected via kwargs_light (see ComposableLightModel docstring).

        When `kwargs_light` is None, the source(s) of the full lens model
        (`ComposableLensModel(...).source`) are plotted. If that yields several
        light models (e.g. multiple source planes), their surface brightnesses
        are summed on `coordinates`, which must then be provided.

        Returns
        -------
        image
            Surface brightness evaluated on a regular grid, or None when the
            model was displayed on an irregular grid.
        coordinates
            The `coordinates` passed in, or those returned by the light model's
            `surface_brightness(return_extra=True)` when none were given.

        Raises
        ------
        ValueError
            If the deprecated `extent_irreg` is given, if `plot_caustics` is set
            without `kwargs_lens_mass`, if no light model is selected, or if
            several light models are selected without `coordinates`.
        """
        if extent_irreg is not None:
            raise ValueError("`extent_irreg` is deprecated; use `xylim` instead.")

        if kwargs_light is None:
            lens_model = ComposableLensModel(self.coolest, self._directory)
            light_model = lens_model.source
        else:
            light_model = ComposableLightModel(self.coolest, self._directory, **kwargs_light)

        # With multiple sources the lens model exposes a list of light models.
        # A single-item list behaves exactly like a single light model; several
        # models can only be combined on a user-provided common grid.
        if isinstance(light_model, list):
            if len(light_model) == 0:
                raise ValueError("No light model selected for plotting the surface brightness.")
            if len(light_model) == 1:
                light_model = light_model[0]
            elif coordinates is None:
                raise ValueError("Several light models were selected: `coordinates` "
                                 "must be provided to evaluate them on a common grid.")

        if plot_caustics:
            if kwargs_lens_mass is None:
                raise ValueError("`kwargs_lens_mass` must be provided to compute caustics")
            if coordinates_lens is None:
                # finer grid (pixel_scale_factor=0.1) to trace the lens lines
                coordinates_lens = util.get_coordinates(self.coolest).create_new_coordinates(pixel_scale_factor=0.1)
            # NOTE: here we assume that the plotted light is the source, since
            # caustics live in the source plane.
            mass_model = ComposableMassModel(self.coolest, self._directory, **kwargs_lens_mass)
            _, caustics = util.find_all_lens_lines(coordinates_lens, mass_model)
        if cmap is None:
            cmap = self.cmap_flux

        if coordinates is not None:
            x, y = coordinates.pixel_coordinates
            if isinstance(light_model, list):
                # total brightness of all selected sources on the same grid
                image = np.sum([lm.evaluate_surface_brightness(x, y) for lm in light_model], axis=0)
            else:
                image = light_model.evaluate_surface_brightness(x, y)
            extent = coordinates.plt_extent
            ax, im = plut.plot_regular_grid(ax, title, image, extent=extent, cmap=cmap,
                                            neg_values_as_bad=neg_values_as_bad,
                                            norm=norm, xylim=xylim)
        else:
            values, extent_model, coordinates = light_model.surface_brightness(return_extra=True)
            if isinstance(values, np.ndarray) and len(values.shape) == 2:
                # regular pixel grid
                image = values
                ax, im = plut.plot_regular_grid(ax, title, image, extent=extent_model,
                                                cmap=cmap,
                                                neg_values_as_bad=neg_values_as_bad,
                                                norm=norm, xylim=xylim)
            else:
                # anything else is displayed as an irregular set of points
                points = values
                if xylim is None:
                    xylim = extent_model
                ax, im = plut.plot_irregular_grid(ax, title, points, xylim, norm=norm, cmap=cmap,
                                                  neg_values_as_bad=neg_values_as_bad,
                                                  plot_points=plot_points_irreg)
                image = None
        if plot_caustics:
            for caustic in caustics:
                ax.plot(caustic[0], caustic[1], lw=1, color=caustics_color, alpha=caustics_alpha)
        if add_colorbar:
            cb = plut.nice_colorbar(im, ax=ax)
            cb.set_label("flux")
        if add_scalebar:
            plut.scale_bar(ax, scalebar_size, color='white', loc='lower right')
        return image, coordinates

    def plot_model_image(self, ax, title=None,
                         norm=None, cmap=None, xylim=None, neg_values_as_bad=False,
                         add_colorbar=True, add_scalebar=True, scalebar_size=1,
                         auto_selection=False,
                         kwargs_lens_mass=None,
                         kwargs_lens_light=None,
                         kwargs_source=None,
                         **model_image_kwargs):
        """plt.imshow panel showing the surface brightness of the (lensed)
        selected lensing entities (see ComposableLensModel docstring).

        Extra keyword arguments are forwarded to
        `ComposableLensModel.model_image`. Returns the model image.
        """
        if cmap is None:
            cmap = self.cmap_flux
        lens_model = ComposableLensModel(
            self.coolest, self._directory,
            auto_selection=auto_selection,
            kwargs_selection_source=kwargs_source,
            kwargs_selection_lens_mass=kwargs_lens_mass,
            kwargs_selection_lens_light=kwargs_lens_light
        )
        image, coordinates = lens_model.model_image(**model_image_kwargs)
        extent = coordinates.plt_extent
        ax, im = plut.plot_regular_grid(ax, title, image, extent=extent,
                                        cmap=cmap,
                                        neg_values_as_bad=neg_values_as_bad,
                                        norm=norm, xylim=xylim)
        if add_colorbar:
            cb = plut.nice_colorbar(im, ax=ax)
            cb.set_label("flux")
        if add_scalebar:
            plut.scale_bar(ax, scalebar_size, color='white', loc='lower right')
        return image

    def plot_model_residuals(self, ax, title=None, mask=None,
                             norm=None, cmap=None, xylim=None, add_chi2_label=False, chi2_fontsize=12,
                             kwargs_source=None,
                             kwargs_lens_mass=None,
                             kwargs_lens_light=None,
                             add_colorbar=True, add_scalebar=True, scalebar_size=1,
                             **model_image_kwargs):
        """plt.imshow panel showing the normalized model residuals image.

        The mask attached to the imaging-data likelihood takes precedence over
        the user-provided `mask` (see `_get_likelihood_mask`). When
        `add_chi2_label` is exactly True, the reduced chi-square
        sum(residuals**2) / N is written on the panel, with N the number of
        pixels (or the sum of the mask when a mask is used).
        Returns the residuals image.
        """
        if cmap is None:
            cmap = self.cmap_res
        if norm is None:
            norm = Normalize(-6, 6)
        ll_mask = self._get_likelihood_mask(mask)
        lens_model = ComposableLensModel(self.coolest, self._directory,
                                         kwargs_selection_source=kwargs_source,
                                         kwargs_selection_lens_mass=kwargs_lens_mass,
                                         kwargs_selection_lens_light=kwargs_lens_light)
        image, coordinates = lens_model.model_residuals(mask=ll_mask, **model_image_kwargs)
        extent = coordinates.plt_extent
        ax, im = plut.plot_regular_grid(ax, title, image, extent=extent,
                                        cmap=cmap,
                                        neg_values_as_bad=False,
                                        norm=norm, xylim=xylim)
        if add_colorbar:
            cb = plut.nice_colorbar(im, ax=ax)
            cb.set_label("(data $-$ model) / noise")
        if add_scalebar:
            plut.scale_bar(ax, scalebar_size, color='black', loc='lower right')
        if add_chi2_label is True:
            num_constraints = np.size(image) if ll_mask is None else np.sum(ll_mask)
            red_chi2 = np.sum(image**2) / num_constraints
            ax.text(0.05, 0.05, r'$\chi^2_\nu$='+f'{red_chi2:.2f}', color='black', alpha=1,
                    fontsize=chi2_fontsize, va='bottom', ha='left', transform=ax.transAxes,
                    bbox={'color': 'white', 'alpha': 0.6})
        return image

    def plot_convergence(self, ax, title=None, coordinates=None,
                         norm=None, cmap=None, xylim=None, neg_values_as_bad=False,
                         add_colorbar=True,
                         add_scalebar=True, scalebar_size=1,
                         kwargs_lens_mass=None):
        """plt.imshow panel showing the 2D convergence map associated to the
        selected lensing entities (see ComposableMassModel docstring).

        The map is evaluated on `coordinates`, or on the data coordinates when
        None. Returns the convergence map.
        """
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        mass_model = ComposableMassModel(self.coolest, self._directory,
                                         **kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_conv
        if coordinates is None:
            coordinates = util.get_coordinates(self.coolest)
        extent = coordinates.plt_extent
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_convergence(x, y)
        ax, im = plut.plot_regular_grid(ax, title, image, extent=extent,
                                        cmap=cmap,
                                        neg_values_as_bad=neg_values_as_bad,
                                        norm=norm, xylim=xylim)
        if add_colorbar:
            cb = plut.nice_colorbar(im, ax=ax)
            cb.set_label(r"$\kappa$")
        if add_scalebar:
            plut.scale_bar(ax, scalebar_size, color='white', loc='lower right')
        return image

    def plot_convergence_diff(
            self, ax, reference_map, title=None, relative_error=True,
            norm=None, cmap=None, xylim=None, coordinates=None,
            add_colorbar=True, add_scalebar=True, scalebar_size=1,
            kwargs_lens_mass=None,
            plot_crit_lines=False, crit_lines_color='black', crit_lines_alpha=0.5):
        """plt.imshow panel showing the (absolute or relative) difference
        between `reference_map` and the 2D convergence map associated to the
        selected lensing entities (see ComposableMassModel docstring).

        The relative difference divides by `reference_map`, so zero entries in
        it produce non-finite values. Note that the returned value is the
        evaluated convergence map, not the plotted difference.
        """
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        mass_model = ComposableMassModel(self.coolest, self._directory,
                                         **kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_res
        if norm is None:
            norm = Normalize(-1, 1)
        if coordinates is None:
            coordinates = util.get_coordinates(self.coolest)
        if plot_crit_lines:
            critical_lines, _ = util.find_all_lens_lines(coordinates, mass_model)
        extent = coordinates.plt_extent
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_convergence(x, y)
        if relative_error is True:
            diff = (reference_map - image) / reference_map
        else:
            diff = reference_map - image
        ax, im = plut.plot_regular_grid(ax, title, diff, extent=extent,
                                        cmap=cmap,
                                        norm=norm, xylim=xylim)
        if plot_crit_lines:
            for cline in critical_lines:
                ax.plot(cline[0], cline[1], lw=1, color=crit_lines_color, alpha=crit_lines_alpha)
        if add_colorbar:
            cb = plut.nice_colorbar(im, ax=ax)
            cb.set_label(r"$\kappa$")
        if add_scalebar:
            plut.scale_bar(ax, scalebar_size, color='black', loc='lower right')
        return image

    def plot_magnification(self, ax, title=None,
                           norm=None, cmap=None, xylim=None,
                           add_colorbar=True, add_scalebar=True, scalebar_size=1,
                           coordinates=None, kwargs_lens_mass=None):
        """plt.imshow panel showing the 2D magnification map associated to the
        selected lensing entities (see ComposableMassModel docstring).

        The map is evaluated on `coordinates`, or on the data coordinates when
        None. Returns the magnification map.
        """
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        mass_model = ComposableMassModel(self.coolest, self._directory,
                                         **kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_mag
        if norm is None:
            norm = Normalize(-10, 10)
        if coordinates is None:
            coordinates = util.get_coordinates(self.coolest)
        x, y = coordinates.pixel_coordinates
        extent = coordinates.plt_extent
        image = mass_model.evaluate_magnification(x, y)
        ax, im = plut.plot_regular_grid(ax, title, image, extent=extent,
                                        cmap=cmap,
                                        norm=norm, xylim=xylim)
        if add_colorbar:
            cb = plut.nice_colorbar(im, ax=ax)
            cb.set_label(r"$\mu$")
        if add_scalebar:
            plut.scale_bar(ax, scalebar_size, color='white', loc='lower right')
        return image

    def plot_magnification_diff(
            self, ax, reference_map, title=None, relative_error=True,
            norm=None, cmap=None, xylim=None,
            add_colorbar=True, add_scalebar=True, scalebar_size=1,
            coordinates=None, kwargs_lens_mass=None):
        """plt.imshow panel showing the (absolute or relative)
        difference between `reference_map` and the 2D magnification map.

        The relative difference divides by `reference_map`, so zero entries in
        it produce non-finite values. Note that the returned value is the
        evaluated magnification map, not the plotted difference.
        """
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        mass_model = ComposableMassModel(self.coolest, self._directory,
                                         **kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_res
        if norm is None:
            norm = Normalize(-1, 1)
        if coordinates is None:
            coordinates = util.get_coordinates(self.coolest)
        x, y = coordinates.pixel_coordinates
        extent = coordinates.plt_extent
        image = mass_model.evaluate_magnification(x, y)
        if relative_error is True:
            diff = (reference_map - image) / reference_map
        else:
            diff = reference_map - image
        ax, im = plut.plot_regular_grid(ax, title, diff, extent=extent,
                                        cmap=cmap,
                                        norm=norm, xylim=xylim)
        if add_colorbar:
            cb = plut.nice_colorbar(im, ax=ax)
            cb.set_label(r"$\mu$")
        if add_scalebar:
            plut.scale_bar(ax, scalebar_size, color='black', loc='lower right')
        return image

    def _get_likelihood_mask(self, user_mask):
        """Return the mask attached to the imaging-data likelihood, if any.

        Returns None when the COOLEST object has no likelihoods or no
        'ImagingDataLikelihood' entry. When that likelihood exists but has no
        mask, `user_mask` is returned instead.
        """
        if self.coolest.likelihoods is None:
            return None
        try:
            # NOTE(review): `likelihoods` is looked up by the likelihood type
            # name; presumably a COOLEST container supporting this — confirm.
            img_ll_idx = self.coolest.likelihoods.index('ImagingDataLikelihood')
        except ValueError:
            return None
        img_ll = self.coolest.likelihoods[img_ll_idx]
        mask = img_ll.get_mask_pixels(directory=self._directory)
        if mask is None:  # then we use the user-provided mask
            mask = user_mask
        return mask
369

370

371
class MultiModelPlotter(object):
    """Wrapper around a set of ModelPlotter instances to produce panels that
    consistently compare different models, evaluated on the same
    coordinates systems.

    Parameters
    ----------
    coolest_objects : list
        List of COOLEST instances
    coolest_directories : list, optional
        List of directories corresponding to each COOLEST instance, by default None
    kwargs_plotter : dict, optional
        Additional keyword arguments passed to ModelPlotter

    Notes
    -----
    Every plotting method takes a sequence `axes` with one entry per model
    (entries that are None are skipped) and an optional `titles` sequence
    indexed the same way. Selection keyword arguments (`kwargs_light`,
    `kwargs_source`, `kwargs_lens_mass`, `kwargs_lens_light`, depending on the
    method) are dicts whose values are sequences indexed by model: entry ``i``
    of each value is forwarded to the ``i``-th model. All other keyword
    arguments are shared by all models. Each method returns the list of values
    returned by the per-model plotters (skipped axes are not included).

    Raises
    ------
    ValueError
        If `coolest_directories` does not have one entry per COOLEST object.
    """

    # Selection keyword arguments that are split per model, for each family
    # of plotting methods.
    _LIGHT_SELECTION_KEYS = ('kwargs_light', 'kwargs_lens_mass')  # lens mass is used for caustics
    _MASS_SELECTION_KEYS = ('kwargs_lens_mass',)
    _LENS_MODEL_SELECTION_KEYS = ('kwargs_source', 'kwargs_lens_mass', 'kwargs_lens_light')

    def __init__(self, coolest_objects, coolest_directories=None, **kwargs_plotter):
        self.num_models = len(coolest_objects)
        if coolest_directories is None:
            coolest_directories = self.num_models * [None]
        elif len(coolest_directories) != self.num_models:
            # zip() below would otherwise silently drop unmatched models
            raise ValueError("`coolest_directories` must have one entry per COOLEST object "
                             f"(got {len(coolest_directories)} for {self.num_models} objects).")
        self.plotter_list = [
            ModelPlotter(coolest, coolest_directory=c_dir, **kwargs_plotter)
            for coolest, c_dir in zip(coolest_objects, coolest_directories)
        ]

    def plot_surface_brightness(self, axes, **kwargs):
        return self._plot_light_multi('plot_surface_brightness', axes, **kwargs)

    def plot_data_image(self, axes, **kwargs):
        return self._plot_data_multi(axes, **kwargs)

    def plot_model_image(self, axes, **kwargs):
        return self._plot_lens_model_multi('plot_model_image', axes, **kwargs)

    def plot_model_residuals(self, axes, **kwargs):
        return self._plot_lens_model_multi('plot_model_residuals', axes, **kwargs)

    def plot_convergence(self, axes, **kwargs):
        return self._plot_lens_model_multi('plot_convergence', axes, **kwargs)

    def plot_magnification(self, axes, **kwargs):
        return self._plot_lens_model_multi('plot_magnification', axes, **kwargs)

    def plot_convergence_diff(self, axes, *args, **kwargs):
        return self._plot_lens_model_multi('plot_convergence_diff', axes, *args, **kwargs)

    def plot_magnification_diff(self, axes, *args, **kwargs):
        return self._plot_lens_model_multi('plot_magnification_diff', axes, *args, **kwargs)

    def _plot_multi(self, method_name, axes, per_model_keys, *args, **kwargs):
        """Call `method_name` on each ModelPlotter with its own axis and title.

        Keyword arguments listed in `per_model_keys` are dicts of per-model
        sequences and are split so that model ``i`` receives entry ``i`` of
        each value; a None value is forwarded unchanged. Positional `args` and
        all other keyword arguments are passed to every model.
        """
        assert len(axes) == self.num_models, "Inconsistent number of subplot axes"
        titles = kwargs.get('titles')
        kwargs_common = copy.deepcopy(kwargs)
        kwargs_common.pop('titles', None)
        image_list = []
        for i, (ax, plotter) in enumerate(zip(axes, self.plotter_list)):
            if ax is None:
                continue
            kwargs_ = dict(kwargs_common)
            for key in per_model_keys:
                selection = kwargs.get(key)
                if selection is not None:
                    kwargs_[key] = {k: v[i] for k, v in selection.items()}
            # without `titles`, panels are simply untitled
            title = titles[i] if titles is not None else None
            image = getattr(plotter, method_name)(ax, title, *args, **kwargs_)
            image_list.append(image)
        return image_list

    def _plot_light_multi(self, method_name, axes, **kwargs):
        return self._plot_multi(method_name, axes, self._LIGHT_SELECTION_KEYS, **kwargs)

    def _plot_mass_multi(self, method_name, axes, **kwargs):
        return self._plot_multi(method_name, axes, self._MASS_SELECTION_KEYS, **kwargs)

    def _plot_lens_model_multi(self, method_name, axes, *args, **kwargs):
        return self._plot_multi(method_name, axes, self._LENS_MODEL_SELECTION_KEYS,
                                *args, **kwargs)

    def _plot_data_multi(self, axes, **kwargs):
        return self._plot_multi('plot_data_image', axes, (), **kwargs)
490

491

492
class ParametersPlotter(object):
    """Handles plot of analytical models in a comparative way.

    `init_getdist` must be called before any of the getdist plotting methods,
    as it builds the samples and reference values they use.

    Parameters
    ----------
    parameter_id_list : list of str
        A list of parameter unique IDs obtained from lensing entities (IDs starting
        with 'beta' are looked up in the multi-plane betas instead). Their order
        determines the order of the plot panels.
    coolest_objects : list
        A list of COOLEST objects that have a chain file associated to them
        (``meta["chain_file_name"]``).
    coolest_directories : list of str, optional
        A list of directory paths matching the objects in `coolest_objects`; each
        chain file name is joined to the corresponding directory. Required by
        `init_getdist`.
    coolest_names : list of str, optional
        A list of labels for the models in `coolest_objects` (same order).
        By default "Model 0", "Model 1", etc.
    ref_coolest_objects : list, optional
        A list of COOLEST objects whose parameter point estimates are drawn as
        reference (black) markers.
    ref_coolest_directories : list of str, optional
        A list of paths matching the objects in `ref_coolest_objects`.
    ref_coolest_names : list of str, optional
        A list of labels for the objects in `ref_coolest_objects` (same order).
    posterior_bool_list : list of bool, optional
        List of bool to toggle errorbars on point-estimate values
        (used by the legacy `plotting_routine`).
    colors : list, optional
        List of pyplot colors to associate to each coolest model.
        By default sampled from the 'turbo' colormap.
    linestyles : list of str, optional
        List of pyplot linestyles to associate to each coolest model.
        By default all '-'.
    add_multivariate_margin_samples : bool, optional
        If True, `init_getdist` appends to the list of compared models a new
        chain that combines the samples of all models (with getdist's
        `getCombinedSamplesWithSamples`), each model carrying the same total
        probability mass. By default False.
    num_samples_per_model_margin : int, optional
        Stored as `_ns_per_model_margin`.
        NOTE(review): not used anywhere in this class at the moment.
    """

    # NOTE: this statement runs once, when the class is created (i.e. at module
    # import), and seeds numpy's global random state.
    np.random.seed(598237)  # fix the random seed for reproducibility

    def __init__(self, parameter_id_list, coolest_objects, coolest_directories=None, coolest_names=None,
                 ref_coolest_objects=None, ref_coolest_directories=None, ref_coolest_names=None,
                 posterior_bool_list=None, colors=None, linestyles=None,
                 add_multivariate_margin_samples=False, num_samples_per_model_margin=5_000):
        self.parameter_id_list = parameter_id_list
        self.coolest_objects = coolest_objects
        self.coolest_directories = coolest_directories
        if coolest_names is None:
            coolest_names = ["Model "+str(i) for i in range(len(coolest_objects))]
        self.coolest_names = coolest_names
        self.ref_coolest_objects = ref_coolest_objects
        self.ref_coolest_directories = ref_coolest_directories
        self.ref_coolest_names = ref_coolest_names
        self.ref_file_names = ref_coolest_names
        self.posterior_bool_list = posterior_bool_list

        self.num_models = len(self.coolest_objects)
        self.num_params = len(self.parameter_id_list)
        if colors is None:
            colors = plt.cm.turbo(np.linspace(0.1, 0.9, self.num_models))
        self.colors = colors
        if linestyles is None:
            linestyles = ['-']*self.num_models
        self.linestyles = linestyles
        # styles of the reference (point-estimate) models, cycled if there are more models
        self.ref_linestyles = ['--', ':', '-.', '-']
        self.ref_markers = ['s', '^', 'o', '*']

        self._add_margin_samples = add_multivariate_margin_samples
        self._ns_per_model_margin = num_samples_per_model_margin
        self._color_margin = 'black'
        self._label_margin = "Combined"

    @staticmethod
    def _get_parameter(coolest_obj, par_id):
        """Return the parameter object with ID `par_id` from `coolest_obj`.

        IDs starting with 'beta' are looked up in the multi-plane betas, all
        others in the lensing entities.
        """
        if par_id.startswith('beta'):
            return coolest_obj.multiplane_betas.get_using_param_id(par_id).beta
        return coolest_obj.lensing_entities.get_parameter_from_id(par_id)

    def _ref_style(self, k):
        """Return the (linestyle, marker) of the k-th reference model, cycling when needed."""
        return (self.ref_linestyles[k % len(self.ref_linestyles)],
                self.ref_markers[k % len(self.ref_markers)])

    def init_getdist(self, shift_sample_list=None, settings_mcsamples=None,
                     add_multivariate_margin_samples=False):
        """Initializes the getdist samples and the reference point estimates.

        Parameters
        ----------
        shift_sample_list : list of (dict or None), optional
            One entry per model. Each dict maps a parameter ID to a constant that
            is added to all the samples of that parameter for that model (None
            means no shift). By default no shift is applied.
        settings_mcsamples : dict, optional
            Keyword arguments passed as the `settings` argument of getdist.MCSamples, by default None
        add_multivariate_margin_samples : bool, optional
            Currently unused: the combined samples are controlled by the
            constructor argument of the same name.

        Raises
        ------
        ValueError
            If the header of a chain file contains a semicolon (columns must be
            comma-separated).
        AssertionError
            If some of the given parameters are not columns of a chain file.
        """
        chains.print_load_details = False  # Just to silence messages
        parameter_id_set = set(self.parameter_id_list)

        if shift_sample_list is None:
            shift_sample_list = [None]*self.num_models

        # Get the values of the point estimates (None where not available)
        point_estimates = []
        if self.ref_coolest_objects is not None:
            for coolest_obj in self.ref_coolest_objects:
                point_estimates.append([
                    self._get_parameter(coolest_obj, par_id).point_estimate.value
                    for par_id in self.parameter_id_list
                ])

        mcsamples = []
        mysample_margin = None
        for i in range(self.num_models):
            mysample = self._load_mcsamples(i, parameter_id_set, shift_sample_list[i],
                                            settings_mcsamples)
            mcsamples.append(mysample)

            # if required, aggregate the samples in a "marginalized" posterior
            if self._add_margin_samples:
                if i == 0:
                    mysample_margin = copy.deepcopy(mysample)
                else:
                    # the combined samples already contain i models: weighting them by i
                    # against 1 for the new model gives every model the same probability mass
                    mysample_margin = mysample_margin.getCombinedSamplesWithSamples(
                        mysample, sample_weights=(i, 1))

        if self._add_margin_samples:
            mcsamples.append(mysample_margin)

        self._mcsamples = mcsamples
        self.ref_values = point_estimates
        self.ref_values_markers = [dict(zip(self.parameter_id_list, values)) for values in self.ref_values]

    def _load_mcsamples(self, i, parameter_id_set, shift_samples, settings_mcsamples):
        """Read the chain file of the i-th model and return it as a getdist MCSamples.

        The last column of the chain file must be the 'probability_weights' column.
        """
        coolest_obj = self.coolest_objects[i]
        chain_file = os.path.join(self.coolest_directories[i], coolest_obj.meta["chain_file_name"])

        # Each chain file can have a different number of free parameters
        with open(chain_file) as f:
            header = f.readline()

        if ';' in header:
            raise ValueError("Columns must be coma-separated (no semi-colon) in chain file.")

        chain_file_headers = header.split(',')
        num_cols = len(chain_file_headers)
        chain_file_headers.pop()  # Remove the last column name that is the probability weights

        # Check that the given parameters are a subset of those in the chain file
        assert parameter_id_set.issubset(set(chain_file_headers)), "Not all given parameters are free parameters for model %d (not in the chain file: %s)!" % (i,chain_file)

        # Set the labels for the parameters in the chain file
        labels = [self._get_parameter(coolest_obj, par_id).latex_str.strip('$')
                  for par_id in self.parameter_id_list]

        # Read parameter values and probability weights
        column_indices = [chain_file_headers.index(par_id) for par_id in self.parameter_id_list]
        columns_to_read = sorted(column_indices) + [num_cols-1]  # add last one for probability weights
        samples = pd.read_csv(chain_file, usecols=columns_to_read, delimiter=',')

        # Re-order columns to match self.parameter_id_list and labels
        sample_par_values = np.array(samples[self.parameter_id_list])

        # If needed, shift samples by a constant
        if shift_samples is not None:
            for param_id, value in shift_samples.items():
                sample_par_values[:, self.parameter_id_list.index(param_id)] += value
                logging.info(f"posterior for parameter '{param_id}' from model '{self.coolest_names[i]}' "
                             f"has been shifted by {value}.")

        # Clean-up the probability weights: weights below the smallest non-zero weight
        # (i.e. zeros, for non-negative weights) are raised to it, so that the log is finite
        mypost = np.array(samples['probability_weights'])
        min_non_zero = np.min(mypost[np.nonzero(mypost)])
        sample_prob_weight = np.where(mypost < min_non_zero, min_non_zero, mypost)

        # Create MCSamples object
        mysample = MCSamples(samples=sample_par_values, names=self.parameter_id_list,
                             labels=labels, settings=settings_mcsamples)
        mysample.reweightAddingLogLikes(-np.log(sample_prob_weight))
        return mysample

    def get_mcsamples_getdist(self, with_margin=False):
        """Return the list of MCSamples built by `init_getdist`.

        The combined ("margin") samples, when present, are excluded unless
        `with_margin` is True.
        """
        if not self._add_margin_samples or with_margin:
            return self._mcsamples
        else:
            return self._mcsamples[:-1]

    def get_margin_mcsamples_getdist(self):
        """Return the combined ("margin") MCSamples, or None if they were not requested."""
        if not self._add_margin_samples:
            return None
        else:
            return self._mcsamples[-1]

    def plot_triangle_getdist(self, filled_contours=True, angles_range=None,
                              linewidth_hist=2, linewidth_cont=2, linewidth_margin=4,
                              marker_linewidth=2, marker_size=15,
                              axes_labelsize=None, legend_fontsize=None,
                              **subplot_kwargs):
        """Corner array of subplots using getdist.triangle_plot method.

        Parameters
        ----------
        filled_contours : bool, optional
            Whether or not to fill the 2D contours, by default True
            (the combined samples, if any, are always filled).
        angles_range : tuple, optional
            (min, max) range for angle parameters (whose ID ends with '-phi' or
            '-phi_ext'); their axis limits are clipped to it where they exceed it.
            By default (-90, 90).
        linewidth_hist : int, optional
            Line width for 1D histograms, by default 2
        linewidth_cont : int, optional
            Line width for 2D contours, by default 2
        linewidth_margin : int, optional
            Line width for the combined samples, by default 4
        marker_linewidth : int, optional
            Line width of the reference marker lines, by default 2
        marker_size : int, optional
            Size of the reference (scatter) markers on 2D contours plots, by default 15
        axes_labelsize : float, optional
            If given, overrides getdist's `axes_labelsize` setting.
        legend_fontsize : float, optional
            If given, overrides getdist's `legend_fontsize` setting.
        **subplot_kwargs
            Keyword arguments passed to getdist.plots.get_subplot_plotter
            (e.g. `subplot_size`).

        Returns
        -------
        GetDistPlotter
            Instance of GetDistPlotter corresponding to the figure
        """
        line_args, contour_lws, contour_ls, colors, legend_labels \
            = self._prepare_getdist_plot(linewidth_hist,
                                         lw_cont=linewidth_cont,
                                         lw_margin=linewidth_margin)

        filled_contours = [filled_contours]*len(self._mcsamples)
        alphas = [1]*len(self._mcsamples)
        if self._add_margin_samples:
            filled_contours[-1] = True

        # Make the plot
        g = plots.get_subplot_plotter(**subplot_kwargs)
        if legend_fontsize is not None:
            g.settings.legend_fontsize = legend_fontsize
        if axes_labelsize is not None:
            g.settings.axes_labelsize = axes_labelsize
        g.triangle_plot(
            self._mcsamples,
            params=self.parameter_id_list,
            legend_labels=legend_labels,
            filled=filled_contours,
            colors=colors,
            line_args=line_args,   # TODO: issue that linewidth settings in line_args are being overwritten by contour_lws
            # NOTE(review): self.colors has no entry for the combined (margin) samples
            contour_colors=self.colors,
            contour_lws=contour_lws,
            contour_ls=contour_ls,
            alphas=alphas,
        )

        # Add marker lines and points
        for k, (values, markers_dict) in enumerate(zip(self.ref_values, self.ref_values_markers)):
            ref_ls, ref_marker = self._ref_style(k)
            g.add_param_markers(markers_dict, color='black', ls=ref_ls, lw=marker_linewidth)
            for i in range(self.num_params):
                val_x = values[i]
                for j in range(i+1, self.num_params):
                    val_y = values[j]
                    if val_x is not None and val_y is not None:
                        g.subplots[j, i].scatter(val_x, val_y, s=marker_size, facecolors='black',
                                                 color='black', marker=ref_marker)

        # Set default ranges for angles
        if angles_range is None:
            angles_range = (-90, 90)
        for i, par_id in enumerate(self.parameter_id_list):
            name = par_id.split('-')[-1]
            if name not in ('phi', 'phi_ext'):
                continue
            xlim = g.subplots[i, i].get_xlim()
            if xlim[0] < angles_range[0]:
                for ax in g.subplots[i:, i]:
                    ax.set_xlim(left=angles_range[0])
                for ax in g.subplots[i, :i]:
                    ax.set_ylim(bottom=angles_range[0])
            if xlim[1] > angles_range[1]:
                for ax in g.subplots[i:, i]:
                    ax.set_xlim(right=angles_range[1])
                for ax in g.subplots[i, :i]:
                    ax.set_ylim(top=angles_range[1])
        return g

    def plot_rectangle_getdist(self, x_param_ids, y_param_ids, subplot_size=1,
                               legend_ncol=None, legend_fontsize=None,
                               filled_contours=True, linewidth=1,
                               marker_size=15, axes_labelsize=None, **subplot_kwargs):
        """Array of (2D contours) subplots using getdist.rectangle_plot method.

        Parameters
        ----------
        x_param_ids : list of str
            Parameter IDs along the columns (x axes).
        y_param_ids : list of str
            Parameter IDs along the rows (y axes).
        subplot_size : int, optional
            NOTE(review): currently not used by this method.
        legend_ncol : int, optional
            Number of columns in the legend, by default 3
        legend_fontsize : float, optional
            If given, overrides getdist's `legend_fontsize` setting.
        filled_contours : bool, optional
            Whether or not to fill the 2D contours, by default True
        linewidth : int, optional
            Line width for contours and reference marker lines, by default 1
        marker_size : int, optional
            Size of the reference (scatter) markers on 2D contours plots, by default 15
        axes_labelsize : float, optional
            If given, overrides getdist's `axes_labelsize` setting.
        **subplot_kwargs
            Keyword arguments passed to getdist.plots.get_subplot_plotter.

        Returns
        -------
        GetDistPlotter
            Instance of GetDistPlotter corresponding to the figure
        """
        line_args, _, _, colors, legend_labels = self._prepare_getdist_plot(linewidth)

        if legend_ncol is None:
            legend_ncol = 3
        # Make the plot
        g = plots.get_subplot_plotter(**subplot_kwargs)
        if legend_fontsize is not None:
            g.settings.legend_fontsize = legend_fontsize
        if axes_labelsize is not None:
            g.settings.axes_labelsize = axes_labelsize
        g.rectangle_plot(x_param_ids, y_param_ids, roots=self._mcsamples,
                         filled=filled_contours,
                         colors=colors,
                         legend_ncol=legend_ncol,
                         legend_labels=legend_labels,
                         line_args=line_args,
                         contour_colors=self.colors)
        for k, markers_dict in enumerate(self.ref_values_markers):
            ref_ls, ref_marker = self._ref_style(k)
            g.add_param_markers(markers_dict, color='black', ls=ref_ls, lw=linewidth)
            for j, key_x in enumerate(x_param_ids):
                val_x = markers_dict[key_x]
                for i, key_y in enumerate(y_param_ids):
                    val_y = markers_dict[key_y]
                    if val_x is not None and val_y is not None:
                        g.subplots[i, j].scatter(val_x, val_y, s=marker_size, facecolors='black',
                                                 color='black', marker=ref_marker)
        return g

    def plot_1d_getdist(self, num_columns=None, legend_ncol=None,
                        legend_fontsize=None, axes_labelsize=None,
                        linewidth=1, **subplot_kwargs):
        """Array of 1D histogram subplots using getdist.plots_1d method.

        Parameters
        ----------
        num_columns : int, optional
            Number of columns of the subplot array, by default ``num_models // 2 + 1``
        legend_ncol : int, optional
            Number of columns in the legend, by default 3
        legend_fontsize : float, optional
            If given, overrides getdist's `legend_fontsize` setting.
        axes_labelsize : float, optional
            If given, overrides getdist's `axes_labelsize` setting.
        linewidth : int, optional
            Line width of the histograms and reference marker lines, by default 1
        **subplot_kwargs
            Keyword arguments passed to getdist.plots.get_subplot_plotter
            (e.g. `subplot_size`).

        Returns
        -------
        GetDistPlotter
            Instance of GetDistPlotter corresponding to the figure
        """
        line_args, _, _, colors, legend_labels = self._prepare_getdist_plot(linewidth)

        if num_columns is None:
            num_columns = self.num_models//2+1
        if legend_ncol is None:
            legend_ncol = 3
        # Make the plot
        g = plots.get_subplot_plotter(**subplot_kwargs)
        if legend_fontsize is not None:
            g.settings.legend_fontsize = legend_fontsize
        if axes_labelsize is not None:
            g.settings.axes_labelsize = axes_labelsize
        g.plots_1d(self._mcsamples,
                   params=self.parameter_id_list,
                   legend_labels=legend_labels,
                   colors=colors,
                   share_y=True,
                   line_args=line_args,
                   nx=num_columns, legend_ncol=legend_ncol,
        )
        for k, markers_dict in enumerate(self.ref_values_markers):
            ref_ls, _ = self._ref_style(k)
            g.add_param_markers(markers_dict, color='black', ls=ref_ls, lw=linewidth)
        return g

    def plot_source(self, idx_file=0):
        """Run `plotting_routine` on `self.param_source` (legacy).

        NOTE(review): `self.param_source` is not set by `__init__`; it must be
        assigned by the caller beforehand.
        """
        f, ax = self.plotting_routine(self.param_source, idx_file)
        return f, ax

    def plot_lens(self, idx_file=0):
        """Run `plotting_routine` on `self.param_lens` (legacy).

        NOTE(review): `self.param_lens` is not set by `__init__`; it must be
        assigned by the caller beforehand.
        """
        f, ax = self.plotting_routine(self.param_lens, idx_file)
        return f, ax

    def plotting_routine(self, param_dict, idx_file=0):
        """Plot the parameters of the different files side by side, 4 panels per row (legacy).

        NOTE(review): relies on `self.file_names`, which is not set by `__init__`.

        Parameters
        ----------
        param_dict : dict
            Organized dictionary with all parameter results of the different files,
            keyed by file name then by parameter key. Each entry holds 'latex_str'
            and either 'point_estimate' or 'median', 'percentile_16th' and
            'percentile_84th', depending on `self.posterior_bool_list`.
            It is not modified.
        idx_file : int, optional
            Index of the file whose number of parameters sets the figure layout
            (e.g. if file 0 has a Sersic fit and file 1 Sersic+shapelets, idx_file=0
            plots the Sersic results of both files, while idx_file=1 plots all the
            Sersic and shapelets parameters when available).

        Returns
        -------
        tuple
            The pyplot figure and its 2D array of axes.
        """
        # Find the number of parameters to plot and define a nice looking figure
        num_cols = 4
        number_param = len(param_dict[self.file_names[idx_file]])
        num_lines = max(1, -(-number_param // num_cols))  # ceiling division
        # negative column indices of the empty panels at the end of the last row
        unused_figs = [-idx - 1 for idx in range(num_lines * num_cols - number_param)]

        # squeeze=False keeps `ax` 2D even when there is a single row
        f, ax = plt.subplots(num_lines, num_cols, figsize=(num_cols * 3.5, 2.5 * num_lines),
                             squeeze=False)
        markers = ['*', '.', 's', '^', '<', '>', 'v', 'p', 'P', 'X', 'D', '1', '2', '3', '4', '+']

        for j, file_name in enumerate(self.file_names):
            m = markers[j % len(markers)]
            for i, (key, entry) in enumerate(param_dict[file_name].items()):
                axis = ax[i // num_cols, i % num_cols]
                p = dict(entry)  # copy: the angle wrapping below must not alter param_dict
                if self.posterior_bool_list[j]:
                    # trick to plot correct error bars if close to the +180/-180 edge
                    if key in ('SHEAR_0_phi_ext', 'PEMD_0_phi'):
                        if p['percentile_16th'] > p['median']:
                            p['percentile_16th'] -= 180.
                        if p['percentile_84th'] < p['median']:
                            p['percentile_84th'] += 180.
                    axis.errorbar(j, p['median'], [[p['median'] - p['percentile_16th']],
                                                   [p['percentile_84th'] - p['median']]],
                                  marker=m, ls='', label=file_name)
                else:
                    axis.plot(j, p['point_estimate'], marker=m, ls='', label=file_name)

                if j == 0:
                    axis.get_xaxis().set_visible(False)
                    axis.set_ylabel(p['latex_str'], fontsize=12)
                    axis.tick_params(axis='y', labelsize=12)

        ax[0, 0].legend()
        for idx in unused_figs:
            ax[-1, idx].axis('off')
        plt.tight_layout()
        plt.show()
        return f, ax

    def _prepare_getdist_plot(self, lw, lw_cont=None, lw_margin=None):
        """Build the per-model style lists passed to getdist.

        Parameters
        ----------
        lw : float
            Line width of the model lines.
        lw_cont : float, optional
            Line width of the model 2D contours.
        lw_margin : float, optional
            Line width of the combined samples, by default ``lw + 2``.

        Returns
        -------
        tuple
            (line_args, contour line widths, contour linestyles, colors, legend labels),
            each extended with the combined-samples entry when they are enabled.
            The plotter's own attributes are never modified.
        """
        if lw_margin is None:
            lw_margin = lw + 2
        line_args = [{'ls': ls, 'lw': lw, 'color': c} for ls, c in zip(self.linestyles, self.colors)]
        lw_conts = [lw_cont]*self.num_models
        ls_conts = list(self.linestyles)  # copy: the append below must not alter self.linestyles
        legend_labels = list(self.coolest_names)
        colors = copy.deepcopy(self.colors)
        if self._add_margin_samples:
            line_args.append({'ls': '-.', 'lw': lw_margin, 'alpha': 0.8, 'color': self._color_margin})
            ls_conts.append('-.')
            if lw_cont is not None:
                lw_conts.append(lw_margin)
            legend_labels.append(self._label_margin)
            # list(): self.colors may be a numpy array (the default), which has no append
            colors = list(colors) + [self._color_margin]
        return line_args, lw_conts, ls_conts, colors, legend_labels
981

982
# def plot_corner(parameter_id_list, 
983
#                 chain_objs, chain_dirs, chain_names=None, 
984
#                 point_estimate_objs=None, point_estimate_dirs=None, point_estimate_names=None, 
985
#                 colors=None, labels=None, subplot_size=1, mc_samples_kwargs=None, 
986
#                 filled_contours=True, angles_range=None, shift_sample_list=None):
987
#     """
988
#     Adding this as just a function for the moment.
989
#     Takes a list of COOLEST files as input, which must have a chain file associated to them, and returns a corner plot.
990

991
#     Parameters
992
#     ----------
993
#     parameter_id_list : array
994
#         A list of parameter unique ids obtained from lensing entities. Their order determines the order of the plot panels.
995
#     chain_objs : array
996
#         A list of coolest objects that have a chain file associated to them.
997
#     chain_dirs : array
998
#         A list of paths matching the coolest files in 'chain_objs'.
999
#     chain_names : array, optional
1000
#         A list of labels for the coolest models in the 'chain_objs' list. Must have the same order as 'chain_objs'.
1001
#     point_estimate_objs : array, optional
1002
#         A list of coolest objects that will be used as point estimates.
1003
#     point_estimate_dirs : array
1004
#         A list of paths matching the coolest files in 'point_estimate_objs'.
1005
#     point_estimate_names : array, optional
1006
#         A list of labels for the models in the 'point_estimate_objs' list. Must have the same order as 'point_estimate_objs'.
1007
#     labels : dict, optional
1008
#         A dictionary matching the parameter_id_list entries to some human-readable labels.
1009

1010
#     Returns
1011
#     -------
1012
#     An image
1013
#     """
1014

1015
#     chains.print_load_details = False # Just to silence messages
1016
#     parameter_id_set = set(parameter_id_list)
1017
#     Npars = len(parameter_id_list)
1018
#     Nobjs = len(chain_objs)
1019
    
1020
#     # Set the chain names
1021
#     if chain_names is None:
1022
#         chain_names = ["chain "+str(i) for i in range(Nobjs)]
1023
    
1024
#     if shift_sample_list is None:
1025
#         shift_sample_list = [None]*Nobjs
1026
    
1027
#     # Get the values of the point_estimates
1028
#     point_estimates = []
1029
#     if point_estimate_objs is not None:
1030
#         for coolest_obj in point_estimate_objs:
1031
#             values = []
1032
#             for par in parameter_id_list:
1033
#                 param = coolest_obj.lensing_entities.get_parameter_from_id(par)
1034
#                 val = param.point_estimate.value
1035
#                 if val is None:
1036
#                     values.append(None)
1037
#                 else:
1038
#                     values.append(val)
1039
#             point_estimates.append(values)
1040

1041

1042
            
1043
#     mcsamples = []
1044
#     for i in range(Nobjs):
1045
#         chain_file = os.path.join(chain_dirs[i],chain_objs[i].meta["chain_file_name"]) # Here get the chain file path for each coolest object
1046

1047
#         # Each chain file can have a different number of free parameters
1048
#         f = open(chain_file)
1049
#         header = f.readline()
1050
#         f.close()
1051

1052
#         if ';' in header:
1053
#             raise ValueError("Columns must be coma-separated (no semi-colon) in chain file.")
1054

1055
#         chain_file_headers = header.split(',')
1056
#         num_cols = len(chain_file_headers)
1057
#         chain_file_headers.pop() # Remove the last column name that is the probability weights
1058
#         chain_file_headers_set = set(chain_file_headers)
1059
        
1060
#         # Check that the given parameters are a subset of those in the chain file
1061
#         assert parameter_id_set.issubset(chain_file_headers_set), "Not all given parameters are free parameters for model %d (not in the chain file: %s)!" % (i,chain_file)
1062

1063
#         # Set the labels for the parameters in the chain file
1064
#         par_labels = []
1065
#         if labels is None:
1066
#             labels = {}
1067
#         for par_id in parameter_id_list:
1068
#             if labels.get(par_id, None) is None:
1069
#                 param = coolest_obj.lensing_entities.get_parameter_from_id(par_id)
1070
#                 par_labels.append(param.latex_str.strip('$'))
1071
#             else:
1072
#                 par_labels.append(labels[par_id])
1073
                    
1074
#         # Read parameter values and probability weights
1075
#         column_indices = [chain_file_headers.index(par_id) for par_id in parameter_id_list]
1076
#         columns_to_read = sorted(column_indices) + [num_cols-1]  # add last one for probability weights
1077
#         samples = pd.read_csv(chain_file, usecols=columns_to_read, delimiter=',')
1078
    
1079
#         # Re-order columnds to match parameter_id_list and par_labels
1080
#         sample_par_values = np.array(samples[parameter_id_list])
1081

1082
#         # If needed, shift samples by a constant
1083
#         if shift_sample_list[i] is not None:
1084
#             for param_id, value in shift_sample_list[i].items():
1085
#                 sample_par_values[:, parameter_id_list.index(param_id)] += value
1086
#                 print(f"INFO: posterior for parameter '{param_id}' from model '{chain_names[i]}' "
1087
#                       f"has been shifted by {value}.")
1088

1089
#         # Clean-up the probability weights
1090
#         mypost = np.array(samples['probability_weights'])
1091
#         min_non_zero = np.min(mypost[np.nonzero(mypost)])
1092
#         sample_prob_weight = np.where(mypost<min_non_zero,min_non_zero,mypost)
1093
#         #sample_prob_weight = mypost
1094

1095
#         # Create MCSamples object
1096
#         mysample = MCSamples(samples=sample_par_values,names=parameter_id_list,labels=par_labels,settings=mc_samples_kwargs)
1097
#         mysample.reweightAddingLogLikes(-np.log(sample_prob_weight))
1098
#         mcsamples.append(mysample)
1099

1100

1101
        
1102
#     # Make the plot
1103
#     image = plots.getSubplotPlotter(subplot_size=subplot_size)    
1104
#     image.triangle_plot(mcsamples,
1105
#                         params=parameter_id_list,
1106
#                         legend_labels=chain_names,
1107
#                         filled=filled_contours,
1108
#                         colors=colors,
1109
#                         line_args=[{'ls':'-', 'lw': 2, 'color': c} for c in colors], 
1110
#                         contour_colors=colors)
1111

1112

1113
#     my_linestyles = ['solid','dotted','dashed','dashdot']
1114
#     my_markers    = ['s','^','o','star']
1115

1116
#     for k in range(0,len(point_estimates)):
1117
#         # Add vertical and horizontal lines
1118
#         for i in range(0,Npars):
1119
#             val = point_estimates[k][i]
1120
#             if val is not None:
1121
#                 for ax in image.subplots[i:,i]:
1122
#                     ax.axvline(val,color='black',ls=my_linestyles[k],alpha=1.0,lw=1)
1123
#                 for ax in image.subplots[i,:i]:
1124
#                     ax.axhline(val,color='black',ls=my_linestyles[k],alpha=1.0,lw=1)
1125

1126
#         # Add points
1127
#         for i in range(0,Npars):
1128
#             val_x = point_estimates[k][i]
1129
#             for j in range(i+1,Npars):
1130
#                 val_y = point_estimates[k][j]
1131
#                 if val_x is not None and val_y is not None:
1132
#                     image.subplots[j,i].scatter(val_x,val_y,s=10,facecolors='black',color='black',marker=my_markers[k])
1133
#                 else:
1134
#                     pass    
1135

1136

1137
#     # Set default ranges for angles
1138
#     if angles_range is None:
1139
#         angles_range = (-90, 90)
1140
#     for i in range(0,len(parameter_id_list)):
1141
#         dum = parameter_id_list[i].split('-')
1142
#         name = dum[-1]
1143
#         if name in ['phi','phi_ext']:
1144
#             xlim = image.subplots[i,i].get_xlim()
1145
#             #print(xlim)
1146
        
1147
#             if xlim[0] < -90:
1148
#                 for ax in image.subplots[i:,i]:
1149
#                     ax.set_xlim(left=angles_range[0])
1150
#                 for ax in image.subplots[i,:i]:
1151
#                     ax.set_ylim(bottom=angles_range[0])
1152
#             if xlim[1] > 90:
1153
#                 for ax in image.subplots[i:,i]:
1154
#                     ax.set_xlim(right=angles_range[1])
1155
#                 for ax in image.subplots[i,:i]:
1156
#                     ax.set_ylim(top=angles_range[1])
1157

1158
            
1159
#     return image
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc