• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ghiggi / gpm_api / 8295059095

15 Mar 2024 10:52AM UTC coverage: 84.86% (+3.1%) from 81.754%
8295059095

push

github

web-flow
Add pycolorbar for colormaps and colorbar configurations (#44)

* Add pycolorbar configuration files

* Fix cmap=None bug

* Enable plot levels argument

* Fix code indent

* Add pycolorbar dependency

* Fix bug for horizontal colorbars with ticklabels

* Fix deprecation warnings

* Remove utils_cmap code

* Refactor viz internals

* Switch from {} to None for default arguments

* Update package requirements

* Update dataset decoding

* Update tests data

502 of 524 new or added lines in 30 files covered. (95.8%)

24 existing lines in 10 files now uncovered.

8514 of 10033 relevant lines covered (84.86%)

0.85 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.0
/gpm/visualization/plot.py
1
# -----------------------------------------------------------------------------.
2
# MIT License
3

4
# Copyright (c) 2024 GPM-API developers
5
#
6
# This file is part of GPM-API.
7

8
# Permission is hereby granted, free of charge, to any person obtaining a copy
9
# of this software and associated documentation files (the "Software"), to deal
10
# in the Software without restriction, including without limitation the rights
11
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
# copies of the Software, and to permit persons to whom the Software is
13
# furnished to do so, subject to the following conditions:
14
#
15
# The above copyright notice and this permission notice shall be included in all
16
# copies or substantial portions of the Software.
17
#
18
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
# SOFTWARE.
25

26
# -----------------------------------------------------------------------------.
27
"""This module contains basic functions for GPM-API data visualization."""
1✔
28
import inspect
1✔
29

30
import cartopy
1✔
31
import cartopy.crs as ccrs
1✔
32
import matplotlib.pyplot as plt
1✔
33
import numpy as np
1✔
34
import xarray as xr
1✔
35
from mpl_toolkits.axes_grid1 import make_axes_locatable
1✔
36
from scipy.interpolate import griddata
1✔
37

38
import gpm
1✔
39

40

41
def is_generator(obj):
    """Return True if ``obj`` is a generator function or a generator object."""
    if inspect.isgeneratorfunction(obj):
        return True
    return inspect.isgenerator(obj)
43

44

45
def _call_optimize_layout(self):
    """Resize the figure to the subplot aspect ratio, then tighten its layout.

    Meant to be bound as the ``optimize_layout`` method of a plot object that
    exposes ``axes`` and ``figure`` attributes.
    """
    target_ax = self.axes
    adapt_fig_size(ax=target_ax)
    fig = self.figure
    fig.tight_layout()
49

50

51
def add_optimize_layout_method(p):
    """Attach an ``optimize_layout`` bound method to ``p`` via monkey patching.

    The attached method is ``_call_optimize_layout`` bound with ``p`` as ``self``.
    The same (mutated) object is returned.
    """
    bound_method = _call_optimize_layout.__get__(p, type(p))
    p.optimize_layout = bound_method
    return p
55

56

57
def adapt_fig_size(ax, nrow=1, ncol=1):
    """
    Rescale the figure so that its height matches the aspect ratio of the subplots.

    Call this once all plotting is done. ``ax`` is taken as representative of
    every subplot: all subplots are assumed to share its aspect ratio and size,
    arranged on a grid of ``nrow`` rows and ``ncol`` columns.

    If the fitted height is less than half the original figure height, both
    width and height are enlarged by ``original_height / fitted_height / 2``
    so that the figure does not shrink excessively.

    The approach is inspired by Mathias Hauser's mplotutils ``set_map_layout`` function.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Representative subplot axes.
    nrow : int, optional
        Number of subplot rows. The default is 1.
    ncol : int, optional
        Number of subplot columns. The default is 1.
    """
    fig = ax.get_figure()
    fig_width, initial_height = fig.get_size_inches()

    # Draw first, so that the subplot geometry reflects the latest figure state.
    fig.canvas.draw()

    params = fig.subplotpars
    # Fractions of the figure taken by the outer margins.
    horizontal_margin = params.left + (1 - params.right)
    vertical_margin = params.bottom + (1 - params.top)

    # Width of one panel: usable width shared by the columns and inter-column gaps.
    panel_width = (fig_width - fig_width * horizontal_margin) / (ncol + (ncol - 1) * params.wspace)
    # Height of one panel, following the data aspect ratio.
    panel_height = panel_width * ax.get_data_ratio()
    # Figure height: panels plus inter-row gaps, over the usable vertical fraction.
    fig_height = (panel_height * (nrow + ((nrow - 1) * params.hspace))) / (1.0 - vertical_margin)

    shrink_ratio = initial_height / fig_height
    if shrink_ratio > 2:
        # Enlarge both dimensions (keeping their ratio) to stay closer to the original size.
        scale = shrink_ratio / 2
        fig_width *= scale
        fig_height *= scale

    fig.set_figwidth(fig_width)
    fig.set_figheight(fig_height)
119

120

121
####--------------------------------------------------------------------------.
122

123

124
def get_antimeridian_mask(lons, buffer=True):
    """Get the mask of cells adjacent to longitude jumps across the antimeridian.

    A jump is detected where two neighbouring longitude values differ by more
    than 180 degrees.

    Parameters
    ----------
    lons : numpy.ndarray
        2D array of longitudes with shape ``(n_y, n_x)``. Presumably these are
        cell-corner longitudes, as passed by ``_plot_cartopy_pcolormesh``.
    buffer : bool, optional
        If True (the default), the mask is dilated by one cell with
        ``scipy.ndimage.binary_dilation`` (default 4-connectivity structure),
        so that neighbouring cells are masked too.
        If False, only the cells flagged at each jump are returned.

    Returns
    -------
    numpy.ndarray
        Boolean mask of shape ``(n_y - 1, n_x - 1)``.
    """
    from scipy.ndimage import binary_dilation

    n_y, n_x = lons.shape
    mask = np.zeros((n_y - 1, n_x - 1), dtype=bool)
    # Degenerate input (a single row or column of values): there are no cells to flag.
    if mask.size == 0:
        return mask

    # Jumps between rows (diff along axis 0): flag the cell on the left of the
    # jump column (or the first cell when the jump is on the first column).
    row_idx, col_idx = np.where(np.abs(np.diff(lons, axis=0)) > 180)
    col_idx = np.clip(col_idx - 1, 0, n_x - 2)
    mask[row_idx, col_idx] = True

    # Jumps between columns (diff along axis 1): flag the cell above the jump
    # row (or the first cell when the jump is on the first row).
    row_idx, col_idx = np.where(np.abs(np.diff(lons, axis=1)) > 180)
    row_idx = np.clip(row_idx - 1, 0, n_y - 2)
    mask[row_idx, col_idx] = True

    if buffer:
        # Buffer by 1 to avoid plotting cells neighbouring those crossing the antimeridian.
        # --> This should not be needed, but it's needed to avoid cartopy bugs!
        mask = binary_dilation(mask)
    return mask
143

144

145
def infill_invalid_coords(xr_obj, x="lon", y="lat"):
    """Return a copy of ``xr_obj`` with its non-finite ``x``/``y`` coordinate values infilled.

    The coordinates are infilled with ``get_valid_pcolormesh_inputs``: interpolation
    within the convex hull of the valid values, and nearest neighbour outside of it.
    The data values are not masked or modified.
    """
    out = xr_obj.copy()
    lon_values = np.asanyarray(out[x].data)
    lat_values = np.asanyarray(out[y].data)
    lon_values, lat_values, _ = get_valid_pcolormesh_inputs(
        x=lon_values,
        y=lat_values,
        data=None,
        mask_data=False,
    )
    out[x].data = lon_values
    out[y].data = lat_values
    return out
160

161

162
def get_valid_pcolormesh_inputs(x, y, data, rgb=False, mask_data=True):
    """Infill invalid (non-finite) coordinates so they can be used by pcolormesh.

    ``pcolormesh`` does not accept non-finite coordinates. Invalid coordinate
    values are interpolated within the convex hull of the valid ones and filled
    by nearest neighbour outside of it.

    If ``mask_data=True``, data values located at invalid coordinates are masked
    and a numpy masked array is returned. Masked values are not displayed by
    pcolormesh. If ``rgb=True``, the RGB dimension is assumed to be the last
    dimension of ``data``.

    If every coordinate is valid, the inputs are returned unchanged.

    Returns
    -------
    tuple
        ``(x, y, data)`` with infilled coordinates and, optionally, masked data.
    """
    x_invalid = ~np.isfinite(x)
    y_invalid = ~np.isfinite(y)
    invalid = np.logical_or(x_invalid, y_invalid)

    # Fast path: nothing to infill or mask
    if not np.any(invalid):
        return x, y, data

    if not mask_data:
        data_out = data
    elif rgb:
        # Repeat the 2D mask over the trailing RGB dimension
        rgb_invalid = np.broadcast_to(invalid[..., np.newaxis], data.shape)
        data_out = np.ma.masked_where(rgb_invalid, data)
    else:
        data_out = np.ma.masked_where(invalid, data)

    def _infill(coords):
        coords = _interpolate_data(coords, method="linear")  # inside the convex hull
        return _interpolate_data(coords, method="nearest")  # outside the convex hull

    if np.any(x_invalid):
        x = _infill(x)
    if np.any(y_invalid):
        y = _infill(y)
    return x, y, data_out
204

205

206
def _interpolate_data(arr, method="linear"):
1✔
207
    # Find invalid locations
208
    is_invalid = ~np.isfinite(arr)
1✔
209

210
    # Find the indices of NaN values
211
    nan_indices = np.where(is_invalid)
1✔
212

213
    # Return array if not NaN values
214
    if len(nan_indices) == 0:
1✔
NEW
215
        return arr
×
216

217
    # Find the indices of non-NaN values
218
    non_nan_indices = np.where(~is_invalid)
1✔
219

220
    # Create a meshgrid of indices
221
    X, Y = np.meshgrid(range(arr.shape[1]), range(arr.shape[0]))
1✔
222

223
    # Points (X, Y) where we have valid data
224
    points = np.array([Y[non_nan_indices], X[non_nan_indices]]).T
1✔
225

226
    # Points where data is NaN
227
    points_nan = np.array([Y[nan_indices], X[nan_indices]]).T
1✔
228

229
    # Values at the non-NaN points
230
    values = arr[non_nan_indices]
1✔
231

232
    # Interpolate using griddata
233
    arr_new = arr.copy()
1✔
234
    arr_new[nan_indices] = griddata(points, values, points_nan, method=method)
1✔
235
    return arr_new
1✔
236

237

238
####--------------------------------------------------------------------------.
239

240

241
def preprocess_figure_args(ax, fig_kwargs=None, subplot_kwargs=None):
    """Validate the figure arguments and return the figure keyword arguments.

    ``fig_kwargs`` and ``subplot_kwargs`` are only meaningful when a new figure
    is created, i.e. when ``ax`` is None.

    Returns
    -------
    dict
        ``fig_kwargs``, or an empty dict if it is None.

    Raises
    ------
    ValueError
        If ``ax`` is provided together with non-empty ``subplot_kwargs`` or ``fig_kwargs``.
    """
    fig_kwargs = {} if fig_kwargs is None else fig_kwargs
    subplot_kwargs = {} if subplot_kwargs is None else subplot_kwargs
    if ax is not None:
        if subplot_kwargs:
            raise ValueError("Provide `subplot_kwargs` only if `ax` is None")
        if fig_kwargs:
            raise ValueError("Provide `fig_kwargs` only if `ax` is None")
    return fig_kwargs
250

251

252
def preprocess_subplot_kwargs(subplot_kwargs):
    """Return a copy of ``subplot_kwargs`` that always defines a ``projection``.

    If no projection is given, ``ccrs.PlateCarree()`` is used.
    The input dictionary is never modified.
    """
    kwargs = {} if subplot_kwargs is None else subplot_kwargs.copy()
    if "projection" not in kwargs:
        kwargs["projection"] = ccrs.PlateCarree()
    return kwargs
258

259

260
def initialize_cartopy_plot(
    ax,
    fig_kwargs,
    subplot_kwargs,
    add_background,
):
    """Return the axes to plot on, creating a cartopy figure if necessary.

    If ``ax`` is None, a new figure and axes are created with ``plt.subplots``
    using ``fig_kwargs`` and ``subplot_kwargs``; a PlateCarree projection is used
    if none is given. When ``ax`` is provided, ``fig_kwargs`` and
    ``subplot_kwargs`` are not used.
    If ``add_background`` is True, the cartopy background is drawn on the axes.
    """
    if ax is None:
        fig_kwargs = preprocess_figure_args(
            ax=None,
            fig_kwargs=fig_kwargs,
            subplot_kwargs=subplot_kwargs,
        )
        subplot_kw = preprocess_subplot_kwargs(subplot_kwargs)
        _, ax = plt.subplots(subplot_kw=subplot_kw, **fig_kwargs)

    if add_background:
        ax = plot_cartopy_background(ax)
    return ax
279

280

281
####--------------------------------------------------------------------------.
282

283

284
def plot_cartopy_background(ax):
    """Draw coastlines, land, ocean, borders and light gridlines on a cartopy GeoAxes.

    Gridline labels are drawn on the bottom and left sides only.
    The same axes object is returned.
    """
    # Map features
    ax.coastlines()
    land_color = [0.9, 0.9, 0.9]
    ax.add_feature(cartopy.feature.LAND, facecolor=land_color)
    ax.add_feature(cartopy.feature.OCEAN, alpha=0.6)
    ax.add_feature(cartopy.feature.BORDERS)  # BORDERS also draws provinces, ...

    # Light grid lines, labelled only at the bottom and on the left
    gridlines = ax.gridlines(
        crs=ccrs.PlateCarree(),
        draw_labels=True,
        linewidth=1,
        color="gray",
        alpha=0.1,
        linestyle="-",
    )
    gridlines.top_labels = False
    gridlines.right_labels = False
    gridlines.xlines = True
    gridlines.ylines = True
    return ax
305

306

307
def plot_sides(sides, ax, **plot_kwargs):
    """Plot boundary sides as geodetic lines.

    Parameters
    ----------
    sides : iterable of tuple
        ``(lon, lat)`` tuples, one per side.
    ax : cartopy.mpl.geoaxes.GeoAxes
        Axes on which to draw.
    **plot_kwargs
        Keyword arguments forwarded to ``ax.plot``.

    Returns
    -------
    matplotlib.lines.Line2D or None
        The line of the last plotted side, or None if ``sides`` is empty.
    """
    lines = None
    for side in sides:
        lines = ax.plot(*side, transform=ccrs.Geodetic(), **plot_kwargs)
    # Nothing was plotted (empty ``sides``): there is no line to return.
    if lines is None:
        return None
    return lines[0]
315

316

317
def plot_colorbar(p, ax, cbar_kwargs=None):
    """Add a colorbar to a matplotlib/cartopy plot.

    The colorbar is drawn in a new axes appended to ``ax`` with
    ``mpl_toolkits.axes_grid1.make_axes_locatable``.

    Parameters
    ----------
    p : matplotlib.cm.ScalarMappable
        The mappable (e.g. ``matplotlib.image.AxesImage``) described by the colorbar.
    ax : matplotlib.axes.Axes or cartopy.mpl.geoaxes.GeoAxes
        The axes next to which the colorbar is placed.
    cbar_kwargs : dict, optional
        Colorbar options, passed to ``plt.colorbar``. The following keys are
        consumed here and not forwarded:

        - ``size``: size of the colorbar axes. The default is ``"5%"``.
        - ``pad``: padding between ``ax`` and the colorbar. The default is
          ``0.1`` for vertical and ``0.25`` for horizontal colorbars.
        - ``ticklabels``: labels assigned to the colorbar ticks.

        ``orientation`` (``"vertical"`` by default, or ``"horizontal"``) selects
        whether the colorbar is appended on the right or at the bottom.
        The caller's dictionary is not modified.

    Returns
    -------
    matplotlib.colorbar.Colorbar

    Raises
    ------
    ValueError
        If ``orientation`` is neither ``"vertical"`` nor ``"horizontal"``.
    """
    # Copy so that popping keys does not modify the caller's dictionary
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs.copy()
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    orientation = cbar_kwargs.get("orientation", "vertical")

    # 'size' and 'pad' configure the appended colorbar axes: they are not
    # Colorbar arguments and must not be forwarded to plt.colorbar.
    size = cbar_kwargs.pop("size", "5%")
    if orientation == "vertical":
        pad = cbar_kwargs.pop("pad", 0.1)
        position = "right"
    elif orientation == "horizontal":
        pad = cbar_kwargs.pop("pad", 0.25)
        position = "bottom"
    else:
        raise ValueError("Invalid orientation. Choose 'vertical' or 'horizontal'.")

    # Validate before creating the divider, which modifies ``ax``
    divider = make_axes_locatable(ax)
    cax = divider.append_axes(position, size=size, pad=pad, axes_class=plt.Axes)
    p.figure.add_axes(cax)
    cbar = plt.colorbar(p, cax=cax, ax=ax, **cbar_kwargs)

    if ticklabels is not None:
        if orientation == "vertical":
            cbar.ax.set_yticklabels(ticklabels)
        else:  # horizontal
            cbar.ax.set_xticklabels(ticklabels)
    return cbar
352

353

354
####--------------------------------------------------------------------------.
355

356

357
def get_dataarray_extent(da, x="lon", y="lat"):
    """Return the ``(x_min, x_max, y_min, y_max)`` extent of the ``x`` and ``y`` coordinates.

    The extent is taken from the coordinate values themselves; cell corners are
    not taken into account.
    """
    # TODO: compute corners array to estimate the extent
    # - OR increase by 1° in every direction and then wrap between -180, 180, -90, 90
    x_values = da[x]
    y_values = da[y]
    return (x_values.min(), x_values.max(), y_values.min(), y_values.max())
365

366

367
def _compute_extent(x_coords, y_coords):
1✔
368
    """
369
    Compute the extent (x_min, x_max, y_min, y_max) from the pixel centroids in x and y coordinates.
370
    This function assumes that the spacing between each pixel is uniform.
371
    """
372
    # Calculate the pixel size assuming uniform spacing between pixels
373
    pixel_size_x = (x_coords[-1] - x_coords[0]) / (len(x_coords) - 1)
1✔
374
    pixel_size_y = (y_coords[-1] - y_coords[0]) / (len(y_coords) - 1)
1✔
375

376
    # Adjust min and max to get the corners of the outer pixels
377
    x_min, x_max = x_coords[0] - pixel_size_x / 2, x_coords[-1] + pixel_size_x / 2
1✔
378
    y_min, y_max = y_coords[0] - pixel_size_y / 2, y_coords[-1] + pixel_size_y / 2
1✔
379

380
    return [x_min, x_max, y_min, y_max]
1✔
381

382

383
def _plot_cartopy_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot a regular lon/lat grid with ``imshow`` on a cartopy GeoAxes.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        Axes on which to draw.
    da : xarray.DataArray
        2D DataArray. The ``x``/``y`` coordinates are interpreted as PlateCarree
        longitude/latitude pixel centroids with uniform spacing.
    x, y : str
        Names of the longitude and latitude coordinates (and dimensions).
    interpolation : str, optional
        Passed to ``imshow``. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar with ``plot_colorbar``. The default is True.
    plot_kwargs : dict, optional
        Additional arguments passed to ``imshow``. The default is None.
    cbar_kwargs : dict, optional
        Colorbar options. The default is None.

    Returns
    -------
    matplotlib.image.AxesImage
        The mappable returned by ``ax.imshow``.
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs

    # - Ensure image with correct dimensions orders
    da = da.transpose(y, x)
    arr = np.asanyarray(da.data)

    # - Derive the image extent (outer pixel edges) from the pixel centroids
    x_coords = da[x].values
    y_coords = da[y].values
    extent = _compute_extent(x_coords=x_coords, y_coords=y_coords)

    # - Determine origin based on the orientation of da[y] values
    # -->  If increasing, set origin="lower"
    # -->  If decreasing, set origin="upper"
    origin = "lower" if y_coords[1] > y_coords[0] else "upper"

    # - Add variable field with cartopy
    p = ax.imshow(
        arr,
        transform=ccrs.PlateCarree(),
        extent=extent,
        origin=origin,
        interpolation=interpolation,
        **plot_kwargs,
    )

    # - Zoom on the data domain, using the requested coordinate names
    ax.set_extent(get_dataarray_extent(da, x=x, y=y))

    # - Add colorbar
    if add_colorbar:
        # --> TODO: set axis proportion in a meaningful way ...
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
428

429

430
def _mask_antimeridian_crossing_arr(arr, antimeridian_mask, rgb):
1✔
431
    if np.ma.is_masked(arr):
1✔
NEW
432
        if rgb:
×
NEW
433
            data_mask = np.broadcast_to(np.expand_dims(antimeridian_mask, axis=-1), arr.shape)
×
NEW
434
            combined_mask = np.logical_or(data_mask, antimeridian_mask)
×
435
        else:
NEW
436
            combined_mask = np.logical_or(arr.mask, antimeridian_mask)
×
NEW
437
        arr = np.ma.masked_where(combined_mask, arr)
×
438
    else:
439
        if rgb:
1✔
440
            antimeridian_mask = np.broadcast_to(
1✔
441
                np.expand_dims(antimeridian_mask, axis=-1), arr.shape
442
            )
443
        arr = np.ma.masked_where(antimeridian_mask, arr)
1✔
444
    return arr
1✔
445

446

447
def _plot_cartopy_pcolormesh(
    ax,
    da,
    x,
    y,
    rgb=False,
    add_colorbar=True,
    add_swath_lines=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot a 2D field with ``pcolormesh`` on a cartopy GeoAxes.

    ``x`` and ``y`` must be the names of the (2D) longitude and latitude coordinates.
    Non-finite coordinates are infilled and the corresponding data values are masked.
    If ``rgb=True``, the RGB(A) dimension (3 or 4 channels) is expected at the last
    position, and no colorbar is drawn.

    If ``gpm.config.get("viz_hide_antimeridian_data")`` is truthy, the scanning
    pixels spanning the antimeridian (plus a 1-cell buffer) are masked. The function
    currently does not allow zooming on regions across the antimeridian.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        Axes on which to draw.
    da : xarray.DataArray
        Data to plot.
    x, y : str
        Names of the longitude and latitude coordinates.
    rgb : bool, optional
        Whether ``da`` has a trailing RGB(A) dimension. The default is False.
    add_colorbar : bool, optional
        Whether to add a colorbar (ignored if ``rgb=True``). The default is True.
    add_swath_lines : bool, optional
        Whether to draw the first and last mesh rows as dashed lines. The default is True.
    plot_kwargs : dict, optional
        Additional arguments passed to ``pcolormesh``. The default is None.
        The caller's dictionary and colormap are not modified.
    cbar_kwargs : dict, optional
        Colorbar options. The default is None.

    Returns
    -------
    The mappable returned by ``ax.pcolormesh``.

    Raises
    ------
    ValueError
        If ``rgb=True`` and the last dimension does not have 3 or 4 channels.
    """
    import copy

    from gpm.utils.area import _get_lonlat_corners

    # Copy so that replacing the colormap below does not modify the caller's dictionary
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs.copy()

    # Get x, y, and array to plot
    da = da.compute()
    lon = da[x].data
    lat = da[y].data
    arr = da.data

    # If RGB, expect the last dimension to have 3 or 4 channels
    if rgb and arr.shape[-1] not in (3, 4):
        raise ValueError("RGB array must have 3 or 4 channels in the last dimension.")

    # Infill invalid coordinates and mask the corresponding data values
    lon, lat, arr = get_valid_pcolormesh_inputs(lon, lat, arr, rgb=rgb)

    # RGB images are drawn without a colorbar
    if rgb:
        add_colorbar = False

    # Compute coordinates of cell corners for the pcolormesh quadrilateral mesh
    # - This enables correct masking of cells crossing the antimeridian
    lon, lat = _get_lonlat_corners(lon, lat)

    # Mask cells crossing the antimeridian
    # --> Here we assume there are no invalid coordinates anymore
    # --> Cartopy still bugs with several projections when data cross the antimeridian
    # --> This flag can be unset with gpm.config.set({"viz_hide_antimeridian_data": False})
    if gpm.config.get("viz_hide_antimeridian_data"):
        antimeridian_mask = get_antimeridian_mask(lon, buffer=True)
        if np.any(antimeridian_mask):
            arr = _mask_antimeridian_crossing_arr(arr, antimeridian_mask=antimeridian_mask, rgb=rgb)

            # Make the colormap "bad" color transparent to avoid a cartopy bug
            # - TODO cartopy requires bad_color to be transparent ...
            cmap = plot_kwargs.get("cmap", None)
            if cmap is not None:
                # Resolve colormap names and work on a copy, so that the caller's
                # colormap object is not altered in place.
                cmap = copy.copy(plt.get_cmap(cmap))
                bad = cmap.get_bad()
                bad[3] = 0  # enforce alpha to 0 (transparent)
                cmap.set_bad(bad)
                plot_kwargs["cmap"] = cmap

    # Add variable field with cartopy
    p = ax.pcolormesh(
        lon,
        lat,
        arr,
        transform=ccrs.PlateCarree(),
        **plot_kwargs,
    )

    # Add swath lines (first and last rows of the corner mesh)
    if add_swath_lines:
        sides = [(lon[0, :], lat[0, :]), (lon[-1, :], lat[-1, :])]
        plot_sides(sides=sides, ax=ax, linestyle="--", color="black")

    # Add colorbar
    # --> TODO: set axis proportion in a meaningful way ...
    if add_colorbar:
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
527

528

529
def _plot_mpl_imshow(
1✔
530
    ax,
531
    da,
532
    x,
533
    y,
534
    interpolation="nearest",
535
    add_colorbar=True,
536
    plot_kwargs={},
537
    cbar_kwargs=None,
538
):
539
    """Plot imshow with matplotlib."""
540
    # - Ensure image with correct dimensions orders
541
    da = da.transpose(y, x)
×
542
    arr = np.asanyarray(da.data)
×
543

544
    # - Add variable field with matplotlib
545
    p = ax.imshow(
×
546
        arr,
547
        origin="upper",
548
        interpolation=interpolation,
549
        **plot_kwargs,
550
    )
551
    # - Add colorbar
552
    if add_colorbar:
×
553
        # --> TODO: set axis proportion in a meaningful way ...
554
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
×
555
    # - Return mappable
556
    return p
×
557

558

559
def set_colorbar_fully_transparent(p):
    """Hide the colorbar of ``p`` while keeping the figure area it occupied.

    This is useful for animations where the colorbar should not be shown in
    every frame but the plot area must stay fixed.
    The colorbar axes is hidden, and an invisible rectangle (no face, no edge)
    is added to the figure at the colorbar position.

    Parameters
    ----------
    p : matplotlib mappable
        A mappable with an attached ``colorbar`` attribute.
    """
    cbar_ax = p.colorbar.ax
    # Position of the colorbar in figure coordinates
    cbar_pos = cbar_ax.get_position()

    # Hide the colorbar
    cbar_ax.set_visible(False)

    # Add an empty rectangle to the figure that owns the colorbar.
    # plt.gcf() is not used because the current figure may be a different one.
    fig = cbar_ax.figure
    rect = plt.Rectangle(
        (cbar_pos.x0, cbar_pos.y0),
        cbar_pos.width,
        cbar_pos.height,
        transform=fig.transFigure,
        facecolor="none",
        edgecolor="none",
    )
    fig.add_artist(rect)
586

587

588
def _plot_xr_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
    xarray_colorbar=True,  # unused, kept for backward compatibility
    visible_colorbar=True,
):
    """Plot imshow with xarray.

    The colorbar is added by xarray, so that multiple colorbars can be displayed
    when this function is called several times on different fields with
    different colorbars.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes passed to ``da.plot.imshow``.
    da : xarray.DataArray
        Data to plot. Its name is used as the plot title.
    x, y : str
        Names of the x and y coordinates.
    interpolation : str, optional
        Passed to imshow. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether xarray should add a colorbar. The default is True.
    plot_kwargs : dict, optional
        Additional arguments passed to ``da.plot.imshow``. The default is None.
    cbar_kwargs : dict, optional
        Colorbar options passed to xarray. An optional ``ticklabels`` entry is
        applied to the colorbar ticks: x axis if ``orientation="horizontal"``,
        otherwise y axis. The caller's dictionary is not modified.
    xarray_colorbar : bool, optional
        Unused.
    visible_colorbar : bool, optional
        If False (and ``add_colorbar`` is True), the colorbar is hidden while
        its space is kept. The default is True.

    Returns
    -------
    The object returned by ``da.plot.imshow``.
    """
    # --> BUG with colorbar: https://github.com/pydata/xarray/issues/7014
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    # Copy so that popping 'ticklabels' does not modify the caller's dictionary
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs.copy()
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    orientation = cbar_kwargs.get("orientation", "vertical")
    if not add_colorbar:
        cbar_kwargs = None

    p = da.plot.imshow(
        x=x,
        y=y,
        ax=ax,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )
    plt.title(da.name)

    if add_colorbar and ticklabels is not None:
        if orientation == "horizontal":
            p.colorbar.ax.set_xticklabels(ticklabels)
        else:
            p.colorbar.ax.set_yticklabels(ticklabels)

    # Make the colorbar fully transparent with a smart trick ;)
    # - TODO: this still cause issues when plotting 2 colorbars !
    if add_colorbar and not visible_colorbar:
        set_colorbar_fully_transparent(p)
    return p
641

642

643
def _plot_xr_pcolormesh(
    ax,
    da,
    x,
    y,
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot pcolormesh with xarray.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes passed to ``da.plot.pcolormesh``.
    da : xarray.DataArray
        Data to plot. Its name is used as the plot title.
    x, y : str
        Names of the x and y coordinates.
    add_colorbar : bool, optional
        Whether xarray should add a colorbar. The default is True.
    plot_kwargs : dict, optional
        Additional arguments passed to ``da.plot.pcolormesh``. The default is None.
    cbar_kwargs : dict, optional
        Colorbar options passed to xarray. An optional ``ticklabels`` entry is
        applied to the colorbar ticks: x axis if ``orientation="horizontal"``,
        otherwise y axis. The caller's dictionary is not modified.

    Returns
    -------
    The object returned by ``da.plot.pcolormesh``.
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    # Copy so that popping 'ticklabels' does not modify the caller's dictionary
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs.copy()
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    orientation = cbar_kwargs.get("orientation", "vertical")
    if not add_colorbar:
        cbar_kwargs = None

    p = da.plot.pcolormesh(
        x=x,
        y=y,
        ax=ax,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )
    plt.title(da.name)

    if add_colorbar and ticklabels is not None:
        if orientation == "horizontal":
            p.colorbar.ax.set_xticklabels(ticklabels)
        else:
            p.colorbar.ax.set_yticklabels(ticklabels)
    return p
668

669

670
####--------------------------------------------------------------------------.
671

672

673
def plot_map(
    da,
    x="lon",
    y="lat",
    ax=None,
    add_colorbar=True,
    add_swath_lines=True,  # used only for GPM orbit objects
    add_background=True,
    rgb=False,
    interpolation="nearest",  # used only for GPM grid objects
    fig_kwargs=None,
    subplot_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """
    Plot a GPM ORBIT or GRID DataArray on a geographic map.

    ORBIT objects are plotted with ``plot_orbit_map``, GRID objects with ``plot_grid_map``.

    Parameters
    ----------
    da : xr.DataArray
        xarray DataArray.
    x : str, optional
        Longitude coordinate name. The default is `"lon"`.
    y : str, optional
        Latitude coordinate name. The default is `"lat"`.
    ax : cartopy.GeoAxes, optional
        The cartopy GeoAxes where to plot the map.
        If `None`, a figure is initialized using the
        specified `fig_kwargs` and `subplot_kwargs`.
        The default is `None`.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is `True`.
    add_swath_lines : bool, optional
        Whether to plot the swath sides with a dashed line. The default is `True`.
        This argument only applies to ORBIT objects.
    add_background : bool, optional
        Whether to add the map background. The default is `True`.
    rgb : bool, optional
        Whether the input DataArray has a rgb dimension. The default is `False`.
        This argument only applies to ORBIT objects.
    interpolation : str, optional
        Argument to be passed to imshow. Only applies to GRID objects.
        The default is `"nearest"`.
    fig_kwargs : dict, optional
        Figure options to be passed to `plt.subplots`.
        The default is `None`.
        Only used if `ax` is `None`.
    subplot_kwargs : dict, optional
        Dictionary of keyword arguments for Matplotlib subplots.
        Must contain the Cartopy CRS `'projection'` key if specified.
        The default is `None`.
        Only used if `ax` is `None`.
    cbar_kwargs : dict, optional
        Colorbar options. The default is `None`.
    **plot_kwargs
        Additional arguments to be passed to the plotting function.
        Examples include `cmap`, `norm`, `vmin`, `vmax`, `levels`, ...
        For FacetGrid plots, specify `row`, `col` and `col_wrap`.

    Returns
    -------
    The mappable returned by ``plot_orbit_map`` or ``plot_grid_map``.

    Raises
    ------
    ValueError
        If `da` is neither a GPM ORBIT nor a GPM GRID object.
    """
    from gpm.checks import is_grid, is_orbit
    from gpm.visualization.grid import plot_grid_map
    from gpm.visualization.orbit import plot_orbit_map

    # Arguments shared by the ORBIT and GRID plotting functions
    shared_kwargs = {
        "da": da,
        "x": x,
        "y": y,
        "ax": ax,
        "add_colorbar": add_colorbar,
        "add_background": add_background,
        "fig_kwargs": fig_kwargs,
        "subplot_kwargs": subplot_kwargs,
        "cbar_kwargs": cbar_kwargs,
    }

    if is_orbit(da):
        return plot_orbit_map(
            **shared_kwargs,
            add_swath_lines=add_swath_lines,
            rgb=rgb,
            **plot_kwargs,
        )
    if is_grid(da):
        return plot_grid_map(
            **shared_kwargs,
            interpolation=interpolation,
            **plot_kwargs,
        )
    raise ValueError("Can not plot. It's neither a GPM grid, neither a GPM orbit.")
771

772

773
def plot_image(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """
    Plot data using imshow.

    ORBIT objects are plotted with ``plot_orbit_image``, GRID objects with
    ``plot_grid_image``.

    Parameters
    ----------
    da : xr.DataArray
        xarray DataArray.
    x : str, optional
        X dimension name.
        If None, takes the second dimension.
        The default is `None`.
    y : str, optional
        Y dimension name.
        If None, takes the first dimension.
        The default is `None`.
    ax : matplotlib.axes.Axes, optional
        The matplotlib axes where to plot the image.
        If `None`, a figure is initialized using the
        specified `fig_kwargs`.
        The default is `None`.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is `True`.
    interpolation : str, optional
        Argument to be passed to imshow.
        The default is `"nearest"`.
    fig_kwargs : dict, optional
        Figure options to be passed to `plt.subplots`.
        The default is `None`.
        Only used if `ax` is `None`.
    cbar_kwargs : dict, optional
        Colorbar options. The default is `None`.
    **plot_kwargs
        Additional arguments to be passed to the plotting function.
        Examples include `cmap`, `norm`, `vmin`, `vmax`, `levels`, ...
        For FacetGrid plots, specify `row`, `col` and `col_wrap`.

    Returns
    -------
    The mappable returned by ``plot_orbit_image`` or ``plot_grid_image``.

    Raises
    ------
    ValueError
        If `da` is neither a GPM ORBIT nor a GPM GRID object.
    """
    from gpm.checks import is_grid, is_orbit
    from gpm.visualization.grid import plot_grid_image
    from gpm.visualization.orbit import plot_orbit_image

    # Plot orbit
    if is_orbit(da):
        p = plot_orbit_image(
            da=da,
            x=x,
            y=y,
            ax=ax,
            add_colorbar=add_colorbar,
            interpolation=interpolation,
            fig_kwargs=fig_kwargs,
            cbar_kwargs=cbar_kwargs,
            **plot_kwargs,
        )
    # Plot grid
    elif is_grid(da):
        p = plot_grid_image(
            da=da,
            x=x,
            y=y,
            ax=ax,
            add_colorbar=add_colorbar,
            interpolation=interpolation,
            fig_kwargs=fig_kwargs,
            cbar_kwargs=cbar_kwargs,
            **plot_kwargs,
        )
    else:
        raise ValueError("Can not plot. It's neither a GPM GRID, neither a GPM ORBIT.")
    # Return mappable
    return p
859

860

861
####--------------------------------------------------------------------------.
862

863

864
def create_grid_mesh_data_array(xr_obj, x, y):
1✔
865
    """
866
    Create a 2D xarray DataArray with mesh coordinates based on the 1D coordinate arrays
867
    from an existing xarray object (Dataset or DataArray).
868

869
    The function creates a 2D grid (mesh) of x and y coordinates and initializes
870
    the data values to NaN.
871

872
    Parameters
873
    ----------
874
    xr_obj : xarray.DataArray or xarray.Dataset
875
        The input xarray object containing the 1D coordinate arrays.
876
    x : str
877
        The name of the x-coordinate in `xr_obj`.
878
    y : str
879
        The name of the y-coordinate in `xr_obj`.
880

881
    Returns
882
    -------
883
    da_mesh : xarray.DataArray
884
        A 2D xarray DataArray with mesh coordinates for `x` and `y`, and NaN values for data points.
885

886
    Notes
887
    -----
888
    The resulting DataArray has dimensions named 'y' and 'x', corresponding to the y and x coordinates respectively.
889
    The coordinate values are taken directly from the input 1D coordinate arrays, and the data values are set to NaN.
890
    """
891
    # Extract 1D coordinate arrays
892
    x_coords = xr_obj[x].values
1✔
893
    y_coords = xr_obj[y].values
1✔
894

895
    # Create 2D meshgrid for x and y coordinates
896
    X, Y = np.meshgrid(x_coords, y_coords, indexing="xy")
1✔
897

898
    # Create a 2D array of NaN values with the same shape as the meshgrid
899
    dummy_values = np.full(X.shape, np.nan)
1✔
900

901
    # Create a new DataArray with 2D coordinates and NaN values
902
    da_mesh = xr.DataArray(
1✔
903
        dummy_values, coords={x: (("y", "x"), X), y: (("y", "x"), Y)}, dims=("y", "x")
904
    )
905
    return da_mesh
1✔
906

907

908
def plot_map_mesh(
1✔
909
    xr_obj,
910
    x="lon",
911
    y="lat",
912
    ax=None,
913
    edgecolors="k",
914
    linewidth=0.1,
915
    add_background=True,
916
    fig_kwargs=None,
917
    subplot_kwargs=None,
918
    **plot_kwargs,
919
):
920
    from gpm.checks import is_orbit  # is_grid
1✔
921

922
    from .grid import plot_grid_mesh
1✔
923
    from .orbit import plot_orbit_mesh
1✔
924

925
    # Plot orbit
926
    if is_orbit(xr_obj):
1✔
927
        p = plot_orbit_mesh(
1✔
928
            da=xr_obj[y],
929
            ax=ax,
930
            x=x,
931
            y=y,
932
            edgecolors=edgecolors,
933
            linewidth=linewidth,
934
            add_background=add_background,
935
            fig_kwargs=fig_kwargs,
936
            subplot_kwargs=subplot_kwargs,
937
            **plot_kwargs,
938
        )
939
    else:  # Plot grid
940
        p = plot_grid_mesh(
1✔
941
            xr_obj=xr_obj,
942
            x=x,
943
            y=y,
944
            ax=ax,
945
            edgecolors=edgecolors,
946
            linewidth=linewidth,
947
            add_background=add_background,
948
            fig_kwargs=fig_kwargs,
949
            subplot_kwargs=subplot_kwargs,
950
            **plot_kwargs,
951
        )
952
    # Return mappable
953
    return p
1✔
954

955

956
def plot_map_mesh_centroids(
1✔
957
    xr_obj,
958
    x="lon",
959
    y="lat",
960
    ax=None,
961
    c="r",
962
    s=1,
963
    add_background=True,
964
    fig_kwargs=None,
965
    subplot_kwargs=None,
966
    **plot_kwargs,
967
):
968
    """Plot GPM orbit granule mesh centroids in a cartographic map."""
969
    from gpm.checks import is_grid
1✔
970

971
    # - Initialize figure if necessary
972
    ax = initialize_cartopy_plot(
1✔
973
        ax=ax,
974
        fig_kwargs=fig_kwargs,
975
        subplot_kwargs=subplot_kwargs,
976
        add_background=add_background,
977
    )
978

979
    # - Retrieve centroids
980
    if is_grid(xr_obj):
1✔
981
        xr_obj = create_grid_mesh_data_array(xr_obj, x=x, y=y)
1✔
982
    lon = xr_obj[x].values
1✔
983
    lat = xr_obj[y].values
1✔
984

985
    # - Plot centroids
986
    p = ax.scatter(lon, lat, transform=ccrs.PlateCarree(), c=c, s=s, **plot_kwargs)
1✔
987

988
    # - Return mappable
989
    return p
1✔
990

991

992
####--------------------------------------------------------------------------.
993

994

995
def _plot_labels(
1✔
996
    xr_obj,
997
    label_name=None,
998
    max_n_labels=50,
999
    add_colorbar=True,
1000
    interpolation="nearest",
1001
    cmap="Paired",
1002
    fig_kwargs=None,
1003
    **plot_kwargs,
1004
):
1005
    """Plot labels.
1006

1007
    The maximum allowed number of labels to plot is 'max_n_labels'.
1008
    """
1009
    from ximage.labels.labels import get_label_indices, redefine_label_array
1✔
1010
    from ximage.labels.plot_labels import get_label_colorbar_settings
1✔
1011

1012
    from gpm.visualization.plot import plot_image
1✔
1013

1014
    if isinstance(xr_obj, xr.Dataset):
1✔
1015
        dataarray = xr_obj[label_name]
1✔
1016
    else:
1017
        if label_name is not None:
1✔
1018
            dataarray = xr_obj[label_name]
1✔
1019
        else:
1020
            dataarray = xr_obj
1✔
1021

1022
    dataarray = dataarray.compute()
1✔
1023
    label_indices = get_label_indices(dataarray)
1✔
1024
    n_labels = len(label_indices)
1✔
1025
    if add_colorbar and n_labels > max_n_labels:
1✔
1026
        msg = f"""The array currently contains {n_labels} labels
1✔
1027
        and 'max_n_labels' is set to {max_n_labels}. The colorbar is not displayed!"""
1028
        print(msg)
1✔
1029
        add_colorbar = False
1✔
1030
    # Relabel array from 1 to ... for plotting
1031
    dataarray = redefine_label_array(dataarray, label_indices=label_indices)
1✔
1032
    # Replace 0 with nan
1033
    dataarray = dataarray.where(dataarray > 0)
1✔
1034
    # Define appropriate colormap
1035
    plot_kwargs, cbar_kwargs = get_label_colorbar_settings(label_indices, cmap="Paired")
1✔
1036
    # Plot image
1037
    p = plot_image(
1✔
1038
        dataarray,
1039
        interpolation=interpolation,
1040
        add_colorbar=add_colorbar,
1041
        cbar_kwargs=cbar_kwargs,
1042
        fig_kwargs=fig_kwargs,
1043
        **plot_kwargs,
1044
    )
1045
    return p
1✔
1046

1047

1048
def plot_labels(
1✔
1049
    obj,  # Dataset, DataArray or generator
1050
    label_name=None,
1051
    max_n_labels=50,
1052
    add_colorbar=True,
1053
    interpolation="nearest",
1054
    cmap="Paired",
1055
    fig_kwargs=None,
1056
    **plot_kwargs,
1057
):
1058
    if is_generator(obj):
1✔
1059
        for label_id, xr_obj in obj:
1✔
1060
            p = _plot_labels(
1✔
1061
                xr_obj=xr_obj,
1062
                label_name=label_name,
1063
                max_n_labels=max_n_labels,
1064
                add_colorbar=add_colorbar,
1065
                interpolation=interpolation,
1066
                cmap=cmap,
1067
                fig_kwargs=fig_kwargs,
1068
                **plot_kwargs,
1069
            )
1070
            plt.show()
1✔
1071
    else:
1072
        p = _plot_labels(
1✔
1073
            xr_obj=obj,
1074
            label_name=label_name,
1075
            max_n_labels=max_n_labels,
1076
            add_colorbar=add_colorbar,
1077
            interpolation=interpolation,
1078
            cmap=cmap,
1079
            fig_kwargs=fig_kwargs,
1080
            **plot_kwargs,
1081
        )
1082
    return p
1✔
1083

1084

1085
def plot_patches(
1✔
1086
    patch_gen,
1087
    variable=None,
1088
    add_colorbar=True,
1089
    interpolation="nearest",
1090
    fig_kwargs=None,
1091
    cbar_kwargs=None,
1092
    **plot_kwargs,
1093
):
1094
    """Plot patches."""
1095
    from gpm.visualization.plot import plot_image
1✔
1096

1097
    # Plot patches
1098
    for label_id, xr_patch in patch_gen:
1✔
1099
        if isinstance(xr_patch, xr.Dataset):
1✔
1100
            if variable is None:
1✔
1101
                raise ValueError("'variable' must be specified when plotting xr.Dataset patches.")
1✔
1102
            xr_patch = xr_patch[variable]
1✔
1103
        try:
1✔
1104
            plot_image(
1✔
1105
                xr_patch,
1106
                interpolation=interpolation,
1107
                add_colorbar=add_colorbar,
1108
                fig_kwargs=fig_kwargs,
1109
                cbar_kwargs=cbar_kwargs,
1110
                **plot_kwargs,
1111
            )
1112
            plt.show()
1✔
1113
        except:
1✔
1114
            pass
1✔
1115
    return
1✔
1116

1117

1118
####--------------------------------------------------------------------------.
1119

1120

1121
def get_inset_bounds(
1✔
1122
    ax, loc="upper right", inset_height=0.2, inside_figure=True, aspect_ratio=1, y_spacing=0.06
1123
):
1124
    """
1125
    Calculate the bounds for an inset axes in a matplotlib figure.
1126

1127
    This function computes the normalized figure coordinates for placing an inset axes within a figure,
1128
    based on the specified location, size, and whether the inset should be fully inside the figure bounds.
1129
    It is designed to be used with matplotlib figures to facilitate the addition of insets (e.g., for maps
1130
    or zoomed plots) at predefined positions.
1131

1132
    Parameters
1133
    ----------
1134
    loc : str
1135
        The location of the inset within the figure. Valid options are 'lower left', 'lower right',
1136
        'upper left', and 'upper right'. The default is 'upper right'.
1137
    inset_height : float
1138
        The size of the inset height, specified as a fraction of the figure's height.
1139
        For example, a value of 0.2 indicates that the inset's height will be 20% of the figure's height.
1140
        The aspect ratio will govern the inset_width.
1141
    inside_figure : bool, optional
1142
        Determines whether the inset is constrained to be fully inside the figure bounds. If `True` (default),
1143
        the inset is placed fully within the figure. If `False`, the inset can extend beyond the figure's edges,
1144
        allowing for a half-outside placement.
1145
    aspect_ratio : float, optional
1146
        The width-to-height ratio of the inset figure.
1147
        A value greater than 1 indicates an inset figure wider than it is tall,
1148
        and a value less than 1 indicates an inset figure taller than it is wide.
1149
        The default value is 1.0, indicating a square inset figure.
1150

1151
    Returns
1152
    -------
1153
    inset_bounds : list of float
1154
        The calculated bounds of the inset, in the format [x0, y0, width, height], where `x0` and `y0`
1155
        are the normalized figure coordinates of the lower left corner of the inset, and `width` and
1156
        `height` are the normalized width and height of the inset, respectively.
1157

1158
    """
1159
    # Get the bounding box of the parent axes in figure coordinates
1160
    bbox = ax.get_position()
×
1161
    parent_width = bbox.width
×
1162
    parent_height = bbox.height
×
1163

1164
    # Compute the inset width percentage (relative to the parent axes)
1165
    # - Take into account possible different aspect ratios
1166
    inset_height_abs = inset_height * parent_height
×
1167
    inset_width_abs = inset_height_abs * aspect_ratio
×
1168
    inset_width = inset_width_abs / parent_width
×
1169
    loc_mapping = {
×
1170
        "upper right": (1 - inset_width, 1 - inset_height),
1171
        "upper left": (0, 1 - inset_height),
1172
        "lower right": (1 - inset_width, 0),
1173
        "lower left": (0, 0),
1174
    }
1175
    inset_x, inset_y = loc_mapping[loc]
×
1176

1177
    # Adjust for insets that are allowed to be half outside of the figure
1178
    if not inside_figure:
×
1179
        inset_x += inset_width / 2 * (-1 if loc.endswith("left") else 1)
×
1180
        inset_y += inset_height / 2 * (-1 if loc.startswith("lower") else 1)
×
1181

1182
    inset_bounds = [inset_x, inset_y, inset_width, inset_height]
×
1183
    return inset_bounds
×
1184

1185

1186
def add_map_inset(ax, loc="upper left", inset_height=0.2, projection=None, inside_figure=True):
1✔
1187
    """
1188
    Adds an inset map to a matplotlib axis using Cartopy, highlighting the extent of the main plot.
1189

1190
    This function creates a smaller map inset within a larger map plot to show a global view or
1191
    contextual location of the main plot's extent.
1192

1193
    It uses Cartopy for map projections and plotting, and it outlines the extent of the main plot
1194
    within the inset to provide geographical context.
1195

1196
    Parameters
1197
    ----------
1198
    ax : (matplotlib.axes.Axes, cartopy.mpl.geoaxes.GeoAxes)
1199
        The main matplotlib or cartopy axis object where the geographic data is plotted.
1200
    loc : str, optional
1201
        The location of the inset map within the main plot.
1202
        Options include 'lower left', 'lower right', 'upper left', 'upper right'.
1203
        The default is 'upper left'.
1204
    inset_height : float
1205
        The size of the inset height, specified as a fraction of the figure's height.
1206
        For example, a value of 0.2 indicates that the inset's height will be 20% of the figure's height.
1207
        The aspect ratio (of the map inset) will govern the inset_width.
1208
    inside_figure : bool, optional
1209
        Determines whether the inset is constrained to be fully inside the figure bounds. If `True` (default),
1210
        the inset is placed fully within the figure. If `False`, the inset can extend beyond the figure's edges,
1211
        allowing for a half-outside placement.
1212
    projection: cartopy.crs.Projection
1213
        A cartopy projection. If None, am Orthographic projection centered on the extent center is used.
1214

1215
    Returns
1216
    -------
1217
    ax2 : cartopy.mpl.geoaxes.GeoAxes
1218
        The Cartopy GeoAxesSubplot object for the inset map.
1219

1220
    Notes
1221
    -----
1222
    The function adjusts the extent of the inset map based on the main plot's extent, adding a
1223
    slight padding for visual clarity. It then overlays a red outline indicating the main plot's
1224
    geographical extent.
1225

1226
    Examples
1227
    --------
1228
    >>> p = da.gpm.plot_map()
1229
    >>> add_map_inset(ax=p.axes, loc="upper left", inset_height=0.15)
1230

1231
    This example creates a main plot with a specified extent and adds an upper-left inset map
1232
    showing the global context of the main plot's extent.
1233
    """
1234
    import cartopy.crs as ccrs
×
1235
    import cartopy.feature as cfeature
×
1236
    from shapely import Polygon
×
1237

1238
    from gpm.utils.geospatial import extend_geographic_extent
×
1239

1240
    # Retrieve extent and bounds
1241
    extent = ax.get_extent()
×
1242
    extent = extend_geographic_extent(extent, padding=0.5)
×
1243
    bounds = [extent[i] for i in [0, 2, 1, 3]]
×
1244
    # Create Cartopy Polygon
1245
    polygon = Polygon.from_bounds(*bounds)
×
1246
    # Define Orthographic projection
1247
    if projection is None:
×
1248
        lon_min, lon_max, lat_min, lat_max = extent
×
1249
        projection = ccrs.Orthographic(
×
1250
            central_latitude=(lat_min + lat_max) / 2, central_longitude=(lon_min + lon_max) / 2
1251
        )
1252

1253
    # Define aspect ratio of the map inset
1254
    aspect_ratio = float(np.diff(projection.x_limits) / np.diff(projection.y_limits).item())
×
1255

1256
    # Define inset location relative to main plot (ax) in normalized units
1257
    # - Lower-left corner of inset Axes, and its width and height
1258
    # - [x0, y0, width, height]
1259
    inset_bounds = get_inset_bounds(
×
1260
        ax=ax,
1261
        loc=loc,
1262
        inset_height=inset_height,
1263
        inside_figure=inside_figure,
1264
        aspect_ratio=aspect_ratio,
1265
    )
1266

1267
    # ax2 = plt.axes(inset_bounds, projection=projection)
1268
    ax2 = ax.inset_axes(
×
1269
        inset_bounds,
1270
        projection=projection,
1271
    )
1272

1273
    # Add global map
1274
    ax2.set_global()
×
1275
    ax2.add_feature(cfeature.LAND)
×
1276
    ax2.add_feature(cfeature.OCEAN)
×
1277
    # Add extent polygon
1278
    _ = ax2.add_geometries(
×
1279
        [polygon], ccrs.PlateCarree(), facecolor="none", edgecolor="red", linewidth=0.3
1280
    )
1281
    return ax2
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc