• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ghiggi / gpm_api / 8501183871

31 Mar 2024 10:04PM UTC coverage: 87.857% (+0.2%) from 87.669%
8501183871

Pull #53

github

ghiggi
Add flake8-commas rules
Pull Request #53: Refactor code style

693 of 786 new or added lines in 86 files covered. (88.17%)

4 existing lines in 4 files now uncovered.

9001 of 10245 relevant lines covered (87.86%)

0.88 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.45
/gpm/visualization/plot.py
1
# -----------------------------------------------------------------------------.
2
# MIT License
3

4
# Copyright (c) 2024 GPM-API developers
5
#
6
# This file is part of GPM-API.
7

8
# Permission is hereby granted, free of charge, to any person obtaining a copy
9
# of this software and associated documentation files (the "Software"), to deal
10
# in the Software without restriction, including without limitation the rights
11
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
# copies of the Software, and to permit persons to whom the Software is
13
# furnished to do so, subject to the following conditions:
14
#
15
# The above copyright notice and this permission notice shall be included in all
16
# copies or substantial portions of the Software.
17
#
18
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
# SOFTWARE.
25

26
# -----------------------------------------------------------------------------.
27
"""This module contains basic functions for GPM-API data visualization."""
1✔
28
import inspect
1✔
29

30
import cartopy
1✔
31
import cartopy.crs as ccrs
1✔
32
import cartopy.feature as cfeature
1✔
33
import matplotlib.pyplot as plt
1✔
34
import numpy as np
1✔
35
import xarray as xr
1✔
36
from mpl_toolkits.axes_grid1 import make_axes_locatable
1✔
37
from scipy.interpolate import griddata
1✔
38

39
import gpm
1✔
40

41

42
def is_generator(obj):
    """Return True if ``obj`` is a generator function or a generator object."""
    if inspect.isgeneratorfunction(obj):
        return True
    return inspect.isgenerator(obj)
44

45

46
def _call_optimize_layout(self):
    """Resize the figure to the data aspect ratio, then tighten its layout.

    Meant to be bound as the ``optimize_layout`` method of a plot object
    (see ``add_optimize_layout_method``); ``self`` is that plot object and must
    expose ``axes`` and ``figure`` attributes.
    """
    axes = self.axes
    adapt_fig_size(ax=axes)
    self.figure.tight_layout()
50

51

52
def add_optimize_layout_method(p):
    """Monkey-patch an ``optimize_layout`` bound method onto ``p`` and return ``p``."""
    bound_method = _call_optimize_layout.__get__(p, type(p))
    p.optimize_layout = bound_method
    return p
56

57

58
def adapt_fig_size(ax, nrow=1, ncol=1):
    """
    Adjust the figure size so that the subplots match the data aspect ratio of ``ax``.

    Meant to be called once all plotting is done. ``ax`` is taken as representative
    of every subplot: all ``nrow`` x ``ncol`` subplots are assumed to share the same
    size and aspect ratio.

    The new figure height is derived from the current figure width, the subplot
    margins and spacings (``fig.subplotpars``) and ``ax.get_data_ratio()``.
    If the new height is less than half of the original height, both width and
    height are multiplied by ``original_height / new_height / 2`` so that the
    figure ends at half its original height while keeping the new aspect ratio.

    The implementation is inspired by Mathias Hauser's mplotutils set_map_layout function.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Representative subplot axis.
    nrow, ncol : int, optional
        Number of subplot rows and columns in the figure. The default is 1.
    """
    fig = ax.get_figure()
    fig_width, original_height = fig.get_size_inches()

    # Draw first so that the layout geometry queried below is up to date
    fig.canvas.draw()

    params = fig.subplotpars
    horizontal_margins = params.left + (1 - params.right)
    vertical_margins = params.bottom + (1 - params.top)
    aspect = ax.get_data_ratio()

    # Width of one subplot (margins and inter-column spaces removed), then its height
    subplot_width = (fig_width - fig_width * horizontal_margins) / (ncol + (ncol - 1) * params.wspace)
    subplot_height = subplot_width * aspect

    # Figure height hosting nrow subplots plus inter-row spaces and top/bottom margins
    fig_height = (subplot_height * (nrow + ((nrow - 1) * params.hspace))) / (1.0 - vertical_margins)

    # Avoid shrinking the figure to less than half of its original height
    shrink_ratio = original_height / fig_height
    if shrink_ratio > 2:
        scale_factor = shrink_ratio / 2
        fig_width *= scale_factor
        fig_height *= scale_factor

    fig.set_figwidth(fig_width)
    fig.set_figheight(fig_height)
120

121

122
####--------------------------------------------------------------------------.
123

124

125
def get_antimeridian_mask(lons):
    """Return a boolean mask of the mesh cells adjacent to an antimeridian jump.

    ``lons`` is a 2D array of cell-corner longitudes with shape ``(n_y, n_x)``;
    the returned mask has shape ``(n_y - 1, n_x - 1)``. A jump is an absolute
    longitude difference larger than 180 between two adjacent corners.
    """
    from scipy.ndimage import binary_dilation

    n_y, n_x = lons.shape
    cell_mask = np.zeros((n_y - 1, n_x - 1))

    # Jumps between corners adjacent along axis 0: flag the cell in the previous column
    rows, cols = np.nonzero(np.abs(np.diff(lons, axis=0)) > 180)
    cell_mask[rows, np.clip(cols - 1, 0, n_x - 1)] = 1

    # Jumps between corners adjacent along axis 1: flag the cell in the previous row
    rows, cols = np.nonzero(np.abs(np.diff(lons, axis=1)) > 180)
    cell_mask[np.clip(rows - 1, 0, n_y - 1), cols] = 1

    # Grow the mask by one cell (default cross-shaped structuring element).
    # --> This should not be needed, but it avoids cartopy rendering bugs!
    return binary_dilation(cell_mask)
143

144

145
def infill_invalid_coords(xr_obj, x="lon", y="lat"):
    """Return a copy of ``xr_obj`` whose ``x``/``y`` coordinates contain no non-finite values.

    Invalid coordinates are linearly interpolated within the convex hull of the
    valid ones and filled by nearest neighbour outside of it
    (see ``get_valid_pcolormesh_inputs``). The work is done on ``xr_obj.copy()``.
    """
    result = xr_obj.copy()
    lon_values, lat_values, _ = get_valid_pcolormesh_inputs(
        x=np.asanyarray(result[x].data),
        y=np.asanyarray(result[y].data),
        data=None,
        mask_data=False,
    )
    result[x].data = lon_values
    result[y].data = lat_values
    return result
160

161

162
def get_valid_pcolormesh_inputs(x, y, data, rgb=False, mask_data=True):
    """Make coordinate arrays usable by ``pcolormesh`` by infilling non-finite values.

    ``pcolormesh`` does not accept non-finite coordinates. Invalid ``x``/``y``
    values are linearly interpolated within the convex hull of the valid ones
    and filled by nearest neighbour outside of it.

    Parameters
    ----------
    x, y : numpy.ndarray
        2D coordinate arrays.
    data : numpy.ndarray or None
        Values to plot. May be None when ``mask_data=False``.
    rgb : bool, optional
        If True, ``data`` has an extra trailing RGB(A) dimension.
    mask_data : bool, optional
        If True, ``data`` values located at invalid coordinates are masked and a
        ``numpy.ma.MaskedArray`` is returned. Masked values are not displayed by pcolormesh.

    Returns
    -------
    tuple
        ``(x, y, data)``. When all coordinates are finite, the inputs are returned as-is.
    """
    x_invalid = ~np.isfinite(x)
    y_invalid = ~np.isfinite(y)
    invalid = x_invalid | y_invalid

    # Nothing to fix: hand back the original objects
    if not invalid.any():
        return x, y, data

    if not mask_data:
        out_data = data
    elif rgb:
        rgb_invalid = np.broadcast_to(invalid[..., np.newaxis], data.shape)
        out_data = np.ma.masked_where(rgb_invalid, data)
    else:
        out_data = np.ma.masked_where(invalid, data)

    # Linear interpolation inside the convex hull, then nearest neighbour outside of it
    if x_invalid.any():
        x = _interpolate_data(_interpolate_data(x, method="linear"), method="nearest")
    if y_invalid.any():
        y = _interpolate_data(_interpolate_data(y, method="linear"), method="nearest")
    return x, y, out_data
204

205

206
def _interpolate_data(arr, method="linear"):
1✔
207
    # Find invalid locations
208
    is_invalid = ~np.isfinite(arr)
1✔
209

210
    # Find the indices of NaN values
211
    nan_indices = np.where(is_invalid)
1✔
212

213
    # Return array if not NaN values
214
    if len(nan_indices) == 0:
1✔
215
        return arr
×
216

217
    # Find the indices of non-NaN values
218
    non_nan_indices = np.where(~is_invalid)
1✔
219

220
    # Create a meshgrid of indices
221
    X, Y = np.meshgrid(range(arr.shape[1]), range(arr.shape[0]))
1✔
222

223
    # Points (X, Y) where we have valid data
224
    points = np.array([Y[non_nan_indices], X[non_nan_indices]]).T
1✔
225

226
    # Points where data is NaN
227
    points_nan = np.array([Y[nan_indices], X[nan_indices]]).T
1✔
228

229
    # Values at the non-NaN points
230
    values = arr[non_nan_indices]
1✔
231

232
    # Interpolate using griddata
233
    arr_new = arr.copy()
1✔
234
    arr_new[nan_indices] = griddata(points, values, points_nan, method=method)
1✔
235
    return arr_new
1✔
236

237

238
####--------------------------------------------------------------------------.
239

240

241
def preprocess_figure_args(ax, fig_kwargs=None, subplot_kwargs=None):
    """Validate the figure-creation arguments and return ``fig_kwargs`` (``{}`` if None).

    ``fig_kwargs`` and ``subplot_kwargs`` are only meaningful when a new figure
    is created, i.e. when ``ax`` is None.

    Raises
    ------
    ValueError
        If ``ax`` is provided together with non-empty ``subplot_kwargs`` or ``fig_kwargs``.
    """
    if fig_kwargs is None:
        fig_kwargs = {}
    if subplot_kwargs is None:
        subplot_kwargs = {}
    if ax is None:
        return fig_kwargs
    if len(subplot_kwargs) > 0:
        raise ValueError("Provide `subplot_kwargs`only if `ax`is None")
    if len(fig_kwargs) > 0:
        raise ValueError("Provide `fig_kwargs` only if `ax`is None")
    return fig_kwargs
250

251

252
def preprocess_subplot_kwargs(subplot_kwargs):
    """Return a copy of ``subplot_kwargs`` that always contains a ``"projection"`` key.

    ``ccrs.PlateCarree()`` is used when no projection is specified.
    The input dictionary is not modified.
    """
    if subplot_kwargs is None:
        kwargs = {}
    else:
        kwargs = subplot_kwargs.copy()
    if "projection" in kwargs:
        return kwargs
    kwargs["projection"] = ccrs.PlateCarree()
    return kwargs
258

259

260
def initialize_cartopy_plot(
    ax,
    fig_kwargs,
    subplot_kwargs,
    add_background,
):
    """Initialize a cartopy figure if necessary and optionally draw the map background.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes or None
        Existing axis. If None, a new figure is created with ``plt.subplots``
        using ``fig_kwargs`` and ``subplot_kwargs`` (the projection defaults to
        ``ccrs.PlateCarree()`` when not specified).
    fig_kwargs : dict or None
        Figure options. Must be None/empty when ``ax`` is provided.
    subplot_kwargs : dict or None
        Subplot options. Must be None/empty when ``ax`` is provided.
    add_background : bool
        Whether to draw the cartopy background (see ``plot_cartopy_background``).

    Returns
    -------
    ax
        The provided or newly created axis.

    Raises
    ------
    ValueError
        If ``ax`` is provided together with non-empty ``fig_kwargs`` or ``subplot_kwargs``.
    """
    # Validate before branching: the check is only meaningful when an axis is given,
    # otherwise figure/subplot options would be silently ignored.
    fig_kwargs = preprocess_figure_args(
        ax=ax,
        fig_kwargs=fig_kwargs,
        subplot_kwargs=subplot_kwargs,
    )

    # - Initialize figure
    if ax is None:
        subplot_kwargs = preprocess_subplot_kwargs(subplot_kwargs)
        _, ax = plt.subplots(subplot_kw=subplot_kwargs, **fig_kwargs)

    # - Add cartopy background
    if add_background:
        ax = plot_cartopy_background(ax)
    return ax
281

282

283
####--------------------------------------------------------------------------.
284

285

286
def plot_cartopy_background(ax):
    """Draw a standard map background on a cartopy axis and return the axis.

    Adds coastlines, light grey land, semi-transparent ocean, country borders
    and faint gray gridlines, with the top and right gridline labels disabled.
    """
    ax.coastlines()
    background_features = (
        (cfeature.LAND, {"facecolor": [0.9, 0.9, 0.9]}),
        (cfeature.OCEAN, {"alpha": 0.6}),
        (cfeature.BORDERS, {}),  # BORDERS also draws provinces, ...
    )
    for feature, feature_kwargs in background_features:
        ax.add_feature(feature, **feature_kwargs)

    gridliner = ax.gridlines(
        crs=ccrs.PlateCarree(),
        draw_labels=True,
        linewidth=1,
        color="gray",
        alpha=0.1,
        linestyle="-",
    )
    gridliner.top_labels = False
    gridliner.right_labels = False
    gridliner.xlines = True
    gridliner.ylines = True
    return ax
307

308

309
def plot_sides(sides, ax, **plot_kwargs):
    """Plot boundary sides as geodetic lines.

    Parameters
    ----------
    sides : iterable of tuple
        Sequence of ``(lon, lat)`` tuples, one per side.
    ax : cartopy.mpl.geoaxes.GeoAxes
        Axis on which the sides are drawn.
    **plot_kwargs
        Additional arguments passed to ``ax.plot``.

    Returns
    -------
    matplotlib.lines.Line2D or None
        The line of the last plotted side, or None if ``sides`` is empty
        (previously this raised ``UnboundLocalError``).
    """
    sides = list(sides)
    if not sides:
        return None
    transform = ccrs.Geodetic()
    line = None
    for side in sides:
        line = ax.plot(*side, transform=transform, **plot_kwargs)[0]
    return line
317

318

319
def plot_colorbar(p, ax, cbar_kwargs=None):
    """Add a colorbar next to a matplotlib/cartopy axis.

    The colorbar axis is appended to ``ax`` with ``make_axes_locatable``:
    on the right of a vertical colorbar, below for a horizontal one.

    Parameters
    ----------
    p : matplotlib.cm.ScalarMappable
        Mappable returned by the plotting function (e.g. ``AxesImage``, ``QuadMesh``).
    ax : matplotlib.axes.Axes or cartopy.mpl.geoaxes.GeoAxes
        Axis the colorbar is attached to.
    cbar_kwargs : dict, optional
        Colorbar options. The input dict is not modified. These keys are
        consumed here and not forwarded to ``plt.colorbar``:

        - ``ticklabels``: custom tick labels.
        - ``size``: size of the colorbar axis. The default is ``"5%"``.
        - ``pad``: padding between ``ax`` and the colorbar.
          The default is 0.1 (vertical) or 0.25 (horizontal).

        ``orientation`` (``"vertical"`` by default, or ``"horizontal"``) and any
        other key are forwarded to ``plt.colorbar``.

    Returns
    -------
    matplotlib.colorbar.Colorbar

    Raises
    ------
    ValueError
        If ``orientation`` is neither ``"vertical"`` nor ``"horizontal"``.
    """
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs.copy()
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    # 'size' and 'pad' configure the appended axis only. 'size' is not a valid
    # Colorbar argument, so both must be removed before forwarding the rest.
    size = cbar_kwargs.pop("size", "5%")
    pad = cbar_kwargs.pop("pad", None)
    orientation = cbar_kwargs.get("orientation", "vertical")

    if orientation == "vertical":
        position = "right"
        pad = 0.1 if pad is None else pad
    elif orientation == "horizontal":
        position = "bottom"
        pad = 0.25 if pad is None else pad
    else:
        raise ValueError("Invalid orientation. Choose 'vertical' or 'horizontal'.")

    # Create the divider only once the arguments are validated (it modifies ax)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes(position, size=size, pad=pad, axes_class=plt.Axes)
    p.figure.add_axes(cax)
    cbar = plt.colorbar(p, cax=cax, ax=ax, **cbar_kwargs)
    if ticklabels is not None:
        if orientation == "vertical":
            cbar.ax.set_yticklabels(ticklabels)
        else:
            cbar.ax.set_xticklabels(ticklabels)
    return cbar
355

356

357
####--------------------------------------------------------------------------.
358

359

360
def get_dataarray_extent(da, x="lon", y="lat"):
    """Return ``(x_min, x_max, y_min, y_max)`` of the ``x`` and ``y`` coordinates of ``da``.

    The extent is computed from the coordinate values themselves (not from
    pixel corners).
    """
    # TODO: compute corners array to estimate the extent
    # - OR increase by 1° in every direction and then wrap between -180, 180, -90, 90
    x_values = da[x]
    y_values = da[y]
    return (x_values.min(), x_values.max(), y_values.min(), y_values.max())
367

368

369
def _compute_extent(x_coords, y_coords):
1✔
370
    """
371
    Compute the extent (x_min, x_max, y_min, y_max) from the pixel centroids in x and y coordinates.
372
    This function assumes that the spacing between each pixel is uniform.
373
    """
374
    # Calculate the pixel size assuming uniform spacing between pixels
375
    pixel_size_x = (x_coords[-1] - x_coords[0]) / (len(x_coords) - 1)
1✔
376
    pixel_size_y = (y_coords[-1] - y_coords[0]) / (len(y_coords) - 1)
1✔
377

378
    # Adjust min and max to get the corners of the outer pixels
379
    x_min, x_max = x_coords[0] - pixel_size_x / 2, x_coords[-1] + pixel_size_x / 2
1✔
380
    y_min, y_max = y_coords[0] - pixel_size_y / 2, y_coords[-1] + pixel_size_y / 2
1✔
381

382
    return [x_min, x_max, y_min, y_max]
1✔
383

384

385
def _plot_cartopy_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot a regular lon/lat field with ``imshow`` on a cartopy axis.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        Target axis.
    da : xarray.DataArray
        2D field with dimensions ``x`` and ``y`` whose 1D coordinates are pixel
        centroids (assumed uniformly spaced, at least 2 values each).
    x, y : str
        Names of the longitude and latitude dimensions/coordinates.
    interpolation : str, optional
        Passed to ``imshow``. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar with ``plot_colorbar``.
    plot_kwargs : dict, optional
        Additional arguments for ``imshow``.
    cbar_kwargs : dict, optional
        Colorbar options passed to ``plot_colorbar``.

    Returns
    -------
    matplotlib.image.AxesImage
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs

    # - Ensure (y, x) dimension order
    da = da.transpose(y, x)
    arr = np.asanyarray(da.data)

    # - Derive the image extent from the pixel centroids
    x_coords = da[x].to_numpy()
    y_coords = da[y].to_numpy()
    extent = _compute_extent(x_coords=x_coords, y_coords=y_coords)

    # - Increasing y values -> first row at the bottom ("lower"), else at the top ("upper")
    origin = "lower" if y_coords[1] > y_coords[0] else "upper"

    p = ax.imshow(
        arr,
        transform=ccrs.PlateCarree(),
        extent=extent,
        origin=origin,
        interpolation=interpolation,
        **plot_kwargs,
    )

    # - Zoom the map on the data coordinates, using the caller-provided coordinate names
    ax.set_extent(get_dataarray_extent(da, x=x, y=y))

    # - Add colorbar
    if add_colorbar:
        # --> TODO: set axis proportion in a meaningful way ...
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
430

431

432
def _mask_antimeridian_crossing_arr(arr, antimeridian_mask, rgb):
1✔
433
    if np.ma.is_masked(arr):
1✔
434
        if rgb:
×
435
            data_mask = np.broadcast_to(np.expand_dims(antimeridian_mask, axis=-1), arr.shape)
×
436
            combined_mask = np.logical_or(data_mask, antimeridian_mask)
×
437
        else:
438
            combined_mask = np.logical_or(arr.mask, antimeridian_mask)
×
439
        arr = np.ma.masked_where(combined_mask, arr)
×
440
    else:
441
        if rgb:
1✔
442
            antimeridian_mask = np.broadcast_to(
1✔
443
                np.expand_dims(antimeridian_mask, axis=-1),
444
                arr.shape,
445
            )
446
        arr = np.ma.masked_where(antimeridian_mask, arr)
1✔
447
    return arr
1✔
448

449

450
def _plot_cartopy_pcolormesh(
    ax,
    da,
    x,
    y,
    rgb=False,
    add_colorbar=True,
    add_swath_lines=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot a field with 2D lon/lat coordinates using ``pcolormesh`` on a cartopy axis.

    Non-finite coordinates are infilled (and the corresponding data masked) with
    ``get_valid_pcolormesh_inputs``, then cell corners are computed to build the
    quadrilateral mesh. When the ``viz_hide_antimeridian_data`` config option is
    enabled, cells crossing the antimeridian (and their neighbours) are masked
    to work around cartopy rendering bugs; zooming on regions across the
    antimeridian is therefore not supported.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        Target axis.
    da : xarray.DataArray
        Field to plot.
    x, y : str
        Names of the longitude and latitude coordinates.
    rgb : bool, optional
        If True, the last dimension must hold 3 (RGB) or 4 (RGBA) channels,
        and no colorbar is drawn.
    add_colorbar : bool, optional
        Whether to add a colorbar with ``plot_colorbar``.
    add_swath_lines : bool, optional
        Whether to draw dashed lines along the first and last rows of cell corners.
    plot_kwargs : dict, optional
        Additional arguments for ``pcolormesh``. The dict is not modified.
        A ``cmap`` (Colormap or colormap name) is copied before its "bad"
        colour is made transparent, so the caller's colormap is left untouched.
    cbar_kwargs : dict, optional
        Colorbar options passed to ``plot_colorbar``.

    Returns
    -------
    matplotlib.collections.QuadMesh

    Raises
    ------
    ValueError
        If ``rgb=True`` and the last dimension does not have 3 or 4 channels.
    """
    # Work on a copy so that the caller's dictionary is never modified
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)

    # Get x, y, and array to plot
    da = da.compute()
    lon = da[x].data
    lat = da[y].data
    arr = da.data

    # If RGB, expect the last dimension to have 3 or 4 channels
    if rgb and arr.shape[-1] != 3 and arr.shape[-1] != 4:
        raise ValueError("RGB array must have 3 or 4 channels in the last dimension.")

    # Infill invalid coordinates and mask the corresponding data if necessary
    lon, lat, arr = get_valid_pcolormesh_inputs(lon, lat, arr, rgb=rgb)

    if rgb:
        add_colorbar = False

    # Compute coordinates of cell corners for the pcolormesh quadrilateral mesh
    # - This enables correct masking of cells crossing the antimeridian
    from gpm.utils.area import _get_lonlat_corners

    lon, lat = _get_lonlat_corners(lon, lat)

    # Mask cells crossing the antimeridian
    # --> Coordinates are assumed to be all valid at this point
    # --> Cartopy still bugs with several projections when data cross the antimeridian
    # --> This flag can be unset with gpm.config.set({"viz_hide_antimeridian_data": False})
    if gpm.config.get("viz_hide_antimeridian_data"):
        antimeridian_mask = get_antimeridian_mask(lon)
        if np.any(antimeridian_mask):
            arr = _mask_antimeridian_crossing_arr(arr, antimeridian_mask=antimeridian_mask, rgb=rgb)

            # Cartopy requires the "bad" colour to be transparent.
            # Copy the colormap first so that the caller's object is not altered.
            cmap = plot_kwargs.get("cmap", None)
            if cmap is not None:
                if isinstance(cmap, str):
                    cmap = plt.get_cmap(cmap)
                cmap = cmap.copy()
                bad = cmap.get_bad()
                bad[3] = 0  # enforce alpha to 0 (transparent)
                cmap.set_bad(bad)
                plot_kwargs["cmap"] = cmap

    # Add variable field with cartopy
    p = ax.pcolormesh(
        lon,
        lat,
        arr,
        transform=ccrs.PlateCarree(),
        **plot_kwargs,
    )

    # Add swath lines along the first and last rows of cell corners
    if add_swath_lines:
        sides = [(lon[0, :], lat[0, :]), (lon[-1, :], lat[-1, :])]
        plot_sides(sides=sides, ax=ax, linestyle="--", color="black")

    # Add colorbar
    # --> TODO: set axis proportion in a meaningful way ...
    if add_colorbar:
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
529

530

531
def _plot_mpl_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot a 2D DataArray with matplotlib ``imshow`` (no geographic transform).

    The array is transposed to ``(y, x)`` and drawn with ``origin="upper"``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axis.
    da : xarray.DataArray
        2D field with dimensions ``x`` and ``y``.
    x, y : str
        Names of the dimensions used as image columns and rows.
    interpolation : str, optional
        Passed to ``imshow``. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar with ``plot_colorbar``.
    plot_kwargs : dict, optional
        Additional arguments for ``imshow``.
    cbar_kwargs : dict, optional
        Colorbar options passed to ``plot_colorbar``.

    Returns
    -------
    matplotlib.image.AxesImage
    """
    # Avoid a shared mutable default argument
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs

    # - Ensure image with (y, x) dimension order
    arr = np.asanyarray(da.transpose(y, x).data)

    p = ax.imshow(
        arr,
        origin="upper",
        interpolation=interpolation,
        **plot_kwargs,
    )
    if add_colorbar:
        # --> TODO: set axis proportion in a meaningful way ...
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
559

560

561
def set_colorbar_fully_transparent(p):
    """Hide the colorbar of ``p`` while keeping its place in the figure.

    This is useful for animations where the colorbar should not be shown in
    every frame but the plot area must stay fixed. The colorbar axis is made
    invisible and an invisible rectangle (no face nor edge colour) is added at
    its former position.
    """
    cbar_ax = p.colorbar.ax
    # Use the figure owning the colorbar rather than the "current" figure
    # (plt.gcf), which may differ when several figures are open.
    fig = cbar_ax.figure
    bbox = cbar_ax.get_position()

    # Hide the colorbar
    cbar_ax.set_visible(False)

    # Add an empty rectangle where the colorbar was
    placeholder = plt.Rectangle(
        (bbox.x0, bbox.y0),
        bbox.width,
        bbox.height,
        transform=fig.transFigure,
        facecolor="none",
        edgecolor="none",
    )
    fig.patches.append(placeholder)
588

589

590
def _plot_xr_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
    visible_colorbar=True,
):
    """Plot a DataArray with xarray's ``imshow`` and use ``da.name`` as title.

    The colorbar is added by xarray, which enables displaying multiple colorbars
    when this function is called several times on different fields.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Target axis, passed to xarray.
    da : xarray.DataArray
        Field to plot.
    x, y : str
        Coordinates used as x and y axes.
    interpolation : str, optional
        Passed to ``imshow``. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether xarray adds a colorbar.
    plot_kwargs : dict, optional
        Additional arguments passed to ``da.plot.imshow``.
    cbar_kwargs : dict, optional
        Colorbar options passed to xarray. A ``ticklabels`` entry is removed and
        applied as y tick labels of the colorbar. The input dict is not modified.
    visible_colorbar : bool, optional
        If False, the colorbar is hidden but its space is kept
        (see ``set_colorbar_fully_transparent``).

    Returns
    -------
    The object returned by ``da.plot.imshow``.
    """
    # --> BUG with colorbar: https://github.com/pydata/xarray/issues/7014
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    # Copy so that popping "ticklabels" does not alter the caller's dict
    # (e.g. when reused across animation frames), and accept None.
    cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs)
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    if not add_colorbar:
        cbar_kwargs = None
    p = da.plot.imshow(
        x=x,
        y=y,
        ax=ax,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )
    # Title the target axis: plt.title acts on the "current" axis, which is
    # not necessarily the provided ``ax``.
    if ax is not None:
        ax.set_title(da.name)
    else:
        plt.title(da.name)
    if add_colorbar and ticklabels is not None:
        p.colorbar.ax.set_yticklabels(ticklabels)

    # Make the colorbar fully transparent with a smart trick ;)
    # - TODO: this still causes issues when plotting 2 colorbars !
    if add_colorbar and not visible_colorbar:
        set_colorbar_fully_transparent(p)
    return p
642

643

644
def _plot_xr_pcolormesh(
    ax,
    da,
    x,
    y,
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot a DataArray with xarray's ``pcolormesh`` and use ``da.name`` as title.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Target axis, passed to xarray.
    da : xarray.DataArray
        Field to plot.
    x, y : str
        Coordinates used as x and y axes.
    add_colorbar : bool, optional
        Whether xarray adds a colorbar.
    plot_kwargs : dict, optional
        Additional arguments passed to ``da.plot.pcolormesh``.
    cbar_kwargs : dict, optional
        Colorbar options passed to xarray. A ``ticklabels`` entry is removed and
        applied as y tick labels of the colorbar. The input dict is not modified.

    Returns
    -------
    The object returned by ``da.plot.pcolormesh``.
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    # Copy so that popping "ticklabels" does not alter the caller's dict, and accept None
    cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs)
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    if not add_colorbar:
        cbar_kwargs = None
    p = da.plot.pcolormesh(
        x=x,
        y=y,
        ax=ax,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )
    # Title the target axis: plt.title acts on the "current" axis, which is
    # not necessarily the provided ``ax``.
    if ax is not None:
        ax.set_title(da.name)
    else:
        plt.title(da.name)
    if add_colorbar and ticklabels is not None:
        p.colorbar.ax.set_yticklabels(ticklabels)
    return p
669

670

671
####--------------------------------------------------------------------------.
672

673

674
def plot_map(
    da,
    x="lon",
    y="lat",
    ax=None,
    add_colorbar=True,
    add_swath_lines=True,  # used only for GPM orbit objects
    add_background=True,
    rgb=False,
    interpolation="nearest",  # used only for GPM grid objects
    fig_kwargs=None,
    subplot_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """
    Plot data on a geographic map.

    Dispatches to the ORBIT or GRID map plotting function depending on ``da``.

    Parameters
    ----------
    da : xr.DataArray
        xarray DataArray.
    x : str, optional
        Longitude coordinate name. The default is `"lon"`.
    y : str, optional
        Latitude coordinate name. The default is `"lat"`.
    ax : cartopy.GeoAxes, optional
        The cartopy GeoAxes where to plot the map.
        If `None`, a figure is initialized using the
        specified `fig_kwargs` and `subplot_kwargs`.
        The default is `None`.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is `True`.
    add_swath_lines : bool, optional
        Whether to plot the swath sides with a dashed line. The default is `True`.
        This argument only applies for ORBIT objects.
    add_background : bool, optional
        Whether to add the map background. The default is `True`.
    rgb : bool, optional
        Whether the input DataArray has a rgb dimension. The default is `False`.
        This argument only applies for ORBIT objects.
    interpolation : str, optional
        Argument to be passed to imshow. Only applies for GRID objects.
        The default is `"nearest"`.
    fig_kwargs : dict, optional
        Figure options to be passed to plt.subplots.
        The default is `None`.
        Only used if `ax` is `None`.
    subplot_kwargs : dict, optional
        Dictionary of keyword arguments for Matplotlib subplots.
        Must contain the Cartopy CRS `'projection'` key if specified.
        The default is `None`.
        Only used if `ax` is `None`.
    cbar_kwargs : dict, optional
        Colorbar options. The default is `None`.
    **plot_kwargs
        Additional arguments to be passed to the plotting function.
        Examples include `cmap`, `norm`, `vmin`, `vmax`, `levels`, ...
        For FacetGrid plots, specify `row`, `col` and `col_wrap`.

    Raises
    ------
    ValueError
        If ``da`` is neither a GPM ORBIT nor a GPM GRID object.
    """
    from gpm.checks import is_grid, is_orbit
    from gpm.visualization.grid import plot_grid_map
    from gpm.visualization.orbit import plot_orbit_map

    # Arguments shared by the ORBIT and GRID plotting functions
    common_kwargs = {
        "da": da,
        "x": x,
        "y": y,
        "ax": ax,
        "add_colorbar": add_colorbar,
        "add_background": add_background,
        "fig_kwargs": fig_kwargs,
        "subplot_kwargs": subplot_kwargs,
        "cbar_kwargs": cbar_kwargs,
    }
    if is_orbit(da):
        return plot_orbit_map(
            add_swath_lines=add_swath_lines,
            rgb=rgb,
            **common_kwargs,
            **plot_kwargs,
        )
    if is_grid(da):
        return plot_grid_map(
            interpolation=interpolation,
            **common_kwargs,
            **plot_kwargs,
        )
    raise ValueError("Can not plot. It's neither a GPM grid, neither a GPM orbit.")
772

773

774
def plot_image(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """
    Plot data using imshow.

    Dispatches to the ORBIT or GRID image plotting function depending on ``da``.

    Parameters
    ----------
    da : xr.DataArray
        xarray DataArray.
    x : str, optional
        X dimension name.
        If ``None``, takes the second dimension.
        The default is `None`.
    y : str, optional
        Y dimension name.
        If ``None``, takes the first dimension.
        The default is `None`.
    ax : matplotlib.axes.Axes, optional
        The matplotlib axes where to plot the image.
        If ``None``, a figure is initialized using the
        specified `fig_kwargs`.
        The default is `None`.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is `True`.
    interpolation : str, optional
        Argument to be passed to imshow.
        The default is `"nearest"`.
    fig_kwargs : dict, optional
        Figure options to be passed to `plt.subplots`.
        The default is ``None``.
        Only used if `ax` is None.
    cbar_kwargs : dict, optional
        Colorbar options. The default is `None`.
    **plot_kwargs
        Additional arguments to be passed to the plotting function.
        Examples include `cmap`, `norm`, `vmin`, `vmax`, `levels`, ...
        For FacetGrid plots, specify `row`, `col` and `col_wrap`.

    Raises
    ------
    ValueError
        If ``da`` is neither a GPM ORBIT nor a GPM GRID object.
    """
    # figsize, dpi, subplot_kw only used if ax is None
    from gpm.checks import is_grid, is_orbit
    from gpm.visualization.grid import plot_grid_image
    from gpm.visualization.orbit import plot_orbit_image

    # Plot orbit
    if is_orbit(da):
        return plot_orbit_image(
            da=da,
            x=x,
            y=y,
            ax=ax,
            add_colorbar=add_colorbar,
            interpolation=interpolation,
            fig_kwargs=fig_kwargs,
            cbar_kwargs=cbar_kwargs,
            **plot_kwargs,
        )
    # Plot grid
    if is_grid(da):
        return plot_grid_image(
            da=da,
            x=x,
            y=y,
            ax=ax,
            add_colorbar=add_colorbar,
            interpolation=interpolation,
            fig_kwargs=fig_kwargs,
            cbar_kwargs=cbar_kwargs,
            **plot_kwargs,
        )
    raise ValueError("Can not plot. It's neither a GPM GRID, neither a GPM ORBIT.")
860

861

862
####--------------------------------------------------------------------------.
863

864

865
def create_grid_mesh_data_array(xr_obj, x, y):
1✔
866
    """
867
    Create a 2D xarray DataArray with mesh coordinates based on the 1D coordinate arrays
868
    from an existing xarray object (Dataset or DataArray).
869

870
    The function creates a 2D grid (mesh) of x and y coordinates and initializes
871
    the data values to NaN.
872

873
    Parameters
874
    ----------
875
    xr_obj : xarray.DataArray or xarray.Dataset
876
        The input xarray object containing the 1D coordinate arrays.
877
    x : str
878
        The name of the x-coordinate in `xr_obj`.
879
    y : str
880
        The name of the y-coordinate in `xr_obj`.
881

882
    Returns
883
    -------
884
    da_mesh : xarray.DataArray
885
        A 2D xarray DataArray with mesh coordinates for `x` and `y`, and NaN values for data points.
886

887
    Notes
888
    -----
889
    The resulting DataArray has dimensions named 'y' and 'x', corresponding to the y and x coordinates respectively.
890
    The coordinate values are taken directly from the input 1D coordinate arrays, and the data values are set to NaN.
891
    """
892
    # Extract 1D coordinate arrays
893
    x_coords = xr_obj[x].to_numpy()
1✔
894
    y_coords = xr_obj[y].to_numpy()
1✔
895

896
    # Create 2D meshgrid for x and y coordinates
897
    X, Y = np.meshgrid(x_coords, y_coords, indexing="xy")
1✔
898

899
    # Create a 2D array of NaN values with the same shape as the meshgrid
900
    dummy_values = np.full(X.shape, np.nan)
1✔
901

902
    # Create a new DataArray with 2D coordinates and NaN values
903
    return xr.DataArray(
1✔
904
        dummy_values,
905
        coords={x: (("y", "x"), X), y: (("y", "x"), Y)},
906
        dims=("y", "x"),
907
    )
908

909

910
def plot_map_mesh(
1✔
911
    xr_obj,
912
    x="lon",
913
    y="lat",
914
    ax=None,
915
    edgecolors="k",
916
    linewidth=0.1,
917
    add_background=True,
918
    fig_kwargs=None,
919
    subplot_kwargs=None,
920
    **plot_kwargs,
921
):
922
    from gpm.checks import is_orbit  # is_grid
1✔
923

924
    from .grid import plot_grid_mesh
1✔
925
    from .orbit import plot_orbit_mesh
1✔
926

927
    # Plot orbit
928
    if is_orbit(xr_obj):
1✔
929
        p = plot_orbit_mesh(
1✔
930
            da=xr_obj[y],
931
            ax=ax,
932
            x=x,
933
            y=y,
934
            edgecolors=edgecolors,
935
            linewidth=linewidth,
936
            add_background=add_background,
937
            fig_kwargs=fig_kwargs,
938
            subplot_kwargs=subplot_kwargs,
939
            **plot_kwargs,
940
        )
941
    else:  # Plot grid
942
        p = plot_grid_mesh(
1✔
943
            xr_obj=xr_obj,
944
            x=x,
945
            y=y,
946
            ax=ax,
947
            edgecolors=edgecolors,
948
            linewidth=linewidth,
949
            add_background=add_background,
950
            fig_kwargs=fig_kwargs,
951
            subplot_kwargs=subplot_kwargs,
952
            **plot_kwargs,
953
        )
954
    # Return mappable
955
    return p
1✔
956

957

958
def plot_map_mesh_centroids(
1✔
959
    xr_obj,
960
    x="lon",
961
    y="lat",
962
    ax=None,
963
    c="r",
964
    s=1,
965
    add_background=True,
966
    fig_kwargs=None,
967
    subplot_kwargs=None,
968
    **plot_kwargs,
969
):
970
    """Plot GPM orbit granule mesh centroids in a cartographic map."""
971
    from gpm.checks import is_grid
1✔
972

973
    # - Initialize figure if necessary
974
    ax = initialize_cartopy_plot(
1✔
975
        ax=ax,
976
        fig_kwargs=fig_kwargs,
977
        subplot_kwargs=subplot_kwargs,
978
        add_background=add_background,
979
    )
980

981
    # - Retrieve centroids
982
    if is_grid(xr_obj):
1✔
983
        xr_obj = create_grid_mesh_data_array(xr_obj, x=x, y=y)
1✔
984
    lon = xr_obj[x].to_numpy()
1✔
985
    lat = xr_obj[y].to_numpy()
1✔
986

987
    # - Plot centroids
988
    return ax.scatter(lon, lat, transform=ccrs.PlateCarree(), c=c, s=s, **plot_kwargs)
1✔
989

990
    # - Return mappable
991

992

993
####--------------------------------------------------------------------------.
994

995

996
def _plot_labels(
1✔
997
    xr_obj,
998
    label_name=None,
999
    max_n_labels=50,
1000
    add_colorbar=True,
1001
    interpolation="nearest",
1002
    cmap="Paired",
1003
    fig_kwargs=None,
1004
    **plot_kwargs,
1005
):
1006
    """Plot labels.
1007

1008
    The maximum allowed number of labels to plot is 'max_n_labels'.
1009
    """
1010
    from ximage.labels.labels import get_label_indices, redefine_label_array
1✔
1011
    from ximage.labels.plot_labels import get_label_colorbar_settings
1✔
1012

1013
    from gpm.visualization.plot import plot_image
1✔
1014

1015
    if isinstance(xr_obj, xr.Dataset):
1✔
1016
        dataarray = xr_obj[label_name]
1✔
1017
    else:
1018
        dataarray = xr_obj[label_name] if label_name is not None else xr_obj
1✔
1019

1020
    dataarray = dataarray.compute()
1✔
1021
    label_indices = get_label_indices(dataarray)
1✔
1022
    n_labels = len(label_indices)
1✔
1023
    if add_colorbar and n_labels > max_n_labels:
1✔
1024
        msg = f"""The array currently contains {n_labels} labels
1✔
1025
        and 'max_n_labels' is set to {max_n_labels}. The colorbar is not displayed!"""
1026
        print(msg)
1✔
1027
        add_colorbar = False
1✔
1028
    # Relabel array from 1 to ... for plotting
1029
    dataarray = redefine_label_array(dataarray, label_indices=label_indices)
1✔
1030
    # Replace 0 with nan
1031
    dataarray = dataarray.where(dataarray > 0)
1✔
1032
    # Define appropriate colormap
1033
    default_plot_kwargs, cbar_kwargs = get_label_colorbar_settings(label_indices, cmap=cmap)
1✔
1034
    default_plot_kwargs.update(plot_kwargs)
1✔
1035
    # Plot image
1036
    return plot_image(
1✔
1037
        dataarray,
1038
        interpolation=interpolation,
1039
        add_colorbar=add_colorbar,
1040
        cbar_kwargs=cbar_kwargs,
1041
        fig_kwargs=fig_kwargs,
1042
        **default_plot_kwargs,
1043
    )
1044

1045

1046
def plot_labels(
1✔
1047
    obj,  # Dataset, DataArray or generator
1048
    label_name=None,
1049
    max_n_labels=50,
1050
    add_colorbar=True,
1051
    interpolation="nearest",
1052
    cmap="Paired",
1053
    fig_kwargs=None,
1054
    **plot_kwargs,
1055
):
1056
    if is_generator(obj):
1✔
1057
        for _, xr_obj in obj:  # label_id, xr_obj
1✔
1058
            p = _plot_labels(
1✔
1059
                xr_obj=xr_obj,
1060
                label_name=label_name,
1061
                max_n_labels=max_n_labels,
1062
                add_colorbar=add_colorbar,
1063
                interpolation=interpolation,
1064
                cmap=cmap,
1065
                fig_kwargs=fig_kwargs,
1066
                **plot_kwargs,
1067
            )
1068
            plt.show()
1✔
1069
    else:
1070
        p = _plot_labels(
1✔
1071
            xr_obj=obj,
1072
            label_name=label_name,
1073
            max_n_labels=max_n_labels,
1074
            add_colorbar=add_colorbar,
1075
            interpolation=interpolation,
1076
            cmap=cmap,
1077
            fig_kwargs=fig_kwargs,
1078
            **plot_kwargs,
1079
        )
1080
    return p
1✔
1081

1082

1083
def plot_patches(
1✔
1084
    patch_gen,
1085
    variable=None,
1086
    add_colorbar=True,
1087
    interpolation="nearest",
1088
    fig_kwargs=None,
1089
    cbar_kwargs=None,
1090
    **plot_kwargs,
1091
):
1092
    """Plot patches."""
1093
    from gpm.visualization.plot import plot_image
1✔
1094

1095
    # Plot patches
1096
    for _, xr_patch in patch_gen:  # label_id, xr_obj
1✔
1097
        if isinstance(xr_patch, xr.Dataset):
1✔
1098
            if variable is None:
1✔
1099
                raise ValueError("'variable' must be specified when plotting xr.Dataset patches.")
1✔
1100
            xr_patch = xr_patch[variable]
1✔
1101
        try:
1✔
1102
            plot_image(
1✔
1103
                xr_patch,
1104
                interpolation=interpolation,
1105
                add_colorbar=add_colorbar,
1106
                fig_kwargs=fig_kwargs,
1107
                cbar_kwargs=cbar_kwargs,
1108
                **plot_kwargs,
1109
            )
1110
            plt.show()
1✔
1111
        except Exception:
1✔
1112
            pass
1✔
1113

1114

1115
####--------------------------------------------------------------------------.
1116

1117

1118
def get_inset_bounds(
1✔
1119
    ax,
1120
    loc="upper right",
1121
    inset_height=0.2,
1122
    inside_figure=True,
1123
    aspect_ratio=1,
1124
):
1125
    """
1126
    Calculate the bounds for an inset axes in a matplotlib figure.
1127

1128
    This function computes the normalized figure coordinates for placing an inset axes within a figure,
1129
    based on the specified location, size, and whether the inset should be fully inside the figure bounds.
1130
    It is designed to be used with matplotlib figures to facilitate the addition of insets (e.g., for maps
1131
    or zoomed plots) at predefined positions.
1132

1133
    Parameters
1134
    ----------
1135
    loc : str
1136
        The location of the inset within the figure. Valid options are 'lower left', 'lower right',
1137
        'upper left', and 'upper right'. The default is 'upper right'.
1138
    inset_height : float
1139
        The size of the inset height, specified as a fraction of the figure's height.
1140
        For example, a value of 0.2 indicates that the inset's height will be 20% of the figure's height.
1141
        The aspect ratio will govern the inset_width.
1142
    inside_figure : bool, optional
1143
        Determines whether the inset is constrained to be fully inside the figure bounds. If `True` (default),
1144
        the inset is placed fully within the figure. If `False`, the inset can extend beyond the figure's edges,
1145
        allowing for a half-outside placement.
1146
    aspect_ratio : float, optional
1147
        The width-to-height ratio of the inset figure.
1148
        A value greater than 1 indicates an inset figure wider than it is tall,
1149
        and a value less than 1 indicates an inset figure taller than it is wide.
1150
        The default value is 1.0, indicating a square inset figure.
1151

1152
    Returns
1153
    -------
1154
    inset_bounds : list of float
1155
        The calculated bounds of the inset, in the format [x0, y0, width, height], where `x0` and `y0`
1156
        are the normalized figure coordinates of the lower left corner of the inset, and `width` and
1157
        `height` are the normalized width and height of the inset, respectively.
1158

1159
    """
1160
    # Get the bounding box of the parent axes in figure coordinates
1161
    bbox = ax.get_position()
×
1162
    parent_width = bbox.width
×
1163
    parent_height = bbox.height
×
1164

1165
    # Compute the inset width percentage (relative to the parent axes)
1166
    # - Take into account possible different aspect ratios
1167
    inset_height_abs = inset_height * parent_height
×
1168
    inset_width_abs = inset_height_abs * aspect_ratio
×
1169
    inset_width = inset_width_abs / parent_width
×
1170
    loc_mapping = {
×
1171
        "upper right": (1 - inset_width, 1 - inset_height),
1172
        "upper left": (0, 1 - inset_height),
1173
        "lower right": (1 - inset_width, 0),
1174
        "lower left": (0, 0),
1175
    }
1176
    inset_x, inset_y = loc_mapping[loc]
×
1177

1178
    # Adjust for insets that are allowed to be half outside of the figure
1179
    if not inside_figure:
×
1180
        inset_x += inset_width / 2 * (-1 if loc.endswith("left") else 1)
×
1181
        inset_y += inset_height / 2 * (-1 if loc.startswith("lower") else 1)
×
1182

NEW
1183
    return [inset_x, inset_y, inset_width, inset_height]
×
1184

1185

1186
def add_map_inset(ax, loc="upper left", inset_height=0.2, projection=None, inside_figure=True):
1✔
1187
    """
1188
    Adds an inset map to a matplotlib axis using Cartopy, highlighting the extent of the main plot.
1189

1190
    This function creates a smaller map inset within a larger map plot to show a global view or
1191
    contextual location of the main plot's extent.
1192

1193
    It uses Cartopy for map projections and plotting, and it outlines the extent of the main plot
1194
    within the inset to provide geographical context.
1195

1196
    Parameters
1197
    ----------
1198
    ax : (matplotlib.axes.Axes, cartopy.mpl.geoaxes.GeoAxes)
1199
        The main matplotlib or cartopy axis object where the geographic data is plotted.
1200
    loc : str, optional
1201
        The location of the inset map within the main plot.
1202
        Options include 'lower left', 'lower right', 'upper left', 'upper right'.
1203
        The default is 'upper left'.
1204
    inset_height : float
1205
        The size of the inset height, specified as a fraction of the figure's height.
1206
        For example, a value of 0.2 indicates that the inset's height will be 20% of the figure's height.
1207
        The aspect ratio (of the map inset) will govern the inset_width.
1208
    inside_figure : bool, optional
1209
        Determines whether the inset is constrained to be fully inside the figure bounds. If `True` (default),
1210
        the inset is placed fully within the figure. If `False`, the inset can extend beyond the figure's edges,
1211
        allowing for a half-outside placement.
1212
    projection: cartopy.crs.Projection
1213
        A cartopy projection. If ``None``, am Orthographic projection centered on the extent center is used.
1214

1215
    Returns
1216
    -------
1217
    ax2 : cartopy.mpl.geoaxes.GeoAxes
1218
        The Cartopy GeoAxesSubplot object for the inset map.
1219

1220
    Notes
1221
    -----
1222
    The function adjusts the extent of the inset map based on the main plot's extent, adding a
1223
    slight padding for visual clarity. It then overlays a red outline indicating the main plot's
1224
    geographical extent.
1225

1226
    Examples
1227
    --------
1228
    >>> p = da.gpm.plot_map()
1229
    >>> add_map_inset(ax=p.axes, loc="upper left", inset_height=0.15)
1230

1231
    This example creates a main plot with a specified extent and adds an upper-left inset map
1232
    showing the global context of the main plot's extent.
1233
    """
1234
    from shapely import Polygon
×
1235

1236
    from gpm.utils.geospatial import extend_geographic_extent
×
1237

1238
    # Retrieve extent and bounds
1239
    extent = ax.get_extent()
×
1240
    extent = extend_geographic_extent(extent, padding=0.5)
×
1241
    bounds = [extent[i] for i in [0, 2, 1, 3]]
×
1242
    # Create Cartopy Polygon
1243
    polygon = Polygon.from_bounds(*bounds)
×
1244
    # Define Orthographic projection
1245
    if projection is None:
×
1246
        lon_min, lon_max, lat_min, lat_max = extent
×
1247
        projection = ccrs.Orthographic(
×
1248
            central_latitude=(lat_min + lat_max) / 2,
1249
            central_longitude=(lon_min + lon_max) / 2,
1250
        )
1251

1252
    # Define aspect ratio of the map inset
1253
    aspect_ratio = float(np.diff(projection.x_limits) / np.diff(projection.y_limits).item())
×
1254

1255
    # Define inset location relative to main plot (ax) in normalized units
1256
    # - Lower-left corner of inset Axes, and its width and height
1257
    # - [x0, y0, width, height]
1258
    inset_bounds = get_inset_bounds(
×
1259
        ax=ax,
1260
        loc=loc,
1261
        inset_height=inset_height,
1262
        inside_figure=inside_figure,
1263
        aspect_ratio=aspect_ratio,
1264
    )
1265

1266
    # ax2 = plt.axes(inset_bounds, projection=projection)
1267
    ax2 = ax.inset_axes(
×
1268
        inset_bounds,
1269
        projection=projection,
1270
    )
1271

1272
    # Add global map
1273
    ax2.set_global()
×
1274
    ax2.add_feature(cfeature.LAND)
×
1275
    ax2.add_feature(cfeature.OCEAN)
×
1276
    # Add extent polygon
1277
    _ = ax2.add_geometries(
×
1278
        [polygon],
1279
        ccrs.PlateCarree(),
1280
        facecolor="none",
1281
        edgecolor="red",
1282
        linewidth=0.3,
1283
    )
1284
    return ax2
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc