• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

aymgal / COOLEST / 5707586311

pending completion
5707586311

push

github

gvernard
Merge branch 'joss' of github.com:aymgal/COOLEST into joss

75 of 75 new or added lines in 2 files covered. (100.0%)

1350 of 2471 relevant lines covered (54.63%)

0.55 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/coolest/api/plotting.py
1
__author__ = 'aymgal', 'lynevdv', 'gvernard'
×
2

3

4
import os
×
5
import copy
×
6
import logging
×
7
import numpy as np
×
8
import matplotlib.pyplot as plt
×
9
from matplotlib.colors import Normalize, LogNorm, TwoSlopeNorm
×
10
from matplotlib.colors import ListedColormap
×
11
from getdist import plots,chains,MCSamples
×
12

13
from coolest.api.analysis import Analysis
×
14
from coolest.api.composable_models import *
×
15
from coolest.api import util
×
16
from coolest.api import plot_util as plut
×
17

18

19
# Global matplotlib settings (applied on import): imshow panels are drawn
# without interpolation and with the array origin at the lower-left corner.
plt.rc('image', origin='lower', interpolation='none')

# Global logging settings (applied on import): root logger set to INFO.
logging.root.setLevel(logging.INFO)
24

25

26
class ModelPlotter(object):
    """Create pyplot panels from a lens model stored in the COOLEST format.

    Parameters
    ----------
    coolest_object : COOLEST
        COOLEST instance
    coolest_directory : str, optional
        Directory which contains the COOLEST template, by default None
    color_bad_values : str, optional
        Color assigned to NaN values (typically negative values in log-scale), 
        by default '#111111' (dark gray)
    """

    def __init__(self, coolest_object, coolest_directory=None, 
                 color_bad_values='#111111'):
        self.coolest = coolest_object
        self._directory = coolest_directory

        # Work on a copy of the colormap so that set_bad() only customizes
        # this instance's flux colormap.
        self.cmap_flux = copy.copy(plt.get_cmap('magma'))
        self.cmap_flux.set_bad(color_bad_values)

        self.cmap_mag = plt.get_cmap('twilight_shifted')
        self.cmap_conv = plt.get_cmap('cividis')
        self.cmap_res = plt.get_cmap('RdBu_r')

    def plot_data_image(self, ax, title=None, norm=None, cmap=None, neg_values_as_bad=False):
        """plt.imshow panel with the data image.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes on which the panel is drawn.
        title : str, optional
            Panel title, by default None (no title).
        norm : matplotlib.colors.Normalize, optional
            Color normalization passed to imshow, by default None.
        cmap : matplotlib colormap, optional
            Colormap, by default the flux colormap of this instance.
        neg_values_as_bad : bool, optional
            If True, negative pixels are replaced by NaN (drawn with the
            'bad' color of the colormap), by default False.

        Returns
        -------
        image : array
            Data pixels as loaded by ``observation.pixels.get_pixels``.
        """
        if cmap is None:
            cmap = self.cmap_flux
        coordinates = util.get_coordinates(self.coolest)
        extent = coordinates.plt_extent
        image = self.coolest.observation.pixels.get_pixels(directory=self._directory)
        ax, cb = self._plot_regular_grid(ax, image, extent=extent,
                                         cmap=cmap,
                                         neg_values_as_bad=neg_values_as_bad,
                                         norm=norm)
        cb.set_label("flux")
        if title is not None:
            ax.set_title(title)
        return image

    def plot_surface_brightness(self, ax, title=None, coordinates=None, 
                                extent=None, norm=None, cmap=None, neg_values_as_bad=True,
                                plot_points_irreg=False, kwargs_light=None):
        """plt.imshow panel showing the surface brightness of the (unlensed)
        lensing entity selected via kwargs_light (see ComposableLightModel docstring).

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes on which the panel is drawn.
        title : str, optional
            Panel title, by default None (no title).
        coordinates : optional
            Coordinates object (providing ``pixel_coordinates`` and
            ``plt_extent``) on which the light model is evaluated. If None,
            the light model's own evaluation grid is used, which may be
            regular (2D array) or irregular (points).
        extent : optional
            Panel extent. Only used when `coordinates` is None, in which case
            it defaults to the extent returned by the light model.
        norm : matplotlib.colors.Normalize, optional
            Color normalization, by default None.
        cmap : matplotlib colormap, optional
            Colormap, by default the flux colormap of this instance.
        neg_values_as_bad : bool, optional
            If True (default), negative values are shown with the 'bad' color.
        plot_points_irreg : bool, optional
            If True, overlay the irregular-grid point positions, by default False.
        kwargs_light : dict, optional
            Keyword arguments forwarded to ComposableLightModel.

        Returns
        -------
        image : array or None
            Surface brightness evaluated on a regular grid, or None when the
            model is plotted on an irregular grid.
        coordinates
            Coordinates used for the evaluation (as given, or as returned by
            the light model).
        """
        if kwargs_light is None:
            kwargs_light = {}
        light_model = ComposableLightModel(self.coolest, self._directory, **kwargs_light)
        if cmap is None:
            cmap = self.cmap_flux
        if coordinates is not None:
            x, y = coordinates.pixel_coordinates
            image = light_model.evaluate_surface_brightness(x, y)
            extent = coordinates.plt_extent
            ax, cb = self._plot_regular_grid(ax, image, extent=extent, cmap=cmap,
                                             neg_values_as_bad=neg_values_as_bad,
                                             norm=norm)
        else:
            values, extent_model, coordinates = light_model.surface_brightness(return_extra=True)
            if extent is None:
                extent = extent_model
            if isinstance(values, np.ndarray) and len(values.shape) == 2:
                image = values
                ax, cb = self._plot_regular_grid(ax, image, extent=extent,
                                                 cmap=cmap,
                                                 neg_values_as_bad=neg_values_as_bad,
                                                 norm=norm)
            else:
                points = values
                ax, cb = self._plot_irregular_grid(ax, points, extent, norm=norm, cmap=cmap,
                                                   neg_values_as_bad=neg_values_as_bad,
                                                   plot_points=plot_points_irreg)
                image = None
        cb.set_label("flux")
        if title is not None:
            ax.set_title(title)
        return image, coordinates

    def plot_model_image(self, ax, supersampling=5, convolved=False, title=None,
                         norm=None, cmap=None, neg_values_as_bad=False,
                         kwargs_source=None, kwargs_lens_mass=None):
        """plt.imshow panel showing the surface brightness of the (lensed)
        selected lensing entities (see ComposableLensModel docstring).

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes on which the panel is drawn.
        supersampling : int, optional
            Forwarded to ``ComposableLensModel.model_image``, by default 5.
        convolved : bool, optional
            Forwarded to ``ComposableLensModel.model_image``, by default False.
        title : str, optional
            Panel title, by default None (no title).
        norm : matplotlib.colors.Normalize, optional
            Color normalization, by default None.
        cmap : matplotlib colormap, optional
            Colormap, by default the flux colormap of this instance.
        neg_values_as_bad : bool, optional
            If True, negative pixels are shown with the 'bad' color, by default False.
        kwargs_source : dict, optional
            Source selection forwarded as ``kwargs_selection_source``.
        kwargs_lens_mass : dict, optional
            Lens mass selection forwarded as ``kwargs_selection_lens_mass``.

        Returns
        -------
        image : array
            Model image returned by ``ComposableLensModel.model_image``.
        """
        if cmap is None:
            cmap = self.cmap_flux
        lens_model = ComposableLensModel(self.coolest, self._directory,
                                         kwargs_selection_source=kwargs_source,
                                         kwargs_selection_lens_mass=kwargs_lens_mass)
        image, coordinates = lens_model.model_image(supersampling=supersampling,
                                                    convolved=convolved)
        extent = coordinates.plt_extent
        ax, cb = self._plot_regular_grid(ax, image, extent=extent,
                                         cmap=cmap,
                                         neg_values_as_bad=neg_values_as_bad,
                                         norm=norm)
        cb.set_label("flux")
        if title is not None:
            ax.set_title(title)
        return image

    def plot_model_residuals(self, ax, supersampling=5, mask=None, title=None,
                             norm=None, cmap=None, add_chi2_label=False, chi2_fontsize=12,
                             kwargs_source=None, kwargs_lens_mass=None):
        """plt.imshow panel showing the normalized model residuals image.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes on which the panel is drawn.
        supersampling : int, optional
            Forwarded to ``ComposableLensModel.model_residuals``, by default 5.
        mask : array, optional
            Forwarded to ``ComposableLensModel.model_residuals``. When given,
            ``np.sum(mask)`` is used as the number of constraints for the
            reduced chi-square label.
        title : str, optional
            Panel title, by default None (no title).
        norm : matplotlib.colors.Normalize, optional
            Color normalization, by default Normalize(-6, 6).
        cmap : matplotlib colormap, optional
            Colormap, by default the residuals colormap of this instance.
        add_chi2_label : bool, optional
            If True, write the reduced chi-square in the lower-left corner.
        chi2_fontsize : int, optional
            Font size of the reduced chi-square label, by default 12.
        kwargs_source : dict, optional
            Source selection forwarded as ``kwargs_selection_source``.
        kwargs_lens_mass : dict, optional
            Lens mass selection forwarded as ``kwargs_selection_lens_mass``.

        Returns
        -------
        image : array
            Residuals returned by ``ComposableLensModel.model_residuals``.
        """
        if cmap is None:
            cmap = self.cmap_res
        if norm is None:
            norm = Normalize(-6, 6)
        lens_model = ComposableLensModel(self.coolest, self._directory,
                                         kwargs_selection_source=kwargs_source,
                                         kwargs_selection_lens_mass=kwargs_lens_mass)
        image, coordinates = lens_model.model_residuals(supersampling=supersampling,
                                                        mask=mask)
        extent = coordinates.plt_extent
        ax, cb = self._plot_regular_grid(ax, image, extent=extent,
                                         cmap=cmap,
                                         neg_values_as_bad=False,
                                         norm=norm)
        cb.set_label("(data $-$ model) / noise")
        if add_chi2_label is True:
            num_constraints = np.size(image) if mask is None else np.sum(mask)
            red_chi2 = np.sum(image**2) / num_constraints
            ax.text(0.05, 0.05, r'$\chi^2_\nu$='+f'{red_chi2:.2f}', color='black', alpha=1,
                    fontsize=chi2_fontsize, va='bottom', ha='left', transform=ax.transAxes,
                    bbox={'color': 'white', 'alpha': 0.6})
        if title is not None:
            ax.set_title(title)
        return image

    def plot_convergence(self, ax, title=None,
                         norm=None, cmap=None, neg_values_as_bad=False,
                         kwargs_lens_mass=None):
        """plt.imshow panel showing the 2D convergence map associated to the
        selected lensing entities (see ComposableMassModel docstring).

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes on which the panel is drawn.
        title : str, optional
            Panel title, by default None (no title).
        norm : matplotlib.colors.Normalize, optional
            Color normalization, by default None.
        cmap : matplotlib colormap, optional
            Colormap, by default the convergence colormap of this instance.
        neg_values_as_bad : bool, optional
            If True, negative values are shown with the 'bad' color, by default False.
        kwargs_lens_mass : dict, optional
            Keyword arguments forwarded to ComposableMassModel.

        Returns
        -------
        image : array
            Convergence evaluated on the pixel coordinates of the COOLEST object.
        """
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        mass_model = ComposableMassModel(self.coolest, self._directory,
                                         **kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_conv
        coordinates = util.get_coordinates(self.coolest)
        extent = coordinates.plt_extent
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_convergence(x, y)
        ax, cb = self._plot_regular_grid(ax, image, extent=extent,
                                         cmap=cmap,
                                         neg_values_as_bad=neg_values_as_bad,
                                         norm=norm)
        cb.set_label(r"$\kappa$")
        if title is not None:
            ax.set_title(title)
        return image

    def plot_magnification(self, ax, title=None,
                          norm=None, cmap=None, neg_values_as_bad=False,
                          kwargs_lens_mass=None):
        """plt.imshow panel showing the 2D magnification map associated to the
        selected lensing entities (see ComposableMassModel docstring).

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes on which the panel is drawn.
        title : str, optional
            Panel title, by default None (no title).
        norm : matplotlib.colors.Normalize, optional
            Color normalization, by default Normalize(-10, 10).
        cmap : matplotlib colormap, optional
            Colormap, by default the magnification colormap of this instance.
        neg_values_as_bad : bool, optional
            If True, negative values are shown with the 'bad' color, by default False.
        kwargs_lens_mass : dict, optional
            Keyword arguments forwarded to ComposableMassModel.

        Returns
        -------
        image : array
            Magnification evaluated on the pixel coordinates of the COOLEST object.
        """
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        mass_model = ComposableMassModel(self.coolest, self._directory,
                                         **kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_mag
        if norm is None:
            norm = Normalize(-10, 10)
        coordinates = util.get_coordinates(self.coolest)
        extent = coordinates.plt_extent
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_magnification(x, y)
        ax, cb = self._plot_regular_grid(ax, image, extent=extent,
                                         cmap=cmap,
                                         neg_values_as_bad=neg_values_as_bad,
                                         norm=norm)
        cb.set_label(r"$\mu$")
        if title is not None:
            ax.set_title(title)
        return image

    @staticmethod
    def _mask_negative_values(image):
        """Return a copy of `image` where negative values are replaced by NaN.

        Non-floating-point inputs (e.g. integer images) are promoted to float,
        since NaN cannot be stored in an integer array. The input is not modified.
        """
        masked = np.array(image, copy=True)
        if not np.issubdtype(masked.dtype, np.floating):
            masked = masked.astype(float)
        masked[masked < 0] = np.nan
        return masked

    @staticmethod
    def _plot_regular_grid(ax, image_, neg_values_as_bad=True, **imshow_kwargs):
        """Draw a 2D image with imshow and attach a colorbar.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes on which the image is drawn.
        image_ : array
            2D image to display (left untouched).
        neg_values_as_bad : bool, optional
            If True (default), negative values are set to NaN on a copy so that
            they are drawn with the colormap 'bad' color.
        **imshow_kwargs
            Extra keyword arguments forwarded to ``ax.imshow``.

        Returns
        -------
        ax, cb
            The axes and the colorbar created by ``plut.nice_colorbar``.
        """
        if neg_values_as_bad:
            image = ModelPlotter._mask_negative_values(image_)
        else:
            image = image_
        im = ax.imshow(image, **imshow_kwargs)
        im.set_rasterized(True)
        cb = plut.nice_colorbar(im, ax=ax, max_nbins=4)
        ax.xaxis.set_major_locator(plt.MaxNLocator(3))
        ax.yaxis.set_major_locator(plt.MaxNLocator(3))
        return ax, cb

    @staticmethod
    def _plot_irregular_grid(ax, points, extent, neg_values_as_bad=True,
                             norm=None, cmap=None, plot_points=False):
        """Draw values defined on irregular points as Voronoi cells, with a colorbar.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes on which the cells are drawn.
        points : tuple
            ``(x, y, z)`` point coordinates and values.
        extent : sequence
            ``(xmin, xmax, ymin, ymax)`` axes limits.
        neg_values_as_bad : bool, optional
            Forwarded to ``plut.plot_voronoi``, by default True.
        norm, cmap : optional
            Forwarded to ``plut.plot_voronoi``.
        plot_points : bool, optional
            If True, overlay the point positions as small white dots.

        Returns
        -------
        ax, cb
            The axes and the colorbar created by ``plut.nice_colorbar``.
        """
        x, y, z = points
        m = plut.plot_voronoi(ax, x, y, z, neg_values_as_bad=neg_values_as_bad,
                              norm=norm, cmap=cmap, zorder=1)
        ax.set_xlim(extent[0], extent[1])
        ax.set_ylim(extent[2], extent[3])
        ax.set_aspect('equal', 'box')
        ax.xaxis.set_major_locator(plt.MaxNLocator(3))
        ax.yaxis.set_major_locator(plt.MaxNLocator(3))
        cb = plut.nice_colorbar(m, ax=ax, max_nbins=4)
        if plot_points:
            ax.scatter(x, y, s=5, c='white', marker='.', alpha=0.4, zorder=2)
        return ax, cb
245

246

247
class MultiModelPlotter(object):
    """Wrapper around a set of ModelPlotter instances to produce panels that
    consistently compare different models, evaluated on the same
    coordinates systems.

    Parameters
    ----------
    coolest_objects : list
        List of COOLEST instances
    coolest_directories : list, optional
        List of directories corresponding to each COOLEST instance, by default None
    kwargs_plotter : dict, optional
        Additional keyword arguments passed to ModelPlotter
    """

    def __init__(self, coolest_objects, coolest_directories=None, **kwargs_plotter):
        self.num_models = len(coolest_objects)
        if coolest_directories is None:
            coolest_directories = self.num_models * [None]
        self.plotter_list = [
            ModelPlotter(coolest, coolest_directory=c_dir, **kwargs_plotter)
            for coolest, c_dir in zip(coolest_objects, coolest_directories)
        ]

    def plot_surface_brightness(self, axes, titles=None, **kwargs):
        """Call ModelPlotter.plot_surface_brightness for each model.

        Each value of the optional ``kwargs_light`` dict must be a per-model list.
        """
        return self._plot_light_multi('plot_surface_brightness', "surf. brightness", axes, titles=titles, **kwargs)

    def plot_data_image(self, axes, titles=None, **kwargs):
        """Call ModelPlotter.plot_data_image for each model (kwargs shared by all)."""
        return self._plot_data_multi("data", axes, titles=titles, **kwargs)

    def plot_model_image(self, axes, titles=None, **kwargs):
        """Call ModelPlotter.plot_model_image for each model.

        Each value of the optional ``kwargs_source`` / ``kwargs_lens_mass``
        dicts must be a per-model list.
        """
        return self._plot_lens_model_multi('plot_model_image', "model", axes, titles=titles, **kwargs)

    def plot_model_residuals(self, axes, titles=None, **kwargs):
        """Call ModelPlotter.plot_model_residuals for each model.

        Each value of the optional ``kwargs_source`` / ``kwargs_lens_mass``
        dicts must be a per-model list.
        """
        return self._plot_lens_model_multi('plot_model_residuals', "residuals", axes, titles=titles, **kwargs)

    def plot_convergence(self, axes, titles=None, **kwargs):
        """Call ModelPlotter.plot_convergence for each model.

        Each value of the optional ``kwargs_lens_mass`` dict must be a per-model list.
        """
        return self._plot_mass_multi('plot_convergence', "convergence", axes, titles=titles, **kwargs)

    def plot_magnification(self, axes, titles=None, **kwargs):
        """Call ModelPlotter.plot_magnification for each model.

        Each value of the optional ``kwargs_lens_mass`` dict must be a per-model list.
        """
        return self._plot_mass_multi('plot_magnification', "magnification", axes, titles=titles, **kwargs)

    def _plot_light_multi(self, method_name, default_title, axes, titles=None, **kwargs):
        """Dispatch a light-model plotting method, slicing ``kwargs_light`` per model."""
        return self._plot_multi(method_name, default_title, axes, titles,
                                ('kwargs_light',), kwargs)

    def _plot_mass_multi(self, method_name, default_title, axes, titles=None, **kwargs):
        """Dispatch a mass-model plotting method, slicing ``kwargs_lens_mass`` per model."""
        return self._plot_multi(method_name, default_title, axes, titles,
                                ('kwargs_lens_mass',), kwargs)

    def _plot_lens_model_multi(self, method_name, default_title, axes, titles=None, kwargs_select=None, **kwargs):
        """Dispatch a lens-model plotting method, slicing ``kwargs_source`` and
        ``kwargs_lens_mass`` per model.

        `kwargs_select` is not used; it is kept for backward compatibility.
        """
        return self._plot_multi(method_name, default_title, axes, titles,
                                ('kwargs_source', 'kwargs_lens_mass'), kwargs)

    def _plot_data_multi(self, default_title, axes, titles=None, **kwargs):
        """Dispatch ModelPlotter.plot_data_image; kwargs are shared by all models."""
        return self._plot_multi('plot_data_image', default_title, axes, titles,
                                (), kwargs)

    def _plot_multi(self, method_name, default_title, axes, titles, per_model_keys, kwargs):
        """Call `method_name` on each ModelPlotter with its own axes and title.

        Parameters
        ----------
        method_name : str
            Name of the ModelPlotter method to call.
        default_title : str
            Title used for every panel when `titles` is None.
        axes : sequence
            One axes per model; models whose axes is None are skipped.
        titles : sequence or None
            One title per model.
        per_model_keys : tuple of str
            Names of keyword arguments whose value is a dict of per-model lists;
            model ``i`` receives ``{k: v[i] for k, v in kwargs[key].items()}``.
            A value of None is passed through unchanged.
        kwargs : dict
            Remaining keyword arguments forwarded to `method_name`.

        Returns
        -------
        list
            Return values of the calls, for the non-skipped models only.
        """
        assert len(axes) == self.num_models, "Inconsistent number of subplot axes"
        if titles is None:
            titles = self.num_models * [default_title]
        # Work on a copy when per-model entries get rewritten, so the caller's
        # kwargs are never modified; plain shared kwargs are forwarded as-is.
        kwargs_ = copy.deepcopy(kwargs) if per_model_keys else kwargs
        image_list = []
        for i, (ax, plotter) in enumerate(zip(axes, self.plotter_list)):
            if ax is None:
                continue
            for key in per_model_keys:
                selection = kwargs.get(key)
                if selection is not None:
                    kwargs_[key] = {k: v[i] for k, v in selection.items()}
            image = getattr(plotter, method_name)(ax, title=titles[i], **kwargs_)
            image_list.append(image)
        return image_list
347

348

349
class Comparison_analytical(object):
    """Handles plot of analytical models in a comparative way

    Parameters
    ----------
    coolest_file_list : list
        List of paths to COOLEST templates
    nickname_file_list : list
        List of shorter names related to each COOLEST model
    posterior_bool_list : list
        List of bool to toggle errorbars on point-estimate values
    """

    # Number of panel columns in the comparison figure
    _NUM_COLS = 4
    # Markers, one per compared file (cycled if there are more files)
    _MARKERS = ['*', '.', 's', '^', '<', '>', 'v', 'p', 'P', 'X', 'D', '1', '2', '3', '4', '+']
    # Angle parameters whose percentiles may straddle the +180/-180 edge
    _WRAPPED_ANGLE_KEYS = ('SHEAR_0_phi_ext', 'PEMD_0_phi')

    def __init__(self,coolest_file_list, nickname_file_list, posterior_bool_list):
        self.file_names = nickname_file_list
        self.posterior_bool_list = posterior_bool_list
        self.param_lens, self.param_source = util.read_json_param(coolest_file_list,
                                                                  self.file_names,
                                                                  lens_light=False)

    @staticmethod
    def _grid_layout(number_param, num_cols=4):
        """Return ``(num_lines, unused_figs)`` for a grid of `number_param` panels.

        `num_lines` is the number of rows (at least one) and `unused_figs` lists
        the negative column indices of the empty panels in the last row.
        """
        num_lines = max(1, -(-number_param // num_cols))  # ceiling division
        num_unused = num_lines * num_cols - number_param
        unused_figs = [-(k + 1) for k in range(num_unused)]
        return num_lines, unused_figs

    @classmethod
    def _errorbar_bounds(cls, key, p):
        """Return the (lower, upper) error bar lengths around ``p['median']``.

        For angle parameters close to the +180/-180 edge, a 16th (84th)
        percentile found above (below) the median is shifted by -180 (+180)
        degrees so the error bars stay on the correct side. `p` is not modified.
        """
        median = p['median']
        low, high = p['percentile_16th'], p['percentile_84th']
        if key in cls._WRAPPED_ANGLE_KEYS:
            if low > median:
                low -= 180.
            if high < median:
                high += 180.
        return median - low, high - median

    def plotting_routine(self,param_dict,idx_file=0):
        """Plot the parameters of all files side by side, one panel per parameter.

        Parameters
        ----------
        param_dict : dict
            Organized dictionary with the parameter results of the different
            files, keyed by file nickname. Each entry maps a parameter key to a
            dict holding 'latex_str' and either 'point_estimate' or 'median',
            'percentile_16th' and 'percentile_84th'.
        idx_file : int, optional
            Index of the reference file whose parameters are plotted, by default 0.
            For instance, if file 0 holds a Sersic fit and file 1 a
            Sersic+shapelets fit, idx_file=0 plots only the Sersic parameters
            (for both files), while idx_file=1 plots all the Sersic and
            shapelets parameters when available.

        Returns
        -------
        f, ax : matplotlib Figure and 2D array of Axes
        """
        ref_params = param_dict[self.file_names[idx_file]]
        # One panel per parameter of the reference file, in its order
        panel_index = {key: i for i, key in enumerate(ref_params)}
        num_cols = self._NUM_COLS
        num_lines, unused_figs = self._grid_layout(len(panel_index), num_cols)

        # squeeze=False keeps `ax` 2D even when there is a single row
        f, ax = plt.subplots(num_lines, num_cols, figsize=(num_cols * 3.5, 2.5 * num_lines),
                             squeeze=False)

        # Axis decorations are taken from the reference file
        for key, i in panel_index.items():
            panel = ax[i // num_cols, i % num_cols]
            panel.get_xaxis().set_visible(False)
            panel.set_ylabel(ref_params[key]['latex_str'], fontsize=12)
            panel.tick_params(axis='y', labelsize=12)

        for j, file_name in enumerate(self.file_names):
            marker = self._MARKERS[j % len(self._MARKERS)]
            for key, p in param_dict[file_name].items():
                if key not in panel_index:
                    continue  # parameter not part of the reference file
                i = panel_index[key]
                panel = ax[i // num_cols, i % num_cols]
                if self.posterior_bool_list[j]:
                    lower, upper = self._errorbar_bounds(key, p)
                    panel.errorbar(j, p['median'], [[lower], [upper]],
                                   marker=marker, ls='', label=file_name)
                else:
                    panel.plot(j, p['point_estimate'], marker=marker, ls='', label=file_name)

        ax[0, 0].legend()
        for idx in unused_figs:
            ax[-1, idx].axis('off')
        plt.tight_layout()
        plt.show()
        return f, ax

    def plot_source(self,idx_file=0):
        """Plot the source parameters of all files (see plotting_routine)."""
        f, ax = self.plotting_routine(self.param_source, idx_file)
        return f, ax

    def plot_lens(self,idx_file=0):
        """Plot the lens parameters of all files (see plotting_routine)."""
        f, ax = self.plotting_routine(self.param_lens, idx_file)
        return f, ax
450

451

452

453

454

455
def plot_corner(parameter_id_list,chain_objs,chain_dirs,chain_names=None,point_estimate_objs=None,point_estimate_dirs=None,point_estimate_names=None,colors=None,labels=None,mc_samples_kwargs=None):
×
456
    """
457
    Adding this as just a function for the moment.
458
    Takes a list of COOLEST files as input, which must have a chain file associated to them, and returns a corner plot.
459

460
    Parameters
461
    ----------
462
    parameter_id_list : array
463
        A list of parameter unique ids obtained from lensing entities. Their order determines the order of the plot panels.
464
    chain_objs : array
465
        A list of coolest objects that have a chain file associated to them.
466
    chain_dirs : array
467
        A list of paths matching the coolest files in 'chain_objs'.
468
    chain_names : array, optional
469
        A list of labels for the coolest models in the 'chain_objs' list. Must have the same order as 'chain_objs'.
470
    point_estimate_objs : array, optional
471
        A list of coolest objects that will be used as point estimates.
472
    point_estimate_dirs : array
473
        A list of paths matching the coolest files in 'point_estimate_objs'.
474
    point_estimate_names : array, optional
475
        A list of labels for the models in the 'point_estimate_objs' list. Must have the same order as 'point_estimate_objs'.
476
    labels : dict, optional
477
        A dictionary matching the parameter_id_list entries to some human-readable labels.
478
    
479

480
    Returns
481
    -------
482
    An image
483
    """
484

485
    chains.print_load_details = False # Just to silence messages
×
486
    parameter_id_set = set(parameter_id_list)
×
487
    Npars = len(parameter_id_list)
×
488
    
489
    # Get the chain file headers from the first object in the list
490
    chain_file = os.path.join(chain_dirs[0],chain_objs[0].meta["chain_file_name"])
×
491
    
492

493
    # Set the chain names
494
    if chain_names is None:
×
495
        chain_names = ["chain "+str(i) for i in range(0,len(chain_objs))]
×
496
    
497

498
    # Get the values of the point_estimates
499
    point_estimates = []
×
500
    if point_estimate_objs is not None:
×
501
        for coolest_obj in point_estimate_objs:
×
502
            values = []
×
503
            for par in parameter_id_list:
×
504
                param = coolest_obj.lensing_entities.get_parameter_from_id(par)
×
505
                val = param.point_estimate.value
×
506
                if val is None:
×
507
                    values.append(None)
×
508
                else:
509
                    values.append(val)
×
510
            point_estimates.append(values)
×
511

512
            
513
    mcsamples = []
×
514
    for i in range(0,len(chain_objs)):
×
515
        chain_file = os.path.join(chain_dirs[i],chain_objs[i].meta["chain_file_name"]) # Here get the chain file path for each coolest object
×
516

517
        # Each chain file can have a different number of free parameters
518
        f = open(chain_file)
×
519
        header = f.readline()
×
520
        f.close()
×
521
        chain_file_headers = header.split(',')
×
522
        chain_file_headers.pop() # Remove the last column name that is the probability weights
×
523
        chain_file_headers_set = set(chain_file_headers)
×
524
        
525
        # Check that the given parameters are a subset of those in the chain file
526
        assert parameter_id_set.issubset(chain_file_headers_set), "Not all given parameters are free parameters for model %d (not in the chain file: %s)!" % (i,chain_file)
×
527

528
        # Set the labels for the parameters in the chain file
529
        par_labels = []
×
530
        if labels is None:
×
531
            for par_id in chain_file_headers:
×
532
                if par_id in parameter_id_list:
×
533
                    param = coolest_obj.lensing_entities.get_parameter_from_id(par_id)
×
534
                    par_labels.append(param.latex_str.strip('$'))
×
535
                else:
536
                    par_labels.append(par_id)
×
537
        else:
538
            label_keys = list(labels.keys())
×
539
            for par_id in chain_file_headers:
×
540
                if par_id in label_keys:
×
541
                    par_labels.append(labels[par_id])
×
542
                else:
543
                    param = coolest_obj.lensing_entities.get_parameter_from_id(par_id)
×
544
                    if param:
×
545
                        par_labels.append(param.latex_str.strip('$'))
×
546
                    else:
547
                        par_labels.append(par_id)
×
548
                    
549
        # Read parameter values and probability weights
550
        samples = np.loadtxt(chain_file,skiprows=1,delimiter=',')
×
551
        sample_par_values = samples[:,:-1]
×
552
        
553
        # Clean-up the probability weights
554
        mypost = samples[:,-1]
×
555
        min_non_zero = np.min(mypost[np.nonzero(mypost)])
×
556
        sample_prob_weight = np.where(mypost<min_non_zero,min_non_zero,mypost)
×
557
        #sample_prob_weight = mypost
558

559
        # Create MCSamples object
560
        mysample = MCSamples(samples=sample_par_values,names=chain_file_headers,labels=par_labels,settings=mc_samples_kwargs)
×
561
        mysample.reweightAddingLogLikes(-np.log(sample_prob_weight))
×
562
        mcsamples.append(mysample)
×
563

564

565
        
566
    # Make the plot
567
    image = plots.getSubplotPlotter(subplot_size=1)    
×
568
    image.triangle_plot(mcsamples,params=parameter_id_list,legend_labels=chain_names,filled=True,colors=colors)
×
569

570
    my_linestyles = ['solid','dotted','dashed','dashdot']
×
571
    my_markers    = ['s','^','o','star']
×
572

573
    for k in range(0,len(point_estimates)):
×
574
        # Add vertical and horizontal lines
575
        for i in range(0,Npars):
×
576
            val = point_estimates[k][i]
×
577
            if val is not None:
×
578
                for ax in image.subplots[i:,i]:
×
579
                    ax.axvline(val,color='black',ls=my_linestyles[k],alpha=1.0,lw=1)
×
580
                for ax in image.subplots[i,:i]:
×
581
                    ax.axhline(val,color='black',ls=my_linestyles[k],alpha=1.0,lw=1)
×
582

583
        # Add points
584
        for i in range(0,Npars):
×
585
            val_x = point_estimates[k][i]
×
586
            for j in range(i+1,Npars):
×
587
                val_y = point_estimates[k][j]
×
588
                if val_x is not None and val_y is not None:
×
589
                    image.subplots[j,i].scatter(val_x,val_y,s=10,facecolors='black',color='black',marker=my_markers[k])
×
590
                else:
591
                    pass    
×
592

593
                
594
    return image
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc