• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

aymgal / COOLEST / 6010555315

29 Aug 2023 09:38AM UTC coverage: 54.118%. Remained the same
6010555315

push

github

web-flow
Update plotting corner routine to support samples with lists

4 of 4 new or added lines in 1 file covered. (100.0%)

1347 of 2489 relevant lines covered (54.12%)

0.54 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/coolest/api/plotting.py
1
__author__ = 'aymgal', 'lynevdv', 'gvernard'
×
2

3

4
import os
×
5
import copy
×
6
import logging
×
7
import numpy as np
×
8
import matplotlib.pyplot as plt
×
9
from matplotlib.colors import Normalize, LogNorm, TwoSlopeNorm
×
10
from matplotlib.colors import ListedColormap
×
11
from getdist import plots,chains,MCSamples
×
12

13
from coolest.api.analysis import Analysis
×
14
from coolest.api.composable_models import *
×
15
from coolest.api import util
×
16
from coolest.api import plot_util as plut
×
17

18
import pandas as pd
×
19

20

21
# Matplotlib global settings: imshow draws pixels without interpolation and
# places the array origin at the lower-left corner.
plt.rc('image', **{'interpolation': 'none', 'origin': 'lower'})

# Logging settings: the root logger reports INFO messages and above.
# NOTE(review): this changes the process-wide root logger as soon as the
# module is imported.
logging.getLogger().setLevel(level=logging.INFO)
26

27

28
class ModelPlotter(object):
    """Create pyplot panels from a lens model stored in the COOLEST format.

    Parameters
    ----------
    coolest_object : COOLEST
        COOLEST instance
    coolest_directory : str, optional
        Directory which contains the COOLEST template, by default None
    color_bad_values : str, optional
        Color assigned to NaN values (typically negative values in log-scale),
        by default '#111111' (dark gray)
    """

    def __init__(self, coolest_object, coolest_directory=None,
                 color_bad_values='#111111'):
        self.coolest = coolest_object
        self._directory = coolest_directory

        # Work on a copy so that set_bad() cannot leak into the colormap
        # registered under 'magma' (some matplotlib versions return the
        # shared registered instance from get_cmap).
        self.cmap_flux = copy.copy(plt.get_cmap('magma'))
        self.cmap_flux.set_bad(color_bad_values)

        self.cmap_mag = plt.get_cmap('twilight_shifted')
        self.cmap_conv = plt.get_cmap('cividis')
        self.cmap_res = plt.get_cmap('RdBu_r')

    def plot_data_image(self, ax, title=None, norm=None, cmap=None,
                        neg_values_as_bad=False, add_colorbar=True):
        """plt.imshow panel with the observed data image.

        Parameters
        ----------
        ax : matplotlib Axes
            Panel to draw into.
        title : str, optional
            Panel title.
        norm : matplotlib Normalize, optional
            Color normalization passed to imshow.
        cmap : matplotlib Colormap, optional
            Colormap, by default the flux colormap of this plotter.
        neg_values_as_bad : bool, optional
            If True, negative pixels are displayed with the 'bad' color (NaN).
        add_colorbar : bool, optional
            Whether to attach a colorbar labelled "flux".

        Returns
        -------
        The data pixels as read from the COOLEST observation (never modified).
        """
        if cmap is None:
            cmap = self.cmap_flux
        coordinates = util.get_coordinates(self.coolest)
        image = self.coolest.observation.pixels.get_pixels(directory=self._directory)
        ax, im = self._plot_regular_grid(ax, image, extent=coordinates.plt_extent,
                                         cmap=cmap, norm=norm,
                                         neg_values_as_bad=neg_values_as_bad)
        self._finalize_panel(ax, im, add_colorbar, "flux", title)
        return image

    def plot_surface_brightness(self, ax, title=None, coordinates=None,
                                extent_irreg=None, norm=None, cmap=None, neg_values_as_bad=True,
                                plot_points_irreg=False, add_colorbar=True, kwargs_light=None):
        """plt.imshow panel showing the surface brightness of the (unlensed)
        lensing entity selected via kwargs_light (see ComposableLightModel docstring).

        Parameters
        ----------
        ax : matplotlib Axes
            Panel to draw into.
        title : str, optional
            Panel title.
        coordinates : optional
            If given, the light model is evaluated on
            ``coordinates.pixel_coordinates`` and displayed with
            ``coordinates.plt_extent``. Otherwise the model's own surface
            brightness representation is used (regular or irregular grid).
        extent_irreg : sequence of 4 floats, optional
            Axis limits (xmin, xmax, ymin, ymax) used when the model is defined
            on an irregular grid; defaults to the extent returned by the model.
        norm, cmap, neg_values_as_bad, add_colorbar :
            Same meaning as in ``plot_data_image``.
        plot_points_irreg : bool, optional
            Overlay the irregular grid points.
        kwargs_light : dict, optional
            Keyword arguments passed to ComposableLightModel.

        Returns
        -------
        image : 2D array or None
            The displayed pixel values (None for an irregular grid).
        coordinates :
            The coordinates used (those given, or those returned by the model).
        """
        if kwargs_light is None:
            kwargs_light = {}
        light_model = ComposableLightModel(self.coolest, self._directory, **kwargs_light)
        if cmap is None:
            cmap = self.cmap_flux
        if coordinates is not None:
            x, y = coordinates.pixel_coordinates
            image = light_model.evaluate_surface_brightness(x, y)
            ax, im = self._plot_regular_grid(ax, image, extent=coordinates.plt_extent,
                                             cmap=cmap, norm=norm,
                                             neg_values_as_bad=neg_values_as_bad)
        else:
            values, extent_model, coordinates = light_model.surface_brightness(return_extra=True)
            if isinstance(values, np.ndarray) and values.ndim == 2:
                # Pixelated model: display it as an image
                image = values
                ax, im = self._plot_regular_grid(ax, image, extent=extent_model,
                                                 cmap=cmap, norm=norm,
                                                 neg_values_as_bad=neg_values_as_bad)
            else:
                # Irregular model: `values` is unpacked as (x, y, z) points
                if extent_irreg is None:
                    extent_irreg = extent_model
                ax, im = self._plot_irregular_grid(ax, values, extent_irreg, norm=norm, cmap=cmap,
                                                   neg_values_as_bad=neg_values_as_bad,
                                                   plot_points=plot_points_irreg)
                image = None
        self._finalize_panel(ax, im, add_colorbar, "flux", title)
        return image, coordinates

    def plot_model_image(self, ax, supersampling=5, convolved=False, title=None,
                         norm=None, cmap=None, neg_values_as_bad=False,
                         kwargs_source=None, add_colorbar=True, kwargs_lens_mass=None):
        """plt.imshow panel showing the surface brightness of the (lensed)
        selected lensing entities (see ComposableLensModel docstring).

        Parameters
        ----------
        supersampling : int, optional
            Passed to ``ComposableLensModel.model_image``.
        convolved : bool, optional
            Passed to ``ComposableLensModel.model_image``.
        kwargs_source, kwargs_lens_mass : dict, optional
            Entity selections passed to ComposableLensModel.
        ax, title, norm, cmap, neg_values_as_bad, add_colorbar :
            Same meaning as in ``plot_data_image``.

        Returns
        -------
        The model image array.
        """
        if cmap is None:
            cmap = self.cmap_flux
        lens_model = ComposableLensModel(self.coolest, self._directory,
                                         kwargs_selection_source=kwargs_source,
                                         kwargs_selection_lens_mass=kwargs_lens_mass)
        image, coordinates = lens_model.model_image(supersampling=supersampling,
                                                    convolved=convolved)
        ax, im = self._plot_regular_grid(ax, image, extent=coordinates.plt_extent,
                                         cmap=cmap, norm=norm,
                                         neg_values_as_bad=neg_values_as_bad)
        self._finalize_panel(ax, im, add_colorbar, "flux", title)
        return image

    def plot_model_residuals(self, ax, supersampling=5, mask=None, title=None,
                             norm=None, cmap=None, add_chi2_label=False, chi2_fontsize=12,
                             kwargs_source=None, add_colorbar=True, kwargs_lens_mass=None):
        """plt.imshow panel showing the normalized model residuals image.

        Parameters
        ----------
        mask : array, optional
            Passed to ``ComposableLensModel.model_residuals``; also used to
            count the constraints of the reduced chi^2 (``np.sum(mask)``).
        norm : matplotlib Normalize, optional
            By default ``Normalize(-6, 6)``.
        cmap : matplotlib Colormap, optional
            By default the diverging residual colormap of this plotter.
        add_chi2_label : bool, optional
            If truthy, write the reduced chi^2 in the lower-left corner.
        chi2_fontsize : int, optional
            Font size of that label.
        ax, supersampling, title, kwargs_source, kwargs_lens_mass, add_colorbar :
            Same meaning as in ``plot_model_image``.

        Returns
        -------
        The normalized residuals image, (data - model) / noise.
        """
        if cmap is None:
            cmap = self.cmap_res
        if norm is None:
            norm = Normalize(-6, 6)
        lens_model = ComposableLensModel(self.coolest, self._directory,
                                         kwargs_selection_source=kwargs_source,
                                         kwargs_selection_lens_mass=kwargs_lens_mass)
        image, coordinates = lens_model.model_residuals(supersampling=supersampling,
                                                        mask=mask)
        ax, im = self._plot_regular_grid(ax, image, extent=coordinates.plt_extent,
                                         cmap=cmap, norm=norm,
                                         neg_values_as_bad=False)
        if add_colorbar:
            cb = plut.nice_colorbar(im, ax=ax, max_nbins=4)
            cb.set_label("(data $-$ model) / noise")
        if add_chi2_label:
            # NOTE(review): assumes residuals outside the mask do not contribute
            # to np.sum(image**2) (e.g. are zeroed by model_residuals) — confirm.
            num_constraints = np.size(image) if mask is None else np.sum(mask)
            # Guard against an empty mask instead of dividing by zero
            red_chi2 = np.sum(image**2) / num_constraints if num_constraints > 0 else np.nan
            ax.text(0.05, 0.05, r'$\chi^2_\nu$='+f'{red_chi2:.2f}', color='black', alpha=1,
                    fontsize=chi2_fontsize, va='bottom', ha='left', transform=ax.transAxes,
                    bbox={'color': 'white', 'alpha': 0.6})
        if title is not None:
            ax.set_title(title)
        return image

    def plot_convergence(self, ax, title=None,
                         norm=None, cmap=None, neg_values_as_bad=False,
                         add_colorbar=True, kwargs_lens_mass=None):
        """plt.imshow panel showing the 2D convergence map associated to the
        selected lensing entities (see ComposableMassModel docstring).

        The map is evaluated on the pixel coordinates of the COOLEST
        observation. ``kwargs_lens_mass`` is passed to ComposableMassModel;
        the other arguments have the same meaning as in ``plot_data_image``.

        Returns
        -------
        The convergence map array.
        """
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        mass_model = ComposableMassModel(self.coolest, self._directory,
                                         **kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_conv
        coordinates = util.get_coordinates(self.coolest)
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_convergence(x, y)
        ax, im = self._plot_regular_grid(ax, image, extent=coordinates.plt_extent,
                                         cmap=cmap, norm=norm,
                                         neg_values_as_bad=neg_values_as_bad)
        self._finalize_panel(ax, im, add_colorbar, r"$\kappa$", title)
        return image

    def plot_magnification(self, ax, title=None,
                          norm=None, cmap=None, neg_values_as_bad=False,
                          add_colorbar=True, kwargs_lens_mass=None):
        """plt.imshow panel showing the 2D magnification map associated to the
        selected lensing entities (see ComposableMassModel docstring).

        The map is evaluated on the pixel coordinates of the COOLEST
        observation, with ``Normalize(-10, 10)`` as default normalization.
        ``kwargs_lens_mass`` is passed to ComposableMassModel; the other
        arguments have the same meaning as in ``plot_data_image``.

        Returns
        -------
        The magnification map array.
        """
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        mass_model = ComposableMassModel(self.coolest, self._directory,
                                         **kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_mag
        if norm is None:
            norm = Normalize(-10, 10)
        coordinates = util.get_coordinates(self.coolest)
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_magnification(x, y)
        ax, im = self._plot_regular_grid(ax, image, extent=coordinates.plt_extent,
                                         cmap=cmap, norm=norm,
                                         neg_values_as_bad=neg_values_as_bad)
        self._finalize_panel(ax, im, add_colorbar, r"$\mu$", title)
        return image

    @staticmethod
    def _finalize_panel(ax, im, add_colorbar, colorbar_label, title):
        """Optionally attach a labelled colorbar for `im`, then set the title of `ax`."""
        if add_colorbar:
            cb = plut.nice_colorbar(im, ax=ax, max_nbins=4)
            cb.set_label(colorbar_label)
        if title is not None:
            ax.set_title(title)

    @staticmethod
    def _prepare_image(image_, neg_values_as_bad):
        """Return the array to display, with negative values set to NaN if requested.

        When masking, a floating-point copy is made so that the caller's array
        is never modified and integer inputs can hold NaN (np.copy would keep
        an integer dtype, on which assigning NaN raises).
        """
        if not neg_values_as_bad:
            return image_
        arr = np.asarray(image_)
        dtype = arr.dtype if np.issubdtype(arr.dtype, np.floating) else float
        image = arr.astype(dtype, copy=True)
        image[image < 0] = np.nan
        return image

    @staticmethod
    def _plot_regular_grid(ax, image_, neg_values_as_bad=True, **imshow_kwargs):
        """Display a pixelated image with imshow; returns (ax, image artist)."""
        image = ModelPlotter._prepare_image(image_, neg_values_as_bad)
        im = ax.imshow(image, **imshow_kwargs)
        im.set_rasterized(True)
        ax.xaxis.set_major_locator(plt.MaxNLocator(3))
        ax.yaxis.set_major_locator(plt.MaxNLocator(3))
        return ax, im

    @staticmethod
    def _plot_irregular_grid(ax, points, extent, neg_values_as_bad=True,
                             norm=None, cmap=None, plot_points=False):
        """Display (x, y, z) points as Voronoi cells; returns (ax, artist).

        `extent` gives the axis limits as (xmin, xmax, ymin, ymax). If
        `plot_points` is True, the points themselves are overlaid.
        """
        x, y, z = points
        im = plut.plot_voronoi(ax, x, y, z, neg_values_as_bad=neg_values_as_bad,
                               norm=norm, cmap=cmap, zorder=1)
        ax.set_xlim(extent[0], extent[1])
        ax.set_ylim(extent[2], extent[3])
        ax.set_aspect('equal', 'box')
        ax.xaxis.set_major_locator(plt.MaxNLocator(3))
        ax.yaxis.set_major_locator(plt.MaxNLocator(3))
        if plot_points:
            ax.scatter(x, y, s=5, c='white', marker='.', alpha=0.4, zorder=2)
        return ax, im
258

259

260
class MultiModelPlotter(object):
    """Wrapper around a set of ModelPlotter instances to produce panels that
    consistently compare different models, evaluated on the same
    coordinates systems.

    Parameters
    ----------
    coolest_objects : list
        List of COOLEST instances
    coolest_directories : list, optional
        List of directories corresponding to each COOLEST instance, by default None
    kwargs_plotter : dict, optional
        Additional keyword arguments passed to ModelPlotter

    Notes
    -----
    Every plotting method takes a sequence ``axes`` with one entry per model
    (an entry may be None to skip that model) and an optional sequence
    ``titles`` (by default every panel gets the global title). Selection
    dictionaries such as ``kwargs_light``, ``kwargs_source`` or
    ``kwargs_lens_mass`` are given per model: each of their values is a
    sequence and model ``i`` receives ``{key: value[i]}``. The returned list
    only contains the results of models whose axis is not None.
    """

    def __init__(self, coolest_objects, coolest_directories=None, **kwargs_plotter):
        self.num_models = len(coolest_objects)
        if coolest_directories is None:
            coolest_directories = self.num_models * [None]
        elif len(coolest_directories) != self.num_models:
            # zip() below would otherwise silently drop models
            raise ValueError("coolest_directories must have one entry per COOLEST object")
        self.plotter_list = [ModelPlotter(coolest, coolest_directory=c_dir, **kwargs_plotter)
                             for coolest, c_dir in zip(coolest_objects, coolest_directories)]

    def plot_surface_brightness(self, axes, global_title="surf. brightness", titles=None, **kwargs):
        """Panels from ModelPlotter.plot_surface_brightness, one per model."""
        return self._plot_light_multi('plot_surface_brightness', global_title, axes,
                                      titles=titles, **kwargs)

    def plot_data_image(self, axes, global_title="data", titles=None, **kwargs):
        """Panels from ModelPlotter.plot_data_image, one per model."""
        return self._plot_data_multi(global_title, axes, titles=titles, **kwargs)

    def plot_model_image(self, axes, global_title="model", titles=None, **kwargs):
        """Panels from ModelPlotter.plot_model_image, one per model."""
        return self._plot_lens_model_multi('plot_model_image', global_title, axes,
                                           titles=titles, **kwargs)

    def plot_model_residuals(self, axes, global_title="residuals", titles=None, **kwargs):
        """Panels from ModelPlotter.plot_model_residuals, one per model."""
        return self._plot_lens_model_multi('plot_model_residuals', global_title, axes,
                                           titles=titles, **kwargs)

    def plot_convergence(self, axes, global_title="convergence", titles=None, **kwargs):
        """Panels from ModelPlotter.plot_convergence, one per model."""
        return self._plot_lens_model_multi('plot_convergence', global_title, axes,
                                           titles=titles, **kwargs)

    def plot_magnification(self, axes, titles=None, global_title="magnification", **kwargs):
        """Panels from ModelPlotter.plot_magnification, one per model.

        `global_title` is accepted as a keyword (placed after `titles` to keep
        positional calls working) for consistency with the other methods.
        """
        return self._plot_lens_model_multi('plot_magnification', global_title, axes,
                                           titles=titles, **kwargs)

    def _plot_light_multi(self, method_name, global_title, axes, titles=None, **kwargs):
        """Dispatch to a light plotting method; `kwargs_light` is split per model."""
        return self._plot_multi(method_name, global_title, axes, titles,
                                ('kwargs_light',), kwargs)

    def _plot_mass_multi(self, method_name, global_title, axes, titles=None, **kwargs):
        """Dispatch to a mass plotting method; `kwargs_lens_mass` is split per model."""
        return self._plot_multi(method_name, global_title, axes, titles,
                                ('kwargs_lens_mass',), kwargs)

    def _plot_lens_model_multi(self, method_name, global_title, axes, titles=None,
                               kwargs_select=None, **kwargs):
        """Dispatch to a lens-model plotting method; `kwargs_source` and
        `kwargs_lens_mass` are split per model.

        `kwargs_select` is accepted for backward compatibility and ignored.
        """
        return self._plot_multi(method_name, global_title, axes, titles,
                                ('kwargs_source', 'kwargs_lens_mass'), kwargs)

    def _plot_data_multi(self, global_title, axes, titles=None, **kwargs):
        """Dispatch to plot_data_image; all kwargs are shared by every model."""
        return self._plot_multi('plot_data_image', global_title, axes, titles, (), kwargs)

    def _plot_multi(self, method_name, global_title, axes, titles, per_model_keys, kwargs):
        """Call ModelPlotter.<method_name> for each model on its own axis.

        Parameters
        ----------
        method_name : str
            Name of the ModelPlotter method to call.
        global_title : str
            Title used for every panel when `titles` is None.
        axes : sequence
            One axis per model; None entries are skipped.
        titles : sequence of str or None
            Per-model titles.
        per_model_keys : tuple of str
            Keys of `kwargs` holding per-model selection dictionaries; a None
            value is passed through unchanged.
        kwargs : dict
            Keyword arguments forwarded to every call.

        Returns
        -------
        list
            Return values of the plotting method, for non-skipped models only.
        """
        assert len(axes) == self.num_models, "Inconsistent number of subplot axes"
        if titles is None:
            titles = self.num_models * [global_title]
        # Per-model entries are written into a deep copy so the caller's
        # dictionaries are left untouched.
        kwargs_ = copy.deepcopy(kwargs) if per_model_keys else kwargs
        image_list = []
        for i, (ax, plotter) in enumerate(zip(axes, self.plotter_list)):
            if ax is None:
                continue
            for key in per_model_keys:
                per_model = kwargs.get(key)
                if per_model is not None:
                    kwargs_[key] = {k: v[i] for k, v in per_model.items()}
            image = getattr(plotter, method_name)(ax, title=titles[i], **kwargs_)
            image_list.append(image)
        return image_list
360

361

362
class Comparison_analytical(object):
    """Handles plot of analytical models in a comparative way

    Parameters
    ----------
    coolest_file_list : list
        List of paths to COOLEST templates
    nickname_file_list : list
        List of shorter names related to each COOLEST model
    posterior_bool_list : list
        List of bool to toggle errorbars on point-estimate values
    """

    # One marker per file; reused cyclically if there are more files
    _MARKERS = ['*', '.', 's', '^', '<', '>', 'v', 'p', 'P', 'X', 'D', '1', '2', '3', '4', '+']
    # Angle parameters (degrees) whose credible interval may wrap around
    _ANGLE_KEYS = ('SHEAR_0_phi_ext', 'PEMD_0_phi')
    # Number of panels per row
    _NUM_COLS = 4

    def __init__(self, coolest_file_list, nickname_file_list, posterior_bool_list):
        self.file_names = nickname_file_list
        self.posterior_bool_list = posterior_bool_list
        self.param_lens, self.param_source = util.read_json_param(coolest_file_list,
                                                                  self.file_names,
                                                                  lens_light=False)

    @staticmethod
    def _grid_layout(number_param, num_cols=4):
        """Return (num_lines, unused_figs) for `number_param` panels on `num_cols` columns.

        `num_lines` is at least 1. `unused_figs` lists the negative column
        indices (-1, -2, ...) of the empty panels on the last row.
        """
        num_lines = max(1, -(-number_param // num_cols))  # ceiling division
        num_unused = num_lines * num_cols - number_param
        return num_lines, [-idx - 1 for idx in range(num_unused)]

    @staticmethod
    def _errorbar_bounds(key, p):
        """Return the (lower, upper) error bar lengths around p['median'].

        For angle parameters, a 16th percentile above the median (or an 84th
        percentile below it) means the interval wrapped around the periodic
        edge; it is unwrapped by a shift of 180 degrees. `p` is not modified.
        """
        median = p['median']
        low = p['percentile_16th']
        high = p['percentile_84th']
        if key in Comparison_analytical._ANGLE_KEYS:
            if low > median:
                low -= 180.
            if high < median:
                high += 180.
        return median - low, high - median

    def plotting_routine(self, param_dict, idx_file=0):
        """Plot the parameters of all files side by side, one panel per parameter.

        Parameters
        ----------
        param_dict : dict
            Organized dictionary with the parameter results of the different
            files, keyed by nickname then by parameter name. Each entry holds
            'latex_str' and either 'point_estimate' or 'median',
            'percentile_16th' and 'percentile_84th' (depending on
            posterior_bool_list).
        idx_file : int
            Index of the file whose parameters define the panels. For example,
            if file 0 has a Sersic fit and file 1 a Sersic+shapelets fit,
            idx_file=0 plots only the Sersic parameters of both files, while
            idx_file=1 plots all Sersic and shapelets parameters, whenever
            available. Parameters missing from a file are skipped for it.

        Returns
        -------
        f, ax : matplotlib Figure and 2D array of Axes
        """
        reference = param_dict[self.file_names[idx_file]]
        keys = list(reference.keys())
        num_cols = self._NUM_COLS
        num_lines, unused_figs = self._grid_layout(len(keys), num_cols)

        # squeeze=False keeps `ax` 2D even when a single row is needed
        f, ax = plt.subplots(num_lines, num_cols, figsize=(num_cols * 3.5, 2.5 * num_lines),
                             squeeze=False)

        for j, file_name in enumerate(self.file_names):
            result = param_dict[file_name]
            marker = self._MARKERS[j % len(self._MARKERS)]
            for i, key in enumerate(keys):
                if key not in result:
                    continue  # this model does not have this parameter
                axis = ax[i // num_cols, i % num_cols]
                p = result[key]
                if self.posterior_bool_list[j]:
                    lower, upper = self._errorbar_bounds(key, p)
                    axis.errorbar(j, p['median'], [[lower], [upper]],
                                  marker=marker, ls='', label=file_name)
                else:
                    axis.plot(j, p['point_estimate'], marker=marker, ls='', label=file_name)

        # Decorate every panel from the reference file so each gets its label
        for i, key in enumerate(keys):
            axis = ax[i // num_cols, i % num_cols]
            axis.get_xaxis().set_visible(False)
            axis.set_ylabel(reference[key]['latex_str'], fontsize=12)
            axis.tick_params(axis='y', labelsize=12)

        ax[0, 0].legend()
        for idx in unused_figs:
            ax[-1, idx].axis('off')
        plt.tight_layout()
        plt.show()
        return f, ax

    def plot_source(self, idx_file=0):
        """Plot the source parameters (see plotting_routine)."""
        f, ax = self.plotting_routine(self.param_source, idx_file)
        return f, ax

    def plot_lens(self, idx_file=0):
        """Plot the lens parameters (see plotting_routine)."""
        f, ax = self.plotting_routine(self.param_lens, idx_file)
        return f, ax
463

464

465

466

467

468
def _read_chain_header(chain_file):
    """Return the column names from the first line of a comma-separated chain file."""
    with open(chain_file) as f:
        header = f.readline()
    if ';' in header:
        raise ValueError("Columns must be comma-separated (no semi-colon) in chain file.")
    return header.rstrip('\r\n').split(',')


def _clean_probability_weights(weights):
    """Raise every weight below the smallest non-zero weight to that value.

    This keeps -log(weight) finite when the weights are converted into
    log-likelihood offsets.
    """
    weights = np.asarray(weights)
    non_zero = weights[np.nonzero(weights)]
    if non_zero.size == 0:
        raise ValueError("All probability weights are zero in the chain file.")
    min_non_zero = np.min(non_zero)
    return np.where(weights < min_non_zero, min_non_zero, weights)


def _read_point_estimates(point_estimate_objs, parameter_id_list):
    """Return, per COOLEST object, the point-estimate value of each parameter (None if unset)."""
    return [[obj.lensing_entities.get_parameter_from_id(par).point_estimate.value
             for par in parameter_id_list]
            for obj in point_estimate_objs]


def _load_chain_samples(coolest_obj, chain_dir, parameter_id_list, labels,
                        mc_samples_kwargs, model_index):
    """Build a getdist MCSamples for `parameter_id_list` from the chain file of `coolest_obj`."""
    chain_file = os.path.join(chain_dir, coolest_obj.meta["chain_file_name"])

    # Each chain file can have a different number of free parameters
    chain_file_headers = _read_chain_header(chain_file)
    num_cols = len(chain_file_headers)
    free_params = chain_file_headers[:-1]  # last column holds the probability weights

    # Check that the given parameters are a subset of those in the chain file
    assert set(parameter_id_list).issubset(free_params), "Not all given parameters are free parameters for model %d (not in the chain file: %s)!" % (model_index, chain_file)

    # Labels: user-provided ones first, otherwise the LaTeX string stored in this model
    par_labels = []
    for par_id in parameter_id_list:
        label = labels.get(par_id)
        if label is None:
            param = coolest_obj.lensing_entities.get_parameter_from_id(par_id)
            label = param.latex_str.strip('$')
        par_labels.append(label)

    # Read parameter values and probability weights
    # TODO: handle samples that are given as a list / array (e.g. using `converters`)
    column_indices = [free_params.index(par_id) for par_id in parameter_id_list]
    columns_to_read = sorted(column_indices) + [num_cols - 1]  # last one for probability weights
    samples = pd.read_csv(chain_file, usecols=columns_to_read, delimiter=',')

    # Re-order columns to match parameter_id_list and par_labels
    sample_par_values = np.array(samples[parameter_id_list])
    sample_prob_weight = _clean_probability_weights(samples['probability_weights'])

    mysample = MCSamples(samples=sample_par_values, names=parameter_id_list,
                         labels=par_labels, settings=mc_samples_kwargs)
    mysample.reweightAddingLogLikes(-np.log(sample_prob_weight))
    return mysample


def _overlay_point_estimates(image, point_estimates):
    """Draw each point estimate as black lines on the 1D/2D panels and markers on the 2D panels."""
    linestyles = ['solid', 'dotted', 'dashed', 'dashdot']
    markers = ['s', '^', 'o', '*']
    for k, values in enumerate(point_estimates):
        ls = linestyles[k % len(linestyles)]
        marker = markers[k % len(markers)]
        npars = len(values)

        # Vertical lines below the diagonal (and on it), horizontal lines to its left
        for i, val in enumerate(values):
            if val is None:
                continue
            for ax in image.subplots[i:, i]:
                ax.axvline(val, color='black', ls=ls, alpha=1.0, lw=1)
            for ax in image.subplots[i, :i]:
                ax.axhline(val, color='black', ls=ls, alpha=1.0, lw=1)

        # Markers in the 2D panels where both coordinates are known
        for i in range(npars):
            for j in range(i + 1, npars):
                val_x, val_y = values[i], values[j]
                if val_x is not None and val_y is not None:
                    image.subplots[j, i].scatter(val_x, val_y, s=10, facecolors='black',
                                                 color='black', marker=marker)


def _clip_angle_ranges(image, parameter_id_list):
    """Clip the axis ranges of angle parameters ('phi', 'phi_ext') to [-90, 90].

    Limits are only changed on the side where the diagonal panel's x-range
    extends beyond that interval.
    """
    for i, par_id in enumerate(parameter_id_list):
        if par_id.split('-')[-1] not in ('phi', 'phi_ext'):
            continue
        xlim = image.subplots[i, i].get_xlim()
        if xlim[0] < -90:
            for ax in image.subplots[i:, i]:
                ax.set_xlim(left=-90)
            for ax in image.subplots[i, :i]:
                ax.set_ylim(bottom=-90)
        if xlim[1] > 90:
            for ax in image.subplots[i:, i]:
                ax.set_xlim(right=90)
            for ax in image.subplots[i, :i]:
                ax.set_ylim(top=90)


def plot_corner(parameter_id_list, chain_objs, chain_dirs, chain_names=None,
                point_estimate_objs=None, point_estimate_dirs=None,
                point_estimate_names=None, colors=None, labels=None,
                subplot_size=1, mc_samples_kwargs=None, filled_contours=True):
    """Make a corner (triangle) plot comparing the posteriors of several COOLEST models.

    Takes a list of COOLEST objects that each have a chain file associated to
    them (``meta["chain_file_name"]``) and draws their marginalized posteriors
    with getdist, optionally overlaying point estimates read from other COOLEST
    objects. Chain files must be comma-separated, start with a header line of
    parameter ids, and have a last column named 'probability_weights'.

    Parameters
    ----------
    parameter_id_list : list of str
        Parameter unique ids obtained from lensing entities. Their order
        determines the order of the plot panels.
    chain_objs : list
        COOLEST objects that have a chain file associated to them.
    chain_dirs : list of str
        Directories matching the COOLEST objects in 'chain_objs'.
    chain_names : list of str, optional
        Legend labels for the models in 'chain_objs' (same order). Defaults to
        "chain 0", "chain 1", ...
    point_estimate_objs : list, optional
        COOLEST objects whose point estimates are overlaid as lines and markers.
    point_estimate_dirs : list of str, optional
        Directories matching 'point_estimate_objs' (currently unused).
    point_estimate_names : list of str, optional
        Labels for 'point_estimate_objs' (currently unused).
    colors : list, optional
        One color per chain, used for lines and contours. If None, getdist's
        default colors are used.
    labels : dict, optional
        Maps parameter ids to LaTeX labels without enclosing '$'. Parameters
        not in it use the 'latex_str' of the corresponding chain model.
    subplot_size : float, optional
        Passed to ``plots.getSubplotPlotter``.
    mc_samples_kwargs : dict, optional
        Passed as ``settings`` to getdist ``MCSamples``.
    filled_contours : bool, optional
        Whether the 2D contours are filled.

    Returns
    -------
    The getdist subplot plotter holding the figure (panels in ``.subplots``).
    """
    chains.print_load_details = False  # Just to silence messages
    parameter_id_list = list(parameter_id_list)

    if chain_names is None:
        chain_names = ["chain " + str(i) for i in range(len(chain_objs))]
    if labels is None:
        labels = {}

    point_estimates = []
    if point_estimate_objs is not None:
        point_estimates = _read_point_estimates(point_estimate_objs, parameter_id_list)

    # Labels are taken from each chain's own COOLEST object
    mcsamples = [_load_chain_samples(chain_obj, chain_dirs[i], parameter_id_list, labels,
                                     mc_samples_kwargs, i)
                 for i, chain_obj in enumerate(chain_objs)]

    # Only forward color options when colors are given (iterating None would fail)
    triangle_kwargs = dict(params=parameter_id_list,
                           legend_labels=chain_names,
                           filled=filled_contours)
    if colors is not None:
        triangle_kwargs.update(colors=colors,
                               line_args=[{'ls': '-', 'lw': 2, 'color': c} for c in colors],
                               contour_colors=colors)

    image = plots.getSubplotPlotter(subplot_size=subplot_size)
    image.triangle_plot(mcsamples, **triangle_kwargs)

    _overlay_point_estimates(image, point_estimates)
    _clip_angle_ranges(image, parameter_id_list)
    return image
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc