• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

aymgal / COOLEST / 18935731027

30 Oct 2025 09:18AM UTC coverage: 45.411% (-0.1%) from 45.557%
18935731027

push

github

aymgal
Ready for quick release 0.1.11

1 of 1 new or added line in 1 file covered. (100.0%)

40 existing lines in 1 file now uncovered.

1420 of 3127 relevant lines covered (45.41%)

0.45 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/coolest/api/plot_util.py
1
__author__ = 'aymgal', 'Giorgos Vernardos'
×
2

3

4
import numpy as np
×
5
import warnings
×
6
import matplotlib
×
7
import matplotlib.pyplot as plt
×
8
from matplotlib import ticker
×
9
from mpl_toolkits.axes_grid1 import make_axes_locatable
×
10
from scipy.spatial import Voronoi, voronoi_plot_2d
×
11
from matplotlib.colors import Normalize, LogNorm, TwoSlopeNorm
×
12
from matplotlib.cm import ScalarMappable
×
13
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
×
14
from coolest.api.composable_models import ComposableLensModel
×
15
import os
×
16
import tarfile
×
17
import tempfile
×
18
import zipfile
×
19
import io
×
20
from coolest.api import util
×
21
from coolest.api.analysis import Analysis
×
22

23

24
def plot_voronoi(ax, x, y, z, neg_values_as_bad=False, 
                 norm=None, cmap=None, zmin=None, zmax=None, 
                 edgecolor=None, zorder=1):
    """Plot values defined on an irregular set of points as filled Voronoi cells.

    Parameters
    ----------
    ax : matplotlib axes
        Axes on which the cells are drawn (one ``ax.fill`` call per cell).
    x, y : array_like
        Coordinates of the Voronoi sites.
    z : array_like
        Value at each site; ``z[i]`` sets the color of the cell of site ``i``.
    neg_values_as_bad : bool, optional
        If truthy, cells with a negative value get the colormap's "bad" color
        (the color returned for NaN).
    norm : matplotlib.colors.Normalize, optional
        Normalization of ``z``. If None, a linear ``Normalize(zmin, zmax)`` is used.
    cmap : colormap, optional
        Colormap, defaults to 'inferno'.
    zmin, zmax : float, optional
        Bounds of the default linear normalization; default to the min / max
        of ``z``. Ignored when ``norm`` is given.
    edgecolor : color, optional
        Edge color of the cells.
    zorder : int, optional
        Drawing order of the cells.

    Returns
    -------
    ScalarMappable
        Mappable holding the norm and colormap used (e.g. to build a colorbar).
    """
    if cmap is None:
        cmap = plt.get_cmap('inferno')
    if norm is None:
        if zmin is None:
            zmin = np.min(z)
        if zmax is None:
            zmax = np.max(z)
        norm = Normalize(zmin, zmax)

    # Voronoi cells, with unbounded outer cells closed by far-away points
    voronoi_points = np.column_stack((x, y))
    vor = Voronoi(voronoi_points)
    new_regions, vertices = voronoi_finite_polygons_2d(vor)

    # Mappable used to convert values into RGBA colors
    m = ScalarMappable(norm=norm, cmap=cmap)

    # new_regions holds one region per input point, in input order,
    # so region i is colored with z[i]
    for i, region in enumerate(new_regions):
        polygon = vertices[region]
        z_i = z[i]
        # truthiness test (not `is True`) so that e.g. np.True_ is honoured,
        # consistently with plot_regular_grid
        if neg_values_as_bad and z_i < 0.:
            cell_color = m.to_rgba(np.nan)
        else:
            cell_color = m.to_rgba(z_i)
        ax.fill(*zip(*polygon), facecolor=cell_color, edgecolor=edgecolor, zorder=zorder)
    return m
×
59

60

61
def voronoi_finite_polygons_2d(vor, radius=None):
    """
    Reconstruct infinite voronoi regions in a 2D diagram to finite
    regions.
    This function is adapted from: https://gist.github.com/pv/8036995

    Parameters
    ----------
    vor : scipy.spatial.Voronoi
        Input diagram
    radius : float, optional
        Distance to 'points at infinity'. Defaults to the peak-to-peak range
        of all point coordinates (x and y values pooled together).

    Returns
    -------
    regions : list of lists of int
        Indices of vertices in each revised Voronoi region: one region per
        input point, in the same order as ``vor.points``. Regions that had
        to be reconstructed are sorted counterclockwise.
    vertices : ndarray of shape (n_vertices, 2)
        Coordinates for revised Voronoi vertices. Same as coordinates
        of input vertices, with 'points at infinity' appended to the
        end.

    Raises
    ------
    ValueError
        If the input diagram is not two-dimensional.
    """
    if vor.points.shape[1] != 2:
        raise ValueError("Requires 2D input")

    new_regions = []
    new_vertices = vor.vertices.tolist()

    center = vor.points.mean(axis=0)
    if radius is None:
        # np.ptp function: the ndarray.ptp method was removed in NumPy 2.0
        radius = np.ptp(vor.points)

    # Construct a map containing all ridges for a given point
    all_ridges = {}
    for (p1, p2), (v1, v2) in zip(vor.ridge_points, vor.ridge_vertices):
        all_ridges.setdefault(p1, []).append((p2, v1, v2))
        all_ridges.setdefault(p2, []).append((p1, v1, v2))

    # Reconstruct infinite regions
    for p1, region in enumerate(vor.point_region):
        vertices = vor.regions[region]

        if all(v >= 0 for v in vertices):
            # finite region
            new_regions.append(vertices)
            continue

        # reconstruct a non-finite region
        ridges = all_ridges[p1]
        new_region = [v for v in vertices if v >= 0]

        for p2, v1, v2 in ridges:
            if v2 < 0:
                v1, v2 = v2, v1
            if v1 >= 0:
                # finite ridge: already in the region
                continue

            # Compute the missing endpoint of an infinite ridge
            t = vor.points[p2] - vor.points[p1]  # tangent
            t /= np.linalg.norm(t)
            n = np.array([-t[1], t[0]])  # normal

            # push the far point away from the diagram's center
            midpoint = vor.points[[p1, p2]].mean(axis=0)
            direction = np.sign(np.dot(midpoint - center, n)) * n
            far_point = vor.vertices[v2] + direction * radius

            new_region.append(len(new_vertices))
            new_vertices.append(far_point.tolist())

        # sort region counterclockwise
        vs = np.asarray([new_vertices[v] for v in new_region])
        c = vs.mean(axis=0)
        angles = np.arctan2(vs[:, 1] - c[1], vs[:, 0] - c[0])
        new_region = np.array(new_region)[np.argsort(angles)]

        # finish
        new_regions.append(new_region.tolist())

    return new_regions, np.asarray(new_vertices)
×
143

144

145
def std_colorbar(mappable, label=None, fontsize=12, label_kwargs=None, **colorbar_kwargs):
    """Create a colorbar for ``mappable`` with ``plt.colorbar``.

    Parameters
    ----------
    mappable : matplotlib ScalarMappable
        Object whose norm and colormap are shown.
    label : str, optional
        Colorbar label. When given, it replaces any ``label`` entry of
        ``colorbar_kwargs``.
    fontsize : int, optional
        Font size of the label.
    label_kwargs : dict, optional
        Extra keyword arguments forwarded to ``Colorbar.set_label``.
    **colorbar_kwargs
        Extra keyword arguments forwarded to ``plt.colorbar``.

    Returns
    -------
    matplotlib.colorbar.Colorbar
    """
    if label_kwargs is None:
        label_kwargs = {}
    if label is not None:
        # drop a competing label *before* building the colorbar; it would be
        # overwritten by set_label below anyway
        colorbar_kwargs.pop('label', None)
    cb = plt.colorbar(mappable, **colorbar_kwargs)
    if label is not None:
        cb.set_label(label, fontsize=fontsize, **label_kwargs)
    return cb
×
151

152
def std_colorbar_residuals(mappable, res_map, vmin, vmax, label=None, fontsize=12, 
                           label_kwargs={}, **colorbar_kwargs):
    """Create a standard colorbar for a residual map, with out-of-range arrows.

    The colorbar ``extend`` option is derived from ``res_map``: an arrow is
    drawn at the low end if some residual is below ``vmin``, and at the high
    end if some residual is above ``vmax``. Any ``extend`` entry passed in
    ``colorbar_kwargs`` is overridden.

    Parameters
    ----------
    mappable : matplotlib ScalarMappable
        Object whose norm and colormap are shown.
    res_map : ndarray
        Residual values displayed by ``mappable``.
    vmin, vmax : float
        Lower and upper bounds of the displayed color range.
    label, fontsize, label_kwargs, **colorbar_kwargs
        Forwarded to ``std_colorbar``.

    Returns
    -------
    matplotlib.colorbar.Colorbar
    """
    exceeds_low = bool(res_map.min() < vmin)
    exceeds_high = bool(res_map.max() > vmax)
    extend_by_case = {
        (True, True): 'both',
        (True, False): 'min',
        (False, True): 'max',
        (False, False): 'neither',
    }
    colorbar_kwargs['extend'] = extend_by_case[(exceeds_low, exceeds_high)]
    return std_colorbar(mappable, label=label, fontsize=fontsize,
                        label_kwargs=label_kwargs, **colorbar_kwargs)
165

166
def nice_colorbar(mappable, ax=None, position='right', pad=0.1, size='5%', label=None, fontsize=12, 
                  invisible=False, 
                  divider_kwargs=None, colorbar_kwargs=None, label_kwargs=None):
    """Attach a colorbar in a new axes appended next to ``ax``.

    The colorbar axes is created with ``make_axes_locatable(ax).append_axes``,
    so it is sized relative to ``ax``.

    Parameters
    ----------
    mappable : matplotlib ScalarMappable
        Object whose norm and colormap are shown.
    ax : matplotlib axes, optional
        Axes next to which the colorbar is placed; defaults to ``mappable.axes``.
    position : str, optional
        Side where the colorbar axes is appended ('right', 'top', ...).
    pad, size : float or str, optional
        Padding and size of the appended axes, forwarded to ``append_axes``.
    label : str, optional
        Colorbar label; replaces any ``label`` entry of ``colorbar_kwargs``.
    fontsize : int, optional
        Font size of the label.
    invisible : bool, optional
        If True, the appended axes is hidden and no colorbar is drawn (useful
        to keep panels aligned with others that do have a colorbar).
    divider_kwargs : dict, optional
        Extra keyword arguments for ``append_axes``; ``position``, ``pad`` and
        ``size`` take precedence over entries with the same names.
    colorbar_kwargs : dict, optional
        Extra keyword arguments for ``plt.colorbar``.
    label_kwargs : dict, optional
        Extra keyword arguments for ``Colorbar.set_label``.

    Returns
    -------
    matplotlib.colorbar.Colorbar or None
        None when ``invisible`` is True.
    """
    # Work on copies: neither the caller's dictionaries nor shared defaults
    # must be modified by this function.
    divider_kwargs = dict(divider_kwargs or {})
    divider_kwargs.update({'position': position, 'pad': pad, 'size': size})
    colorbar_kwargs = dict(colorbar_kwargs or {})
    label_kwargs = label_kwargs or {}

    if ax is None:
        ax = mappable.axes
    divider = make_axes_locatable(ax)
    cax = divider.append_axes(**divider_kwargs)
    if invisible:
        cax.axis('off')
        return None

    if label is not None:
        # the explicit `label` argument wins over a label in colorbar_kwargs
        colorbar_kwargs.pop('label', None)
    cb = plt.colorbar(mappable, cax=cax, **colorbar_kwargs)
    if label is not None:
        cb.set_label(label, fontsize=fontsize, **label_kwargs)
    if position == 'top':
        cax.xaxis.set_ticks_position('top')
    return cb
×
191

192
def nice_colorbar_residuals(mappable, res_map, vmin, vmax, ax=None, position='right', pad=0.1, size='5%', 
                            invisible=False, label=None, fontsize=12,
                            divider_kwargs=None, colorbar_kwargs=None, label_kwargs=None):
    """Attach a colorbar for a residual map next to ``ax``, with out-of-range arrows.

    The colorbar ``extend`` option is derived from ``res_map``: an arrow is
    drawn at the low end if some residual is below ``vmin``, and at the high
    end if some residual is above ``vmax``. Any ``extend`` entry of
    ``colorbar_kwargs`` is overridden. All other arguments are forwarded to
    ``nice_colorbar``.

    Parameters
    ----------
    mappable : matplotlib ScalarMappable
        Object whose norm and colormap are shown.
    res_map : ndarray
        Residual values displayed by ``mappable``.
    vmin, vmax : float
        Lower and upper bounds of the displayed color range.

    Returns
    -------
    matplotlib.colorbar.Colorbar or None
        None when ``invisible`` is True.
    """
    below = res_map.min() < vmin
    above = res_map.max() > vmax
    if below and above:
        cb_extend = 'both'
    elif below:
        cb_extend = 'min'
    elif above:
        cb_extend = 'max'
    else:
        cb_extend = 'neither'

    # Copy so the 'extend' entry never leaks into the caller's dictionary
    # (or into a shared default dictionary) across calls.
    colorbar_kwargs = dict(colorbar_kwargs or {})
    colorbar_kwargs['extend'] = cb_extend
    return nice_colorbar(mappable, ax=ax, position=position, pad=pad, size=size, label=label, fontsize=fontsize,
                         invisible=invisible, colorbar_kwargs=colorbar_kwargs,
                         label_kwargs=dict(label_kwargs or {}),
                         divider_kwargs=dict(divider_kwargs or {}))
207

208
def cax_colorbar(fig, cax, norm=None, cmap=None, mappable=None, label=None, fontsize=12, orientation='horizontal', label_kwargs=None):
    """Draw a colorbar into an existing axes ``cax`` of figure ``fig``.

    Parameters
    ----------
    fig : matplotlib Figure
        Figure owning ``cax``.
    cax : matplotlib axes
        Axes in which the colorbar is drawn.
    norm, cmap : optional
        Used to build a standalone ``ScalarMappable`` when ``mappable`` is None.
    mappable : matplotlib ScalarMappable, optional
        Object whose norm and colormap are shown; takes precedence over
        ``norm`` / ``cmap``.
    label : str, optional
        Colorbar label.
    fontsize : int, optional
        Font size of the label.
    orientation : {'horizontal', 'vertical'}, optional
        Colorbar orientation.
    label_kwargs : dict, optional
        Extra keyword arguments for ``Colorbar.set_label``.

    Returns
    -------
    matplotlib.colorbar.Colorbar
        The created colorbar, like the other colorbar helpers of this module.
    """
    if label_kwargs is None:
        label_kwargs = {}
    if mappable is None:
        mappable = ScalarMappable(norm=norm, cmap=cmap)
    cb = fig.colorbar(mappable=mappable, orientation=orientation, cax=cax)
    if label is not None:
        cb.set_label(label, fontsize=fontsize, **label_kwargs)
    return cb
×
214

215
def scale_bar(ax, size, unit_suffix='"', loc='lower left', color='#FFFFFFBB', fontsize=12):
    """Add a scale bar of length ``size`` (in data units) to ``ax``.

    The label above the bar shows ``size`` as an integer when it is
    integer-valued, otherwise with one decimal, followed by ``unit_suffix``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes receiving the scale bar.
    size : float
        Length of the bar, in data coordinates.
    unit_suffix : str, optional
        Text appended to the numeric label (arcsec symbol by default).
    loc : str, optional
        Anchor location of the bar within the axes.
    color : color, optional
        Color of the bar and of its label.
    fontsize : int, optional
        Font size of the label.

    Returns
    -------
    AnchoredSizeBar
        The artist added to ``ax``.
    """
    number = f"{int(size)}" if size == int(size) else f"{size:.1f}"
    bar = AnchoredSizeBar(
        ax.transData,
        size,
        number + unit_suffix,
        loc=loc,
        label_top=True,
        pad=0.8,
        sep=5,
        color=color,
        fontproperties=dict(size=fontsize),
        frameon=False,
        size_vertical=0,
    )
    ax.add_artist(bar)
    return bar
×
231

232
def plot_regular_grid(ax, title, image_, neg_values_as_bad=False, xylim=None, **imshow_kwargs):
    """Display a pixelated image with ``ax.imshow`` and apply the standard formatting.

    Parameters
    ----------
    ax : matplotlib axes
        Axes on which the image is drawn.
    title : str
        Title of the panel.
    image_ : array_like
        2D image to display. It is never modified in place.
    neg_values_as_bad : bool, optional
        If truthy, negative pixels are set to NaN so they are rendered with the
        colormap's "bad" color.
    xylim : None, number, 2- or 4-sequence, optional
        Axis limits, forwarded to ``set_xy_limits``.
    **imshow_kwargs
        Extra keyword arguments for ``ax.imshow``.

    Returns
    -------
    ax : matplotlib axes
    im : matplotlib AxesImage
        The (rasterized) image artist.
    """
    if neg_values_as_bad:
        # Floating-point copy: the caller's array is left untouched, and NaN
        # can be stored even when the input has an integer dtype (assigning
        # NaN into an integer array raises).
        image = np.asarray(image_)
        if np.issubdtype(image.dtype, np.floating):
            image = image.copy()
        else:
            image = image.astype(float)
        image[image < 0] = np.nan
    else:
        image = image_
    im = ax.imshow(image, **imshow_kwargs)
    im.set_rasterized(True)
    set_xy_limits(ax, xylim)
    ax.xaxis.set_major_locator(plt.MaxNLocator(3))
    ax.yaxis.set_major_locator(plt.MaxNLocator(3))
    ax.set_title(title)
    return ax, im
×
245

246
def plot_irregular_grid(ax, title, points, xylim, neg_values_as_bad=False,
                            norm=None, cmap=None, plot_points=False):
    """Display values sampled on an irregular grid as colored Voronoi cells.

    Parameters
    ----------
    ax : matplotlib axes
        Axes on which the cells are drawn.
    title : str
        Title of the panel.
    points : tuple (x, y, z)
        Site coordinates and the value at each site.
    xylim : None, number, 2- or 4-sequence
        Axis limits, forwarded to ``set_xy_limits``.
    neg_values_as_bad, norm, cmap : optional
        Forwarded to ``plot_voronoi``.
    plot_points : bool, optional
        If True, overlay the sites as small translucent white dots.

    Returns
    -------
    ax : matplotlib axes
    mappable : ScalarMappable
        Mappable returned by ``plot_voronoi`` (e.g. to build a colorbar).
    """
    x_coords, y_coords, values = points
    mappable = plot_voronoi(ax, x_coords, y_coords, values,
                            neg_values_as_bad=neg_values_as_bad,
                            norm=norm, cmap=cmap, zorder=1)
    ax.set_aspect('equal', 'box')
    set_xy_limits(ax, xylim)
    # a fresh locator per axis: locators must not be shared between axes
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(plt.MaxNLocator(3))
    if plot_points:
        ax.scatter(x_coords, y_coords, s=5, c='white', marker='.', alpha=0.4, zorder=2)
    ax.set_title(title)
    return ax, mappable
×
259

260
def set_xy_limits(ax, xylim):
    """Set the x and y limits of ``ax`` from a compact specification.

    Parameters
    ----------
    ax : matplotlib axes
        Axes whose limits are set.
    xylim : None, number, 2-sequence or 4-sequence
        - None: limits are left unchanged.
        - number ``a``: both axes span ``[-a, a]``.
        - ``(lo, hi)``: both axes span ``[lo, hi]``.
        - ``(xmin, xmax, ymin, ymax)``: explicit limits.
        Lists, tuples and 1D NumPy arrays are accepted for the sequence forms,
        and NumPy scalars for the single-number form.

    Raises
    ------
    ValueError
        If ``xylim`` has any other form (e.g. a 3-element sequence).
    """
    if xylim is None:
        return  # do nothing
    error_msg = "`xylim` argument should be a single number, a 2-tuple or a 4-tuple."
    ndim = np.ndim(xylim)
    if ndim == 0:
        try:
            xylim_ = [-xylim, xylim, -xylim, xylim]
        except TypeError:  # non-numeric scalar such as a string
            raise ValueError(error_msg) from None
    elif ndim == 1 and len(xylim) == 2:
        xylim_ = [xylim[0], xylim[1], xylim[0], xylim[1]]
    elif ndim == 1 and len(xylim) == 4:
        xylim_ = list(xylim)
    else:
        raise ValueError(error_msg)
    ax.set_xlim(xylim_[0], xylim_[1])
    ax.set_ylim(xylim_[2], xylim_[3])
×
272

273
def panel_label(ax, text, color, fontsize, alpha=0.8, loc='upper left'):
    """Write ``text`` in one corner of ``ax``, positioned in axes coordinates.

    Parameters
    ----------
    ax : matplotlib axes
        Axes receiving the label.
    text : str
        Label text.
    color : color
        Text color.
    fontsize : int
        Font size.
    alpha : float, optional
        Text transparency.
    loc : {'upper left', 'lower left', 'upper right', 'lower right'}, optional
        Corner where the text is anchored, 3% away from both edges.

    Returns
    -------
    matplotlib.text.Text
        The created text artist.

    Raises
    ------
    ValueError
        If ``loc`` is not one of the supported corners.
    """
    # (x, y, horizontal alignment, vertical alignment) for each corner
    anchors = {
        'upper left': (0.03, 0.97, 'left', 'top'),
        'lower left': (0.03, 0.03, 'left', 'bottom'),
        'upper right': (0.97, 0.97, 'right', 'top'),
        'lower right': (0.97, 0.03, 'right', 'bottom'),
    }
    try:
        x, y, ha, va = anchors[loc]
    except KeyError:
        raise ValueError(f"Unsupported `loc` {loc!r}; choose among {list(anchors)}.") from None
    return ax.text(x, y, text, color=color, fontsize=fontsize, alpha=alpha,
                   ha=ha, va=va, transform=ax.transAxes)
284

285
def normalize_across_images(plotter_list, data_model_specifier, kwargs_source = None, kwargs_lens_mass = None,
                            supersampling=5, convolved=True, super_convolution=True):
    """Calculate the vmin and vmax to normalize the colormap across multiple COOLEST objects.

    Parameters
    ----------
    plotter_list: list
        List of ModelPlotter objects. May be acquired from a MultiModelPlotter object.
    data_model_specifier: list
        List of 0s and 1s; 0 = data, 1 = model. Specifies which set of pixel values
        (data or model) is used when finding the global minimum and maximum.
    kwargs_source: dict, optional
        Dictionary with an "entity_selection" key, same as used in MultiModelPlotter.
        "entity_selection" contains a list of lists selecting the source entities,
        one entry per plotter; use dummy None entries for data.
        If None, every entry is None (which is what data entries need).
        NOTE(review): how ComposableLensModel handles entity_selection=None for
        model entries is not visible here — confirm before relying on it.
    kwargs_lens_mass: dict, optional
        Same as ``kwargs_source``, but selecting the lens mass entities.
    supersampling: int
        Model image generation parameter.
    convolved: bool
        Model image generation parameter.
    super_convolution: bool
        Model image generation parameter.

    Returns
    -------
    vmin: float
        Global min value across all COOLEST objects in plotter_list for the specified data/models.
    vmax: float
        Global max value across all COOLEST objects in plotter_list for the specified data/models.

    Raises
    ------
    ValueError
        If an entry of ``data_model_specifier`` is neither 0 nor 1, or if there is
        nothing to normalize (empty inputs).
    """
    from itertools import repeat

    ks_arr = repeat(None) if kwargs_source is None else kwargs_source['entity_selection']
    km_arr = repeat(None) if kwargs_lens_mass is None else kwargs_lens_mass['entity_selection']

    mins = []
    maxes = []
    for plotter, d_or_f, ks, km in zip(plotter_list, data_model_specifier, ks_arr, km_arr):
        # Check if we are finding extrema for data or model
        if d_or_f == 0:
            image = plotter.coolest.observation.pixels.get_pixels(directory=plotter._directory)
        elif d_or_f == 1:
            lens_model = ComposableLensModel(plotter.coolest, plotter._directory,
                                             kwargs_selection_source=dict(entity_selection=ks),
                                             kwargs_selection_lens_mass=dict(entity_selection=km))
            image, _ = lens_model.model_image(supersampling, convolved, super_convolution)
        else:
            # without this, a stale (or undefined) image from a previous iteration would be used
            raise ValueError(f"Entries of `data_model_specifier` must be 0 (data) or 1 (model), "
                             f"got {d_or_f!r}.")
        # Find min and max and append
        mins.append(np.min(image))
        maxes.append(np.max(image))

    if not mins:
        raise ValueError("No image to normalize: `plotter_list` and `data_model_specifier` "
                         "must be non-empty.")
    return min(mins), max(maxes)
×
339

340
    
341
def _extract_coolest_archive(archive_path, destination):
    """Extract a ``.tar.gz`` or ``.zip`` archive into the directory ``destination``.

    Raises
    ------
    ValueError
        If the archive extension is not supported.
    """
    archive_path = os.fspath(archive_path)
    if archive_path.endswith('.tar.gz'):
        with tarfile.open(archive_path, "r:gz") as tar:
            if hasattr(tarfile, 'data_filter'):
                # The 'data' extraction filter rejects absolute paths, '..'
                # components and links pointing outside `destination`
                # (protection against path-traversal in untrusted archives).
                tar.extractall(path=destination, filter='data')
            else:
                # NOTE(review): Python without extraction filters; archive
                # members are not sanitized here — only use trusted archives.
                tar.extractall(path=destination)
    elif archive_path.endswith('.zip'):
        # zipfile.extractall already strips absolute paths and '..' components
        with zipfile.ZipFile(archive_path, 'r') as zipf:
            zipf.extractall(destination)
    else:
        raise ValueError("Target path must point to a .tar.gz or .zip archive.")


def _find_coolest_target(directory):
    """Locate the single COOLEST JSON file extracted in ``directory``.

    The macOS ``__MACOSX`` folder and hidden entries (names starting with '.')
    are ignored. If what remains is exactly one sub-directory, the JSON file is
    searched for inside it; otherwise it is searched for at the top level.

    Returns
    -------
    str
        Path to the JSON file without its ``.json`` extension (the form passed
        to ``util.get_coolest_object``).

    Raises
    ------
    ValueError
        If no, or more than one, ``.json`` file is found.
    """
    items = [name for name in os.listdir(directory)
             if name != '__MACOSX' and not name.startswith('.')]
    search_dir = directory
    if len(items) == 1 and os.path.isdir(os.path.join(directory, items[0])):
        search_dir = os.path.join(directory, items[0])

    # hidden '._*.json' files are macOS metadata, not COOLEST files
    json_files = [f for f in os.listdir(search_dir)
                  if f.endswith(".json") and not f.startswith(".")]
    if not json_files:
        raise ValueError("No .json file found in archive.")
    elif len(json_files) > 1:
        raise ValueError("Multiple .json files found in archive.")

    json_path = os.path.join(search_dir, json_files[0])
    return os.path.splitext(json_path)[0]


def dmr_corner(tar_path, output_dir = None):
    """Given a .tar.gz (or .zip) COOLEST archive, plot and optionally save the DMR and corner plots.

    The DMR figure has four panels: observed data, image model, normalized
    residuals and source surface brightness. The corner plot is only produced
    when the COOLEST metadata contains a 'chain_file_name' entry (i.e. when a
    sampling method produced chains).

    Parameters
    ----------
    tar_path : string
        Path to the .tar.gz or .zip COOLEST archive
    output_dir : string, optional
        Directory where the DMR ('dmr_plot.png') and corner ('corner_plot.png')
        plots are saved, created if needed. By default None (nothing is saved).

    Returns
    -------
    results: dictionary
        Contains useful information about the COOLEST object: 'r_eff_source',
        'einstein_radius', 'lensing_entities', 'source_light_model', and, when
        applicable, 'dmr_plot', 'free_parameters' and 'corner_plot'.

    Raises
    ------
    ValueError
        If the archive type is not supported, or if it does not contain exactly
        one .json file.
    """
    from coolest.api.plotting import ModelPlotter, ParametersPlotter  # placed here to avoid circular import

    results = {}
    with tempfile.TemporaryDirectory() as tmpdir:
        _extract_coolest_archive(tar_path, tmpdir)
        target_path = _find_coolest_target(tmpdir)

        # Load COOLEST object
        coolest_1 = util.get_coolest_object(target_path, verbose=False)

        # Run analysis
        analysis = Analysis(coolest_1, target_path, supersampling=5)

        # Finer coordinates used for the source (pixel_scale_factor=0.1)
        coord_orig = util.get_coordinates(coolest_1)
        coord_src = coord_orig.create_new_coordinates(pixel_scale_factor=0.1, grid_shape=(1.42, 1.42))

        # Extract values
        # NOTE(review): entity indices are hard-coded (0 and 1 used as lens mass,
        # 2 as source) — confirm this matches the archives fed to this function.
        r_eff_source = analysis.effective_radius_light(center=(0,0), coordinates=coord_src, outer_radius=1., entity_selection=[2])
        einstein_radius = analysis.effective_einstein_radius(entity_selection=[0,1])

        results['r_eff_source'] = r_eff_source
        results['einstein_radius'] = einstein_radius
        results['lensing_entities'] = [type(le).__name__ for le in coolest_1.lensing_entities]
        results['source_light_model'] = [type(m).__name__ for m in coolest_1.lensing_entities[2].light_model]

        ### DMR Plot
        norm = Normalize(-0.005, 0.05)
        fig, axes = plt.subplots(2, 2, constrained_layout=True)
        splotter = ModelPlotter(coolest_1, coolest_directory=os.path.dirname(target_path))

        splotter.plot_data_image(axes[0, 0], norm=norm)
        axes[0, 0].set_title("Observed Data")

        splotter.plot_model_image(
            axes[0, 1],
            supersampling=5, convolved=True,
            kwargs_source=dict(entity_selection=[2]),
            kwargs_lens_mass=dict(entity_selection=[0, 1]),
            norm=norm
        )
        axes[0, 1].text(0.05, 0.05, f"$\\theta_{{\\rm E}}$ = {einstein_radius:.2f}\"", color='white', fontsize=12,
                        transform=axes[0, 1].transAxes)
        axes[0, 1].set_title("Image Model")

        splotter.plot_model_residuals(axes[1, 0], supersampling=5, add_chi2_label=True, chi2_fontsize=12,
                                      kwargs_source=dict(entity_selection=[2]),
                                      kwargs_lens_mass=dict(entity_selection=[0, 1]))
        axes[1, 0].set_title("Normalized Residuals")

        splotter.plot_surface_brightness(axes[1, 1], kwargs_light=dict(entity_selection=[2]),
                                         norm=norm, coordinates=coord_src)
        axes[1, 1].text(0.05, 0.05, f"$\\theta_{{\\rm eff}}$ = {r_eff_source:.2f}\"", color='white', fontsize=12,
                        transform=axes[1, 1].transAxes)
        axes[1, 1].set_title("Surface Brightness")

        for ax in axes[1]:
            ax.set_xlabel(r"$x$ (arcsec)")
            ax.set_ylabel(r"$y$ (arcsec)")

        if output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)
            dmr_plot_path = os.path.join(output_dir, "dmr_plot.png")
            fig.savefig(dmr_plot_path, format='png', bbox_inches='tight')
            results['dmr_plot'] = dmr_plot_path
        plt.show()
        plt.close(fig)

        ### Corner Plot
        truth = coolest_1
        # Only creates corner plot if sampling method was used to create lens model
        # Otherwise, no chains available for corner plot!
        if 'chain_file_name' in truth.meta.keys():
            # drop the last two parameter ids, then reorder the remaining seven
            free_pars = truth.lensing_entities.get_parameter_ids()[:-2]
            reorder = [2, 3, 4, 5, 6, 0, 1]
            pars = [free_pars[i] for i in reorder]
            results['free_parameters'] = pars

            param_plotter = ParametersPlotter(
                pars, [truth],
                coolest_directories=[os.path.dirname(target_path)],
                coolest_names=["Smooth source"],
                ref_coolest_objects=[truth],
                colors=['#7FB6F5', '#E03424'],
            )

            settings = {
                "ignore_rows": 0.0,
                "fine_bins_2D": 800,
                "smooth_scale_2D": 0.5,
                "mult_bias_correction_order": 5
            }
            param_plotter.init_getdist(settings_mcsamples=settings)
            param_plotter.plot_triangle_getdist(filled_contours=True, subplot_size=3)
            if output_dir is not None:
                os.makedirs(output_dir, exist_ok=True)
                corner_plot_path = os.path.join(output_dir, "corner_plot.png")
                plt.savefig(corner_plot_path, format='png', bbox_inches='tight')
                results['corner_plot'] = corner_plot_path
            plt.close()

    return results
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc