• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

aymgal / COOLEST / 5802599031

pending completion
5802599031

push

github

gvernard
Fixed a mistake from the previous commit.

5 of 5 new or added lines in 1 file covered. (100.0%)

1347 of 2489 relevant lines covered (54.12%)

0.54 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/coolest/api/plotting.py
1
__author__ = 'aymgal', 'lynevdv', 'gvernard'
×
2

3

4
import os
×
5
import copy
×
6
import logging
×
7
import numpy as np
×
8
import matplotlib.pyplot as plt
×
9
from matplotlib.colors import Normalize, LogNorm, TwoSlopeNorm
×
10
from matplotlib.colors import ListedColormap
×
11
from getdist import plots,chains,MCSamples
×
12

13
from coolest.api.analysis import Analysis
×
14
from coolest.api.composable_models import *
×
15
from coolest.api import util
×
16
from coolest.api import plot_util as plut
×
17

18

19
# matplotlib global settings: default imshow behaviour used by every panel below
# (no interpolation between pixels, array origin at the lower-left corner)
plt.rcParams.update({'image.interpolation': 'none', 'image.origin': 'lower'})

# logging settings: note that this changes the level of the *root* logger
logging.root.setLevel(logging.INFO)
24

25

26
class ModelPlotter(object):
    """Create pyplot panels from a lens model stored in the COOLEST format.

    Parameters
    ----------
    coolest_object : COOLEST
        COOLEST instance
    coolest_directory : str, optional
        Directory which contains the COOLEST template, by default None
    color_bad_values : str, optional
        Color assigned to NaN values (typically negative values in log-scale),
        by default '#111111' (dark gray)
    """

    def __init__(self, coolest_object, coolest_directory=None,
                 color_bad_values='#111111'):
        self.coolest = coolest_object
        self._directory = coolest_directory

        # Work on a copy so that set_bad() does not modify the colormap
        # instance returned by plt.get_cmap().
        self.cmap_flux = copy.copy(plt.get_cmap('magma'))
        self.cmap_flux.set_bad(color_bad_values)

        self.cmap_mag = plt.get_cmap('twilight_shifted')
        self.cmap_conv = plt.get_cmap('cividis')
        self.cmap_res = plt.get_cmap('RdBu_r')

    def plot_data_image(self, ax, title=None, norm=None, cmap=None,
                        neg_values_as_bad=False, add_colorbar=True):
        """plt.imshow panel with the data image.

        Parameters
        ----------
        ax : matplotlib axes
            Axes on which the image is drawn.
        title : str, optional
            Panel title, by default None (no title).
        norm : matplotlib.colors.Normalize, optional
            Color normalization forwarded to imshow.
        cmap : colormap, optional
            Colormap, by default the flux colormap set in the constructor.
        neg_values_as_bad : bool, optional
            If True, negative pixels are set to NaN (drawn with the 'bad' color).
        add_colorbar : bool, optional
            If True, add a colorbar labelled "flux".

        Returns
        -------
        The data pixels as returned by `observation.pixels.get_pixels()`.
        """
        if cmap is None:
            cmap = self.cmap_flux
        coordinates = util.get_coordinates(self.coolest)
        image = self.coolest.observation.pixels.get_pixels(directory=self._directory)
        ax, im = self._plot_regular_grid(ax, image, extent=coordinates.plt_extent,
                                         cmap=cmap,
                                         neg_values_as_bad=neg_values_as_bad,
                                         norm=norm)
        self._finalize_panel(ax, im, add_colorbar, "flux", title)
        return image

    def plot_surface_brightness(self, ax, title=None, coordinates=None,
                                extent_irreg=None, norm=None, cmap=None, neg_values_as_bad=True,
                                plot_points_irreg=False, add_colorbar=True, kwargs_light=None):
        """plt.imshow panel showing the surface brightness of the (unlensed)
        lensing entity selected via kwargs_light (see ComposableLightModel docstring).

        If `coordinates` is given, the light model is evaluated on its pixel
        coordinates. Otherwise the model's own representation is used: a 2D
        array is drawn with imshow, anything else is treated as an irregular
        set of points and drawn via `plut.plot_voronoi`.

        Returns
        -------
        (image, coordinates)
            `image` is the evaluated 2D array, or None for irregular grids.
            `coordinates` is the input coordinates or those returned by the
            light model.
        """
        if kwargs_light is None:
            kwargs_light = {}
        light_model = ComposableLightModel(self.coolest, self._directory, **kwargs_light)
        if cmap is None:
            cmap = self.cmap_flux
        if coordinates is not None:
            x, y = coordinates.pixel_coordinates
            image = light_model.evaluate_surface_brightness(x, y)
            ax, im = self._plot_regular_grid(ax, image, extent=coordinates.plt_extent,
                                             cmap=cmap,
                                             neg_values_as_bad=neg_values_as_bad,
                                             norm=norm)
        else:
            values, extent_model, coordinates = light_model.surface_brightness(return_extra=True)
            if isinstance(values, np.ndarray) and len(values.shape) == 2:
                image = values
                ax, im = self._plot_regular_grid(ax, image, extent=extent_model,
                                                 cmap=cmap,
                                                 neg_values_as_bad=neg_values_as_bad,
                                                 norm=norm)
            else:
                # Irregular grid: `values` is unpacked as (x, y, z) by _plot_irregular_grid
                if extent_irreg is None:
                    extent_irreg = extent_model
                ax, im = self._plot_irregular_grid(ax, values, extent_irreg, norm=norm, cmap=cmap,
                                                   neg_values_as_bad=neg_values_as_bad,
                                                   plot_points=plot_points_irreg)
                image = None
        self._finalize_panel(ax, im, add_colorbar, "flux", title)
        return image, coordinates

    def plot_model_image(self, ax, supersampling=5, convolved=False, title=None,
                         norm=None, cmap=None, neg_values_as_bad=False,
                         kwargs_source=None, add_colorbar=True, kwargs_lens_mass=None):
        """plt.imshow panel showing the surface brightness of the (lensed)
        selected lensing entities (see ComposableLensModel docstring).

        Returns
        -------
        The model image returned by `ComposableLensModel.model_image`.
        """
        if cmap is None:
            cmap = self.cmap_flux
        lens_model = ComposableLensModel(self.coolest, self._directory,
                                         kwargs_selection_source=kwargs_source,
                                         kwargs_selection_lens_mass=kwargs_lens_mass)
        image, coordinates = lens_model.model_image(supersampling=supersampling,
                                                    convolved=convolved)
        ax, im = self._plot_regular_grid(ax, image, extent=coordinates.plt_extent,
                                         cmap=cmap,
                                         neg_values_as_bad=neg_values_as_bad,
                                         norm=norm)
        self._finalize_panel(ax, im, add_colorbar, "flux", title)
        return image

    def plot_model_residuals(self, ax, supersampling=5, mask=None, title=None,
                             norm=None, cmap=None, add_chi2_label=False, chi2_fontsize=12,
                             kwargs_source=None, add_colorbar=True, kwargs_lens_mass=None):
        """plt.imshow panel showing the normalized model residuals image.

        The default normalization is Normalize(-6, 6). If `add_chi2_label` is
        truthy, the reduced chi^2 is printed in the lower-left corner. It is
        computed as sum(residuals**2) divided by the number of pixels, or by
        sum(mask) when a mask is given.

        Returns
        -------
        The residuals image returned by `ComposableLensModel.model_residuals`.
        """
        if cmap is None:
            cmap = self.cmap_res
        if norm is None:
            norm = Normalize(-6, 6)
        lens_model = ComposableLensModel(self.coolest, self._directory,
                                         kwargs_selection_source=kwargs_source,
                                         kwargs_selection_lens_mass=kwargs_lens_mass)
        image, coordinates = lens_model.model_residuals(supersampling=supersampling,
                                                        mask=mask)
        ax, im = self._plot_regular_grid(ax, image, extent=coordinates.plt_extent,
                                         cmap=cmap,
                                         neg_values_as_bad=False,
                                         norm=norm)
        if add_chi2_label:
            num_constraints = np.size(image) if mask is None else np.sum(mask)
            red_chi2 = np.sum(image**2) / num_constraints
            ax.text(0.05, 0.05, r'$\chi^2_\nu$='+f'{red_chi2:.2f}', color='black', alpha=1,
                    fontsize=chi2_fontsize, va='bottom', ha='left', transform=ax.transAxes,
                    bbox={'color': 'white', 'alpha': 0.6})
        self._finalize_panel(ax, im, add_colorbar, "(data $-$ model) / noise", title)
        return image

    def plot_convergence(self, ax, title=None,
                         norm=None, cmap=None, neg_values_as_bad=False,
                         add_colorbar=True, kwargs_lens_mass=None):
        """plt.imshow panel showing the 2D convergence map associated to the
        selected lensing entities (see ComposableMassModel docstring).

        The map is evaluated on the pixel coordinates of the observation.

        Returns
        -------
        The convergence map returned by `ComposableMassModel.evaluate_convergence`.
        """
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        mass_model = ComposableMassModel(self.coolest, self._directory,
                                         **kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_conv
        coordinates = util.get_coordinates(self.coolest)
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_convergence(x, y)
        ax, im = self._plot_regular_grid(ax, image, extent=coordinates.plt_extent,
                                         cmap=cmap,
                                         neg_values_as_bad=neg_values_as_bad,
                                         norm=norm)
        self._finalize_panel(ax, im, add_colorbar, r"$\kappa$", title)
        return image

    def plot_magnification(self, ax, title=None,
                           norm=None, cmap=None, neg_values_as_bad=False,
                           add_colorbar=True, kwargs_lens_mass=None):
        """plt.imshow panel showing the 2D magnification map associated to the
        selected lensing entities (see ComposableMassModel docstring).

        The map is evaluated on the pixel coordinates of the observation. The
        default normalization is Normalize(-10, 10).

        Returns
        -------
        The magnification map returned by `ComposableMassModel.evaluate_magnification`.
        """
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        mass_model = ComposableMassModel(self.coolest, self._directory,
                                         **kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_mag
        if norm is None:
            norm = Normalize(-10, 10)
        coordinates = util.get_coordinates(self.coolest)
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_magnification(x, y)
        ax, im = self._plot_regular_grid(ax, image, extent=coordinates.plt_extent,
                                         cmap=cmap,
                                         neg_values_as_bad=neg_values_as_bad,
                                         norm=norm)
        self._finalize_panel(ax, im, add_colorbar, r"$\mu$", title)
        return image

    @staticmethod
    def _finalize_panel(ax, im, add_colorbar, colorbar_label, title):
        """Optionally add a labelled colorbar for `im`, then optionally set the title."""
        if add_colorbar:
            cb = plut.nice_colorbar(im, ax=ax, max_nbins=4)
            cb.set_label(colorbar_label)
        if title is not None:
            ax.set_title(title)

    @staticmethod
    def _mask_negative_values(image, neg_values_as_bad):
        """Return `image` with negative entries replaced by NaN if requested.

        When masking, a new floating-point array is returned: integer inputs
        are promoted so that NaN can be stored, and the caller's array is
        never modified. When not masking, `image` is returned unchanged.
        """
        if not neg_values_as_bad:
            return image
        image = np.asarray(image)
        return np.where(image < 0, np.nan, image)

    @staticmethod
    def _plot_regular_grid(ax, image_, neg_values_as_bad=True, **imshow_kwargs):
        """Draw a 2D array with `ax.imshow` and limit each axis to 3 major ticks.

        Negative values are shown as 'bad' pixels (NaN) if `neg_values_as_bad`.
        Extra keyword arguments (extent, cmap, norm, ...) go to imshow.
        Returns (ax, im), where `im` is the AxesImage created by imshow.
        """
        image = ModelPlotter._mask_negative_values(image_, neg_values_as_bad)
        im = ax.imshow(image, **imshow_kwargs)
        im.set_rasterized(True)
        ax.xaxis.set_major_locator(plt.MaxNLocator(3))
        ax.yaxis.set_major_locator(plt.MaxNLocator(3))
        return ax, im

    @staticmethod
    def _plot_irregular_grid(ax, points, extent, neg_values_as_bad=True,
                             norm=None, cmap=None, plot_points=False):
        """Draw irregularly sampled values as Voronoi cells via `plut.plot_voronoi`.

        `points` is unpacked as (x, y, z). `extent` is (xmin, xmax, ymin, ymax)
        and sets the axis limits. If `plot_points` is True, the sample
        positions are overlaid as small white dots.
        Returns (ax, im), where `im` is the object returned by `plut.plot_voronoi`.
        """
        x, y, z = points
        im = plut.plot_voronoi(ax, x, y, z, neg_values_as_bad=neg_values_as_bad,
                               norm=norm, cmap=cmap, zorder=1)
        ax.set_xlim(extent[0], extent[1])
        ax.set_ylim(extent[2], extent[3])
        ax.set_aspect('equal', 'box')
        ax.xaxis.set_major_locator(plt.MaxNLocator(3))
        ax.yaxis.set_major_locator(plt.MaxNLocator(3))
        if plot_points:
            ax.scatter(x, y, s=5, c='white', marker='.', alpha=0.4, zorder=2)
        return ax, im
256

257

258
class MultiModelPlotter(object):
    """Wrapper around a set of ModelPlotter instances to produce panels that
    consistently compare different models, evaluated on the same
    coordinates systems.

    Parameters
    ----------
    coolest_objects : list
        List of COOLEST instances
    coolest_directories : list, optional
        List of directories corresponding to each COOLEST instance, by default None
    kwargs_plotter : dict, optional
        Additional keyword arguments passed to ModelPlotter

    Raises
    ------
    ValueError
        If `coolest_directories` does not have one entry per COOLEST instance.
    """

    def __init__(self, coolest_objects, coolest_directories=None, **kwargs_plotter):
        self.num_models = len(coolest_objects)
        if coolest_directories is None:
            coolest_directories = self.num_models * [None]
        elif len(coolest_directories) != self.num_models:
            # zip() would otherwise silently drop the extra models
            raise ValueError("Inconsistent number of COOLEST objects and directories")
        self.plotter_list = [ModelPlotter(coolest, coolest_directory=c_dir, **kwargs_plotter)
                             for coolest, c_dir in zip(coolest_objects, coolest_directories)]

    def plot_surface_brightness(self, axes, global_title="surf. brightness", titles=None, **kwargs):
        """Call ModelPlotter.plot_surface_brightness on each model.

        Entries of `kwargs_light` are lists indexed by model.
        """
        return self._plot_light_multi('plot_surface_brightness', global_title, axes, titles=titles, **kwargs)

    def plot_data_image(self, axes, global_title="data", titles=None, **kwargs):
        """Call ModelPlotter.plot_data_image on each model."""
        return self._plot_data_multi(global_title, axes, titles=titles, **kwargs)

    def plot_model_image(self, axes, global_title="model", titles=None, **kwargs):
        """Call ModelPlotter.plot_model_image on each model.

        Entries of `kwargs_source` and `kwargs_lens_mass` are lists indexed by model.
        """
        return self._plot_lens_model_multi('plot_model_image', global_title, axes, titles=titles, **kwargs)

    def plot_model_residuals(self, axes, global_title="residuals", titles=None, **kwargs):
        """Call ModelPlotter.plot_model_residuals on each model.

        Entries of `kwargs_source` and `kwargs_lens_mass` are lists indexed by model.
        """
        return self._plot_lens_model_multi('plot_model_residuals', global_title, axes, titles=titles, **kwargs)

    def plot_convergence(self, axes, global_title="convergence", titles=None, **kwargs):
        """Call ModelPlotter.plot_convergence on each model.

        Entries of `kwargs_lens_mass` are lists indexed by model.
        """
        return self._plot_mass_multi('plot_convergence', global_title, axes, titles=titles, **kwargs)

    def plot_magnification(self, axes, titles=None, global_title="magnification", **kwargs):
        """Call ModelPlotter.plot_magnification on each model.

        Entries of `kwargs_lens_mass` are lists indexed by model. `global_title`
        comes after `titles` so that existing positional calls keep working.
        """
        return self._plot_mass_multi('plot_magnification', global_title, axes, titles=titles, **kwargs)

    def _plot_multi(self, method_name, global_title, axes, titles, per_model_keys, kwargs,
                    copy_kwargs=True):
        """Call ModelPlotter.`method_name` for every model with its own axis and title.

        For each key of `per_model_keys` present in `kwargs` with a non-None
        value, that value must be a dict of lists indexed by model. Model i
        then receives the dict made of the i-th element of each list. Other
        keyword arguments are forwarded unchanged (deep-copied once when
        `copy_kwargs` is True). Models whose axis is None are skipped, so
        the returned list only holds results for the axes actually drawn.
        """
        assert len(axes) == self.num_models, "Inconsistent number of subplot axes"
        if titles is None:
            titles = self.num_models * [global_title]
        kwargs_ = copy.deepcopy(kwargs) if copy_kwargs else kwargs
        image_list = []
        for i, (ax, plotter) in enumerate(zip(axes, self.plotter_list)):
            if ax is None:
                continue
            if per_model_keys:
                kwargs_ = dict(kwargs_)  # never write into the caller's dict
                for key in per_model_keys:
                    if kwargs.get(key) is not None:
                        kwargs_[key] = {k: v[i] for k, v in kwargs[key].items()}
            image = getattr(plotter, method_name)(ax, title=titles[i], **kwargs_)
            image_list.append(image)
        return image_list

    def _plot_light_multi(self, method_name, global_title, axes, titles=None, **kwargs):
        """Multi-model light plots; splits `kwargs_light` per model."""
        return self._plot_multi(method_name, global_title, axes, titles,
                                ('kwargs_light',), kwargs)

    def _plot_mass_multi(self, method_name, global_title, axes, titles=None, **kwargs):
        """Multi-model mass plots; splits `kwargs_lens_mass` per model."""
        return self._plot_multi(method_name, global_title, axes, titles,
                                ('kwargs_lens_mass',), kwargs)

    def _plot_lens_model_multi(self, method_name, global_title, axes, titles=None, kwargs_select=None, **kwargs):
        """Multi-model lensed-image plots; splits `kwargs_source` and `kwargs_lens_mass` per model.

        `kwargs_select` is not used; it is kept for backward compatibility.
        """
        return self._plot_multi(method_name, global_title, axes, titles,
                                ('kwargs_source', 'kwargs_lens_mass'), kwargs)

    def _plot_data_multi(self, global_title, axes, titles=None, **kwargs):
        """Multi-model data plots. Keyword arguments are passed as-is (not copied)."""
        return self._plot_multi('plot_data_image', global_title, axes, titles,
                                (), kwargs, copy_kwargs=False)
358

359

360
class Comparison_analytical(object):
    """Handles plot of analytical models in a comparative way

    Parameters
    ----------
    coolest_file_list : list
        List of paths to COOLEST templates
    nickname_file_list : list
        List of shorter names related to each COOLEST model
    posterior_bool_list : list
        List of bool to toggle errorbars on point-estimate values
    """

    # Parameter keys handled as angles when drawing error bars
    _ANGLE_KEYS = ('SHEAR_0_phi_ext', 'PEMD_0_phi')
    # One marker per file; cycled if there are more files than markers
    _MARKERS = ['*', '.', 's', '^', '<', '>', 'v', 'p', 'P', 'X', 'D', '1', '2', '3', '4', '+']

    def __init__(self, coolest_file_list, nickname_file_list, posterior_bool_list):
        self.file_names = nickname_file_list
        self.posterior_bool_list = posterior_bool_list
        self.param_lens, self.param_source = util.read_json_param(coolest_file_list,
                                                                  self.file_names,
                                                                  lens_light=False)

    @staticmethod
    def _grid_layout(number_param, num_cols=4):
        """Return (num_lines, unused_figs) for a grid of `num_cols` columns.

        `num_lines` is the number of rows needed for `number_param` panels
        (at least 1). `unused_figs` lists the negative column indices
        (-1, -2, ...) of the empty panels in the last row.
        """
        num_lines = max(1, -(-number_param // num_cols))  # ceiling division
        num_unused = num_lines * num_cols - number_param
        return num_lines, [-idx - 1 for idx in range(num_unused)]

    @classmethod
    def _error_bar_bounds(cls, key, p):
        """Return the (lower, upper) error-bar bounds for parameter `key`.

        The bounds are p['percentile_16th'] and p['percentile_84th']. For the
        angle keys in `_ANGLE_KEYS`, a bound on the wrong side of p['median']
        is shifted by 180 degrees, so that error bars stay correct close to
        the wrapping edge. `p` itself is not modified.
        """
        lower, upper = p['percentile_16th'], p['percentile_84th']
        if key in cls._ANGLE_KEYS:
            if lower > p['median']:
                lower -= 180.
            if upper < p['median']:
                upper += 180.
        return lower, upper

    def plotting_routine(self, param_dict, idx_file=0):
        """
        Plot the parameters, one panel per parameter, one marker per file.

        INPUT
        -----
        param_dict: dict, organized dictionary with all parameters results of the different files
            (param_dict[file_name][key] is read for 'median', 'percentile_16th',
            'percentile_84th', 'point_estimate' and 'latex_str')
        idx_file: int, chooses the file on which the choice of plotted parameters will be made.
            For example, file 0 may hold a Sersic fit and file 1 a Sersic+shapelets fit.
            With idx_file=0, only the Sersic parameters are plotted, for both files.
            With idx_file=1, all Sersic and shapelets parameters are plotted, wherever available.

        OUTPUT
        ------
        (f, ax): the matplotlib figure and the 2D array of axes
        """
        reference = param_dict[self.file_names[idx_file]]
        keys = list(reference.keys())
        num_cols = 4
        num_lines, unused_figs = self._grid_layout(len(keys), num_cols)

        # squeeze=False keeps `ax` 2D even when there is a single row
        f, ax = plt.subplots(num_lines, num_cols, figsize=(num_cols * 3.5, 2.5 * num_lines),
                             squeeze=False)

        # Panel positions follow the reference file; other files are drawn on
        # the panel of the matching key, or skipped if they lack that key.
        for i, key in enumerate(keys):
            panel = ax[i // num_cols, i % num_cols]
            for j, file_name in enumerate(self.file_names):
                result = param_dict[file_name]
                if key not in result:
                    continue
                p = result[key]
                m = self._MARKERS[j % len(self._MARKERS)]
                if self.posterior_bool_list[j]:
                    lower, upper = self._error_bar_bounds(key, p)
                    panel.errorbar(j, p['median'], [[p['median'] - lower],
                                                    [upper - p['median']]],
                                   marker=m, ls='', label=file_name)
                else:
                    panel.plot(j, p['point_estimate'], marker=m, ls='', label=file_name)
            panel.get_xaxis().set_visible(False)
            panel.set_ylabel(reference[key]['latex_str'], fontsize=12)
            panel.tick_params(axis='y', labelsize=12)

        ax[0, 0].legend()
        for idx in unused_figs:
            ax[-1, idx].axis('off')
        plt.tight_layout()
        plt.show()
        return f, ax

    def plot_source(self, idx_file=0):
        """Run plotting_routine on the source parameters."""
        f, ax = self.plotting_routine(self.param_source, idx_file)
        return f, ax

    def plot_lens(self, idx_file=0):
        """Run plotting_routine on the lens parameters."""
        f, ax = self.plotting_routine(self.param_lens, idx_file)
        return f, ax
461

462

463

464

465

466
def plot_corner(parameter_id_list, chain_objs, chain_dirs, chain_names=None,
                point_estimate_objs=None, point_estimate_dirs=None, point_estimate_names=None,
                colors=None, labels=None, subplot_size=1, mc_samples_kwargs=None,
                filled_contours=True):
    """
    Adding this as just a function for the moment.
    Takes a list of COOLEST files as input, which must have a chain file associated to them, and returns a corner plot.

    Parameters
    ----------
    parameter_id_list : array
        A list of parameter unique ids obtained from lensing entities. Their order determines the order of the plot panels.
    chain_objs : array
        A list of coolest objects that have a chain file associated to them (meta["chain_file_name"]).
    chain_dirs : array
        A list of paths matching the coolest files in 'chain_objs'.
    chain_names : array, optional
        A list of labels for the coolest models in the 'chain_objs' list. Must have the same order as 'chain_objs'.
    point_estimate_objs : array, optional
        A list of coolest objects whose point-estimate values are overlaid as black lines and markers.
    point_estimate_dirs : array, optional
        A list of paths matching the coolest files in 'point_estimate_objs' (currently unused).
    point_estimate_names : array, optional
        A list of labels for the models in the 'point_estimate_objs' list (currently unused).
    colors : list, optional
        One color per chain, forwarded to getdist; by default getdist's own colors.
    labels : dict, optional
        A dictionary matching the parameter_id_list entries to some human-readable labels.
        Parameters missing from it use the `latex_str` of the parameter.
    subplot_size : float, optional
        Size of each panel, forwarded to getdist's getSubplotPlotter.
    mc_samples_kwargs : dict, optional
        Settings forwarded to getdist MCSamples.
    filled_contours : bool, optional
        Whether the 2D contours are filled.

    Returns
    -------
    The getdist subplot plotter holding the triangle plot.

    Raises
    ------
    ValueError
        If a chain file header uses semicolons as separators.

    Notes
    -----
    Chain files must be comma-separated, with a one-line header of parameter
    ids and the probability weights in the last column.
    """

    chains.print_load_details = False  # Just to silence messages
    parameter_id_set = set(parameter_id_list)
    Npars = len(parameter_id_list)
    if labels is None:
        labels = {}

    # Set the chain names
    if chain_names is None:
        chain_names = ["chain " + str(i) for i in range(len(chain_objs))]

    # Get the values of the point estimates (None where not available)
    point_estimates = []
    if point_estimate_objs is not None:
        for pe_obj in point_estimate_objs:
            point_estimates.append([
                pe_obj.lensing_entities.get_parameter_from_id(par).point_estimate.value
                for par in parameter_id_list
            ])

    mcsamples = []
    for i, chain_obj in enumerate(chain_objs):
        chain_file = os.path.join(chain_dirs[i], chain_obj.meta["chain_file_name"])

        # Each chain file can have a different number of free parameters
        with open(chain_file) as f:
            header = f.readline()

        if ';' in header:
            raise ValueError("Columns must be coma-separated (no semi-colon) in chain file.")

        chain_file_headers = [name.strip() for name in header.split(',')]
        num_cols = len(chain_file_headers)
        chain_file_headers.pop()  # Remove the last column name that is the probability weights

        # Check that the given parameters are a subset of those in the chain file
        assert parameter_id_set.issubset(chain_file_headers), "Not all given parameters are free parameters for model %d (not in the chain file: %s)!" % (i, chain_file)

        # Set the labels for the parameters, taken from this chain's own COOLEST object
        par_labels = []
        for par_id in parameter_id_list:
            if labels.get(par_id) is None:
                param = chain_obj.lensing_entities.get_parameter_from_id(par_id)
                par_labels.append(param.latex_str.strip('$'))
            else:
                par_labels.append(labels[par_id])

        # Read parameter values and probability weights
        # TODO: handle samples that are given as a list / array (e.g. using `converters`)
        column_indices = [chain_file_headers.index(par_id) for par_id in parameter_id_list]
        sorted_indices = sorted(column_indices)
        columns_to_read = sorted_indices + [num_cols - 1]  # add last one for probability weights
        # ndmin=2 keeps the (n_samples, n_columns) shape even for a single sample
        samples = np.loadtxt(chain_file, usecols=columns_to_read, skiprows=1, delimiter=',',
                             comments=None, ndmin=2)

        # The parameter columns were read in increasing file order; map each
        # parameter to its position among them to match parameter_id_list.
        order = [sorted_indices.index(c) for c in column_indices]
        sample_par_values = samples[:, :-1][:, order]

        # Clean-up the probability weights: values below the smallest non-zero
        # weight are raised to it so that the log below stays finite
        mypost = samples[:, -1]
        min_non_zero = np.min(mypost[np.nonzero(mypost)])
        sample_prob_weight = np.where(mypost < min_non_zero, min_non_zero, mypost)

        # Create MCSamples object
        mysample = MCSamples(samples=sample_par_values, names=parameter_id_list,
                             labels=par_labels, settings=mc_samples_kwargs)
        mysample.reweightAddingLogLikes(-np.log(sample_prob_weight))
        mcsamples.append(mysample)

    # Make the plot
    image = plots.getSubplotPlotter(subplot_size=subplot_size)
    line_args = None if colors is None else [{'ls': '-', 'lw': 2, 'color': c} for c in colors]
    image.triangle_plot(mcsamples,
                        params=parameter_id_list,
                        legend_labels=chain_names,
                        filled=filled_contours,
                        colors=colors,
                        line_args=line_args,
                        contour_colors=colors)

    # Styles cycle if there are more point estimates than entries
    my_linestyles = ['solid', 'dotted', 'dashed', 'dashdot']
    my_markers = ['s', '^', 'o', '*']

    for k, values in enumerate(point_estimates):
        ls = my_linestyles[k % len(my_linestyles)]
        marker = my_markers[k % len(my_markers)]

        # Add vertical and horizontal lines
        for i, val in enumerate(values):
            if val is None:
                continue
            for ax in image.subplots[i:, i]:
                ax.axvline(val, color='black', ls=ls, alpha=1.0, lw=1)
            for ax in image.subplots[i, :i]:
                ax.axhline(val, color='black', ls=ls, alpha=1.0, lw=1)

        # Add points
        for i in range(Npars):
            val_x = values[i]
            for j in range(i + 1, Npars):
                val_y = values[j]
                if val_x is not None and val_y is not None:
                    image.subplots[j, i].scatter(val_x, val_y, s=10, facecolors='black',
                                                 color='black', marker=marker)

    # Set default ranges for angles: clip the axes of 'phi'/'phi_ext' parameters to [-90, 90]
    for i, par_id in enumerate(parameter_id_list):
        name = par_id.split('-')[-1]
        if name in ['phi', 'phi_ext']:
            xlim = image.subplots[i, i].get_xlim()
            if xlim[0] < -90:
                for ax in image.subplots[i:, i]:
                    ax.set_xlim(left=-90)
                for ax in image.subplots[i, :i]:
                    ax.set_ylim(bottom=-90)
            if xlim[1] > 90:
                for ax in image.subplots[i:, i]:
                    ax.set_xlim(right=90)
                for ax in image.subplots[i, :i]:
                    ax.set_ylim(top=90)

    return image
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc