• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

aymgal / COOLEST / 27645366736

16 Jun 2026 08:16PM UTC coverage: 43.458% (-2.0%) from 45.463%
27645366736

Pull #75

github

web-flow
Merge 79a70e2d3 into 9ef83563c
Pull Request #75: Adding multi-source plane lens functionality

31 of 242 new or added lines in 7 files covered. (12.81%)

6 existing lines in 4 files now uncovered.

1528 of 3516 relevant lines covered (43.46%)

0.43 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/coolest/api/plot_util.py
1
__author__ = 'aymgal', 'Giorgos Vernardos'
×
2

3

4
import numpy as np
×
5
import warnings
×
6
import matplotlib
×
7
import matplotlib.pyplot as plt
×
8
from matplotlib import ticker
×
9
from mpl_toolkits.axes_grid1 import make_axes_locatable
×
10
from scipy.spatial import Voronoi, voronoi_plot_2d
×
11
from matplotlib.colors import Normalize, LogNorm, TwoSlopeNorm
×
12
from matplotlib.cm import ScalarMappable
×
13
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
×
14
from coolest.api.composable_models import ComposableLensModel
×
15
import os
×
16
import tarfile
×
17
import tempfile
×
18
import zipfile
×
19
import io
×
20
from coolest.api import util
×
21
from coolest.api.analysis import Analysis
×
22

23

24
def plot_voronoi(ax, x, y, z, neg_values_as_bad=False, 
                 norm=None, cmap=None, zmin=None, zmax=None, 
                 edgecolor=None, zorder=1):
    """Fill the Voronoi cells of the sites (x, y) on `ax`, coloring each cell by `z`.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes on which the cell polygons are drawn with ``ax.fill``.
    x, y : array_like
        Coordinates of the Voronoi sites.
    z : array_like
        One value per site, indexed in the same order as (x, y).
    neg_values_as_bad : bool, optional
        When exactly ``True``, negative cells are colored with ``to_rgba(np.nan)``.
    norm : matplotlib.colors.Normalize, optional
        Color normalization. If None, a linear ``Normalize(zmin, zmax)`` is built.
    cmap : Colormap, optional
        Colormap; defaults to 'inferno'.
    zmin, zmax : float, optional
        Only used when `norm` is None; default to ``np.min(z)`` / ``np.max(z)``.
    edgecolor : color, optional
        Edge color of the cell polygons.
    zorder : float, optional
        Drawing order of the cell polygons.

    Returns
    -------
    matplotlib.cm.ScalarMappable
        The value-to-color mapping used for the cells (e.g. for a colorbar).
    """
    colormap = plt.get_cmap('inferno') if cmap is None else cmap
    if norm is None:
        lower = np.min(z) if zmin is None else zmin
        upper = np.max(z) if zmax is None else zmax
        norm = Normalize(lower, upper)

    # Infinite outer cells are closed off so that every site gets a polygon.
    sites = np.column_stack((x, y))
    cells, cell_vertices = voronoi_finite_polygons_2d(Voronoi(sites))

    mapper = matplotlib.cm.ScalarMappable(norm=norm, cmap=colormap)

    mask_negative = neg_values_as_bad is True
    for idx, cell in enumerate(cells):
        corners = cell_vertices[cell]
        value = z[idx]
        if mask_negative and value < 0.:
            rgba = mapper.to_rgba(np.nan)
        else:
            rgba = mapper.to_rgba(value)
        ax.fill(*zip(*corners), facecolor=rgba, edgecolor=edgecolor, zorder=zorder)
    return mapper
59

60

61
def voronoi_finite_polygons_2d(vor,radius=None):
    """
    Reconstruct infinite Voronoi regions of a 2D diagram as finite regions.
    This function is adapted from: https://gist.github.com/pv/8036995

    Every region that extends to infinity is closed by adding, for each of its
    infinite ridges, a "far point" located `radius` away from the ridge's
    finite vertex, in the outward normal direction of that ridge.

    Parameters
    ----------
    vor : scipy.spatial.Voronoi
        Input diagram (2D points only).
    radius : float, optional
        Distance to the 'points at infinity'. Defaults to the peak-to-peak
        range of all point coordinates (``np.ptp`` over the flattened points).

    Returns
    -------
    regions : list of lists of int
        One list per input point (in input-point order) of indices into
        `vertices`. Reconstructed regions are sorted counterclockwise around
        their vertex centroid.
    vertices : numpy.ndarray
        Vertex coordinates: the original ``vor.vertices`` followed by the
        appended 'points at infinity'.

    Raises
    ------
    ValueError
        If the input points are not 2D.
    """
    if vor.points.shape[1] != 2:
        raise ValueError("Requires 2D input")

    points = vor.points
    out_vertices = vor.vertices.tolist()
    out_regions = []

    centroid = points.mean(axis=0)
    if radius is None:
        radius = np.ptp(points).max()

    # For each input point, collect (neighbor point, ridge vertex a, ridge vertex b)
    ridges_of = {}
    for (site_a, site_b), (vert_a, vert_b) in zip(vor.ridge_points, vor.ridge_vertices):
        ridges_of.setdefault(site_a, []).append((site_b, vert_a, vert_b))
        ridges_of.setdefault(site_b, []).append((site_a, vert_a, vert_b))

    for p1, region_index in enumerate(vor.point_region):
        region = vor.regions[region_index]

        # A region without the -1 sentinel is already finite.
        if all(v >= 0 for v in region):
            out_regions.append(region)
            continue

        polygon = [v for v in region if v >= 0]
        for p2, va, vb in ridges_of[p1]:
            # Normalize so that `va` is the (possibly) infinite end.
            if vb < 0:
                va, vb = vb, va
            if va >= 0:
                continue  # finite ridge: both vertices already in `polygon`

            tangent = points[p2] - points[p1]
            tangent /= np.linalg.norm(tangent)
            normal = np.array([-tangent[1], tangent[0]])

            # Point the normal away from the diagram's centroid.
            midpoint = points[[p1, p2]].mean(axis=0)
            direction = np.sign(np.dot(midpoint - centroid, normal)) * normal
            far_point = vor.vertices[vb] + direction * radius

            polygon.append(len(out_vertices))
            out_vertices.append(far_point.tolist())

        # Order the polygon's vertices by angle around their mean position.
        coords = np.asarray([out_vertices[v] for v in polygon])
        center = coords.mean(axis=0)
        angles = np.arctan2(coords[:, 1] - center[1], coords[:, 0] - center[0])
        out_regions.append(np.array(polygon)[np.argsort(angles)].tolist())

    return out_regions, np.asarray(out_vertices)
143

144

145
def std_colorbar(mappable, label=None, fontsize=12, label_kwargs=None, **colorbar_kwargs):
    """Create a colorbar for `mappable` with ``plt.colorbar`` and optionally label it.

    Parameters
    ----------
    mappable : matplotlib ScalarMappable
        Artist (image, collection, ...) the colorbar describes.
    label : str, optional
        Colorbar label, applied with ``cb.set_label``. When given, any 'label'
        entry in `colorbar_kwargs` is dropped so it is not set twice.
    fontsize : int, optional
        Font size of the label.
    label_kwargs : dict, optional
        Extra keyword arguments for ``cb.set_label``.
    **colorbar_kwargs
        Forwarded to ``plt.colorbar``.

    Returns
    -------
    matplotlib.colorbar.Colorbar
    """
    if label_kwargs is None:
        label_kwargs = {}
    if label is not None:
        # Must happen BEFORE plt.colorbar consumes the kwargs to have any effect.
        colorbar_kwargs.pop('label', None)
    cb = plt.colorbar(mappable, **colorbar_kwargs)
    if label is not None:
        cb.set_label(label, fontsize=fontsize, **label_kwargs)
    return cb
151

152
def std_colorbar_residuals(mappable, res_map, vmin, vmax, label=None, fontsize=12, 
                           label_kwargs={}, **colorbar_kwargs):
    """Create a ``std_colorbar`` whose ``extend`` option reflects the residual range.

    The colorbar gets an extended end on each side where `res_map` exceeds the
    displayed range [vmin, vmax]: 'min' if some values are below `vmin`, 'max'
    if some are above `vmax`, 'both' if both, and 'neither' otherwise. Any
    'extend' entry passed in `colorbar_kwargs` is overridden.

    Returns
    -------
    matplotlib.colorbar.Colorbar
        The colorbar returned by ``std_colorbar``.
    """
    goes_below = res_map.min() < vmin
    goes_above = res_map.max() > vmax
    if goes_below:
        extend = 'both' if goes_above else 'min'
    else:
        extend = 'max' if goes_above else 'neither'
    colorbar_kwargs['extend'] = extend
    return std_colorbar(mappable, label=label, fontsize=fontsize,
                        label_kwargs=label_kwargs, **colorbar_kwargs)
165

166
def nice_colorbar(mappable, ax=None, position='right', pad=0.1, size='5%', label=None, fontsize=12, 
                  invisible=False, 
                  divider_kwargs=None, colorbar_kwargs=None, label_kwargs=None):
    """Attach a colorbar in a new axes appended next to `ax`.

    Parameters
    ----------
    mappable : matplotlib ScalarMappable
        Artist the colorbar describes.
    ax : matplotlib Axes, optional
        Axes next to which the colorbar axes is appended; defaults to ``mappable.axes``.
    position : str, optional
        Side on which the colorbar axes is appended (passed to ``append_axes``).
    pad, size : optional
        Padding and size of the appended axes (passed to ``append_axes``).
    label : str, optional
        Colorbar label; when given, any 'label' in `colorbar_kwargs` is dropped.
    fontsize : int, optional
        Font size of the label.
    invisible : bool, optional
        If True, the appended axes is hidden and no colorbar is created (useful to
        keep panel sizes aligned with panels that do have a colorbar).
    divider_kwargs, colorbar_kwargs, label_kwargs : dict, optional
        Extra keyword arguments for ``append_axes``, ``plt.colorbar`` and
        ``cb.set_label``. The dicts passed in are never modified.

    Returns
    -------
    matplotlib.colorbar.Colorbar or None
        None when `invisible` is True.
    """
    # Work on copies so neither the caller's dicts nor any default get mutated.
    divider_kwargs = dict(divider_kwargs or {})
    divider_kwargs.update({'position': position, 'pad': pad, 'size': size})
    colorbar_kwargs = dict(colorbar_kwargs or {})
    label_kwargs = dict(label_kwargs or {})

    if ax is None:
        ax = mappable.axes
    divider = make_axes_locatable(ax)
    cax = divider.append_axes(**divider_kwargs)
    if invisible:
        cax.axis('off')
        return None
    if label is not None:
        # Drop a duplicate label before the colorbar is created; set_label below wins.
        colorbar_kwargs.pop('label', None)
    cb = plt.colorbar(mappable, cax=cax, **colorbar_kwargs)
    if label is not None:
        cb.set_label(label, fontsize=fontsize, **label_kwargs)
    if position == 'top':
        cax.xaxis.set_ticks_position('top')
    # NOTE: a `max_nbins` option (custom tick locator) was tried and left out
    # because it produced strange results.
    return cb
191

192
def nice_colorbar_residuals(mappable, res_map, vmin, vmax, ax=None, position='right', pad=0.1, size='5%', 
                            invisible=False, label=None, fontsize=12,
                            divider_kwargs=None, colorbar_kwargs=None, label_kwargs=None):
    """Create a ``nice_colorbar`` whose ``extend`` option reflects the residual range.

    'extend' is 'min' if some values of `res_map` are below `vmin`, 'max' if some
    are above `vmax`, 'both' if both, and 'neither' otherwise; it overrides any
    'extend' given in `colorbar_kwargs`. The other arguments are forwarded to
    ``nice_colorbar``. The dicts passed in are never modified.

    Returns
    -------
    matplotlib.colorbar.Colorbar or None
        Whatever ``nice_colorbar`` returns (None when `invisible` is True).
    """
    if res_map.min() < vmin and res_map.max() > vmax:
        cb_extend = 'both'
    elif res_map.min() < vmin:
        cb_extend = 'min'
    elif res_map.max() > vmax:
        cb_extend = 'max'
    else:
        cb_extend = 'neither'
    # Fresh dicts: the caller's dict (or a shared default) must not be mutated,
    # neither here nor by the callee.
    colorbar_kwargs = {**(colorbar_kwargs or {}), 'extend': cb_extend}
    return nice_colorbar(mappable, ax=ax, position=position, pad=pad, size=size,
                         label=label, fontsize=fontsize, invisible=invisible,
                         colorbar_kwargs=colorbar_kwargs,
                         label_kwargs=dict(label_kwargs or {}),
                         divider_kwargs=dict(divider_kwargs or {}))
207

208
def cax_colorbar(fig, cax, norm=None, cmap=None, mappable=None, label=None, fontsize=12, orientation='horizontal', label_kwargs=None):
    """Draw a colorbar into an existing axes `cax` of figure `fig`.

    Parameters
    ----------
    fig : matplotlib Figure
        Figure whose ``colorbar`` method is used.
    cax : matplotlib Axes
        Axes in which the colorbar is drawn.
    norm, cmap : optional
        Used to build a standalone ``ScalarMappable`` when `mappable` is None.
    mappable : matplotlib ScalarMappable, optional
        Artist the colorbar describes.
    label : str, optional
        Colorbar label.
    fontsize : int, optional
        Font size of the label.
    orientation : str, optional
        Colorbar orientation, 'horizontal' by default.
    label_kwargs : dict, optional
        Extra keyword arguments for ``cb.set_label``.

    Returns
    -------
    matplotlib.colorbar.Colorbar
        The created colorbar (returned for consistency with the other helpers).
    """
    if mappable is None:
        mappable = ScalarMappable(norm=norm, cmap=cmap)
    cb = fig.colorbar(mappable=mappable, orientation=orientation, cax=cax)
    if label is not None:
        cb.set_label(label, fontsize=fontsize, **(label_kwargs or {}))
    return cb
214

215
def scale_bar(ax, size, unit_suffix='"', loc='lower left', color='#FFFFFFBB', fontsize=12):
    """Add a labelled horizontal scale bar of length `size` (data units) to `ax`.

    The label shows `size` as an integer when it is whole, otherwise with one
    decimal, followed by `unit_suffix` (an arcsecond sign by default).

    Returns
    -------
    AnchoredSizeBar
        The artist added to `ax`.
    """
    number = f"{int(size)}" if size == int(size) else f"{size:.1f}"
    bar = AnchoredSizeBar(
        ax.transData, size, number + unit_suffix,
        loc=loc, label_top=True, pad=0.8, sep=5,
        color=color, fontproperties=dict(size=fontsize),
        frameon=False, size_vertical=0,
    )
    ax.add_artist(bar)
    return bar
231

232
def plot_regular_grid(ax, title, image_, neg_values_as_bad=False, xylim=None, **imshow_kwargs):
    """Display a pixelated image on `ax` with ``imshow`` and apply the standard panel styling.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    title : str
        Panel title.
    image_ : array_like
        Image passed to ``ax.imshow``.
    neg_values_as_bad : bool, optional
        If True, negative pixels are replaced by NaN on a (float) copy of the
        image, so they are treated as invalid values when displayed. The input
        is never modified.
    xylim : optional
        Axis limits, as accepted by ``set_xy_limits`` (None leaves them unchanged).
    **imshow_kwargs
        Forwarded to ``ax.imshow``.

    Returns
    -------
    ax : matplotlib Axes
    im : matplotlib AxesImage
        The (rasterized) image artist returned by ``imshow``.
    """
    if neg_values_as_bad:
        image = np.array(image_, copy=True)
        # NaN cannot be stored in an integer/bool array: promote to float first.
        if not np.issubdtype(image.dtype, np.floating):
            image = image.astype(float)
        image[image < 0] = np.nan
    else:
        image = image_
    im = ax.imshow(image, **imshow_kwargs)
    im.set_rasterized(True)
    set_xy_limits(ax, xylim)
    ax.xaxis.set_major_locator(plt.MaxNLocator(3))
    ax.yaxis.set_major_locator(plt.MaxNLocator(3))
    ax.set_title(title)
    return ax, im
245

246
def plot_irregular_grid(ax, title, points, xylim, neg_values_as_bad=False,
                            norm=None, cmap=None, plot_points=False):
    """Display values on an irregular grid as colored Voronoi cells on `ax`.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    title : str
        Panel title.
    points : tuple
        ``(x, y, z)``: site coordinates and the value of each site.
    xylim : optional
        Axis limits, as accepted by ``set_xy_limits``.
    neg_values_as_bad, norm, cmap : optional
        Forwarded to ``plot_voronoi``.
    plot_points : bool, optional
        If True, the sites are overlaid as small semi-transparent white dots.

    Returns
    -------
    ax : matplotlib Axes
    im : matplotlib.cm.ScalarMappable
        Color mapping returned by ``plot_voronoi``.
    """
    xs, ys, values = points
    mappable = plot_voronoi(ax, xs, ys, values, neg_values_as_bad=neg_values_as_bad,
                            norm=norm, cmap=cmap, zorder=1)
    ax.set_aspect('equal', 'box')
    set_xy_limits(ax, xylim)
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(plt.MaxNLocator(3))
    if plot_points:
        # drawn above the cells (zorder 2 > 1)
        ax.scatter(xs, ys, s=5, c='white', marker='.', alpha=0.4, zorder=2)
    ax.set_title(title)
    return ax, mappable
259

260
def set_xy_limits(ax, xylim):
    """Set the x and y limits of `ax` from a compact specification.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes whose limits are set (via ``set_xlim`` / ``set_ylim``).
    xylim : None, number, or sequence of length 2 or 4
        - None: leave the limits unchanged.
        - number ``L`` (Python or NumPy scalar): both axes span ``[-L, L]``.
        - ``(a, b)``: both axes span ``[a, b]``.
        - ``(xmin, xmax, ymin, ymax)``: explicit limits.
        Lists, tuples and 1D NumPy arrays are accepted as sequences.

    Raises
    ------
    ValueError
        If `xylim` matches none of the forms above (e.g. a sequence of length 3).
    """
    if xylim is None:
        return  # do nothing
    if isinstance(xylim, (int, float, np.integer, np.floating)):
        limits = [-xylim, xylim, -xylim, xylim]
    else:
        length = None
        if not isinstance(xylim, (str, bytes)):
            try:
                length = len(xylim)
            except TypeError:
                length = None
        if length == 2:
            limits = [xylim[0], xylim[1], xylim[0], xylim[1]]
        elif length == 4:
            limits = [xylim[0], xylim[1], xylim[2], xylim[3]]
        else:
            raise ValueError("`xylim` argument should be a single number, a 2-tuple or a 4-tuple.")
    ax.set_xlim(limits[0], limits[1])
    ax.set_ylim(limits[2], limits[3])
272

273
def panel_label(ax, text, color, fontsize, alpha=0.8, loc='upper left'):
    """Write `text` in a corner of `ax`, 3% away from the edges (axes coordinates).

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to annotate.
    text : str
        Label text.
    color : color
        Text color.
    fontsize : float
        Text size.
    alpha : float, optional
        Text opacity.
    loc : {'upper left', 'lower left', 'upper right', 'lower right'}, optional
        Corner in which the text is placed.

    Returns
    -------
    matplotlib.text.Text
        The artist returned by ``ax.text``.

    Raises
    ------
    ValueError
        If `loc` is not one of the supported corners.
    """
    positions = {
        'upper left': (0.03, 0.97, 'left', 'top'),
        'lower left': (0.03, 0.03, 'left', 'bottom'),
        'upper right': (0.97, 0.97, 'right', 'top'),
        'lower right': (0.97, 0.03, 'right', 'bottom'),
    }
    try:
        x, y, ha, va = positions[loc]
    except KeyError:
        raise ValueError(
            f"Unsupported `loc` {loc!r}; expected one of {sorted(positions)}."
        ) from None
    return ax.text(x, y, text, color=color, fontsize=fontsize, alpha=alpha,
                   ha=ha, va=va, transform=ax.transAxes)
284

285
def normalize_across_images(
        plotter_list, 
        data_model_specifier, 
        auto_selection=False,
        kwargs_source=None, 
        kwargs_lens_mass=None,
        kwargs_lens_light=None,
        supersampling=5, 
        convolved=True, 
        super_convolution=True
    ):
    """Calculate the vmin and vmax to normalize the colormap across multiple coolest objects

    Parameters
    ----------
    plotter_list: list
        List of ModelPlotter objects. May be acquired from MultiModelPlotter object
    data_model_specifier: list
        List of 0s and 1s; 0 = data, 1 = model. Specifies which set of pixel values should be used
        when finding global minima and maxima -- data or model
    auto_selection: bool
        Forwarded to ComposableLensModel when building model images.
    kwargs_source: dict
        Dictionary with "entity_selection" key, same as used in MultiModelPlotters.
        "Entity_selection" contains list of lists. Selects source entities.
        Insert dummy None values into dictionary for data.
    kwargs_lens_mass: dict
        Dictionary with "entity_selection" key, same as used in MultiModelPlotters.
        "Entity_selection" contains list of lists. Selects lens mass entities.
        Insert dummy None values into dictionary for data.
    kwargs_lens_light: dict
        Dictionary with "entity_selection" key, same as used in MultiModelPlotters.
        "Entity_selection" contains list of lists. Selects lens light entities.
        Insert dummy None values into dictionary for data.
    supersampling: int
        Model image generation param
    convolved: bool
        Model image generation param
    super_convolution: bool
        Model image generation param

    Returns
    -------
    vmin: float
        global min value across all coolest objects in plotter_list for the specified data/models
    vmax: float
        global max value across all coolest objects in plotter_list for the specified data/models

    Raises
    ------
    ValueError
        If an entry of `data_model_specifier` is neither 0 nor 1, or if there is
        no image to compare (empty inputs).
    """
    mins = []
    maxes = []
    # One entity selection per plotter; None means "no explicit selection".
    ks_arr = kwargs_source['entity_selection'] if kwargs_source is not None else [None]*len(plotter_list)
    km_arr = kwargs_lens_mass['entity_selection'] if kwargs_lens_mass is not None else [None]*len(plotter_list)
    kl_arr = kwargs_lens_light['entity_selection'] if kwargs_lens_light is not None else [None]*len(plotter_list)
    for plotter, d_or_f, ks, km, kl in zip(plotter_list, data_model_specifier, ks_arr, km_arr, kl_arr):
        # Check if we are finding extrema for data or model
        if d_or_f == 0:
            image = plotter.coolest.observation.pixels.get_pixels(directory=plotter._directory)
        elif d_or_f == 1:
            lens_model = ComposableLensModel(
                plotter.coolest, plotter._directory,
                auto_selection=auto_selection,
                kwargs_selection_source=dict(entity_selection=ks),
                kwargs_selection_lens_mass=dict(entity_selection=km),
                kwargs_selection_lens_light=dict(entity_selection=kl)
            )
            image, _ = lens_model.model_image(supersampling, convolved, super_convolution)
        else:
            # Without this, `image` would be undefined or silently reused from
            # the previous iteration.
            raise ValueError(
                f"Entries of `data_model_specifier` must be 0 (data) or 1 (model), got {d_or_f!r}."
            )
        # Find min and max and append
        mins.append(np.min(image))
        maxes.append(np.max(image))
    if not mins:
        raise ValueError("No images to normalize: `plotter_list` or `data_model_specifier` is empty.")
    vmin = min(mins)
    vmax = max(maxes)
    return vmin, vmax
357

358
    
NEW
359
def _extract_coolest_archive(archive_path, destination):
    """Extract a ``.tar.gz`` or ``.zip`` COOLEST archive into the existing directory `destination`.

    Raises
    ------
    ValueError
        If `archive_path` has neither a ``.tar.gz`` nor a ``.zip`` extension.
    """
    archive_path = os.fspath(archive_path)  # also accept pathlib.Path
    if archive_path.endswith('.tar.gz'):
        with tarfile.open(archive_path, "r:gz") as tar:
            if hasattr(tarfile, 'data_filter'):
                # The 'data' filter refuses absolute paths, members escaping
                # `destination` (e.g. '../x') and special files.
                tar.extractall(path=destination, filter='data')
            else:
                # NOTE(review): this Python predates tarfile extraction filters;
                # only archives from trusted sources should be extracted here.
                tar.extractall(path=destination)
    elif archive_path.endswith('.zip'):
        with zipfile.ZipFile(archive_path, 'r') as zipf:
            zipf.extractall(destination)
    else:
        raise ValueError("Target path must point to a .tar.gz or .zip archive.")


def _find_coolest_target(root):
    """Return the path (without the ``.json`` extension) of the single JSON file extracted under `root`.

    The JSON file is looked up in the first top-level folder when there is one,
    otherwise directly in `root`. A macOS ``__MACOSX`` metadata folder is ignored.
    Entries are sorted so the choice does not depend on filesystem order.

    Raises
    ------
    ValueError
        If no, or more than one, ``.json`` file is found.
    """
    items = sorted(os.listdir(root))
    if '__MACOSX' in items:
        items.remove('__MACOSX')  # remove macOS artifact folder if present
    search_dir, candidates = root, items
    if items:
        first = os.path.join(root, items[0])
        if os.path.isdir(first):
            search_dir, candidates = first, os.listdir(first)
    json_files = [f for f in candidates if f.endswith(".json")]
    if not json_files:
        raise ValueError("No .json file found in archive.")
    elif len(json_files) > 1:
        raise ValueError("Multiple .json files found in archive.")
    return os.path.splitext(os.path.join(search_dir, json_files[0]))[0]


def dmr_corner(tar_path, output_dir=None, reorder = None, show=True):
    """Plot, and optionally save, the DMR panel and corner plot of an archived COOLEST file.

    Parameters
    ----------
    tar_path : str or os.PathLike
        Path to a ``.tar.gz`` or ``.zip`` archive containing exactly one COOLEST
        ``.json`` file (at the top level or inside one top-level folder).
    output_dir : str, optional
        Directory where ``dmr_plot.png`` and ``corner_plot.png`` are saved; it is
        created if missing. If None (default), nothing is saved.
    reorder : list of int, optional
        Indices into the list of free parameters giving the order (and subset) of
        parameters in the corner plot. If None, the default ordering is used.
    show : bool, optional
        Whether to call ``plt.show()`` after each figure, by default True.

    Returns
    -------
    results : dict
        Always contains 'r_eff_source', 'einstein_radius', 'lensing_entities' and
        'source_light_model'. Contains 'free_parameters' when a corner plot is
        made, and 'dmr_plot' / 'corner_plot' (saved file paths) when
        `output_dir` is given.

    Raises
    ------
    ValueError
        If the archive type is unsupported or the archive does not contain
        exactly one ``.json`` file.

    Notes
    -----
    Entity indices are hard-coded: entities 0 and 1 are used for the Einstein
    radius, and entity 2 for the source. The corner plot is only produced when
    the COOLEST metadata contains 'chain_file_name' (i.e. when chains exist).
    """
    from coolest.api.plotting import ModelPlotter, ParametersPlotter  # placed here to avoid circular import

    results = {}
    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)

    with tempfile.TemporaryDirectory() as tmpdir:
        _extract_coolest_archive(tar_path, tmpdir)
        target_path = _find_coolest_target(tmpdir)
        coolest_dir = os.path.dirname(target_path)

        # Load COOLEST object
        coolest_1 = util.get_coolest_object(target_path, verbose=False)

        # Run analysis
        analysis = Analysis(coolest_1, target_path, supersampling=5)

        coord_orig = util.get_coordinates(coolest_1)
        coord_src = coord_orig.create_new_coordinates(pixel_scale_factor=0.1, grid_shape=(1.42, 1.42))

        # Extract values (entity indices are hard-coded, see Notes)
        r_eff_source = analysis.effective_radius_light(center=(0,0), coordinates=coord_src, outer_radius=1., entity_selection=[2])
        einstein_radius = analysis.effective_einstein_radius(entity_selection=[0,1])

        results['r_eff_source'] = r_eff_source
        results['einstein_radius'] = einstein_radius
        results['lensing_entities'] = [type(le).__name__ for le in coolest_1.lensing_entities]
        results['source_light_model'] = [type(m).__name__ for m in coolest_1.lensing_entities[2].light_model]

        ### DMR Plot
        norm = Normalize(-0.005, 0.05)
        fig, axes = plt.subplots(2, 2, constrained_layout=True)
        splotter = ModelPlotter(coolest_1, coolest_directory=coolest_dir)

        splotter.plot_data_image(axes[0, 0], norm=norm)
        axes[0, 0].set_title("Observed Data")

        splotter.plot_model_image(
            axes[0, 1],
            supersampling=5, convolved=True,
            norm=norm,
            auto_selection = True
        )
        axes[0, 1].text(0.05, 0.05, f"$\\theta_{{\\rm E}}$ = {einstein_radius:.2f}\"", color='white', fontsize=12,
                        transform=axes[0, 1].transAxes)
        axes[0, 1].set_title("Image Model")

        splotter.plot_model_residuals(axes[1, 0], supersampling=5, add_chi2_label=True, chi2_fontsize=12)
        axes[1, 0].set_title("Normalized Residuals")

        splotter.plot_surface_brightness(axes[1, 1],
                                         norm=norm, coordinates=coord_src)
        axes[1, 1].text(0.05, 0.05, f"$\\theta_{{\\rm eff}}$ = {r_eff_source:.2f}\"", color='white', fontsize=12,
                        transform=axes[1, 1].transAxes)
        axes[1, 1].set_title("Surface Brightness")

        for ax in axes[1]:
            ax.set_xlabel(r"$x$ (arcsec)")
            ax.set_ylabel(r"$y$ (arcsec)")

        if output_dir is not None:
            dmr_plot_path = os.path.join(output_dir, "dmr_plot.png")
            fig.savefig(dmr_plot_path, format='png', bbox_inches='tight')
            results['dmr_plot'] = dmr_plot_path
        if show:
            plt.show()
        plt.close(fig)

        ### Corner Plot
        truth = coolest_1
        # Only creates corner plot if sampling method was used to create lens model
        # Otherwise, no chains available for corner plot!
        if 'chain_file_name' in truth.meta.keys():
            print('Creating corner plot')
            # NOTE(review): the last two parameter IDs are excluded; the reason is not visible here.
            free_pars = truth.lensing_entities.get_parameter_ids()[:-2]
            betas = getattr(truth, 'multiplane_betas', None)
            if betas is not None and len(betas) > 0:
                free_pars += betas.get_all_beta_ids()
            pars = free_pars if reorder is None else [free_pars[i] for i in reorder]
            results['free_parameters'] = pars

            param_plotter = ParametersPlotter(
                pars, [truth],
                coolest_directories=[coolest_dir],
                coolest_names=["Smooth source"],
                ref_coolest_objects=[truth],
                colors=['#7FB6F5', '#E03424'],
            )

            settings = {
                "ignore_rows": 0.0,
                "fine_bins_2D": 800,
                "smooth_scale_2D": 0.5,
                "mult_bias_correction_order": 5
            }
            param_plotter.init_getdist(settings_mcsamples=settings)
            param_plotter.plot_triangle_getdist(filled_contours=True, subplot_size=3)
            if output_dir is not None:
                corner_plot_path = os.path.join(output_dir, "corner_plot.png")
                # the triangle plot is drawn on the current figure
                plt.savefig(corner_plot_path, format='png', bbox_inches='tight')
                results['corner_plot'] = corner_plot_path
            if show:
                plt.show()
            plt.close()

    return results
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc