• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

aymgal / COOLEST / 8296930474

15 Mar 2024 01:23PM UTC coverage: 49.288% (-0.2%) from 49.489%
8296930474

push

github

aymgal
Fix issue with no specific linestyles was passed to ParametersPlotter

0 of 17 new or added lines in 1 file covered. (0.0%)

3 existing lines in 1 file now uncovered.

1453 of 2948 relevant lines covered (49.29%)

0.49 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/coolest/api/plotting.py
1
__author__ = 'aymgal', 'lynevdv', 'gvernard'
×
2

3

4
import os
×
5
import copy
×
6
import logging
×
7
import numpy as np
×
8
import matplotlib.pyplot as plt
×
9
from matplotlib.colors import Normalize, LogNorm, TwoSlopeNorm
×
10
from matplotlib.colors import ListedColormap
×
11
from getdist import plots, chains, MCSamples
×
12

13
from coolest.api.analysis import Analysis
×
14
from coolest.api.composable_models import *
×
15
from coolest.api import util
×
16
from coolest.api import plot_util as plut
×
17

18
import pandas as pd
×
19

20

21
# matplotlib global settings: default options applied to every plt.imshow call
plt.rcParams.update({'image.interpolation': 'none', 'image.origin': 'lower'})

# logging settings: let the root logger (name=None) handle INFO-level records
# NOTE(review): configuring the root logger at import time affects every
# logger of the host application — consider a module-level logger instead.
logging.getLogger(name=None).setLevel(logging.INFO)
26

27
# TODO: separate ParametersPlotter from ModelPlotter to avoid dependencies on getdist
28

29

30
class ModelPlotter(object):
    """Create pyplot panels from a lens model stored in the COOLEST format.

    Parameters
    ----------
    coolest_object : COOLEST
        COOLEST instance
    coolest_directory : str, optional
        Directory which contains the COOLEST template, by default None
    color_bad_values : str, optional
        Color assigned to NaN values (typically negative values in log-scale),
        by default '#222222' (dark gray)
    """

    def __init__(self, coolest_object, coolest_directory=None,
                 color_bad_values='#222222'):
        self.coolest = coolest_object
        self._directory = coolest_directory

        # Work on a copy so that set_bad() does not modify the colormap
        # object returned by plt.get_cmap().
        self.cmap_flux = copy.copy(plt.get_cmap('magma'))
        self.cmap_flux.set_bad(color_bad_values)

        self.cmap_mag = plt.get_cmap('viridis')
        self.cmap_conv = plt.get_cmap('cividis')
        self.cmap_res = plt.get_cmap('RdBu_r')

    # ------------------------------------------------------------------
    # Private helpers shared by the plotting methods
    # ------------------------------------------------------------------

    def _mass_model(self, kwargs_lens_mass):
        """Build a ComposableMassModel for the given selection dict
        (None is treated as an empty selection)."""
        if kwargs_lens_mass is None:
            kwargs_lens_mass = {}
        return ComposableMassModel(self.coolest, self._directory,
                                   **kwargs_lens_mass)

    @staticmethod
    def _add_colorbar(im, ax, label):
        """Attach a colorbar (plut.nice_colorbar, max_nbins=4) to `ax` and label it."""
        cb = plut.nice_colorbar(im, ax=ax, max_nbins=4)
        cb.set_label(label)
        return cb

    @staticmethod
    def _difference(reference_map, image, relative_error):
        """Return `reference_map - image`, divided by `reference_map`
        when `relative_error` is exactly True."""
        diff = reference_map - image
        if relative_error is True:
            diff = diff / reference_map
        return diff

    # ------------------------------------------------------------------
    # Public plotting methods
    # ------------------------------------------------------------------

    def plot_data_image(self, ax, norm=None, cmap=None, xylim=None,
                        neg_values_as_bad=False, add_colorbar=True):
        """plt.imshow panel with the data image.

        Parameters
        ----------
        ax : matplotlib axis
            Axis on which the panel is drawn.
        norm : matplotlib.colors.Normalize, optional
            Color normalization, by default None.
        cmap : colormap, optional
            By default `self.cmap_flux`.
        xylim : optional
            Axis limits forwarded to `plut.plot_regular_grid`.
        neg_values_as_bad : bool, optional
            Forwarded to `plut.plot_regular_grid`, by default False.
        add_colorbar : bool, optional
            Add a colorbar labelled "flux", by default True.

        Returns
        -------
        image
            Pixels returned by `observation.pixels.get_pixels`.
        """
        if cmap is None:
            cmap = self.cmap_flux
        coordinates = util.get_coordinates(self.coolest)
        extent = coordinates.plt_extent
        image = self.coolest.observation.pixels.get_pixels(directory=self._directory)
        ax, im = plut.plot_regular_grid(ax, image, extent=extent,
                                        cmap=cmap, norm=norm,
                                        neg_values_as_bad=neg_values_as_bad,
                                        xylim=xylim)
        if add_colorbar:
            self._add_colorbar(im, ax, "flux")
        return image

    def plot_surface_brightness(self, ax, coordinates=None,
                                extent_irreg=None, norm=None, cmap=None,
                                xylim=None, neg_values_as_bad=True,
                                plot_points_irreg=False, add_colorbar=True,
                                kwargs_light=None,
                                plot_caustics=None, caustics_color='white', caustics_alpha=0.5,
                                coordinates_lens=None, kwargs_lens_mass=None):
        """plt.imshow panel showing the surface brightness of the (unlensed)
        lensing entity selected via kwargs_light (see ComposableLightModel docstring).

        Parameters
        ----------
        ax : matplotlib axis
            Axis on which the panel is drawn.
        coordinates : optional
            If given, the light model is evaluated on
            `coordinates.pixel_coordinates`. Otherwise the output of the light
            model's own `surface_brightness` is plotted, on a regular grid
            when it is a 2D array, or as irregular points otherwise.
        extent_irreg : None
            Deprecated; any value other than None raises a ValueError.
        norm, cmap, xylim, neg_values_as_bad : optional
            Forwarded to the plut grid-plotting helpers; `cmap` defaults
            to `self.cmap_flux`.
        plot_points_irreg : bool, optional
            Forwarded as `plot_points` when plotting irregular points.
        add_colorbar : bool, optional
            Add a colorbar labelled "flux", by default True.
        kwargs_light : dict, optional
            Selection passed to ComposableLightModel.
        plot_caustics : bool, optional
            Over-plot the caustics of the mass model selected by
            `kwargs_lens_mass`.
        caustics_color, caustics_alpha : optional
            Line color and transparency of the caustics.
        coordinates_lens : optional
            Coordinates on which caustics are computed; by default the data
            coordinates with `pixel_scale_factor=0.1`.
        kwargs_lens_mass : dict, optional
            Mass-model selection; required when `plot_caustics` is truthy.

        Returns
        -------
        image, coordinates
            `image` is None when irregular points were plotted.

        Raises
        ------
        ValueError
            If `extent_irreg` is given, or if caustics are requested
            without `kwargs_lens_mass`.
        """
        if extent_irreg is not None:
            raise ValueError("`extent_irreg` is deprecated; use `xylim` instead.")
        if kwargs_light is None:
            kwargs_light = {}
        light_model = ComposableLightModel(self.coolest, self._directory, **kwargs_light)
        if plot_caustics:
            if kwargs_lens_mass is None:
                raise ValueError("`kwargs_lens_mass` must be provided to compute caustics")
            if coordinates_lens is None:
                coordinates_lens = util.get_coordinates(self.coolest).create_new_coordinates(pixel_scale_factor=0.1)
            # NOTE: here we assume that `kwargs_light` is for the source!
            mass_model = self._mass_model(kwargs_lens_mass)
            _, caustics = util.find_all_lens_lines(coordinates_lens, mass_model)
        if cmap is None:
            cmap = self.cmap_flux
        if coordinates is not None:
            x, y = coordinates.pixel_coordinates
            image = light_model.evaluate_surface_brightness(x, y)
            extent = coordinates.plt_extent
            ax, im = plut.plot_regular_grid(ax, image, extent=extent, cmap=cmap,
                                            neg_values_as_bad=neg_values_as_bad,
                                            norm=norm, xylim=xylim)
        else:
            values, extent_model, coordinates = light_model.surface_brightness(return_extra=True)
            if isinstance(values, np.ndarray) and len(values.shape) == 2:
                image = values
                ax, im = plut.plot_regular_grid(ax, image, extent=extent_model,
                                                cmap=cmap,
                                                neg_values_as_bad=neg_values_as_bad,
                                                norm=norm, xylim=xylim)
            else:
                points = values
                if xylim is None:
                    xylim = extent_model
                ax, im = plut.plot_irregular_grid(ax, points, xylim, norm=norm, cmap=cmap,
                                                  neg_values_as_bad=neg_values_as_bad,
                                                  plot_points=plot_points_irreg)
                image = None
        if plot_caustics:
            for caustic in caustics:
                ax.plot(caustic[0], caustic[1], lw=1, color=caustics_color, alpha=caustics_alpha)
        if add_colorbar:
            self._add_colorbar(im, ax, "flux")
        return image, coordinates

    def plot_model_image(self, ax,
                         norm=None, cmap=None, xylim=None, neg_values_as_bad=False,
                         kwargs_source=None, add_colorbar=True, kwargs_lens_mass=None,
                         **model_image_kwargs):
        """plt.imshow panel showing the surface brightness of the (lensed)
        selected lensing entities (see ComposableLensModel docstring).

        Parameters
        ----------
        ax : matplotlib axis
            Axis on which the panel is drawn.
        norm, cmap, xylim, neg_values_as_bad : optional
            Forwarded to `plut.plot_regular_grid`; `cmap` defaults to
            `self.cmap_flux`.
        kwargs_source, kwargs_lens_mass : dict, optional
            Source and lens-mass selections passed to ComposableLensModel.
        add_colorbar : bool, optional
            Add a colorbar labelled "flux", by default True.
        **model_image_kwargs
            Forwarded to `ComposableLensModel.model_image`.

        Returns
        -------
        image
            Model image returned by `ComposableLensModel.model_image`.
        """
        if cmap is None:
            cmap = self.cmap_flux
        lens_model = ComposableLensModel(self.coolest, self._directory,
                                         kwargs_selection_source=kwargs_source,
                                         kwargs_selection_lens_mass=kwargs_lens_mass)
        image, coordinates = lens_model.model_image(**model_image_kwargs)
        extent = coordinates.plt_extent
        ax, im = plut.plot_regular_grid(ax, image, extent=extent,
                                        cmap=cmap,
                                        neg_values_as_bad=neg_values_as_bad,
                                        norm=norm, xylim=xylim)
        if add_colorbar:
            self._add_colorbar(im, ax, "flux")
        return image

    def plot_model_residuals(self, ax, mask=None,
                             norm=None, cmap=None, xylim=None, add_chi2_label=False, chi2_fontsize=12,
                             kwargs_source=None, add_colorbar=True, kwargs_lens_mass=None,
                             **model_image_kwargs):
        """plt.imshow panel showing the normalized model residuals image.

        Parameters
        ----------
        ax : matplotlib axis
            Axis on which the panel is drawn.
        mask : optional
            Forwarded to `ComposableLensModel.model_residuals`; when given,
            `np.sum(mask)` is used as the number of constraints for the
            reduced chi^2 label.
        norm : optional
            By default `Normalize(-6, 6)`.
        cmap : optional
            By default `self.cmap_res`.
        xylim : optional
            Forwarded to `plut.plot_regular_grid`.
        add_chi2_label : bool, optional
            When exactly True, write the reduced chi^2 (sum of squared
            residuals over the number of constraints) in the lower-left corner.
        chi2_fontsize : int, optional
            Font size of the chi^2 label, by default 12.
        kwargs_source, kwargs_lens_mass : dict, optional
            Source and lens-mass selections passed to ComposableLensModel.
        add_colorbar : bool, optional
            Add a colorbar labelled "(data - model) / noise", by default True.
        **model_image_kwargs
            Forwarded to `ComposableLensModel.model_residuals`.

        Returns
        -------
        image
            Residuals map returned by `ComposableLensModel.model_residuals`.
        """
        if cmap is None:
            cmap = self.cmap_res
        if norm is None:
            norm = Normalize(-6, 6)
        lens_model = ComposableLensModel(self.coolest, self._directory,
                                         kwargs_selection_source=kwargs_source,
                                         kwargs_selection_lens_mass=kwargs_lens_mass)
        image, coordinates = lens_model.model_residuals(mask=mask, **model_image_kwargs)
        extent = coordinates.plt_extent
        ax, im = plut.plot_regular_grid(ax, image, extent=extent,
                                        cmap=cmap,
                                        neg_values_as_bad=False,
                                        norm=norm, xylim=xylim)
        if add_colorbar:
            self._add_colorbar(im, ax, "(data $-$ model) / noise")
        if add_chi2_label is True:
            num_constraints = np.size(image) if mask is None else np.sum(mask)
            red_chi2 = np.sum(image**2) / num_constraints
            ax.text(0.05, 0.05, r'$\chi^2_\nu$='+f'{red_chi2:.2f}', color='black', alpha=1,
                    fontsize=chi2_fontsize, va='bottom', ha='left', transform=ax.transAxes,
                    bbox={'color': 'white', 'alpha': 0.6})
        return image

    def plot_convergence(self, ax, coordinates=None,
                         norm=None, cmap=None, xylim=None, neg_values_as_bad=False,
                         add_colorbar=True, kwargs_lens_mass=None):
        """plt.imshow panel showing the 2D convergence map associated to the
        selected lensing entities (see ComposableMassModel docstring).

        Parameters
        ----------
        ax : matplotlib axis
            Axis on which the panel is drawn.
        coordinates : optional
            Evaluation grid; by default the data coordinates.
        norm, cmap, xylim, neg_values_as_bad : optional
            Forwarded to `plut.plot_regular_grid`; `cmap` defaults to
            `self.cmap_conv`.
        add_colorbar : bool, optional
            Add a colorbar labelled kappa, by default True.
        kwargs_lens_mass : dict, optional
            Selection passed to ComposableMassModel.

        Returns
        -------
        image
            Convergence evaluated on `coordinates`.
        """
        mass_model = self._mass_model(kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_conv
        if coordinates is None:
            coordinates = util.get_coordinates(self.coolest)
        extent = coordinates.plt_extent
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_convergence(x, y)
        ax, im = plut.plot_regular_grid(ax, image, extent=extent,
                                        cmap=cmap,
                                        neg_values_as_bad=neg_values_as_bad,
                                        norm=norm, xylim=xylim)
        if add_colorbar:
            self._add_colorbar(im, ax, r"$\kappa$")
        return image

    def plot_convergence_diff(
            self, ax, reference_map, relative_error=True,
            norm=None, cmap=None, xylim=None, coordinates=None,
            add_colorbar=True, kwargs_lens_mass=None,
            plot_crit_lines=False, crit_lines_color='black', crit_lines_alpha=0.5):
        """plt.imshow panel showing the (absolute or relative) difference
        between a reference convergence map and the convergence of the
        selected lensing entities (see ComposableMassModel docstring).

        Parameters
        ----------
        ax : matplotlib axis
            Axis on which the panel is drawn.
        reference_map : array-like
            Reference convergence; assumed to be defined on `coordinates`
            (TODO confirm the shape matches the evaluated map).
        relative_error : bool, optional
            When exactly True, plot (reference - model) / reference;
            otherwise reference - model. By default True.
        norm : optional
            By default `Normalize(-1, 1)`.
        cmap : optional
            By default `self.cmap_res`.
        xylim : optional
            Forwarded to `plut.plot_regular_grid`.
        coordinates : optional
            Evaluation grid; by default the data coordinates.
        add_colorbar : bool, optional
            Add a colorbar labelled kappa, by default True.
        kwargs_lens_mass : dict, optional
            Selection passed to ComposableMassModel.
        plot_crit_lines : bool, optional
            Over-plot the critical lines computed on `coordinates`.
        crit_lines_color, crit_lines_alpha : optional
            Line color and transparency of the critical lines.

        Returns
        -------
        image
            The model convergence map (NOT the plotted difference).
        """
        mass_model = self._mass_model(kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_res
        if norm is None:
            norm = Normalize(-1, 1)
        if coordinates is None:
            coordinates = util.get_coordinates(self.coolest)
        if plot_crit_lines:
            critical_lines, _ = util.find_all_lens_lines(coordinates, mass_model)
        extent = coordinates.plt_extent
        x, y = coordinates.pixel_coordinates
        image = mass_model.evaluate_convergence(x, y)
        diff = self._difference(reference_map, image, relative_error)
        ax, im = plut.plot_regular_grid(ax, diff, extent=extent,
                                        cmap=cmap,
                                        norm=norm, xylim=xylim)
        if plot_crit_lines:
            for cline in critical_lines:
                ax.plot(cline[0], cline[1], lw=1, color=crit_lines_color, alpha=crit_lines_alpha)
        if add_colorbar:
            self._add_colorbar(im, ax, r"$\kappa$")
        return image

    def plot_magnification(self, ax,
                           norm=None, cmap=None, xylim=None,
                           add_colorbar=True, coordinates=None, kwargs_lens_mass=None):
        """plt.imshow panel showing the 2D magnification map associated to the
        selected lensing entities (see ComposableMassModel docstring).

        Parameters
        ----------
        ax : matplotlib axis
            Axis on which the panel is drawn.
        norm : optional
            By default `Normalize(-10, 10)`.
        cmap : optional
            By default `self.cmap_mag`.
        xylim : optional
            Forwarded to `plut.plot_regular_grid`.
        add_colorbar : bool, optional
            Add a colorbar labelled mu, by default True.
        coordinates : optional
            Evaluation grid; by default the data coordinates.
        kwargs_lens_mass : dict, optional
            Selection passed to ComposableMassModel.

        Returns
        -------
        image
            Magnification evaluated on `coordinates`.
        """
        mass_model = self._mass_model(kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_mag
        if norm is None:
            norm = Normalize(-10, 10)
        if coordinates is None:
            coordinates = util.get_coordinates(self.coolest)
        x, y = coordinates.pixel_coordinates
        extent = coordinates.plt_extent
        image = mass_model.evaluate_magnification(x, y)
        ax, im = plut.plot_regular_grid(ax, image, extent=extent,
                                        cmap=cmap,
                                        norm=norm, xylim=xylim)
        if add_colorbar:
            self._add_colorbar(im, ax, r"$\mu$")
        return image

    def plot_magnification_diff(
            self, ax, reference_map, relative_error=True,
            norm=None, cmap=None, xylim=None,
            add_colorbar=True, coordinates=None, kwargs_lens_mass=None):
        """plt.imshow panel showing the (absolute or relative)
        difference between 2D magnification maps.

        Parameters
        ----------
        ax : matplotlib axis
            Axis on which the panel is drawn.
        reference_map : array-like
            Reference magnification; assumed to be defined on `coordinates`
            (TODO confirm the shape matches the evaluated map).
        relative_error : bool, optional
            When exactly True, plot (reference - model) / reference;
            otherwise reference - model. By default True.
        norm : optional
            By default `Normalize(-1, 1)`.
        cmap : optional
            By default `self.cmap_res`.
        xylim : optional
            Forwarded to `plut.plot_regular_grid`.
        add_colorbar : bool, optional
            Add a colorbar labelled mu, by default True.
        coordinates : optional
            Evaluation grid; by default the data coordinates.
        kwargs_lens_mass : dict, optional
            Selection passed to ComposableMassModel.

        Returns
        -------
        image
            The model magnification map (NOT the plotted difference).
        """
        mass_model = self._mass_model(kwargs_lens_mass)
        if cmap is None:
            cmap = self.cmap_res
        if norm is None:
            norm = Normalize(-1, 1)
        if coordinates is None:
            coordinates = util.get_coordinates(self.coolest)
        x, y = coordinates.pixel_coordinates
        extent = coordinates.plt_extent
        image = mass_model.evaluate_magnification(x, y)
        diff = self._difference(reference_map, image, relative_error)
        ax, im = plut.plot_regular_grid(ax, diff, extent=extent,
                                        cmap=cmap,
                                        norm=norm, xylim=xylim)
        if add_colorbar:
            self._add_colorbar(im, ax, r"$\mu$")
        return image
306

307

308
class MultiModelPlotter(object):
    """Wrapper around a set of ModelPlotter instances to produce panels that
    consistently compare different models, evaluated on the same
    coordinates systems.

    Parameters
    ----------
    coolest_objects : list
        List of COOLEST instances
    coolest_directories : list, optional
        List of directories corresponding to each COOLEST instance, by default None
    kwargs_plotter : dict, optional
        Additional keyword arguments passed to ModelPlotter

    Notes
    -----
    Selection dictionaries (`kwargs_light`, `kwargs_source`, `kwargs_lens_mass`)
    passed to the plotting methods map each key to a sequence with one entry
    per model: model ``i`` receives ``{key: value[i]}``. Passing None (or
    omitting them) forwards the value unchanged to every ModelPlotter.
    """

    def __init__(self, coolest_objects, coolest_directories=None, **kwargs_plotter):
        self.num_models = len(coolest_objects)
        if coolest_directories is None:
            coolest_directories = self.num_models * [None]
        self.plotter_list = [
            ModelPlotter(coolest, coolest_directory=c_dir, **kwargs_plotter)
            for coolest, c_dir in zip(coolest_objects, coolest_directories)
        ]

    def plot_surface_brightness(self, axes, **kwargs):
        """Call ModelPlotter.plot_surface_brightness for each model."""
        return self._plot_light_multi('plot_surface_brightness', axes, **kwargs)

    def plot_data_image(self, axes, **kwargs):
        """Call ModelPlotter.plot_data_image for each model."""
        return self._plot_data_multi(axes, **kwargs)

    def plot_model_image(self, axes, **kwargs):
        """Call ModelPlotter.plot_model_image for each model."""
        return self._plot_lens_model_multi('plot_model_image', axes, **kwargs)

    def plot_model_residuals(self, axes, **kwargs):
        """Call ModelPlotter.plot_model_residuals for each model."""
        return self._plot_lens_model_multi('plot_model_residuals', axes, **kwargs)

    def plot_convergence(self, axes, **kwargs):
        """Call ModelPlotter.plot_convergence for each model."""
        return self._plot_lens_model_multi('plot_convergence', axes, **kwargs)

    def plot_magnification(self, axes, **kwargs):
        """Call ModelPlotter.plot_magnification for each model."""
        return self._plot_lens_model_multi('plot_magnification', axes, **kwargs)

    def plot_convergence_diff(self, axes, *args, **kwargs):
        """Call ModelPlotter.plot_convergence_diff for each model
        (positional arguments, e.g. the reference map, are shared by all models)."""
        return self._plot_lens_model_multi('plot_convergence_diff', axes, *args, **kwargs)

    def plot_magnification_diff(self, axes, *args, **kwargs):
        """Call ModelPlotter.plot_magnification_diff for each model
        (positional arguments, e.g. the reference map, are shared by all models)."""
        return self._plot_lens_model_multi('plot_magnification_diff', axes, *args, **kwargs)

    def _plot_light_multi(self, method_name, axes, **kwargs):
        # kwargs_lens_mass is split as well: it is used for over-plotting caustics
        return self._plot_multi(method_name, axes, (), kwargs,
                                ('kwargs_light', 'kwargs_lens_mass'))

    def _plot_mass_multi(self, method_name, axes, **kwargs):
        return self._plot_multi(method_name, axes, (), kwargs,
                                ('kwargs_lens_mass',))

    def _plot_lens_model_multi(self, method_name, axes, *args, **kwargs):
        return self._plot_multi(method_name, axes, args, kwargs,
                                ('kwargs_source', 'kwargs_lens_mass'))

    def _plot_data_multi(self, axes, **kwargs):
        # keyword arguments are intentionally NOT deep-copied here, so objects
        # such as `norm` are shared with the caller (historical behavior)
        return self._plot_multi('plot_data_image', axes, (), kwargs, (),
                                deep_copy=False)

    def _plot_multi(self, method_name, axes, args, kwargs, per_model_keys,
                    deep_copy=True):
        """Call `method_name` on every ModelPlotter whose axis is not None.

        Parameters
        ----------
        method_name : str
            Name of the ModelPlotter method to call.
        axes : sequence
            One axis per model; None entries are skipped.
        args : tuple
            Positional arguments shared by all models.
        kwargs : dict
            Keyword arguments shared by all models.
        per_model_keys : tuple of str
            Keys of `kwargs` holding per-model selection dicts; for model ``i``
            each value sequence is replaced by its ``i``-th entry.
        deep_copy : bool, optional
            Deep-copy `kwargs` once before dispatching, by default True.

        Returns
        -------
        list
            Return values of the called methods, for non-None axes only.
        """
        assert len(axes) == self.num_models, "Inconsistent number of subplot axes"
        kwargs_ = copy.deepcopy(kwargs) if deep_copy else dict(kwargs)
        image_list = []
        for i, (ax, plotter) in enumerate(zip(axes, self.plotter_list)):
            if ax is None:
                continue
            for key in per_model_keys:
                selection = kwargs.get(key)
                # None (explicit or absent) is forwarded as-is: the ModelPlotter
                # methods treat it as "no selection"
                if selection is not None:
                    kwargs_[key] = {k: v[i] for k, v in selection.items()}
            image_list.append(getattr(plotter, method_name)(ax, *args, **kwargs_))
        return image_list
408

409

410
class ParametersPlotter(object):
×
411
    """Handles plot of analytical models in a comparative way
412

413
    Parameters
414
    ----------
415
    parameter_id_list : array
416
        A list of parameter unique ids obtained from lensing entities. Their order determines the order of the plot panels.
417
    coolest_objects : array
418
        A list of coolest objects that have a chain file associated to them.
419
    coolest_directories : array
420
        A list of paths matching the coolest files in 'chain_objs'.
421
    coolest_names : array, optional
422
        A list of labels for the coolest models in the 'chain_objs' list. Must have the same order as 'chain_objs'.
423
    ref_coolest_objects : array, optional
424
        A list of coolest objects that will be used as point estimates.
425
    ref_coolest_directories : array
426
        A list of paths matching the coolest files in 'point_estimate_objs'.
427
    ref_coolest_names : array, optional
428
        A list of labels for the models in the 'point_estimate_objs' list. Must have the same order as 'point_estimate_objs'.
429
    posterior_bool_list : list, optional
430
        List of bool to toggle errorbars on point-estimate values
431
    colors : list, optional
432
        List of pyplot color names to associate to each coolest model.
433
    linestyles : list, optional
434
        List of pyplot linesyles to associate to each coolest model.
435
    add_multivariate_margin_samples : bool, optional
436
        If True, will append to the list of compared models
437
        a new chain that is resampled from the multi-variate normal distribution,
438
        where its covariance matrix is computed from the marginalization of
439
        all samples from all models. By default False. 
440
    num_samples_per_model_margin : int, optional
441
        Number of samples to (randomly) draw from each model samples to concatenate
442
        before estimating the multi-variate normal marginalization.
443
    """
444

445
    np.random.seed(598237)  # fix the random seed for reproducibility
×
446
    
447
    def __init__(self, parameter_id_list, coolest_objects, coolest_directories=None, coolest_names=None,
×
448
                 ref_coolest_objects=None, ref_coolest_directories=None, ref_coolest_names=None,
449
                 posterior_bool_list=None, colors=None, linestyles=None,
450
                 add_multivariate_margin_samples=False, num_samples_per_model_margin=5_000):
451
        self.parameter_id_list = parameter_id_list
×
452
        self.coolest_objects = coolest_objects
×
453
        self.coolest_directories = coolest_directories
×
454
        if coolest_names is None:
×
455
            coolest_names = ["Model "+str(i) for i in range(len(coolest_objects))]
×
456
        self.coolest_names = coolest_names
×
457
        self.ref_coolest_objects = ref_coolest_objects
×
458
        self.ref_coolest_directories = ref_coolest_directories
×
459
        self.ref_coolest_names = ref_coolest_names
×
460
        self.ref_file_names = ref_coolest_names
×
461

462
        self.num_models = len(self.coolest_objects)
×
463
        self.num_params = len(self.parameter_id_list)
×
464
        if colors is None:
×
465
            colors = plt.cm.turbo(np.linspace(0.1, 0.9, self.num_models))
×
466
        self.colors = colors
×
NEW
467
        if linestyles is None:
×
NEW
468
            linestyles = ['-']*self.num_models
×
469
        self.linestyles = linestyles
×
470
        self.ref_linestyles = ['--', ':', '-.', '-']
×
471
        self.ref_markers = ['s', '^', 'o', '*']
×
472

473
        self._add_margin_samples = add_multivariate_margin_samples
×
474
        self._ns_per_model_margin = num_samples_per_model_margin
×
475
        self._color_margin = 'black'
×
476
        self._label_margin = "Combined"
×
477

478
        # self.posterior_bool_list = posterior_bool_list
479
        # self.param_lens, self.param_source = util.split_lens_source_params(
480
        #     self.coolest_objects, self.coolest_names, lens_light=False)
481

482
    def init_getdist(self, shift_sample_list=None, settings_mcsamples=None,
×
483
                     add_multivariate_margin_samples=False):
484
        """Initializes the getdist plotter.
485

486
        Parameters
487
        ----------
488
        shift_sample_list : dict
489
            Dictionary keyed by parameter ID to apply a uniform additive shift to
490
            all samples of that parameters posterior distribution.
491
        settings_mcsamples : dict, optional
492
            Keyword arguments passed as the `settings` argument of getdist.MCSamples, by default None
493

494
        Raises
495
        ------
496
        ValueError
497
            If the csv file containing samples is is not coma (,) separated.
498
        """
499
        chains.print_load_details = False # Just to silence messages
×
500
        parameter_id_set = set(self.parameter_id_list)
×
501

502
        if shift_sample_list is None:
×
503
            shift_sample_list = [None]*self.num_models
×
504
        
505
        # Get the values of the point_estimates
506
        point_estimates = []
×
507
        if self.ref_coolest_objects is not None:
×
508
            for coolest_obj in self.ref_coolest_objects:
×
509
                values = []
×
510
                for par in self.parameter_id_list:
×
511
                    param = coolest_obj.lensing_entities.get_parameter_from_id(par)
×
512
                    val = param.point_estimate.value
×
513
                    if val is None:
×
514
                        values.append(None)
×
515
                    else:
516
                        values.append(val)
×
517
                point_estimates.append(values)
×
518

519
        mcsamples = []
×
520
        samples_margin, weights_margin = None, None
×
521
        mysample_margin = None
×
522
        for i in range(self.num_models):
×
523
            chain_file = os.path.join(self.coolest_directories[i],self.coolest_objects[i].meta["chain_file_name"]) # Here get the chain file path for each coolest object
×
524

525
            # Each chain file can have a different number of free parameters
526
            f = open(chain_file)
×
527
            header = f.readline()
×
528
            f.close()
×
529

530
            if ';' in header:
×
531
                raise ValueError("Columns must be coma-separated (no semi-colon) in chain file.")
×
532

533
            chain_file_headers = header.split(',')
×
534
            num_cols = len(chain_file_headers)
×
535
            chain_file_headers.pop() # Remove the last column name that is the probability weights
×
536
            chain_file_headers_set = set(chain_file_headers)
×
537
            
538
            # Check that the given parameters are a subset of those in the chain file
539
            assert parameter_id_set.issubset(chain_file_headers_set), "Not all given parameters are free parameters for model %d (not in the chain file: %s)!" % (i,chain_file)
×
540

541
            # Set the labels for the parameters in the chain file
542
            labels = []
×
543
            for par_id in self.parameter_id_list:
×
544
                param = self.coolest_objects[i].lensing_entities.get_parameter_from_id(par_id)
×
545
                labels.append(param.latex_str.strip('$'))
×
546

547
            # Read parameter values and probability weights
548
            column_indices = [chain_file_headers.index(par_id) for par_id in self.parameter_id_list]
×
549
            columns_to_read = sorted(column_indices) + [num_cols-1]  # add last one for probability weights
×
550
            samples = pd.read_csv(chain_file, usecols=columns_to_read, delimiter=',')
×
551
        
552
            # Re-order columns to match self.parameter_id_list and labels
553
            sample_par_values = np.array(samples[self.parameter_id_list])
×
554

555
            # If needed, shift samples by a constant
556
            if shift_sample_list[i] is not None:
×
557
                for param_id, value in shift_sample_list[i].items():
×
558
                    sample_par_values[:, self.parameter_id_list.index(param_id)] += value
×
559
                    logging.info(f"posterior for parameter '{param_id}' from model '{self.coolest_names[i]}' "
×
560
                                 f"has been shifted by {value}.")
561

562
            # Clean-up the probability weights
563
            mypost = np.array(samples['probability_weights'])
×
564
            min_non_zero = np.min(mypost[np.nonzero(mypost)])
×
565
            sample_prob_weight = np.where(mypost<min_non_zero, min_non_zero, mypost)
×
566
            #sample_prob_weight = mypost
567

568
            # Create MCSamples object
569
            mysample = MCSamples(samples=sample_par_values, names=self.parameter_id_list,
×
570
                                 labels=labels, settings=settings_mcsamples)
571
            mysample.reweightAddingLogLikes(-np.log(sample_prob_weight))
×
572
            mcsamples.append(mysample)
×
573

574
            # if required, aggregate the samples in a "marginalized" posterior
575
            if self._add_margin_samples:
×
576
                if i == 0:
×
577
                    mysample_margin = copy.deepcopy(mysample)
×
578
                else:
579
                    # combine the sample such that the probability mass of each set of samples is the same
580
                    mysample_margin = mysample_margin.getCombinedSamplesWithSamples(mysample, sample_weights=(1, 1))
×
581
        
582
        if self._add_margin_samples:
×
583
            mcsamples.append(mysample_margin)
×
584

585
        self._mcsamples = mcsamples
×
586
        self.ref_values = point_estimates
×
587
        self.ref_values_markers = [dict(zip(self.parameter_id_list, values)) for values in self.ref_values]
×
588

589
    def get_mcsamples_getdist(self, with_margin=False):
×
590
        if not self._add_margin_samples or with_margin:
×
591
            return self._mcsamples
×
592
        else:
593
            return self._mcsamples[:-1]
×
594
    
595
    def get_margin_mcsamples_getdist(self):
×
596
        if not self._add_margin_samples:
×
597
            return None
×
598
        else:
599
            return self._mcsamples[-1]
×
600
    
601
    def plot_triangle_getdist(self, filled_contours=True, angles_range=None, 
                              linewidth_hist=2, linewidth_cont=2, linewidth_margin=4,
                              marker_linewidth=2, marker_size=15, 
                              axes_labelsize=None, legend_fontsize=None,
                              **subplot_kwargs):
        """Corner array of subplots using getdist.triangle_plot method.

        Parameters
        ----------
        filled_contours : bool, optional
            Whether or not to fill the 2D contours, by default True. The contours
            of the aggregated "marginalized" samples (if any) are always filled.
        angles_range : tuple of 2 floats, optional
            (min, max) range to which the axes of angle parameters (those whose
            id ends with '-phi' or '-phi_ext') are restricted, by default (-90, 90).
            Axis limits are only clipped where they extend beyond this range.
        linewidth_hist : int, optional
            Line width for 1D histograms, by default 2
        linewidth_cont : int, optional
            Line width for 2D contours, by default 2
        linewidth_margin : int, optional
            Line width used for the aggregated "marginalized" samples, by default 4
        marker_linewidth : int, optional
            Line width of the reference (point estimate) marker lines, by default 2
        marker_size : int, optional
            Size of the reference (scatter) markers on 2D contours plots, by default 15
        axes_labelsize : float, optional
            If not None, overrides getdist's axes label font size, by default None
        legend_fontsize : float, optional
            If not None, overrides getdist's legend font size, by default None
        **subplot_kwargs
            Keyword arguments passed to getdist.plots.get_subplot_plotter
            (e.g. subplot_size)

        Returns
        -------
        GetDistPlotter
            Instance of GetDistPlotter corresponding to the figure
        """
        line_args, contour_lws, contour_ls, colors, legend_labels \
            = self._prepare_getdist_plot(linewidth_hist, 
                                         lw_cont=linewidth_cont, 
                                         lw_margin=linewidth_margin)

        num_samples = len(self._mcsamples)
        filled_contours = [filled_contours] * num_samples
        alphas = [1] * num_samples
        if self._add_margin_samples:
            # the aggregated samples (last entry) are always shown filled
            filled_contours[-1] = True

        # Make the plot
        g = plots.get_subplot_plotter(**subplot_kwargs)
        if legend_fontsize is not None:
            g.settings.legend_fontsize = legend_fontsize
        if axes_labelsize is not None:
            g.settings.axes_labelsize = axes_labelsize
        g.triangle_plot(
            self._mcsamples,
            params=self.parameter_id_list,
            legend_labels=legend_labels,
            filled=filled_contours,
            colors=colors,
            line_args=line_args,   # TODO: issue that linewidth settings in line_args are being overwritten by contour_lws
            # `colors` extends self.colors with the color of the aggregated
            # samples, so their 2D contours match their 1D histograms
            contour_colors=colors,
            contour_lws=contour_lws,
            contour_ls=contour_ls,
            alphas=alphas,
        )

        # Add marker lines and points
        for k, ref_values in enumerate(self.ref_values):
            g.add_param_markers(self.ref_values_markers[k], color='black', 
                                ls=self.ref_linestyles[k], lw=marker_linewidth)
            for i in range(self.num_params):
                val_x = ref_values[i]
                for j in range(i + 1, self.num_params):
                    val_y = ref_values[j]
                    if val_x is not None and val_y is not None:
                        g.subplots[j, i].scatter(val_x, val_y, s=marker_size, facecolors='black',
                                                 color='black', marker=self.ref_markers[k])

        # Restrict the displayed range of angle parameters
        if angles_range is None:
            angles_range = (-90, 90)
        angle_min, angle_max = angles_range
        for i, param_id in enumerate(self.parameter_id_list):
            name = param_id.split('-')[-1]
            if name not in ('phi', 'phi_ext'):
                continue
            xlim = g.subplots[i, i].get_xlim()
            # parameter i is the x-axis of column i (diagonal and below)
            # and the y-axis of row i (left of the diagonal)
            if xlim[0] < angle_min:
                for ax in g.subplots[i:, i]:
                    ax.set_xlim(left=angle_min)
                for ax in g.subplots[i, :i]:
                    ax.set_ylim(bottom=angle_min)
            if xlim[1] > angle_max:
                for ax in g.subplots[i:, i]:
                    ax.set_xlim(right=angle_max)
                for ax in g.subplots[i, :i]:
                    ax.set_ylim(top=angle_max)
        return g
692
    
693
    def plot_rectangle_getdist(self, x_param_ids, y_param_ids, subplot_size=None, 
                               legend_ncol=None, legend_fontsize=None, 
                               filled_contours=True, linewidth=1,
                               marker_size=15, axes_labelsize=None, **subplot_kwargs):
        """Array of (2D contours) subplots using getdist.rectangle_plot method.

        Parameters
        ----------
        x_param_ids : list of str
            Parameter ids shown along the x-axis (one per column)
        y_param_ids : list of str
            Parameter ids shown along the y-axis (one per row)
        subplot_size : float, optional
            Size of each subplot, forwarded to getdist.plots.get_subplot_plotter
            when not None, by default None (getdist's default)
        legend_ncol : int, optional
            Number of columns in the legend, by default 3
        legend_fontsize : float, optional
            If not None, overrides getdist's legend font size, by default None
        filled_contours : bool, optional
            Whether or not to fill the 2D contours, by default True
        linewidth : int, optional
            Line width for 2D contours and reference marker lines, by default 1
        marker_size : int, optional
            Size of the reference (scatter) markers on 2D contours plots, by default 15
        axes_labelsize : float, optional
            If not None, overrides getdist's axes label font size, by default None
        **subplot_kwargs
            Other keyword arguments passed to getdist.plots.get_subplot_plotter

        Returns
        -------
        GetDistPlotter
            Instance of GetDistPlotter corresponding to the figure
        """
        line_args, _, _, colors, legend_labels = self._prepare_getdist_plot(linewidth)

        if legend_ncol is None:
            legend_ncol = 3
        if subplot_size is not None:
            # forward the requested subplot size to getdist
            subplot_kwargs['subplot_size'] = subplot_size

        # Make the plot
        g = plots.get_subplot_plotter(**subplot_kwargs)
        if legend_fontsize is not None:
            g.settings.legend_fontsize = legend_fontsize
        if axes_labelsize is not None:
            g.settings.axes_labelsize = axes_labelsize
        g.rectangle_plot(x_param_ids, y_param_ids, roots=self._mcsamples,
                         filled=filled_contours,
                         colors=colors,
                         legend_ncol=legend_ncol,
                         legend_labels=legend_labels,
                         line_args=line_args, 
                         # same colors as the lines, including the aggregated samples
                         contour_colors=colors)

        # Add marker lines and points: rows follow y_param_ids, columns x_param_ids
        for k in range(len(self.ref_values)):
            ref_markers = self.ref_values_markers[k]
            g.add_param_markers(ref_markers, color='black', ls=self.ref_linestyles[k], lw=linewidth)
            for j, key_x in enumerate(x_param_ids):
                val_x = ref_markers[key_x]
                for i, key_y in enumerate(y_param_ids):
                    val_y = ref_markers[key_y]
                    if val_x is not None and val_y is not None:
                        g.subplots[i, j].scatter(val_x, val_y, s=marker_size, facecolors='black',
                                                 color='black', marker=self.ref_markers[k])
        return g
742
    
NEW
743
    def plot_1d_getdist(self, num_columns=None, legend_ncol=None, 
                        legend_fontsize=None, axes_labelsize=None,
                        linewidth=1, **subplot_kwargs):
        """Array of 1D histogram subplots using getdist.plots_1d method.

        Parameters
        ----------
        num_columns : int, optional
            Number of columns of the subplot array, by default
            ``self.num_models // 2 + 1``
        legend_ncol : int, optional
            Number of columns in the legend, by default 3
        legend_fontsize : float, optional
            If not None, overrides getdist's legend font size, by default None
        axes_labelsize : float, optional
            If not None, overrides getdist's axes label font size, by default None
        linewidth : int, optional
            Line width of the 1D histograms and of the reference (point estimate)
            marker lines, by default 1
        **subplot_kwargs
            Keyword arguments passed to getdist.plots.get_subplot_plotter
            (e.g. subplot_size)

        Returns
        -------
        GetDistPlotter
            Instance of GetDistPlotter corresponding to the figure
        """
        line_args, _, _, colors, legend_labels = self._prepare_getdist_plot(linewidth)

        if num_columns is None:
            # NOTE(review): derived from the number of models rather than the
            # number of parameters; confirm this is the intended layout
            num_columns = self.num_models // 2 + 1
        if legend_ncol is None:
            legend_ncol = 3

        # Make the plot
        g = plots.get_subplot_plotter(**subplot_kwargs)
        if legend_fontsize is not None:
            g.settings.legend_fontsize = legend_fontsize
        if axes_labelsize is not None:
            g.settings.axes_labelsize = axes_labelsize
        g.plots_1d(self._mcsamples,
                   params=self.parameter_id_list,
                   legend_labels=legend_labels,
                   colors=colors,
                   share_y=True,
                   line_args=line_args,
                   nx=num_columns, legend_ncol=legend_ncol,
        )

        # Add vertical lines at the reference values
        for k in range(len(self.ref_values)):
            g.add_param_markers(self.ref_values_markers[k], color='black', 
                                ls=self.ref_linestyles[k], lw=linewidth)
        return g
796

797
    def plot_source(self, idx_file=0):
        """Plot the source light parameters of every file via `plotting_routine`.

        Returns the (figure, axes) pair produced by `plotting_routine`.
        """
        figure, axes = self.plotting_routine(param_dict=self.param_source,
                                             idx_file=idx_file)
        return figure, axes
800
    
801
    def plot_lens(self, idx_file=0):
        """Plot the lens mass parameters of every file via `plotting_routine`.

        Returns the (figure, axes) pair produced by `plotting_routine`.
        """
        figure, axes = self.plotting_routine(param_dict=self.param_lens,
                                             idx_file=idx_file)
        return figure, axes
804

805
    def plotting_routine(self, param_dict, idx_file=0):
        """Plot the estimate of each parameter, for every file, in its own panel.

        Parameters
        ----------
        param_dict : dict
            Organized dictionary with all the parameter results of the different
            files: ``param_dict[file_name][param_key]`` is a dict holding
            'latex_str' and either 'median', 'percentile_16th' and
            'percentile_84th' (when the file has a posterior) or 'point_estimate'.
        idx_file : int, optional
            Index (in ``self.file_names``) of the reference file whose parameters
            define the panels, by default 0. For example, if file 0 holds a Sersic
            fit and file 1 a Sersic+shapelets fit, idx_file=0 plots the Sersic
            results of both files, while idx_file=1 plots all the Sersic and
            shapelets parameters when available.

        Returns
        -------
        (Figure, ndarray of Axes)
            The figure and its 2D array of axes (4 columns).
        """
        num_cols = 4
        ref_file_name = self.file_names[idx_file]
        ref_result = param_dict[ref_file_name]
        ref_keys = list(ref_result.keys())
        number_param = len(ref_keys)
        # ceil(number_param / num_cols), with at least one line of panels
        num_lines = max(1, -(-number_param // num_cols))
        num_unused = num_lines * num_cols - number_param

        # squeeze=False keeps `ax` 2D even when there is a single line of panels
        f, ax = plt.subplots(num_lines, num_cols,
                             figsize=(num_cols * 3.5, 2.5 * num_lines),
                             squeeze=False)
        # may find a better way to define markers but right now, it is sufficient
        markers = ['*', '.', 's', '^', '<', '>', 'v', 'p', 'P', 'X', 'D', '1', '2', '3', '4', '+']

        # (line, column) of the panel assigned to each parameter of the reference file
        panel_of_key = {key: divmod(i, num_cols) for i, key in enumerate(ref_keys)}

        for j, file_name in enumerate(self.file_names):
            result = param_dict[file_name]
            m = markers[j % len(markers)]  # cycle when there are more files than markers
            for key, (idx_line, idx_col) in panel_of_key.items():
                if key not in result:
                    continue  # this file's model does not have that parameter
                # copy so that the angle wrapping below leaves param_dict untouched
                p = dict(result[key])
                panel = ax[idx_line, idx_col]
                if self.posterior_bool_list[j]:
                    # trick to plot correct error bars if close to the +180/-180 edge
                    if key in ('SHEAR_0_phi_ext', 'PEMD_0_phi'):
                        if p['percentile_16th'] > p['median']:
                            p['percentile_16th'] -= 180.
                        if p['percentile_84th'] < p['median']:
                            p['percentile_84th'] += 180.
                    yerr = [[p['median'] - p['percentile_16th']],
                            [p['percentile_84th'] - p['median']]]
                    panel.errorbar(j, p['median'], yerr, marker=m, ls='', label=file_name)
                else:
                    panel.plot(j, p['point_estimate'], marker=m, ls='', label=file_name)

        # label each panel with the parameter of the reference file
        for key, (idx_line, idx_col) in panel_of_key.items():
            panel = ax[idx_line, idx_col]
            panel.get_xaxis().set_visible(False)
            panel.set_ylabel(ref_result[key]['latex_str'], fontsize=12)
            panel.tick_params(axis='y', labelsize=12)

        ax[0, 0].legend()
        # hide the trailing panels of the last line, which hold no parameter
        for idx in range(num_unused):
            ax[-1, -idx - 1].axis('off')
        plt.tight_layout()
        plt.show()
        return f, ax
878

879
    def _prepare_getdist_plot(self, lw, lw_cont=None, lw_margin=None):
        """Gather the per-model styling arguments shared by the getdist plots.

        Parameters
        ----------
        lw : float
            Line width of the 1D lines of each model
        lw_cont : float, optional
            Line width of the 2D contours of each model, by default None
        lw_margin : float, optional
            Line width of the aggregated "marginalized" samples, by default lw + 2

        Returns
        -------
        tuple
            (line_args, contour line widths, contour line styles, colors, legend
            labels). When aggregated samples are enabled, each list gets one extra
            trailing entry for them (the contour line widths only if lw_cont is
            not None).
        """
        margin_lw = lw + 2 if lw_margin is None else lw_margin
        line_args = []
        for style, color in zip(self.linestyles, self.colors):
            line_args.append({'ls': style, 'lw': lw, 'color': color})
        contour_lws = [lw_cont for _ in range(self.num_models)]
        contour_ls = ['-' for _ in range(self.num_models)]
        # copies, so that appending below does not alter the instance attributes
        labels = copy.deepcopy(self.coolest_names)
        all_colors = copy.deepcopy(self.colors)
        if not self._add_margin_samples:
            return line_args, contour_lws, contour_ls, all_colors, labels
        line_args.append({'ls': '-.', 'lw': margin_lw, 'alpha': 0.8, 'color': self._color_margin})
        contour_ls.append('-.')
        if lw_cont is not None:
            contour_lws.append(margin_lw)
        labels.append(self._label_margin)
        all_colors.append(self._color_margin)
        return line_args, contour_lws, contour_ls, all_colors, labels
894

895
# def plot_corner(parameter_id_list, 
896
#                 chain_objs, chain_dirs, chain_names=None, 
897
#                 point_estimate_objs=None, point_estimate_dirs=None, point_estimate_names=None, 
898
#                 colors=None, labels=None, subplot_size=1, mc_samples_kwargs=None, 
899
#                 filled_contours=True, angles_range=None, shift_sample_list=None):
900
#     """
901
#     Adding this as just a function for the moment.
902
#     Takes a list of COOLEST files as input, which must have a chain file associated to them, and returns a corner plot.
903

904
#     Parameters
905
#     ----------
906
#     parameter_id_list : array
907
#         A list of parameter unique ids obtained from lensing entities. Their order determines the order of the plot panels.
908
#     chain_objs : array
909
#         A list of coolest objects that have a chain file associated to them.
910
#     chain_dirs : array
911
#         A list of paths matching the coolest files in 'chain_objs'.
912
#     chain_names : array, optional
913
#         A list of labels for the coolest models in the 'chain_objs' list. Must have the same order as 'chain_objs'.
914
#     point_estimate_objs : array, optional
915
#         A list of coolest objects that will be used as point estimates.
916
#     point_estimate_dirs : array
917
#         A list of paths matching the coolest files in 'point_estimate_objs'.
918
#     point_estimate_names : array, optional
919
#         A list of labels for the models in the 'point_estimate_objs' list. Must have the same order as 'point_estimate_objs'.
920
#     labels : dict, optional
921
#         A dictionary matching the parameter_id_list entries to some human-readable labels.
922

923
#     Returns
924
#     -------
925
#     An image
926
#     """
927

928
#     chains.print_load_details = False # Just to silence messages
929
#     parameter_id_set = set(parameter_id_list)
930
#     Npars = len(parameter_id_list)
931
#     Nobjs = len(chain_objs)
932
    
933
#     # Set the chain names
934
#     if chain_names is None:
935
#         chain_names = ["chain "+str(i) for i in range(Nobjs)]
936
    
937
#     if shift_sample_list is None:
938
#         shift_sample_list = [None]*Nobjs
939
    
940
#     # Get the values of the point_estimates
941
#     point_estimates = []
942
#     if point_estimate_objs is not None:
943
#         for coolest_obj in point_estimate_objs:
944
#             values = []
945
#             for par in parameter_id_list:
946
#                 param = coolest_obj.lensing_entities.get_parameter_from_id(par)
947
#                 val = param.point_estimate.value
948
#                 if val is None:
949
#                     values.append(None)
950
#                 else:
951
#                     values.append(val)
952
#             point_estimates.append(values)
953

954

955
            
956
#     mcsamples = []
957
#     for i in range(Nobjs):
958
#         chain_file = os.path.join(chain_dirs[i],chain_objs[i].meta["chain_file_name"]) # Here get the chain file path for each coolest object
959

960
#         # Each chain file can have a different number of free parameters
961
#         f = open(chain_file)
962
#         header = f.readline()
963
#         f.close()
964

965
#         if ';' in header:
966
#             raise ValueError("Columns must be comma-separated (no semi-colon) in chain file.")
967

968
#         chain_file_headers = header.split(',')
969
#         num_cols = len(chain_file_headers)
970
#         chain_file_headers.pop() # Remove the last column name that is the probability weights
971
#         chain_file_headers_set = set(chain_file_headers)
972
        
973
#         # Check that the given parameters are a subset of those in the chain file
974
#         assert parameter_id_set.issubset(chain_file_headers_set), "Not all given parameters are free parameters for model %d (not in the chain file: %s)!" % (i,chain_file)
975

976
#         # Set the labels for the parameters in the chain file
977
#         par_labels = []
978
#         if labels is None:
979
#             labels = {}
980
#         for par_id in parameter_id_list:
981
#             if labels.get(par_id, None) is None:
982
#                 param = coolest_obj.lensing_entities.get_parameter_from_id(par_id)
983
#                 par_labels.append(param.latex_str.strip('$'))
984
#             else:
985
#                 par_labels.append(labels[par_id])
986
                    
987
#         # Read parameter values and probability weights
988
#         column_indices = [chain_file_headers.index(par_id) for par_id in parameter_id_list]
989
#         columns_to_read = sorted(column_indices) + [num_cols-1]  # add last one for probability weights
990
#         samples = pd.read_csv(chain_file, usecols=columns_to_read, delimiter=',')
991
    
992
#         # Re-order columns to match parameter_id_list and par_labels
993
#         sample_par_values = np.array(samples[parameter_id_list])
994

995
#         # If needed, shift samples by a constant
996
#         if shift_sample_list[i] is not None:
997
#             for param_id, value in shift_sample_list[i].items():
998
#                 sample_par_values[:, parameter_id_list.index(param_id)] += value
999
#                 print(f"INFO: posterior for parameter '{param_id}' from model '{chain_names[i]}' "
1000
#                       f"has been shifted by {value}.")
1001

1002
#         # Clean-up the probability weights
1003
#         mypost = np.array(samples['probability_weights'])
1004
#         min_non_zero = np.min(mypost[np.nonzero(mypost)])
1005
#         sample_prob_weight = np.where(mypost<min_non_zero,min_non_zero,mypost)
1006
#         #sample_prob_weight = mypost
1007

1008
#         # Create MCSamples object
1009
#         mysample = MCSamples(samples=sample_par_values,names=parameter_id_list,labels=par_labels,settings=mc_samples_kwargs)
1010
#         mysample.reweightAddingLogLikes(-np.log(sample_prob_weight))
1011
#         mcsamples.append(mysample)
1012

1013

1014
        
1015
#     # Make the plot
1016
#     image = plots.getSubplotPlotter(subplot_size=subplot_size)    
1017
#     image.triangle_plot(mcsamples,
1018
#                         params=parameter_id_list,
1019
#                         legend_labels=chain_names,
1020
#                         filled=filled_contours,
1021
#                         colors=colors,
1022
#                         line_args=[{'ls':'-', 'lw': 2, 'color': c} for c in colors], 
1023
#                         contour_colors=colors)
1024

1025

1026
#     my_linestyles = ['solid','dotted','dashed','dashdot']
1027
#     my_markers    = ['s','^','o','star']
1028

1029
#     for k in range(0,len(point_estimates)):
1030
#         # Add vertical and horizontal lines
1031
#         for i in range(0,Npars):
1032
#             val = point_estimates[k][i]
1033
#             if val is not None:
1034
#                 for ax in image.subplots[i:,i]:
1035
#                     ax.axvline(val,color='black',ls=my_linestyles[k],alpha=1.0,lw=1)
1036
#                 for ax in image.subplots[i,:i]:
1037
#                     ax.axhline(val,color='black',ls=my_linestyles[k],alpha=1.0,lw=1)
1038

1039
#         # Add points
1040
#         for i in range(0,Npars):
1041
#             val_x = point_estimates[k][i]
1042
#             for j in range(i+1,Npars):
1043
#                 val_y = point_estimates[k][j]
1044
#                 if val_x is not None and val_y is not None:
1045
#                     image.subplots[j,i].scatter(val_x,val_y,s=10,facecolors='black',color='black',marker=my_markers[k])
1046
#                 else:
1047
#                     pass    
1048

1049

1050
#     # Set default ranges for angles
1051
#     if angles_range is None:
1052
#         angles_range = (-90, 90)
1053
#     for i in range(0,len(parameter_id_list)):
1054
#         dum = parameter_id_list[i].split('-')
1055
#         name = dum[-1]
1056
#         if name in ['phi','phi_ext']:
1057
#             xlim = image.subplots[i,i].get_xlim()
1058
#             #print(xlim)
1059
        
1060
#             if xlim[0] < -90:
1061
#                 for ax in image.subplots[i:,i]:
1062
#                     ax.set_xlim(left=angles_range[0])
1063
#                 for ax in image.subplots[i,:i]:
1064
#                     ax.set_ylim(bottom=angles_range[0])
1065
#             if xlim[1] > 90:
1066
#                 for ax in image.subplots[i:,i]:
1067
#                     ax.set_xlim(right=angles_range[1])
1068
#                 for ax in image.subplots[i,:i]:
1069
#                     ax.set_ylim(top=angles_range[1])
1070

1071
            
1072
#     return image
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc