• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

aymgal / COOLEST / 22070395133

16 Feb 2026 04:25PM UTC coverage: 45.463% (+0.05%) from 45.411%
22070395133

Pull #76

github

web-flow
Merge 226ee2e61 into 52f3a30ef
Pull Request #76: Add support for lens light component and fully-defined pixelated lens model, various other improvements to the analysis and plotting engines

96 of 234 new or added lines in 11 files covered. (41.03%)

5 existing lines in 3 files now uncovered.

1498 of 3295 relevant lines covered (45.46%)

0.45 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/coolest/api/plotting.py
1
__author__ = 'aymgal', 'lynevdv', 'gvernard'
×
2

3

4
import os
×
5
import copy
×
6
import logging
×
7
import numpy as np
×
8
import matplotlib.pyplot as plt
×
9
from matplotlib.colors import Normalize, LogNorm, TwoSlopeNorm
×
10
from matplotlib.colors import ListedColormap
×
11
from getdist import plots, chains, MCSamples
×
12

13
from coolest.api.analysis import Analysis
×
14
from coolest.api.composable_models import *
×
15
from coolest.api import util
×
16
from coolest.api import plot_util as plut
×
17

18
import pandas as pd
×
19

20

21
# matplotlib global settings
22
plt.rc('image', interpolation='none', origin='lower') # imshow settings
×
23

24
# logging settings
25
logging.getLogger().setLevel(logging.INFO)
×
26

27
# TODO: separate ParametersPlotter from ModelPlotter to avoid dependencies on getdist
28

29
__all__ = [
×
30
    'ModelPlotter',
31
    'MultiModelPlotter',
32
    'ParametersPlotter',
33
]
34

35
class ModelPlotter(object):
    """Create pyplot panels from a lens model stored in the COOLEST format.

    Parameters
    ----------
    coolest_object : COOLEST
        COOLEST instance
    coolest_directory : str, optional
        Directory which contains the COOLEST template, by default None
    color_bad_values : str, optional
        Color assigned to NaN values (typically negative values in log-scale),
        by default '#222222' (dark gray)

    Notes
    -----
    All ``plot_*`` methods share the following arguments:

    - ``ax``: the pyplot axes the panel is drawn on.
    - ``title``, ``norm``, ``cmap``, ``xylim``: forwarded to the
      ``plot_util`` grid-plotting routine; when ``cmap`` is None a
      per-quantity default colormap of the instance is used.
    - ``add_colorbar``: if True, a labelled colorbar is attached to the panel.
    - ``add_scalebar``, ``scalebar_size``: if ``add_scalebar`` is True, a scale
      bar is drawn in the lower right corner, ``scalebar_size`` being
      forwarded to ``plot_util.scale_bar``.
    """

    def __init__(self, coolest_object, coolest_directory=None,
                 color_bad_values='#222222'):
        self.coolest = coolest_object
        self._directory = coolest_directory

        # work on a copy so that set_bad() does not modify the colormap
        # object returned by pyplot
        self.cmap_flux = copy.copy(plt.get_cmap('magma'))
        self.cmap_flux.set_bad(color_bad_values)

        self.cmap_mag = plt.get_cmap('viridis')
        self.cmap_conv = plt.get_cmap('cividis')
        self.cmap_res = plt.get_cmap('RdBu_r')

    @staticmethod
    def _decorate_panel(ax, im, label, add_colorbar, add_scalebar,
                        scalebar_size, scalebar_color):
        """Attach the optional labelled colorbar and scale bar to a panel."""
        if add_colorbar:
            cb = plut.nice_colorbar(im, ax=ax)
            cb.set_label(label)
        if add_scalebar:
            plut.scale_bar(ax, scalebar_size, color=scalebar_color,
                           loc='lower right')

    @staticmethod
    def _difference_map(reference_map, image, relative_error):
        """Return ``reference_map - image``, divided by ``reference_map``
        when ``relative_error`` is exactly True."""
        if relative_error is True:
            return (reference_map - image) / reference_map
        return reference_map - image

    def _build_mass_model(self, kwargs_lens_mass):
        """Instantiate a ComposableMassModel from the (optional) selection kwargs."""
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        return ComposableMassModel(self.coolest, self._directory,
                                   **kwargs_lens_mass)

    def plot_data_image(self, ax, title=None, norm=None, cmap=None, xylim=None,
                        neg_values_as_bad=False, add_colorbar=True,
                        add_scalebar=True, scalebar_size=1):
        """plt.imshow panel with the data image

        Returns
        -------
        image
            The observation pixels that were plotted.
        """
        if cmap is None:
            cmap = self.cmap_flux
        coordinates = util.get_coordinates(self.coolest)
        extent = coordinates.plt_extent
        image = self.coolest.observation.pixels.get_pixels(directory=self._directory)
        ax, im = plut.plot_regular_grid(ax, title, image, extent=extent,
                                        cmap=cmap, norm=norm,
                                        neg_values_as_bad=neg_values_as_bad,
                                        xylim=xylim)
        self._decorate_panel(ax, im, "flux", add_colorbar, add_scalebar,
                             scalebar_size, 'white')
        return image

    def plot_surface_brightness(self, ax, title=None, coordinates=None,
                                extent_irreg=None, norm=None, cmap=None,
                                xylim=None, neg_values_as_bad=True,
                                plot_points_irreg=False, add_colorbar=True,
                                add_scalebar=False, scalebar_size=0.4,
                                kwargs_light=None,
                                plot_caustics=None, caustics_color='white', caustics_alpha=0.5,
                                coordinates_lens=None, kwargs_lens_mass=None):
        """plt.imshow panel showing the surface brightness of the (unlensed)
        lensing entity selected via kwargs_light (see ComposableLightModel docstring)

        Parameters
        ----------
        coordinates : optional
            Coordinates on which the light model is evaluated. If None, the
            light model provides its own values, extent and coordinates.
        extent_irreg : optional
            Deprecated, any value other than None raises an error.
        plot_points_irreg : bool, optional
            Forwarded to the irregular-grid plotting routine.
        kwargs_light : dict, optional
            Keyword arguments passed to ComposableLightModel.
        plot_caustics : bool, optional
            If truthy, caustics are computed from the mass model and drawn on
            top of the panel.
        caustics_color, caustics_alpha : optional
            Line color and transparency of the caustics.
        coordinates_lens : optional
            Coordinates used to compute the caustics. If None, they are created
            from the COOLEST coordinates with ``pixel_scale_factor=0.1``.
        kwargs_lens_mass : dict, optional
            Keyword arguments passed to ComposableMassModel; required when
            ``plot_caustics`` is set.

        Returns
        -------
        image, coordinates
            The plotted 2D image (None when the values are drawn on an
            irregular grid) and the coordinates associated to it.

        Raises
        ------
        ValueError
            If ``extent_irreg`` is given, or if caustics are requested without
            ``kwargs_lens_mass``.
        """
        if extent_irreg is not None:
            raise ValueError("`extent_irreg` is deprecated; use `xylim` instead.")
        if kwargs_light is None:
            kwargs_light = {}
        light_model = ComposableLightModel(self.coolest, self._directory, **kwargs_light)
        if plot_caustics:
            if kwargs_lens_mass is None:
                raise ValueError("`kwargs_lens_mass` must be provided to compute caustics")
            if coordinates_lens is None:
                coordinates_lens = util.get_coordinates(self.coolest).create_new_coordinates(pixel_scale_factor=0.1)
            # NOTE: here we assume that `kwargs_light` is for the source!
            mass_model = ComposableMassModel(self.coolest, self._directory, **kwargs_lens_mass)
            _, caustics = util.find_all_lens_lines(coordinates_lens, mass_model)
        if cmap is None:
            cmap = self.cmap_flux
        if coordinates is not None:
            # evaluate the light model on the user-provided regular grid
            x, y = coordinates.pixel_coordinates
            image = light_model.evaluate_surface_brightness(x, y)
            ax, im = plut.plot_regular_grid(ax, title, image,
                                            extent=coordinates.plt_extent,
                                            cmap=cmap,
                                            neg_values_as_bad=neg_values_as_bad,
                                            norm=norm, xylim=xylim)
        else:
            values, extent_model, coordinates = light_model.surface_brightness(return_extra=True)
            if isinstance(values, np.ndarray) and values.ndim == 2:
                image = values
                ax, im = plut.plot_regular_grid(ax, title, image,
                                                extent=extent_model,
                                                cmap=cmap,
                                                neg_values_as_bad=neg_values_as_bad,
                                                norm=norm, xylim=xylim)
            else:
                # anything that is not a 2D array is handed to the
                # irregular-grid routine; no image is returned in that case
                if xylim is None:
                    xylim = extent_model
                ax, im = plut.plot_irregular_grid(ax, title, values, xylim,
                                                  norm=norm, cmap=cmap,
                                                  neg_values_as_bad=neg_values_as_bad,
                                                  plot_points=plot_points_irreg)
                image = None
        if plot_caustics:
            for caustic in caustics:
                ax.plot(caustic[0], caustic[1], lw=1, color=caustics_color, alpha=caustics_alpha)
        self._decorate_panel(ax, im, "flux", add_colorbar, add_scalebar,
                             scalebar_size, 'white')
        return image, coordinates

    def plot_model_image(self, ax, title=None,
                         norm=None, cmap=None, xylim=None, neg_values_as_bad=False,
                         add_colorbar=True, add_scalebar=True, scalebar_size=1,
                         auto_selection=False,
                         kwargs_lens_mass=None,
                         kwargs_lens_light=None,
                         kwargs_source=None,
                         **model_image_kwargs):
        """plt.imshow panel showing the surface brightness of the (lensed)
        selected lensing entities (see ComposableLensModel docstring)

        ``auto_selection``, ``kwargs_source``, ``kwargs_lens_mass`` and
        ``kwargs_lens_light`` are forwarded to ComposableLensModel; the extra
        ``model_image_kwargs`` are forwarded to its ``model_image`` method.

        Returns
        -------
        image
            The model image that was plotted.
        """
        if cmap is None:
            cmap = self.cmap_flux
        lens_model = ComposableLensModel(
            self.coolest, self._directory,
            auto_selection=auto_selection,
            kwargs_selection_source=kwargs_source,
            kwargs_selection_lens_mass=kwargs_lens_mass,
            kwargs_selection_lens_light=kwargs_lens_light
        )
        image, coordinates = lens_model.model_image(**model_image_kwargs)
        ax, im = plut.plot_regular_grid(ax, title, image,
                                        extent=coordinates.plt_extent,
                                        cmap=cmap,
                                        neg_values_as_bad=neg_values_as_bad,
                                        norm=norm, xylim=xylim)
        self._decorate_panel(ax, im, "flux", add_colorbar, add_scalebar,
                             scalebar_size, 'white')
        return image

    def plot_model_residuals(self, ax, title=None, mask=None,
                             norm=None, cmap=None, xylim=None, add_chi2_label=False, chi2_fontsize=12,
                             kwargs_source=None,
                             kwargs_lens_mass=None,
                             kwargs_lens_light=None,
                             add_colorbar=True, add_scalebar=True, scalebar_size=1,
                             **model_image_kwargs):
        """plt.imshow panel showing the normalized model residuals image

        Parameters
        ----------
        mask : optional
            User-provided mask, used when the COOLEST object does not provide
            a likelihood mask (see ``_get_likelihood_mask``).
        norm : optional
            By default ``Normalize(-6, 6)``.
        add_chi2_label : bool, optional
            If exactly True, the reduced chi2 (sum of squared residuals divided
            by the number of pixels, or by the sum of the mask when there is
            one) is written in the lower left corner of the panel.
        chi2_fontsize : int, optional
            Font size of the chi2 label.

        The selection kwargs are forwarded to ComposableLensModel and
        ``model_image_kwargs`` to its ``model_residuals`` method.

        Returns
        -------
        image
            The residuals image that was plotted.
        """
        if cmap is None:
            cmap = self.cmap_res
        if norm is None:
            norm = Normalize(-6, 6)
        ll_mask = self._get_likelihood_mask(mask)
        lens_model = ComposableLensModel(self.coolest, self._directory,
                                         kwargs_selection_source=kwargs_source,
                                         kwargs_selection_lens_mass=kwargs_lens_mass,
                                         kwargs_selection_lens_light=kwargs_lens_light)
        image, coordinates = lens_model.model_residuals(mask=ll_mask, **model_image_kwargs)
        ax, im = plut.plot_regular_grid(ax, title, image,
                                        extent=coordinates.plt_extent,
                                        cmap=cmap,
                                        neg_values_as_bad=False,
                                        norm=norm, xylim=xylim)
        self._decorate_panel(ax, im, "(data $-$ model) / noise", add_colorbar,
                             add_scalebar, scalebar_size, 'black')
        if add_chi2_label is True:
            num_constraints = np.size(image) if ll_mask is None else np.sum(ll_mask)
            red_chi2 = np.sum(image**2) / num_constraints
            ax.text(0.05, 0.05, r'$\chi^2_\nu$='+f'{red_chi2:.2f}', color='black', alpha=1,
                    fontsize=chi2_fontsize, va='bottom', ha='left', transform=ax.transAxes,
                    bbox={'color': 'white', 'alpha': 0.6})
        return image

    def plot_convergence(self, ax, title=None, coordinates=None,
                         norm=None, cmap=None, xylim=None, neg_values_as_bad=False,
                         add_colorbar=True,
                         add_scalebar=True, scalebar_size=1,
                         kwargs_lens_mass=None):
        """plt.imshow panel showing the 2D convergence map associated to the
        selected lensing entities (see ComposableMassModel docstring)

        If ``coordinates`` is None, the coordinates of the COOLEST object are
        used to evaluate the map.

        Returns
        -------
        image
            The convergence map that was plotted.
        """
        mass_model = self._build_mass_model(kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_conv
        if coordinates is None:
            coordinates = util.get_coordinates(self.coolest)
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_convergence(x, y)
        ax, im = plut.plot_regular_grid(ax, title, image,
                                        extent=coordinates.plt_extent,
                                        cmap=cmap,
                                        neg_values_as_bad=neg_values_as_bad,
                                        norm=norm, xylim=xylim)
        self._decorate_panel(ax, im, r"$\kappa$", add_colorbar, add_scalebar,
                             scalebar_size, 'white')
        return image

    def plot_convergence_diff(
            self, ax, reference_map, title=None, relative_error=True,
            norm=None, cmap=None, xylim=None, coordinates=None,
            add_colorbar=True, add_scalebar=True, scalebar_size=1,
            kwargs_lens_mass=None,
            plot_crit_lines=False, crit_lines_color='black', crit_lines_alpha=0.5):
        """plt.imshow panel showing the (absolute or relative)
        difference between a reference map and the 2D convergence map of the
        selected lensing entities (see ComposableMassModel docstring)

        Parameters
        ----------
        reference_map
            Map the model convergence is subtracted from.
        relative_error : bool, optional
            If exactly True, the difference is divided by ``reference_map``.
        norm : optional
            By default ``Normalize(-1, 1)``.
        plot_crit_lines : bool, optional
            If truthy, critical lines of the mass model are computed on
            ``coordinates`` and drawn on top of the panel.
        crit_lines_color, crit_lines_alpha : optional
            Line color and transparency of the critical lines.

        Returns
        -------
        image
            The model convergence map (NOT the plotted difference).
        """
        mass_model = self._build_mass_model(kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_res
        if norm is None:
            norm = Normalize(-1, 1)
        if coordinates is None:
            coordinates = util.get_coordinates(self.coolest)
        if plot_crit_lines:
            critical_lines, _ = util.find_all_lens_lines(coordinates, mass_model)
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_convergence(x, y)
        diff = self._difference_map(reference_map, image, relative_error)
        ax, im = plut.plot_regular_grid(ax, title, diff,
                                        extent=coordinates.plt_extent,
                                        cmap=cmap,
                                        norm=norm, xylim=xylim)
        if plot_crit_lines:
            for cline in critical_lines:
                ax.plot(cline[0], cline[1], lw=1, color=crit_lines_color, alpha=crit_lines_alpha)
        self._decorate_panel(ax, im, r"$\kappa$", add_colorbar, add_scalebar,
                             scalebar_size, 'black')
        return image

    def plot_magnification(self, ax, title=None,
                           norm=None, cmap=None, xylim=None,
                           add_colorbar=True, add_scalebar=True, scalebar_size=1,
                           coordinates=None, kwargs_lens_mass=None):
        """plt.imshow panel showing the 2D magnification map associated to the
        selected lensing entities (see ComposableMassModel docstring)

        ``norm`` defaults to ``Normalize(-10, 10)``. If ``coordinates`` is
        None, the coordinates of the COOLEST object are used.

        Returns
        -------
        image
            The magnification map that was plotted.
        """
        mass_model = self._build_mass_model(kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_mag
        if norm is None:
            norm = Normalize(-10, 10)
        if coordinates is None:
            coordinates = util.get_coordinates(self.coolest)
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_magnification(x, y)
        ax, im = plut.plot_regular_grid(ax, title, image,
                                        extent=coordinates.plt_extent,
                                        cmap=cmap,
                                        norm=norm, xylim=xylim)
        self._decorate_panel(ax, im, r"$\mu$", add_colorbar, add_scalebar,
                             scalebar_size, 'white')
        return image

    def plot_magnification_diff(
            self, ax, reference_map, title=None, relative_error=True,
            norm=None, cmap=None, xylim=None,
            add_colorbar=True, add_scalebar=True, scalebar_size=1,
            coordinates=None, kwargs_lens_mass=None):
        """plt.imshow panel showing the (absolute or relative)
        difference between 2D magnification maps

        The model magnification is subtracted from ``reference_map``; if
        ``relative_error`` is exactly True the result is divided by
        ``reference_map``. ``norm`` defaults to ``Normalize(-1, 1)``.

        Returns
        -------
        image
            The model magnification map (NOT the plotted difference).
        """
        mass_model = self._build_mass_model(kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_res
        if norm is None:
            norm = Normalize(-1, 1)
        if coordinates is None:
            coordinates = util.get_coordinates(self.coolest)
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_magnification(x, y)
        diff = self._difference_map(reference_map, image, relative_error)
        ax, im = plut.plot_regular_grid(ax, title, diff,
                                        extent=coordinates.plt_extent,
                                        cmap=cmap,
                                        norm=norm, xylim=xylim)
        self._decorate_panel(ax, im, r"$\mu$", add_colorbar, add_scalebar,
                             scalebar_size, 'black')
        return image

    def _get_likelihood_mask(self, user_mask):
        """Return the mask to apply to the residuals.

        The mask stored in the 'ImagingDataLikelihood' of the COOLEST object
        takes precedence. ``user_mask`` is the fallback whenever no such mask
        is available: no likelihoods at all, no imaging likelihood, or an
        imaging likelihood without mask.
        """
        likelihoods = self.coolest.likelihoods
        if likelihoods is None:
            return user_mask
        try:
            img_ll_idx = likelihoods.index('ImagingDataLikelihood')
        except ValueError:
            return user_mask
        mask = likelihoods[img_ll_idx].get_mask_pixels(directory=self._directory)
        return user_mask if mask is None else mask
361

362

363
class MultiModelPlotter(object):
    """Wrapper around a set of ModelPlotter instances to produce panels that
    consistently compare different models, evaluated on the same
    coordinates systems.

    Parameters
    ----------
    coolest_objects : list
        List of COOLEST instances
    coolest_directories : list, optional
        List of directories corresponding to each COOLEST instance, by default None
    kwargs_plotter : dict, optional
        Additional keyword arguments passed to ModelPlotter

    Notes
    -----
    Every ``plot_*`` method takes a sequence ``axes`` with one entry per model
    (a None entry skips that model) and returns the list of values returned by
    the corresponding ModelPlotter method, for the models that were plotted.
    The optional ``titles`` keyword is a sequence with one title per model.
    Selection keywords (``kwargs_source``, ``kwargs_lens_mass``,
    ``kwargs_lens_light``, ``kwargs_light``) are dictionaries whose values are
    sequences with one item per model. Any other keyword is deep-copied once
    and passed identically to each ModelPlotter.
    """

    def __init__(self, coolest_objects, coolest_directories=None, **kwargs_plotter):
        self.num_models = len(coolest_objects)
        if coolest_directories is None:
            coolest_directories = self.num_models * [None]
        self.plotter_list = [
            ModelPlotter(coolest, coolest_directory=c_dir, **kwargs_plotter)
            for coolest, c_dir in zip(coolest_objects, coolest_directories)
        ]

    def plot_surface_brightness(self, axes, **kwargs):
        """Call ModelPlotter.plot_surface_brightness for each model."""
        return self._plot_light_multi('plot_surface_brightness', axes, **kwargs)

    def plot_data_image(self, axes, **kwargs):
        """Call ModelPlotter.plot_data_image for each model."""
        return self._plot_data_multi(axes, **kwargs)

    def plot_model_image(self, axes, **kwargs):
        """Call ModelPlotter.plot_model_image for each model."""
        return self._plot_lens_model_multi('plot_model_image', axes, **kwargs)

    def plot_model_residuals(self, axes, **kwargs):
        """Call ModelPlotter.plot_model_residuals for each model."""
        return self._plot_lens_model_multi('plot_model_residuals', axes, **kwargs)

    def plot_convergence(self, axes, **kwargs):
        """Call ModelPlotter.plot_convergence for each model."""
        return self._plot_lens_model_multi('plot_convergence', axes, **kwargs)

    def plot_magnification(self, axes, **kwargs):
        """Call ModelPlotter.plot_magnification for each model."""
        return self._plot_lens_model_multi('plot_magnification', axes, **kwargs)

    def plot_convergence_diff(self, axes, *args, **kwargs):
        """Call ModelPlotter.plot_convergence_diff for each model; the extra
        positional arguments (the reference map) follow the axes."""
        return self._plot_lens_model_multi('plot_convergence_diff', axes, *args, **kwargs)

    def plot_magnification_diff(self, axes, *args, **kwargs):
        """Call ModelPlotter.plot_magnification_diff for each model; the extra
        positional arguments (the reference map) follow the axes."""
        return self._plot_lens_model_multi('plot_magnification_diff', axes, *args, **kwargs)

    def _plot_multi(self, method_name, axes, per_model_keys, *args, **kwargs):
        """Call ``method_name`` of every ModelPlotter on its own axes.

        Parameters
        ----------
        method_name : str
            Name of the ModelPlotter method to call.
        axes : sequence
            One axes per model; None entries are skipped.
        per_model_keys : tuple of str
            Keywords whose value is a dict of per-model sequences, from which
            the i-th item is extracted for the i-th model.
        *args
            Positional arguments inserted right after the axes.
        **kwargs
            Keyword arguments for the method, plus the optional ``titles``.

        Returns
        -------
        list
            Return values of the method, for the models that were plotted.
        """
        assert len(axes) == self.num_models, "Inconsistent number of subplot axes"
        titles = kwargs.get('titles')
        shared_kwargs = copy.deepcopy(kwargs)
        shared_kwargs.pop('titles', None)
        image_list = []
        for i, (ax, plotter) in enumerate(zip(axes, self.plotter_list)):
            if ax is None:
                continue
            call_kwargs = dict(shared_kwargs)
            for key in per_model_keys:
                selection = kwargs.get(key)
                if selection is not None:
                    call_kwargs[key] = {k: v[i] for k, v in selection.items()}
            # the title defaults to None, and is passed by keyword because it
            # is not the second positional parameter of every plotting method
            # (the *_diff methods take the reference map first)
            title = None if titles is None else titles[i]
            image = getattr(plotter, method_name)(ax, *args, title=title, **call_kwargs)
            image_list.append(image)
        return image_list

    def _plot_light_multi(self, method_name, axes, **kwargs):
        """Multi-model call with per-model 'kwargs_light' and
        'kwargs_lens_mass' (the latter used for over-plotting caustics)."""
        return self._plot_multi(method_name, axes,
                                ('kwargs_light', 'kwargs_lens_mass'), **kwargs)

    def _plot_mass_multi(self, method_name, axes, **kwargs):
        """Multi-model call with per-model 'kwargs_lens_mass'."""
        return self._plot_multi(method_name, axes,
                                ('kwargs_lens_mass',), **kwargs)

    def _plot_lens_model_multi(self, method_name, axes, *args, **kwargs):
        """Multi-model call with per-model 'kwargs_source', 'kwargs_lens_mass'
        and 'kwargs_lens_light'."""
        return self._plot_multi(
            method_name, axes,
            ('kwargs_source', 'kwargs_lens_mass', 'kwargs_lens_light'),
            *args, **kwargs)

    def _plot_data_multi(self, axes, **kwargs):
        """Multi-model call of 'plot_data_image' (no per-model keyword)."""
        return self._plot_multi('plot_data_image', axes, (), **kwargs)
482

483

484
class ParametersPlotter(object):
    """Handles plot of analytical models in a comparative way

    Parameters
    ----------
    parameter_id_list : list
        A list of parameter unique ids obtained from lensing entities.
        Their order determines the order of the plot panels.
    coolest_objects : list
        A list of coolest objects that have a chain file associated to them
        (file name read from their ``meta["chain_file_name"]`` entry).
    coolest_directories : list
        A list of directories, one per item of `coolest_objects`, in which
        the chain files are looked up. Required by `init_getdist`.
    coolest_names : list, optional
        A list of labels for the models in `coolest_objects` (same order).
        By default "Model 0", "Model 1", etc.
    ref_coolest_objects : list, optional
        A list of coolest objects that will be used as point estimates.
    ref_coolest_directories : list, optional
        A list of directories matching `ref_coolest_objects`
        (stored, but not used by the methods of this class).
    ref_coolest_names : list, optional
        A list of labels for the models in `ref_coolest_objects` (same order).
    posterior_bool_list : list, optional
        List of bool to toggle errorbars on point-estimate values
        (only read by `plotting_routine`).
    colors : list, optional
        List of pyplot color names to associate to each coolest model.
        By default, colors sampled from the 'turbo' colormap.
    linestyles : list, optional
        List of pyplot linestyles to associate to each coolest model.
        By default solid lines.
    add_multivariate_margin_samples : bool, optional
        If True, `init_getdist` appends to the list of compared models a
        new set of samples combining the samples of all models.
        By default False.
    num_samples_per_model_margin : int, optional
        Number of samples per model for the combined samples
        (stored, but not used by the methods of this class).
    """

    np.random.seed(598237)  # fix the random seed for reproducibility

    def __init__(self, parameter_id_list, coolest_objects, coolest_directories=None, coolest_names=None,
                 ref_coolest_objects=None, ref_coolest_directories=None, ref_coolest_names=None,
                 posterior_bool_list=None, colors=None, linestyles=None,
                 add_multivariate_margin_samples=False, num_samples_per_model_margin=5_000):
        self.parameter_id_list = parameter_id_list
        self.coolest_objects = coolest_objects
        self.coolest_directories = coolest_directories
        if coolest_names is None:
            coolest_names = ["Model "+str(i) for i in range(len(coolest_objects))]
        self.coolest_names = coolest_names
        self.ref_coolest_objects = ref_coolest_objects
        self.ref_coolest_directories = ref_coolest_directories
        self.ref_coolest_names = ref_coolest_names
        self.ref_file_names = ref_coolest_names
        # kept so that the documented argument is not silently dropped
        self.posterior_bool_list = posterior_bool_list

        self.num_models = len(self.coolest_objects)
        self.num_params = len(self.parameter_id_list)
        if colors is None:
            colors = plt.cm.turbo(np.linspace(0.1, 0.9, self.num_models))
        self.colors = colors
        if linestyles is None:
            linestyles = ['-']*self.num_models
        self.linestyles = linestyles
        # styles of the reference (point estimate) models, cycled through if
        # there are more reference models than styles
        self.ref_linestyles = ['--', ':', '-.', '-']
        self.ref_markers = ['s', '^', 'o', '*']

        self._add_margin_samples = add_multivariate_margin_samples
        self._ns_per_model_margin = num_samples_per_model_margin
        self._color_margin = 'black'
        self._label_margin = "Combined"

        # NOTE(review): `self.file_names`, `self.param_lens` and
        # `self.param_source`, which are read by `plot_source`, `plot_lens`
        # and `plotting_routine`, are not set anywhere in this class.

    def init_getdist(self, shift_sample_list=None, settings_mcsamples=None,
                     add_multivariate_margin_samples=False):
        """Initializes the getdist samples, and the reference values, to be plotted.

        For each model, the chain file is read, the columns corresponding
        to `parameter_id_list` and the last column (probability weights)
        are extracted and wrapped into a getdist.MCSamples instance.

        Parameters
        ----------
        shift_sample_list : list, optional
            List with one item per model, each being either None or a
            dictionary keyed by parameter ID to apply a uniform additive
            shift to all samples of that parameter, by default None
        settings_mcsamples : dict, optional
            Keyword arguments passed as the `settings` argument of getdist.MCSamples, by default None
        add_multivariate_margin_samples : bool, optional
            Not used by this method: the flag given to the constructor
            decides if combined samples are added.

        Raises
        ------
        ValueError
            If the csv file containing samples is not comma (,) separated.
        AssertionError
            If some parameters of `parameter_id_list` are not in the chain file.
        """
        chains.print_load_details = False # Just to silence messages
        parameter_id_set = set(self.parameter_id_list)

        if shift_sample_list is None:
            shift_sample_list = [None]*self.num_models

        # Get the values of the point_estimates (None values are kept as is)
        point_estimates = []
        if self.ref_coolest_objects is not None:
            for coolest_obj in self.ref_coolest_objects:
                entities = coolest_obj.lensing_entities
                point_estimates.append([
                    entities.get_parameter_from_id(par).point_estimate.value
                    for par in self.parameter_id_list
                ])

        mcsamples = []
        mysample_margin = None
        for i in range(self.num_models):
            # Here get the chain file path for each coolest object
            chain_file = os.path.join(self.coolest_directories[i],
                                      self.coolest_objects[i].meta["chain_file_name"])

            # Each chain file can have a different number of free parameters
            with open(chain_file) as f:
                header = f.readline()

            if ';' in header:
                raise ValueError("Columns must be coma-separated (no semi-colon) in chain file.")

            chain_file_headers = header.split(',')
            num_cols = len(chain_file_headers)
            chain_file_headers.pop() # Remove the last column name that is the probability weights
            chain_file_headers_set = set(chain_file_headers)

            # Check that the given parameters are a subset of those in the chain file
            assert parameter_id_set.issubset(chain_file_headers_set), "Not all given parameters are free parameters for model %d (not in the chain file: %s)!" % (i,chain_file)

            # Set the labels for the parameters in the chain file
            entities = self.coolest_objects[i].lensing_entities
            labels = [
                entities.get_parameter_from_id(par_id).latex_str.strip('$')
                for par_id in self.parameter_id_list
            ]

            # Read parameter values and probability weights
            column_indices = [chain_file_headers.index(par_id) for par_id in self.parameter_id_list]
            columns_to_read = sorted(column_indices) + [num_cols-1]  # add last one for probability weights
            samples = pd.read_csv(chain_file, usecols=columns_to_read, delimiter=',')

            # Re-order columns to match self.parameter_id_list and labels
            sample_par_values = np.array(samples[self.parameter_id_list])

            # If needed, shift samples by a constant
            if shift_sample_list[i] is not None:
                for param_id, value in shift_sample_list[i].items():
                    sample_par_values[:, self.parameter_id_list.index(param_id)] += value
                    logging.info(f"posterior for parameter '{param_id}' from model '{self.coolest_names[i]}' "
                                 f"has been shifted by {value}.")

            # Clean-up the probability weights: values below the smallest
            # non-zero weight (e.g. zeros) are replaced by that weight,
            # so that the logarithm below stays finite
            mypost = np.array(samples['probability_weights'])
            min_non_zero = np.min(mypost[np.nonzero(mypost)])
            sample_prob_weight = np.where(mypost<min_non_zero, min_non_zero, mypost)

            # Create MCSamples object
            mysample = MCSamples(samples=sample_par_values, names=self.parameter_id_list,
                                 labels=labels, settings=settings_mcsamples)
            mysample.reweightAddingLogLikes(-np.log(sample_prob_weight))
            mcsamples.append(mysample)

            # if required, aggregate the samples in a "marginalized" posterior
            if self._add_margin_samples:
                if i == 0:
                    mysample_margin = copy.deepcopy(mysample)
                else:
                    # combine the sample such that the probability mass of each set of samples is the same
                    mysample_margin = mysample_margin.getCombinedSamplesWithSamples(mysample, sample_weights=(1, 1))

        if self._add_margin_samples:
            # the combined samples are always stored last
            mcsamples.append(mysample_margin)

        self._mcsamples = mcsamples
        self.ref_values = point_estimates
        self.ref_values_markers = [dict(zip(self.parameter_id_list, values)) for values in self.ref_values]

    def get_mcsamples_getdist(self, with_margin=False):
        """Returns the list of samples created by `init_getdist`.

        Parameters
        ----------
        with_margin : bool, optional
            If True, the combined samples (if any) are included as the
            last item of the list, by default False

        Returns
        -------
        list
            List of getdist.MCSamples instances
        """
        if not self._add_margin_samples or with_margin:
            return self._mcsamples
        else:
            return self._mcsamples[:-1]

    def get_margin_mcsamples_getdist(self):
        """Returns the combined samples created by `init_getdist`, or None if there are none."""
        if not self._add_margin_samples:
            return None
        else:
            return self._mcsamples[-1]

    def plot_triangle_getdist(self, filled_contours=True, angles_range=None,
                              linewidth_hist=2, linewidth_cont=2, linewidth_margin=4,
                              marker_linewidth=2, marker_size=15,
                              axes_labelsize=None, legend_fontsize=None,
                              **subplot_kwargs):
        """Corner array of subplots using getdist.triangle_plot method.

        `init_getdist` must have been called beforehand.

        Parameters
        ----------
        filled_contours : bool, optional
            Whether or not to fill the 2D contours, by default True.
            The contours of the combined samples (if any) are always filled.
        angles_range : tuple, optional
            Limits (min, max) applied to the axes of angle parameters (those
            whose ID ends with '-phi' or '-phi_ext') when their automatic
            limits go beyond -90 or 90, by default (-90, 90)
        linewidth_hist : int, optional
            Line width for 1D histograms, by default 2
        linewidth_cont : int, optional
            Line width for 2D contours, by default 2
        linewidth_margin : int, optional
            Line width for the combined samples (if any), by default 4
        marker_linewidth : int, optional
            Line width of the lines indicating the reference values, by default 2
        marker_size : int, optional
            Size of the reference (scatter) markers on 2D contours plots, by default 15
        axes_labelsize : int, optional
            If not None, overrides the `axes_labelsize` getdist setting
        legend_fontsize : int, optional
            If not None, overrides the `legend_fontsize` getdist setting
        subplot_kwargs
            Keyword arguments passed to getdist.plots.get_subplot_plotter

        Returns
        -------
        GetDistPlotter
            Instance of GetDistPlotter corresponding to the figure
        """
        line_args, contour_lws, contour_ls, colors, legend_labels \
            = self._prepare_getdist_plot(linewidth_hist,
                                         lw_cont=linewidth_cont,
                                         lw_margin=linewidth_margin)

        filled_contours = [filled_contours]*len(self._mcsamples)
        alphas = [1]*len(self._mcsamples)
        if self._add_margin_samples:
            filled_contours[-1] = True

        # Make the plot
        g = plots.get_subplot_plotter(**subplot_kwargs)
        if legend_fontsize is not None:
            g.settings.legend_fontsize = legend_fontsize
        if axes_labelsize is not None:
            g.settings.axes_labelsize = axes_labelsize
        g.triangle_plot(
            self._mcsamples,
            params=self.parameter_id_list,
            legend_labels=legend_labels,
            filled=filled_contours,
            colors=colors,
            line_args=line_args,   # TODO: issue that linewidth settings in line_args are being overwritten by contour_lws
            # same list as `colors`, so that the combined samples (if any)
            # also have their own entry
            contour_colors=colors,
            contour_lws=contour_lws,
            contour_ls=contour_ls,
            alphas=alphas,
        )

        # Add marker lines and points
        for k, ref_values in enumerate(self.ref_values):
            ref_ls, ref_marker = self._get_ref_style(k)
            g.add_param_markers(self.ref_values_markers[k], color='black', ls=ref_ls,
                                lw=marker_linewidth)
            for i in range(self.num_params):
                val_x = ref_values[i]
                for j in range(i+1, self.num_params):
                    val_y = ref_values[j]
                    if val_x is not None and val_y is not None:
                        g.subplots[j, i].scatter(val_x, val_y, s=marker_size, facecolors='black',
                                                 color='black', marker=ref_marker)

        # Set default ranges for angles
        if angles_range is None:
            angles_range = (-90, 90)
        for i, param_id in enumerate(self.parameter_id_list):
            name = param_id.split('-')[-1]
            if name not in ('phi', 'phi_ext'):
                continue
            xlim = g.subplots[i, i].get_xlim()
            # NOTE(review): the thresholds are fixed to -90 and 90, whatever
            # the values of `angles_range` -- confirm this is intended
            if xlim[0] < -90:
                for ax in g.subplots[i:, i]:
                    ax.set_xlim(left=angles_range[0])
                for ax in g.subplots[i, :i]:
                    ax.set_ylim(bottom=angles_range[0])
            if xlim[1] > 90:
                for ax in g.subplots[i:, i]:
                    ax.set_xlim(right=angles_range[1])
                for ax in g.subplots[i, :i]:
                    ax.set_ylim(top=angles_range[1])
        return g

    def plot_rectangle_getdist(self, x_param_ids, y_param_ids, subplot_size=1,
                               legend_ncol=None, legend_fontsize=None,
                               filled_contours=True, linewidth=1,
                               marker_size=15, axes_labelsize=None, **subplot_kwargs):
        """Array of (2D contours) subplots using getdist.rectangle_plot method.

        `init_getdist` must have been called beforehand.

        Parameters
        ----------
        x_param_ids : list
            IDs of the parameters along the x axis (one column of subplots each)
        y_param_ids : list
            IDs of the parameters along the y axis (one row of subplots each)
        subplot_size : int, optional
            NOTE(review): currently not forwarded to getdist; the subplot
            size can be set through `subplot_kwargs` instead.
        legend_ncol : int, optional
            Number of columns in the legend, by default 3
        legend_fontsize : int, optional
            If not None, overrides the `legend_fontsize` getdist setting
        filled_contours : bool, optional
            Whether or not to fill the 2D contours, by default True
        linewidth : int, optional
            Line width for the model lines and the reference lines, by default 1
        marker_size : int, optional
            Size of the reference (scatter) markers on 2D contours plots, by default 15
        axes_labelsize : int, optional
            If not None, overrides the `axes_labelsize` getdist setting
        subplot_kwargs
            Keyword arguments passed to getdist.plots.get_subplot_plotter

        Returns
        -------
        GetDistPlotter
            Instance of GetDistPlotter corresponding to the figure
        """
        line_args, _, _, colors, legend_labels = self._prepare_getdist_plot(linewidth)

        if legend_ncol is None:
            legend_ncol = 3
        # Make the plot
        g = plots.get_subplot_plotter(**subplot_kwargs)
        if legend_fontsize is not None:
            g.settings.legend_fontsize = legend_fontsize
        if axes_labelsize is not None:
            g.settings.axes_labelsize = axes_labelsize
        g.rectangle_plot(x_param_ids, y_param_ids, roots=self._mcsamples,
                         filled=filled_contours,
                         colors=colors,
                         legend_ncol=legend_ncol,
                         legend_labels=legend_labels,
                         line_args=line_args,
                         # same list as `colors`, so that the combined samples
                         # (if any) also have their own entry
                         contour_colors=colors)
        for k, ref_markers in enumerate(self.ref_values_markers):
            ref_ls, ref_marker = self._get_ref_style(k)
            g.add_param_markers(ref_markers, color='black', ls=ref_ls, lw=linewidth)
            for j, key_x in enumerate(x_param_ids):
                val_x = ref_markers[key_x]
                for i, key_y in enumerate(y_param_ids):
                    val_y = ref_markers[key_y]
                    if val_x is not None and val_y is not None:
                        g.subplots[i, j].scatter(val_x, val_y, s=marker_size, facecolors='black',
                                                 color='black', marker=ref_marker)
        return g

    def plot_1d_getdist(self, num_columns=None, legend_ncol=None,
                        legend_fontsize=None, axes_labelsize=None,
                        linewidth=1, **subplot_kwargs):
        """Array of 1D histogram subplots using getdist.plots_1d method.

        `init_getdist` must have been called beforehand.

        Parameters
        ----------
        num_columns : int, optional
            Number of columns of the subplot array, by default `num_models//2 + 1`
        legend_ncol : int, optional
            Number of columns in the legend, by default 3
        legend_fontsize : int, optional
            If not None, overrides the `legend_fontsize` getdist setting
        axes_labelsize : int, optional
            If not None, overrides the `axes_labelsize` getdist setting
        linewidth : int, optional
            Line width for the histograms and the reference lines, by default 1
        subplot_kwargs
            Keyword arguments passed to getdist.plots.get_subplot_plotter

        Returns
        -------
        GetDistPlotter
            Instance of GetDistPlotter corresponding to the figure
        """
        line_args, _, _, colors, legend_labels = self._prepare_getdist_plot(linewidth)

        if num_columns is None:
            num_columns = self.num_models//2+1
        if legend_ncol is None:
            legend_ncol = 3
        # Make the plot
        g = plots.get_subplot_plotter(**subplot_kwargs)
        if legend_fontsize is not None:
            g.settings.legend_fontsize = legend_fontsize
        if axes_labelsize is not None:
            g.settings.axes_labelsize = axes_labelsize
        g.plots_1d(self._mcsamples,
                   params=self.parameter_id_list,
                   legend_labels=legend_labels,
                   colors=colors,
                   share_y=True,
                   line_args=line_args,
                   nx=num_columns, legend_ncol=legend_ncol,
        )
        for k, ref_markers in enumerate(self.ref_values_markers):
            ref_ls, _ = self._get_ref_style(k)
            g.add_param_markers(ref_markers, color='black', ls=ref_ls, lw=linewidth)
        return g

    def plot_source(self, idx_file=0):
        """Calls `plotting_routine` with `self.param_source`.

        NOTE(review): `self.param_source` is not set by this class.
        """
        f,ax = self.plotting_routine(self.param_source,idx_file)
        return f,ax

    def plot_lens(self, idx_file=0):
        """Calls `plotting_routine` with `self.param_lens`.

        NOTE(review): `self.param_lens` is not set by this class.
        """
        f,ax = self.plotting_routine(self.param_lens,idx_file)
        return f,ax

    def plotting_routine(self, param_dict, idx_file=0):
        """Plots the parameter values of each file, one parameter per panel.

        NOTE(review): this method reads `self.file_names`, which is not set
        by this class, and `self.posterior_bool_list`, which must be given
        to the constructor.

        Parameters
        ----------
        param_dict : dict
            Organized dictionary with all parameters results of the
            different files: keyed by file name, then by parameter name.
            Each parameter is a dictionary with (at least) the keys
            'latex_str' and either 'point_estimate' or 'median',
            'percentile_16th' and 'percentile_84th'.
        idx_file : int, optional
            Chooses the file whose number of parameters sets the number of
            panels, by default 0

        Returns
        -------
        tuple
            The pyplot figure and the 2D array of axes
        """
        # find the number of parameters to plot and define a nice looking figure
        # with 4 panels per row (and at least one row, so that a small
        # number of parameters also leads to a valid figure)
        number_param = len(param_dict[self.file_names[idx_file]])
        num_lines = max(1, -(-number_param // 4))  # ceil division
        # (negative) column indices of the panels of the last row left empty
        unused_figs = [-(idx + 1) for idx in range(4 * num_lines - number_param)]

        # squeeze=False: `ax` is a 2D array even with a single row
        f, ax = plt.subplots(num_lines, 4, figsize=(4 * 3.5, 2.5 * num_lines), squeeze=False)
        markers = ['*', '.', 's', '^','<','>','v','p','P','X','D','1','2','3','4','+']
        #may find a better way to define markers but right now, it is sufficient

        for j, file_name in enumerate(self.file_names):
            result = param_dict[file_name]
            for i, (key, p) in enumerate(result.items()):
                idx_line, idx_col = divmod(i, 4)
                panel = ax[idx_line, idx_col]
                m = markers[j]
                if self.posterior_bool_list[j]:
                    #trick to plot correct error bars if close to the +180/-180 edge
                    # (the percentiles are modified in place in `param_dict`)
                    if key in ('SHEAR_0_phi_ext', 'PEMD_0_phi'):
                        if p['percentile_16th'] > p['median']:
                            p['percentile_16th'] -= 180.
                        if p['percentile_84th'] < p['median']:
                            p['percentile_84th'] += 180.
                    panel.errorbar(j, p['median'], [[p['median'] - p['percentile_16th']],
                                                    [p['percentile_84th'] - p['median']]],
                                   marker=m, ls='', label=file_name)
                else:
                    panel.plot(j, p['point_estimate'], marker=m, ls='', label=file_name)

                if j == 0:
                    panel.get_xaxis().set_visible(False)
                    panel.set_ylabel(p['latex_str'], fontsize=12)
                    panel.tick_params(axis='y', labelsize=12)

        ax[0, 0].legend()
        for idx in unused_figs:
            ax[-1, idx].axis('off')
        plt.tight_layout()
        plt.show()
        return f, ax

    def _get_ref_style(self, k):
        """Returns the (linestyle, marker) of the k-th reference model, cycling through the available styles."""
        return (self.ref_linestyles[k % len(self.ref_linestyles)],
                self.ref_markers[k % len(self.ref_markers)])

    def _prepare_getdist_plot(self, lw, lw_cont=None, lw_margin=None):
        """Builds the per-model style settings passed to the getdist plots.

        The returned lists are new objects: the attributes of the instance
        are not modified. If the combined samples are enabled, an extra
        item is appended for them to each list (except to the contour line
        widths when `lw_cont` is None).

        Parameters
        ----------
        lw : int
            Line width of the models
        lw_cont : int, optional
            Line width of the 2D contours of the models, by default None
        lw_margin : int, optional
            Line width of the combined samples, by default `lw + 2`

        Returns
        -------
        tuple
            Line arguments (list of dict), contour line widths,
            contour linestyles, colors and legend labels
        """
        if lw_margin is None:
            lw_margin = lw + 2
        line_args = [{'ls': ls, 'lw': lw, 'color': c} for ls, c in zip(self.linestyles, self.colors)]
        lw_conts = [lw_cont]*self.num_models
        # copy, otherwise appending below would alter self.linestyles at each call
        ls_conts = list(self.linestyles)
        legend_labels = copy.deepcopy(self.coolest_names)
        colors = copy.deepcopy(self.colors)
        if self._add_margin_samples:
            line_args.append({'ls': '-.', 'lw': lw_margin, 'alpha': 0.8, 'color': self._color_margin})
            ls_conts.append('-.')
            if lw_cont is not None:
                lw_conts.append(lw_margin)
            # convert to lists first, as e.g. the default colors are stored
            # in a numpy array which cannot be appended to
            legend_labels = list(legend_labels)
            legend_labels.append(self._label_margin)
            colors = list(colors)
            colors.append(self._color_margin)
        return line_args, lw_conts, ls_conts, colors, legend_labels
×
968

969
# def plot_corner(parameter_id_list, 
970
#                 chain_objs, chain_dirs, chain_names=None, 
971
#                 point_estimate_objs=None, point_estimate_dirs=None, point_estimate_names=None, 
972
#                 colors=None, labels=None, subplot_size=1, mc_samples_kwargs=None, 
973
#                 filled_contours=True, angles_range=None, shift_sample_list=None):
974
#     """
975
#     Adding this as just a function for the moment.
976
#     Takes a list of COOLEST files as input, which must have a chain file associated to them, and returns a corner plot.
977

978
#     Parameters
979
#     ----------
980
#     parameter_id_list : array
981
#         A list of parameter unique ids obtained from lensing entities. Their order determines the order of the plot panels.
982
#     chain_objs : array
983
#         A list of coolest objects that have a chain file associated to them.
984
#     chain_dirs : array
985
#         A list of paths matching the coolest files in 'chain_objs'.
986
#     chain_names : array, optional
987
#         A list of labels for the coolest models in the 'chain_objs' list. Must have the same order as 'chain_objs'.
988
#     point_estimate_objs : array, optional
989
#         A list of coolest objects that will be used as point estimates.
990
#     point_estimate_dirs : array
991
#         A list of paths matching the coolest files in 'point_estimate_objs'.
992
#     point_estimate_names : array, optional
993
#         A list of labels for the models in the 'point_estimate_objs' list. Must have the same order as 'point_estimate_objs'.
994
#     labels : dict, optional
995
#         A dictionary matching the parameter_id_list entries to some human-readable labels.
996

997
#     Returns
998
#     -------
999
#     An image
1000
#     """
1001

1002
#     chains.print_load_details = False # Just to silence messages
1003
#     parameter_id_set = set(parameter_id_list)
1004
#     Npars = len(parameter_id_list)
1005
#     Nobjs = len(chain_objs)
1006
    
1007
#     # Set the chain names
1008
#     if chain_names is None:
1009
#         chain_names = ["chain "+str(i) for i in range(Nobjs)]
1010
    
1011
#     if shift_sample_list is None:
1012
#         shift_sample_list = [None]*Nobjs
1013
    
1014
#     # Get the values of the point_estimates
1015
#     point_estimates = []
1016
#     if point_estimate_objs is not None:
1017
#         for coolest_obj in point_estimate_objs:
1018
#             values = []
1019
#             for par in parameter_id_list:
1020
#                 param = coolest_obj.lensing_entities.get_parameter_from_id(par)
1021
#                 val = param.point_estimate.value
1022
#                 if val is None:
1023
#                     values.append(None)
1024
#                 else:
1025
#                     values.append(val)
1026
#             point_estimates.append(values)
1027

1028

1029
            
1030
#     mcsamples = []
1031
#     for i in range(Nobjs):
1032
#         chain_file = os.path.join(chain_dirs[i],chain_objs[i].meta["chain_file_name"]) # Here get the chain file path for each coolest object
1033

1034
#         # Each chain file can have a different number of free parameters
1035
#         f = open(chain_file)
1036
#         header = f.readline()
1037
#         f.close()
1038

1039
#         if ';' in header:
1040
#             raise ValueError("Columns must be coma-separated (no semi-colon) in chain file.")
1041

1042
#         chain_file_headers = header.split(',')
1043
#         num_cols = len(chain_file_headers)
1044
#         chain_file_headers.pop() # Remove the last column name that is the probability weights
1045
#         chain_file_headers_set = set(chain_file_headers)
1046
        
1047
#         # Check that the given parameters are a subset of those in the chain file
1048
#         assert parameter_id_set.issubset(chain_file_headers_set), "Not all given parameters are free parameters for model %d (not in the chain file: %s)!" % (i,chain_file)
1049

1050
#         # Set the labels for the parameters in the chain file
1051
#         par_labels = []
1052
#         if labels is None:
1053
#             labels = {}
1054
#         for par_id in parameter_id_list:
1055
#             if labels.get(par_id, None) is None:
1056
#                 param = coolest_obj.lensing_entities.get_parameter_from_id(par_id)
1057
#                 par_labels.append(param.latex_str.strip('$'))
1058
#             else:
1059
#                 par_labels.append(labels[par_id])
1060
                    
1061
#         # Read parameter values and probability weights
1062
#         column_indices = [chain_file_headers.index(par_id) for par_id in parameter_id_list]
1063
#         columns_to_read = sorted(column_indices) + [num_cols-1]  # add last one for probability weights
1064
#         samples = pd.read_csv(chain_file, usecols=columns_to_read, delimiter=',')
1065
    
1066
#         # Re-order columns to match parameter_id_list and par_labels
1067
#         sample_par_values = np.array(samples[parameter_id_list])
1068

1069
#         # If needed, shift samples by a constant
1070
#         if shift_sample_list[i] is not None:
1071
#             for param_id, value in shift_sample_list[i].items():
1072
#                 sample_par_values[:, parameter_id_list.index(param_id)] += value
1073
#                 print(f"INFO: posterior for parameter '{param_id}' from model '{chain_names[i]}' "
1074
#                       f"has been shifted by {value}.")
1075

1076
#         # Clean-up the probability weights
1077
#         mypost = np.array(samples['probability_weights'])
1078
#         min_non_zero = np.min(mypost[np.nonzero(mypost)])
1079
#         sample_prob_weight = np.where(mypost<min_non_zero,min_non_zero,mypost)
1080
#         #sample_prob_weight = mypost
1081

1082
#         # Create MCSamples object
1083
#         mysample = MCSamples(samples=sample_par_values,names=parameter_id_list,labels=par_labels,settings=mc_samples_kwargs)
1084
#         mysample.reweightAddingLogLikes(-np.log(sample_prob_weight))
1085
#         mcsamples.append(mysample)
1086

1087

1088
        
1089
#     # Make the plot
1090
#     image = plots.getSubplotPlotter(subplot_size=subplot_size)    
1091
#     image.triangle_plot(mcsamples,
1092
#                         params=parameter_id_list,
1093
#                         legend_labels=chain_names,
1094
#                         filled=filled_contours,
1095
#                         colors=colors,
1096
#                         line_args=[{'ls':'-', 'lw': 2, 'color': c} for c in colors], 
1097
#                         contour_colors=colors)
1098

1099

1100
#     my_linestyles = ['solid','dotted','dashed','dashdot']
1101
#     my_markers    = ['s','^','o','star']
1102

1103
#     for k in range(0,len(point_estimates)):
1104
#         # Add vertical and horizontal lines
1105
#         for i in range(0,Npars):
1106
#             val = point_estimates[k][i]
1107
#             if val is not None:
1108
#                 for ax in image.subplots[i:,i]:
1109
#                     ax.axvline(val,color='black',ls=my_linestyles[k],alpha=1.0,lw=1)
1110
#                 for ax in image.subplots[i,:i]:
1111
#                     ax.axhline(val,color='black',ls=my_linestyles[k],alpha=1.0,lw=1)
1112

1113
#         # Add points
1114
#         for i in range(0,Npars):
1115
#             val_x = point_estimates[k][i]
1116
#             for j in range(i+1,Npars):
1117
#                 val_y = point_estimates[k][j]
1118
#                 if val_x is not None and val_y is not None:
1119
#                     image.subplots[j,i].scatter(val_x,val_y,s=10,facecolors='black',color='black',marker=my_markers[k])
1120
#                 else:
1121
#                     pass    
1122

1123

1124
#     # Set default ranges for angles
1125
#     if angles_range is None:
1126
#         angles_range = (-90, 90)
1127
#     for i in range(0,len(parameter_id_list)):
1128
#         dum = parameter_id_list[i].split('-')
1129
#         name = dum[-1]
1130
#         if name in ['phi','phi_ext']:
1131
#             xlim = image.subplots[i,i].get_xlim()
1132
#             #print(xlim)
1133
        
1134
#             if xlim[0] < -90:
1135
#                 for ax in image.subplots[i:,i]:
1136
#                     ax.set_xlim(left=angles_range[0])
1137
#                 for ax in image.subplots[i,:i]:
1138
#                     ax.set_ylim(bottom=angles_range[0])
1139
#             if xlim[1] > 90:
1140
#                 for ax in image.subplots[i:,i]:
1141
#                     ax.set_xlim(right=angles_range[1])
1142
#                 for ax in image.subplots[i,:i]:
1143
#                     ax.set_ylim(top=angles_range[1])
1144

1145
            
1146
#     return image
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc