• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

aymgal / COOLEST / 22070442821

16 Feb 2026 04:27PM UTC coverage: 45.463% (+0.05%) from 45.411%
22070442821

push

github

web-flow
Merge pull request #76 from aymgal/pr-fully-defined-mass

Add support for a lens light component and a fully-defined pixelated lens model, plus various other improvements to the analysis and plotting engines

96 of 234 new or added lines in 11 files covered. (41.03%)

5 existing lines in 3 files now uncovered.

1498 of 3295 relevant lines covered (45.46%)

0.45 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/coolest/api/plot_util.py
1
__author__ = 'aymgal', 'Giorgos Vernardos'
×
2

3

4
import numpy as np
×
5
import warnings
×
6
import matplotlib
×
7
import matplotlib.pyplot as plt
×
8
from matplotlib import ticker
×
9
from mpl_toolkits.axes_grid1 import make_axes_locatable
×
10
from scipy.spatial import Voronoi, voronoi_plot_2d
×
11
from matplotlib.colors import Normalize, LogNorm, TwoSlopeNorm
×
12
from matplotlib.cm import ScalarMappable
×
13
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
×
14
from coolest.api.composable_models import ComposableLensModel
×
15
import os
×
16
import tarfile
×
17
import tempfile
×
18
import zipfile
×
19
import io
×
20
from coolest.api import util
×
21
from coolest.api.analysis import Analysis
×
22

23

24
def plot_voronoi(ax, x, y, z, neg_values_as_bad=False, 
                 norm=None, cmap=None, zmin=None, zmax=None, 
                 edgecolor=None, zorder=1):
    """Plot values defined on an irregular set of points as filled Voronoi cells.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    x, y : array_like
        Coordinates of the Voronoi sites.
    z : array_like
        Value attached to each site; ``z[i]`` sets the color of the cell of
        point ``(x[i], y[i])``.
    neg_values_as_bad : bool, optional
        If truthy, cells with a negative value are colored as NaN values,
        i.e. with the colormap's "bad" color.
    norm : matplotlib.colors.Normalize, optional
        Color normalization. If None, ``Normalize(zmin, zmax)`` is used.
    cmap : Colormap, optional
        Colormap, ``'inferno'`` by default.
    zmin, zmax : float, optional
        Limits of the default normalization, by default ``min(z)`` and
        ``max(z)``. Ignored when ``norm`` is provided.
    edgecolor : color, optional
        Edge color of the cells.
    zorder : float, optional
        Drawing order of the cells.

    Returns
    -------
    matplotlib.cm.ScalarMappable
        Mappable holding ``norm`` and ``cmap``, e.g. to draw a colorbar.
    """
    if cmap is None:
        cmap = plt.get_cmap('inferno')
    if norm is None:
        if zmin is None:
            zmin = np.min(z)
        if zmax is None:
            zmax = np.max(z)
        norm = Normalize(zmin, zmax)

    # Voronoi cells, with the unbounded outer cells closed at a finite distance
    voronoi_points = np.column_stack((x, y))
    vor = Voronoi(voronoi_points)
    new_regions, vertices = voronoi_finite_polygons_2d(vor)

    # Converts each value into an RGBA color
    m = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)

    # new_regions[i] is the cell of input point i, hence it is colored by z[i]
    for i, region in enumerate(new_regions):
        polygon = vertices[region]
        z_i = z[i]
        # Plain truthiness (instead of `is True`) so that numpy booleans also
        # enable the masking, consistently with `plot_regular_grid`
        if neg_values_as_bad and z_i < 0.:
            cell_color = m.to_rgba(np.nan)
        else:
            cell_color = m.to_rgba(z_i)
        ax.fill(*zip(*polygon), facecolor=cell_color, edgecolor=edgecolor, zorder=zorder)
    return m
59

60

61
def voronoi_finite_polygons_2d(vor, radius=None):
    """
    Reconstruct infinite voronoi regions in a 2D diagram to finite
    regions.
    This function is taken from: https://gist.github.com/pv/8036995

    Parameters
    ----------
    vor : scipy.spatial.Voronoi
        Input diagram, built from 2D points.
    radius : float, optional
        Distance to 'points at infinity'. By default, the peak-to-peak range
        of all point coordinates (x and y values pooled together).

    Returns
    -------
    regions : list of lists of int
        Indices of vertices in each revised Voronoi region; ``regions[i]`` is
        the region of input point ``vor.points[i]``. Finite regions are kept
        as given by ``vor.regions``; reconstructed regions are sorted
        counterclockwise around the centroid of their vertices.
    vertices : numpy.ndarray, shape (n_vertices, 2)
        Coordinates of the revised Voronoi vertices. Same as the input
        vertices, with the 'points at infinity' appended to the end.

    Raises
    ------
    ValueError
        If the input diagram is not two-dimensional.
    """
    if vor.points.shape[1] != 2:
        raise ValueError("Requires 2D input")

    new_regions = []
    new_vertices = vor.vertices.tolist()

    center = vor.points.mean(axis=0)
    if radius is None:
        # Peak-to-peak over the flattened coordinates: at least as large as the
        # extent of the points along either axis
        radius = np.ptp(vor.points).max()

    # Construct a map containing all ridges for a given point
    all_ridges = {}
    for (p1, p2), (v1, v2) in zip(vor.ridge_points, vor.ridge_vertices):
        all_ridges.setdefault(p1, []).append((p2, v1, v2))
        all_ridges.setdefault(p2, []).append((p1, v1, v2))

    # Reconstruct infinite regions; iterating over `point_region` keeps the
    # output regions in the order of the input points
    for p1, region in enumerate(vor.point_region):
        vertices = vor.regions[region]

        if all(v >= 0 for v in vertices):
            # finite region (index -1 denotes a vertex outside the diagram)
            new_regions.append(vertices)
            continue

        # reconstruct a non-finite region
        ridges = all_ridges[p1]
        new_region = [v for v in vertices if v >= 0]

        for p2, v1, v2 in ridges:
            if v2 < 0:
                v1, v2 = v2, v1
            if v1 >= 0:
                # finite ridge: already in the region
                continue

            # Compute the missing endpoint of an infinite ridge
            t = vor.points[p2] - vor.points[p1]  # tangent
            t /= np.linalg.norm(t)
            n = np.array([-t[1], t[0]])  # normal

            # Orient the normal away from the center of the point cloud
            midpoint = vor.points[[p1, p2]].mean(axis=0)
            direction = np.sign(np.dot(midpoint - center, n)) * n
            far_point = vor.vertices[v2] + direction * radius

            new_region.append(len(new_vertices))
            new_vertices.append(far_point.tolist())

        # sort region counterclockwise
        vs = np.asarray([new_vertices[v] for v in new_region])
        c = vs.mean(axis=0)
        angles = np.arctan2(vs[:, 1] - c[1], vs[:, 0] - c[0])
        new_region = np.array(new_region)[np.argsort(angles)]

        # finish
        new_regions.append(new_region.tolist())

    return new_regions, np.asarray(new_vertices)
143

144

145
def std_colorbar(mappable, label=None, fontsize=12, label_kwargs=None, **colorbar_kwargs):
    """Add a colorbar for ``mappable`` with ``plt.colorbar``.

    Parameters
    ----------
    mappable : matplotlib.cm.ScalarMappable
        Artist whose colormap and normalization are displayed.
    label : str, optional
        Colorbar label.
    fontsize : float, optional
        Font size of the label.
    label_kwargs : dict, optional
        Extra keyword arguments for ``Colorbar.set_label``.
    **colorbar_kwargs
        Extra keyword arguments for ``plt.colorbar``.

    Returns
    -------
    matplotlib.colorbar.Colorbar
        The created colorbar.
    """
    # None instead of a shared mutable `{}` default
    if label_kwargs is None:
        label_kwargs = {}
    cb = plt.colorbar(mappable, **colorbar_kwargs)
    if label is not None:
        cb.set_label(label, fontsize=fontsize, **label_kwargs)
    return cb
151

152
def std_colorbar_residuals(mappable, res_map, vmin, vmax, label=None, fontsize=12, 
                           label_kwargs={}, **colorbar_kwargs):
    """Add a colorbar whose ends are extended when residuals exceed the display range.

    The colorbar ``extend`` option is chosen from the position of ``res_map``
    relative to ``[vmin, vmax]``: 'min' if some values fall below ``vmin``,
    'max' if some exceed ``vmax``, 'both' if both happen and 'neither'
    otherwise. All other arguments are forwarded to ``std_colorbar``.

    Returns
    -------
    The colorbar created by ``std_colorbar``.
    """
    below_range = bool(res_map.min() < vmin)
    above_range = bool(res_map.max() > vmax)
    extend_choices = {
        (False, False): 'neither',
        (True, False): 'min',
        (False, True): 'max',
        (True, True): 'both',
    }
    colorbar_kwargs['extend'] = extend_choices[(below_range, above_range)]
    return std_colorbar(
        mappable,
        label=label,
        fontsize=fontsize,
        label_kwargs=label_kwargs,
        **colorbar_kwargs,
    )
165

166
def nice_colorbar(mappable, ax=None, position='right', pad=0.1, size='5%', label=None, fontsize=12, 
                  invisible=False, 
                  divider_kwargs=None, colorbar_kwargs=None, label_kwargs=None):
    """Add a colorbar in a new axes appended next to ``ax``.

    Parameters
    ----------
    mappable : matplotlib.cm.ScalarMappable
        Artist whose colormap and normalization are displayed.
    ax : matplotlib Axes, optional
        Axes next to which the colorbar is placed, by default ``mappable.axes``.
    position, pad, size : optional
        Passed to ``append_axes`` of the axes divider (they override the same
        keys in ``divider_kwargs``).
    label : str, optional
        Colorbar label; takes precedence over ``colorbar_kwargs['label']``.
    fontsize : float, optional
        Font size of the label.
    invisible : bool, optional
        If True, the appended axes is hidden and no colorbar is created (e.g.
        to keep panel sizes aligned with panels that have a colorbar).
    divider_kwargs, colorbar_kwargs, label_kwargs : dict, optional
        Extra keyword arguments for ``append_axes``, ``plt.colorbar`` and
        ``Colorbar.set_label``, respectively. They are not modified.

    Returns
    -------
    matplotlib.colorbar.Colorbar or None
        The created colorbar, or None when ``invisible`` is True.
    """
    # Work on copies: the former `{}` defaults were shared across calls and
    # caller-provided dicts were modified in place
    divider_kwargs = dict(divider_kwargs or {})
    divider_kwargs.update({'position': position, 'pad': pad, 'size': size})
    colorbar_kwargs = dict(colorbar_kwargs or {})
    label_kwargs = dict(label_kwargs or {})

    if ax is None:
        ax = mappable.axes
    divider = make_axes_locatable(ax)
    cax = divider.append_axes(**divider_kwargs)
    if invisible:
        cax.axis('off')
        return None
    if label is not None:
        # the explicit `label` argument wins over a label in colorbar_kwargs
        colorbar_kwargs.pop('label', None)
    cb = plt.colorbar(mappable, cax=cax, **colorbar_kwargs)
    if label is not None:
        cb.set_label(label, fontsize=fontsize, **label_kwargs)
    if position == 'top':
        cax.xaxis.set_ticks_position('top')
    # TODO: a `max_nbins` option (custom tick locator) was tried before but
    # gave strange results
    return cb
191

192
def nice_colorbar_residuals(mappable, res_map, vmin, vmax, ax=None, position='right', pad=0.1, size='5%', 
                            invisible=False, label=None, fontsize=12,
                            divider_kwargs=None, colorbar_kwargs=None, label_kwargs=None):
    """Like ``nice_colorbar``, with ends extended when residuals exceed the display range.

    The colorbar ``extend`` option is 'min' if some values of ``res_map`` fall
    below ``vmin``, 'max' if some exceed ``vmax``, 'both' if both happen and
    'neither' otherwise. All other arguments are forwarded to
    ``nice_colorbar``; the dict arguments are never modified.

    Returns
    -------
    The colorbar created by ``nice_colorbar`` (None if ``invisible``).
    """
    below_range = res_map.min() < vmin
    above_range = res_map.max() > vmax
    if below_range and above_range:
        cb_extend = 'both'
    elif below_range:
        cb_extend = 'min'
    elif above_range:
        cb_extend = 'max'
    else:
        cb_extend = 'neither'
    # Fresh copies, so that neither caller dicts nor shared defaults get modified
    colorbar_kwargs = dict(colorbar_kwargs or {})
    colorbar_kwargs['extend'] = cb_extend
    return nice_colorbar(mappable, ax=ax, position=position, pad=pad, size=size,
                         label=label, fontsize=fontsize, invisible=invisible,
                         divider_kwargs=dict(divider_kwargs or {}),
                         colorbar_kwargs=colorbar_kwargs,
                         label_kwargs=dict(label_kwargs or {}))
207

208
def cax_colorbar(fig, cax, norm=None, cmap=None, mappable=None, label=None, fontsize=12, orientation='horizontal', label_kwargs=None):
    """Draw a colorbar into the existing axes ``cax`` of ``fig``.

    Parameters
    ----------
    fig : matplotlib Figure
        Figure owning ``cax``.
    cax : matplotlib Axes
        Axes in which the colorbar is drawn.
    norm, cmap : optional
        Used to build a ``ScalarMappable`` when ``mappable`` is None.
    mappable : matplotlib.cm.ScalarMappable, optional
        Artist whose colormap and normalization are displayed.
    label : str, optional
        Colorbar label.
    fontsize : float, optional
        Font size of the label.
    orientation : {'horizontal', 'vertical'}, optional
        Colorbar orientation.
    label_kwargs : dict, optional
        Extra keyword arguments for ``Colorbar.set_label``.

    Returns
    -------
    matplotlib.colorbar.Colorbar
        The created colorbar, consistently with the other colorbar helpers.
    """
    if mappable is None:
        mappable = ScalarMappable(norm=norm, cmap=cmap)
    cb = fig.colorbar(mappable=mappable, orientation=orientation, cax=cax)
    if label is not None:
        cb.set_label(label, fontsize=fontsize, **(label_kwargs or {}))
    return cb
214

215
def scale_bar(ax, size, unit_suffix='"', loc='lower left', color='#FFFFFFBB', fontsize=12):
    """Add a horizontal scale bar with its length written above it.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to decorate; ``size`` is expressed in its data coordinates.
    size : float
        Length of the bar. Integer-valued sizes are written without decimals,
        other sizes with one decimal.
    unit_suffix : str, optional
        Appended to the number in the label (a double quote, i.e. arcseconds,
        by default).
    loc : str, optional
        Anchor location of the bar within the axes.
    color : color, optional
        Color of the bar and of its label.
    fontsize : float, optional
        Font size of the label.

    Returns
    -------
    AnchoredSizeBar
        The artist added to ``ax``.
    """
    number_text = f"{int(size)}" if size == int(size) else f"{size:.1f}"
    bar_artist = AnchoredSizeBar(
        ax.transData, size, number_text + unit_suffix,
        loc=loc, label_top=True, pad=0.8, sep=5,
        color=color, fontproperties=dict(size=fontsize),
        frameon=False, size_vertical=0,
    )
    ax.add_artist(bar_artist)
    return bar_artist
231

232
def plot_regular_grid(ax, title, image_, neg_values_as_bad=False, xylim=None, **imshow_kwargs):
    """Display a pixelated image with ``imshow`` and standard axes formatting.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    title : str
        Axes title.
    image_ : array_like
        Image array passed to ``ax.imshow``; it is never modified.
    neg_values_as_bad : bool, optional
        If True, negative pixels are replaced by NaN (on a copy) so they are
        drawn with the colormap's "bad" color.
    xylim : None, number or sequence, optional
        Axes limits, passed to ``set_xy_limits``.
    **imshow_kwargs
        Extra keyword arguments for ``ax.imshow``.

    Returns
    -------
    ax : matplotlib Axes
        The input axes.
    im : matplotlib.image.AxesImage
        The (rasterized) image artist, e.g. to draw a colorbar.
    """
    if neg_values_as_bad:
        # Copy, promoted to floating point when needed: NaN cannot be stored in
        # an integer array (the assignment below would raise a ValueError)
        image = np.array(image_, copy=True)
        if not np.issubdtype(image.dtype, np.floating):
            image = image.astype(float)
        image[image < 0] = np.nan
    else:
        image = image_
    im = ax.imshow(image, **imshow_kwargs)
    im.set_rasterized(True)
    set_xy_limits(ax, xylim)
    ax.xaxis.set_major_locator(plt.MaxNLocator(3))
    ax.yaxis.set_major_locator(plt.MaxNLocator(3))
    ax.set_title(title)
    return ax, im
245

246
def plot_irregular_grid(ax, title, points, xylim, neg_values_as_bad=False,
                            norm=None, cmap=None, plot_points=False):
    """Display values on an irregular grid as colored Voronoi cells.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    title : str
        Axes title.
    points : tuple
        ``(x, y, z)``: site coordinates and their values, forwarded to
        ``plot_voronoi``.
    xylim : None, number or sequence
        Axes limits, passed to ``set_xy_limits``.
    neg_values_as_bad, norm, cmap : optional
        Forwarded to ``plot_voronoi``.
    plot_points : bool, optional
        If True, the sites are overlaid as small white dots above the cells.

    Returns
    -------
    ax : matplotlib Axes
        The input axes.
    im : matplotlib.cm.ScalarMappable
        Mappable returned by ``plot_voronoi``, e.g. to draw a colorbar.
    """
    site_x, site_y, site_values = points
    mappable = plot_voronoi(
        ax, site_x, site_y, site_values,
        neg_values_as_bad=neg_values_as_bad, norm=norm, cmap=cmap, zorder=1,
    )
    ax.set_aspect('equal', 'box')
    set_xy_limits(ax, xylim)
    for axis in (ax.xaxis, ax.yaxis):
        # one locator instance per axis
        axis.set_major_locator(plt.MaxNLocator(3))
    if plot_points:
        ax.scatter(site_x, site_y, s=5, c='white', marker='.', alpha=0.4, zorder=2)
    ax.set_title(title)
    return ax, mappable
259

260
def set_xy_limits(ax, xylim):
    """Set the x and y limits of ``ax`` from a compact specification.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes whose limits are set.
    xylim : None, number, or 1D sequence of 2 or 4 values
        - None: the limits are left unchanged;
        - a single number ``L``: both axes span ``[-L, L]``;
        - ``(lo, hi)``: both axes span ``[lo, hi]``;
        - ``(xmin, xmax, ymin, ymax)``: explicit limits for each axis.
        Python and numpy scalars, and lists, tuples and 1D numpy arrays, are
        accepted.

    Raises
    ------
    ValueError
        If ``xylim`` matches none of the forms above.
    """
    if xylim is None:
        return  # do nothing
    if isinstance(xylim, (int, float, np.integer, np.floating)):
        xylim_ = [-xylim, xylim, -xylim, xylim]
    elif np.ndim(xylim) == 1 and len(xylim) == 2:
        xylim_ = [xylim[0], xylim[1], xylim[0], xylim[1]]
    elif np.ndim(xylim) == 1 and len(xylim) == 4:
        xylim_ = xylim
    else:
        # e.g. a sequence of 3 values, a string, a 2D array
        raise ValueError("`xylim` argument should be a single number, a 2-tuple or a 4-tuple.")
    ax.set_xlim(xylim_[0], xylim_[1])
    ax.set_ylim(xylim_[2], xylim_[3])
272

273
def panel_label(ax, text, color, fontsize, alpha=0.8, loc='upper left'):
    """Write ``text`` in a corner of ``ax``, in axes-fraction coordinates.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to annotate.
    text : str
        Label text.
    color : color
        Text color.
    fontsize : float
        Text font size.
    alpha : float, optional
        Text opacity.
    loc : {'upper left', 'lower left', 'upper right', 'lower right'}, optional
        Corner where the text is placed (inset by 3% of the axes size).

    Returns
    -------
    matplotlib.text.Text
        The text artist created by ``ax.text``.

    Raises
    ------
    ValueError
        If ``loc`` is not one of the supported corners.
    """
    # (x, y, horizontal alignment, vertical alignment) for each corner
    corners = {
        'upper left': (0.03, 0.97, 'left', 'top'),
        'lower left': (0.03, 0.03, 'left', 'bottom'),
        'upper right': (0.97, 0.97, 'right', 'top'),
        'lower right': (0.97, 0.03, 'right', 'bottom'),
    }
    try:
        x, y, ha, va = corners[loc]
    except KeyError:
        raise ValueError(f"Unsupported `loc` {loc!r}; expected one of {list(corners)}.") from None
    return ax.text(x, y, text, color=color, fontsize=fontsize, alpha=alpha,
                   ha=ha, va=va, transform=ax.transAxes)
284

NEW
285
def normalize_across_images(
        plotter_list, 
        data_model_specifier, 
        auto_selection=False,
        kwargs_source=None, 
        kwargs_lens_mass=None,
        kwargs_lens_light=None,
        supersampling=5, 
        convolved=True, 
        super_convolution=True
    ):
    """Calculate the vmin and vmax to normalize the colormap across multiple coolest objects

    Parameters
    ----------
    plotter_list: list
        List of ModelPlotter objects. May be acquired from MultiModelPlotter object
    data_model_specifier: list
        List of 0s and 1s, one per plotter; 0 = data, 1 = model. Specifies which
        set of pixel values should be used when finding global minima and
        maxima -- data or model
    auto_selection: bool
        Forwarded to ComposableLensModel when building model images
    kwargs_source: dict
        Dictionary with "entity_selection" key, same as used in MultiModelPlotters.
        "entity_selection" contains a list of lists (one per plotter). Selects source entities.
        Insert dummy None values into the list for data.
    kwargs_lens_mass: dict
        Dictionary with "entity_selection" key, same as used in MultiModelPlotters.
        "entity_selection" contains a list of lists (one per plotter). Selects lens mass entities.
        Insert dummy None values into the list for data.
    kwargs_lens_light: dict
        Dictionary with "entity_selection" key, same as used in MultiModelPlotters.
        "entity_selection" contains a list of lists (one per plotter). Selects lens light entities.
        Insert dummy None values into the list for data.
    supersampling: int
        Model image generation param
    convolved: bool
        Model image generation param
    super_convolution: bool
        Model image generation param

    Returns
    -------
    vmin: float
        global min value across all coolest objects in plotter_list for the specified data/models
    vmax: float
        global max value across all coolest objects in plotter_list for the specified data/models

    Raises
    ------
    ValueError
        If plotter_list is empty, if data_model_specifier or an "entity_selection"
        list does not have one entry per plotter, or if data_model_specifier
        contains a value other than 0 or 1.
    """
    n_plotters = len(plotter_list)
    if n_plotters == 0:
        raise ValueError("`plotter_list` must contain at least one plotter.")
    if len(data_model_specifier) != n_plotters:
        raise ValueError("`data_model_specifier` must have one entry per plotter.")

    def _selections(kwargs_component, arg_name):
        # One entity selection per plotter (None everywhere when not provided);
        # a length mismatch would otherwise be silently truncated by zip()
        if kwargs_component is None:
            return [None] * n_plotters
        selections = kwargs_component['entity_selection']
        if len(selections) != n_plotters:
            raise ValueError(f"`{arg_name}['entity_selection']` must have one entry per plotter.")
        return selections

    ks_arr = _selections(kwargs_source, 'kwargs_source')
    km_arr = _selections(kwargs_lens_mass, 'kwargs_lens_mass')
    kl_arr = _selections(kwargs_lens_light, 'kwargs_lens_light')

    mins = []
    maxes = []
    for plotter, d_or_f, ks, km, kl in zip(plotter_list, data_model_specifier, ks_arr, km_arr, kl_arr):
        # Check if we are finding extrema for data or model
        if d_or_f == 0:
            image = plotter.coolest.observation.pixels.get_pixels(directory=plotter._directory)
        elif d_or_f == 1:
            lens_model = ComposableLensModel(
                plotter.coolest, plotter._directory,
                auto_selection=auto_selection,
                kwargs_selection_source=dict(entity_selection=ks),
                kwargs_selection_lens_mass=dict(entity_selection=km),
                kwargs_selection_lens_light=dict(entity_selection=kl)
            )
            image, _ = lens_model.model_image(supersampling, convolved, super_convolution)
        else:
            # Without this, `image` would be undefined or silently reused from
            # the previous plotter
            raise ValueError(f"Entries of `data_model_specifier` must be 0 (data) or 1 (model), got {d_or_f!r}.")
        # Find min and max and append
        mins.append(np.min(image))
        maxes.append(np.max(image))
    vmin = min(mins)
    vmax = max(maxes)
    return vmin, vmax
357

358
    
NEW
359
def _extract_coolest_archive(archive_path, extract_dir):
    """Extract a .tar.gz or .zip COOLEST archive into ``extract_dir`` and return its target path.

    The target path is the path of the single .json file of the archive without
    its extension, as passed to ``util.get_coolest_object``. The .json file is
    searched at the archive root first, then inside top-level folders (ignoring
    the macOS '__MACOSX' folder and '._*' AppleDouble files).

    Raises
    ------
    ValueError
        If the archive is neither .tar.gz nor .zip, or if it does not contain
        exactly one .json file.
    """
    if archive_path.endswith('.tar.gz'):
        with tarfile.open(archive_path, "r:gz") as tar:
            if hasattr(tarfile, 'data_filter'):
                # Safe extraction (PEP 706): rejects absolute paths, '..' components
                # and links pointing outside `extract_dir`
                tar.extractall(path=extract_dir, filter='data')
            else:  # Python version without extraction filters
                tar.extractall(path=extract_dir)
    elif archive_path.endswith('.zip'):
        with zipfile.ZipFile(archive_path, 'r') as zipf:
            zipf.extractall(extract_dir)
    else:
        raise ValueError("Target path must point to a .tar.gz or .zip archive.")

    def _json_files_in(directory):
        # sorted() makes the result independent of the os.listdir() order
        return [os.path.join(directory, f) for f in sorted(os.listdir(directory))
                if f.endswith(".json") and not f.startswith("._")]

    json_paths = _json_files_in(extract_dir)
    if not json_paths:
        # content is often wrapped in a single top-level folder
        for item in sorted(os.listdir(extract_dir)):
            item_path = os.path.join(extract_dir, item)
            if item != '__MACOSX' and os.path.isdir(item_path):
                json_paths.extend(_json_files_in(item_path))
    if not json_paths:
        raise ValueError("No .json file found in archive.")
    elif len(json_paths) > 1:
        raise ValueError("Multiple .json files found in archive.")
    return os.path.splitext(json_paths[0])[0]


def dmr_corner(tar_path, output_dir=None):
    """Given a .tar.gz or .zip COOLEST archive, plots and optionally saves DMR and corner plots. Returns dictionary of important extracted information.

    Parameters
    ----------
    tar_path : string
        Path to a .tar.gz or .zip archive containing a single COOLEST .json
        file (and the files it references), at the archive root or inside a
        top-level folder
    output_dir : string, optional
        Directory where the DMR and corner plots are saved as "dmr_plot.png"
        and "corner_plot.png" (created if needed), by default None (no saving)

    Returns
    -------
    results: dictionary
        Contains 'r_eff_source', 'einstein_radius', 'lensing_entities' (class
        names), 'source_light_model' (class names of the light profiles of
        entity 2) and, when produced, 'free_parameters', 'dmr_plot' and
        'corner_plot' (paths of the saved figures)

    Raises
    ------
    ValueError
        If the archive is neither .tar.gz nor .zip, or does not contain
        exactly one .json file

    Notes
    -----
    Entity indices are hard-coded: entities 0 and 1 are used as the lens (mass
    and light) and entity 2 as the source.
    NOTE(review): presumably tied to a specific model layout; confirm before
    using this on other COOLEST files.
    The corner plot is only produced when the COOLEST metadata contains a
    'chain_file_name' entry.
    """
    from coolest.api.plotting import ModelPlotter, ParametersPlotter  # placed here to avoid circular import

    results = {}
    with tempfile.TemporaryDirectory() as tmpdir:
        target_path = _extract_coolest_archive(tar_path, tmpdir)
        coolest_dir = os.path.dirname(target_path)

        # Load COOLEST object
        coolest_1 = util.get_coolest_object(target_path, verbose=False)

        # Run analysis
        analysis = Analysis(coolest_1, target_path, supersampling=5)

        coord_orig = util.get_coordinates(coolest_1)
        coord_src = coord_orig.create_new_coordinates(pixel_scale_factor=0.1, grid_shape=(1.42, 1.42))

        # Extract values
        r_eff_source = analysis.effective_radius_light(center=(0, 0), coordinates=coord_src, outer_radius=1., entity_selection=[2])
        einstein_radius = analysis.effective_einstein_radius(entity_selection=[0, 1])

        results['r_eff_source'] = r_eff_source
        results['einstein_radius'] = einstein_radius
        results['lensing_entities'] = [type(le).__name__ for le in coolest_1.lensing_entities]
        results['source_light_model'] = [type(m).__name__ for m in coolest_1.lensing_entities[2].light_model]

        if output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)  # savefig fails on a missing directory

        ### DMR Plot
        norm = Normalize(-0.005, 0.05)
        fig, axes = plt.subplots(2, 2, constrained_layout=True)
        splotter = ModelPlotter(coolest_1, coolest_directory=coolest_dir)

        splotter.plot_data_image(axes[0, 0], norm=norm)
        axes[0, 0].set_title("Observed Data")

        splotter.plot_model_image(
            axes[0, 1],
            supersampling=5, convolved=True,
            kwargs_source=dict(entity_selection=[2]),
            kwargs_lens_mass=dict(entity_selection=[0, 1]),
            kwargs_lens_light=dict(entity_selection=[0, 1]),
            norm=norm
        )
        axes[0, 1].text(0.05, 0.05, f"$\\theta_{{\\rm E}}$ = {einstein_radius:.2f}\"", color='white', fontsize=12,
                        transform=axes[0, 1].transAxes)
        axes[0, 1].set_title("Image Model")

        # NOTE(review): unlike the model panel above, no `kwargs_lens_light` is
        # passed here; confirm that the residuals account for the lens light.
        splotter.plot_model_residuals(axes[1, 0], supersampling=5, add_chi2_label=True, chi2_fontsize=12,
                                      kwargs_source=dict(entity_selection=[2]),
                                      kwargs_lens_mass=dict(entity_selection=[0, 1]))
        axes[1, 0].set_title("Normalized Residuals")

        splotter.plot_surface_brightness(axes[1, 1], kwargs_light=dict(entity_selection=[2]),
                                         norm=norm, coordinates=coord_src)
        axes[1, 1].text(0.05, 0.05, f"$\\theta_{{\\rm eff}}$ = {r_eff_source:.2f}\"", color='white', fontsize=12,
                        transform=axes[1, 1].transAxes)
        axes[1, 1].set_title("Surface Brightness")

        for ax in axes[1]:
            ax.set_xlabel(r"$x$ (arcsec)")
            ax.set_ylabel(r"$y$ (arcsec)")

        if output_dir is not None:
            dmr_plot_path = os.path.join(output_dir, "dmr_plot.png")
            plt.savefig(dmr_plot_path, format='png', bbox_inches='tight')
            results['dmr_plot'] = dmr_plot_path
        plt.show()
        plt.close()

        ### Corner Plot
        truth = coolest_1
        # Only creates corner plot if sampling method was used to create lens model
        # Otherwise, no chains available for corner plot!
        if 'chain_file_name' in truth.meta.keys():
            # NOTE(review): assumes at least 7 free parameters once the last two
            # are dropped; the reordering is specific to one model layout
            free_pars = truth.lensing_entities.get_parameter_ids()[:-2]
            reorder = [2, 3, 4, 5, 6, 0, 1]
            pars = [free_pars[i] for i in reorder]
            results['free_parameters'] = pars

            param_plotter = ParametersPlotter(
                pars, [truth],
                coolest_directories=[coolest_dir],
                coolest_names=["Smooth source"],
                ref_coolest_objects=[truth],
                colors=['#7FB6F5', '#E03424'],
            )

            settings = {
                "ignore_rows": 0.0,
                "fine_bins_2D": 800,
                "smooth_scale_2D": 0.5,
                "mult_bias_correction_order": 5
            }
            param_plotter.init_getdist(settings_mcsamples=settings)
            param_plotter.plot_triangle_getdist(filled_contours=True, subplot_size=3)
            if output_dir is not None:
                corner_plot_path = os.path.join(output_dir, "corner_plot.png")
                plt.savefig(corner_plot_path, format='png', bbox_inches='tight')
                results['corner_plot'] = corner_plot_path
            plt.close()

    return results
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc