• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ghiggi / gpm_api / 8501078316

31 Mar 2024 09:37PM UTC coverage: 87.854% (+0.2%) from 87.669%
8501078316

Pull #53

github

ghiggi
Add pandas-vet rules
Pull Request #53: Refactor code style

649 of 737 new or added lines in 86 files covered. (88.06%)

4 existing lines in 4 files now uncovered.

9005 of 10250 relevant lines covered (87.85%)

0.88 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.5
/gpm/visualization/plot.py
1
# -----------------------------------------------------------------------------.
2
# MIT License
3

4
# Copyright (c) 2024 GPM-API developers
5
#
6
# This file is part of GPM-API.
7

8
# Permission is hereby granted, free of charge, to any person obtaining a copy
9
# of this software and associated documentation files (the "Software"), to deal
10
# in the Software without restriction, including without limitation the rights
11
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
# copies of the Software, and to permit persons to whom the Software is
13
# furnished to do so, subject to the following conditions:
14
#
15
# The above copyright notice and this permission notice shall be included in all
16
# copies or substantial portions of the Software.
17
#
18
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
# SOFTWARE.
25

26
# -----------------------------------------------------------------------------.
27
"""This module contains basic functions for GPM-API data visualization."""
1✔
28
import inspect
1✔
29

30
import cartopy
1✔
31
import cartopy.crs as ccrs
1✔
32
import cartopy.feature as cfeature
1✔
33
import matplotlib.pyplot as plt
1✔
34
import numpy as np
1✔
35
import xarray as xr
1✔
36
from mpl_toolkits.axes_grid1 import make_axes_locatable
1✔
37
from scipy.interpolate import griddata
1✔
38

39
import gpm
1✔
40

41

42
def is_generator(obj):
    """Return ``True`` when ``obj`` is a generator function or a generator object."""
    checks = (inspect.isgeneratorfunction, inspect.isgenerator)
    return any(check(obj) for check in checks)
44

45

46
def _call_optimize_layout(self):
    """Resize the figure to the axes aspect ratio, then apply ``tight_layout``.

    Meant to be bound as a method on a plot object exposing ``axes`` and
    ``figure`` attributes (see ``add_optimize_layout_method``).
    """
    reference_axes = self.axes
    adapt_fig_size(ax=reference_axes)
    self.figure.tight_layout()
50

51

52
def add_optimize_layout_method(p):
    """Attach an ``optimize_layout()`` bound method to ``p`` and return ``p``.

    The descriptor protocol (``__get__``) binds ``_call_optimize_layout`` to
    this specific instance (monkey patching), so ``p.optimize_layout()``
    receives ``p`` as ``self``.
    """
    bound_method = _call_optimize_layout.__get__(p, type(p))
    p.optimize_layout = bound_method
    return p
56

57

58
def adapt_fig_size(ax, nrow=1, ncol=1):
    """Resize the figure of ``ax`` so its height fits the data aspect ratio.

    The figure width is kept. A new height is derived from the data aspect
    ratio of ``ax`` (``ax.get_data_ratio()``), the figure margins and the
    inter-subplot spacing stored in ``fig.subplotpars``.
    If the new height would be less than half of the current height, width
    and height are scaled up together so the height becomes half of the
    original one.

    Call this once all plotting is done. All ``nrow`` x ``ncol`` subplots are
    assumed to share the aspect ratio and size of ``ax``.

    Inspired by the ``set_map_layout`` function of Mathias Hauser's mplotutils.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes used as reference for the aspect ratio.
    nrow : int, optional
        Number of subplot rows. The default is 1.
    ncol : int, optional
        Number of subplot columns. The default is 1.
    """
    fig = ax.get_figure()
    fig_width, initial_height = fig.get_size_inches()

    # Draw once so that subplot parameters and data limits are up to date.
    fig.canvas.draw()

    pars = fig.subplotpars
    lr_margin = pars.left + (1 - pars.right)  # fraction of width used by left/right margins
    bt_margin = pars.bottom + (1 - pars.top)  # fraction of height used by bottom/top margins

    # Width of a single panel: usable width shared by ncol panels and (ncol - 1) gaps.
    panel_width = (fig_width - fig_width * lr_margin) / (ncol + (ncol - 1) * pars.wspace)
    # Height of a single panel from the data aspect ratio.
    panel_height = panel_width * ax.get_data_ratio()

    # Figure height needed for nrow panels, (nrow - 1) gaps and the margins.
    new_height = (panel_height * (nrow + ((nrow - 1) * pars.hspace))) / (1.0 - bt_margin)

    # Do not let the figure shrink below half of its original height.
    if initial_height / new_height > 2:
        scale = initial_height / new_height / 2
        fig_width *= scale
        new_height *= scale

    fig.set_figwidth(fig_width)
    fig.set_figheight(new_height)
120

121

122
####--------------------------------------------------------------------------.
123

124

125
def get_antimeridian_mask(lons):
    """Return a boolean cell mask flagging cells near antimeridian crossings.

    ``lons`` is a 2D array of cell-corner longitudes with shape
    ``(n_y, n_x)``; the returned mask has the cell shape ``(n_y - 1, n_x - 1)``.
    A longitude jump larger than 180 degrees between two adjacent corners
    marks a neighbouring cell. The result is then dilated by one cell
    (``scipy.ndimage.binary_dilation``) to also hide the neighbours.
    The dilation should not be needed, but it works around cartopy bugs.
    """
    from scipy.ndimage import binary_dilation

    n_rows, n_cols = lons.shape
    cell_mask = np.zeros((n_rows - 1, n_cols - 1))

    # Jumps between vertically adjacent corners -> flag the cell on the left
    jump_rows, jump_cols = np.where(np.abs(np.diff(lons, axis=0)) > 180)
    cell_mask[jump_rows, np.clip(jump_cols - 1, 0, n_cols - 1)] = 1

    # Jumps between horizontally adjacent corners -> flag the cell above
    jump_rows, jump_cols = np.where(np.abs(np.diff(lons, axis=1)) > 180)
    cell_mask[np.clip(jump_rows - 1, 0, n_rows - 1), jump_cols] = 1

    return binary_dilation(cell_mask)
143

144

145
def infill_invalid_coords(xr_obj, x="lon", y="lat"):
    """Return a copy of ``xr_obj`` whose non-finite ``x``/``y`` coordinates are filled.

    Invalid coordinates are linearly interpolated inside the convex hull of
    the valid ones and filled by nearest neighbour outside it
    (via ``get_valid_pcolormesh_inputs`` with ``mask_data=False``).
    The input object is not modified.
    """
    xr_obj = xr_obj.copy()
    filled_x, filled_y, _ = get_valid_pcolormesh_inputs(
        x=np.asanyarray(xr_obj[x].data),
        y=np.asanyarray(xr_obj[y].data),
        data=None,
        mask_data=False,
    )
    xr_obj[x].data = filled_x
    xr_obj[y].data = filled_y
    return xr_obj
160

161

162
def get_valid_pcolormesh_inputs(x, y, data, rgb=False, mask_data=True):
    """Make coordinates usable by ``pcolormesh`` by infilling non-finite values.

    ``pcolormesh`` rejects non-finite coordinates. Invalid ``x``/``y`` values
    are linearly interpolated inside the convex hull of the valid ones and
    filled by nearest neighbour outside it.

    If ``mask_data=True``, the ``data`` values located at invalid coordinates
    are masked (a ``numpy.ma`` masked array is returned), so that
    ``pcolormesh`` does not display them.
    If ``rgb=True``, the RGB channels are expected on the last ``data`` axis.

    If all coordinates are finite, the inputs are returned unchanged.

    Returns
    -------
    tuple
        ``(x, y, data)`` after infilling/masking.
    """
    invalid_x = ~np.isfinite(x)
    invalid_y = ~np.isfinite(y)
    invalid = np.logical_or(invalid_x, invalid_y)

    # Nothing to fix: hand back the original objects
    if not np.any(invalid):
        return x, y, data

    if not mask_data:
        out_data = data
    elif rgb:
        # Repeat the 2D mask over the trailing RGB(A) channel axis
        channel_mask = np.broadcast_to(invalid[..., np.newaxis], data.shape)
        out_data = np.ma.masked_where(channel_mask, data)
    else:
        out_data = np.ma.masked_where(invalid, data)

    # Linear interpolation first, then nearest neighbour for points outside the hull
    if np.any(invalid_x):
        x = _interpolate_data(_interpolate_data(x, method="linear"), method="nearest")
    if np.any(invalid_y):
        y = _interpolate_data(_interpolate_data(y, method="linear"), method="nearest")
    return x, y, out_data
204

205

206
def _interpolate_data(arr, method="linear"):
1✔
207
    # Find invalid locations
208
    is_invalid = ~np.isfinite(arr)
1✔
209

210
    # Find the indices of NaN values
211
    nan_indices = np.where(is_invalid)
1✔
212

213
    # Return array if not NaN values
214
    if len(nan_indices) == 0:
1✔
215
        return arr
×
216

217
    # Find the indices of non-NaN values
218
    non_nan_indices = np.where(~is_invalid)
1✔
219

220
    # Create a meshgrid of indices
221
    X, Y = np.meshgrid(range(arr.shape[1]), range(arr.shape[0]))
1✔
222

223
    # Points (X, Y) where we have valid data
224
    points = np.array([Y[non_nan_indices], X[non_nan_indices]]).T
1✔
225

226
    # Points where data is NaN
227
    points_nan = np.array([Y[nan_indices], X[nan_indices]]).T
1✔
228

229
    # Values at the non-NaN points
230
    values = arr[non_nan_indices]
1✔
231

232
    # Interpolate using griddata
233
    arr_new = arr.copy()
1✔
234
    arr_new[nan_indices] = griddata(points, values, points_nan, method=method)
1✔
235
    return arr_new
1✔
236

237

238
####--------------------------------------------------------------------------.
239

240

241
def preprocess_figure_args(ax, fig_kwargs=None, subplot_kwargs=None):
    """Validate figure options and return ``fig_kwargs`` (``{}`` if ``None``).

    ``fig_kwargs`` and ``subplot_kwargs`` are only meaningful when a new
    figure is created, so a ``ValueError`` is raised if either is non-empty
    while an existing ``ax`` is provided.
    """
    if fig_kwargs is None:
        fig_kwargs = {}
    if subplot_kwargs is None:
        subplot_kwargs = {}
    if ax is None:
        return fig_kwargs
    if subplot_kwargs:
        raise ValueError("Provide `subplot_kwargs`only if `ax`is None")
    if fig_kwargs:
        raise ValueError("Provide `fig_kwargs` only if `ax`is None")
    return fig_kwargs
250

251

252
def preprocess_subplot_kwargs(subplot_kwargs):
    """Return a copy of ``subplot_kwargs`` guaranteed to contain a ``"projection"``.

    A missing projection defaults to ``ccrs.PlateCarree()``. The caller's
    dictionary is never modified.
    """
    kwargs = {}.copy() if subplot_kwargs is None else subplot_kwargs.copy()
    if "projection" not in kwargs:
        kwargs["projection"] = ccrs.PlateCarree()
    return kwargs
258

259

260
def initialize_cartopy_plot(
    ax,
    fig_kwargs,
    subplot_kwargs,
    add_background,
):
    """Return a cartopy axes ready for plotting.

    If ``ax`` is ``None``, a new figure and GeoAxes are created with
    ``plt.subplots`` using ``fig_kwargs`` and ``subplot_kwargs`` (projection
    defaults to PlateCarree). If ``add_background`` is true, the cartopy
    background is drawn on the axes.
    """
    if ax is None:
        # NOTE(review): validation in preprocess_figure_args only raises when
        # an ``ax`` is given, so it never fires on this path.
        fig_kwargs = preprocess_figure_args(ax=ax, fig_kwargs=fig_kwargs, subplot_kwargs=subplot_kwargs)
        _, ax = plt.subplots(subplot_kw=preprocess_subplot_kwargs(subplot_kwargs), **fig_kwargs)
    return plot_cartopy_background(ax) if add_background else ax
279

280

281
####--------------------------------------------------------------------------.
282

283

284
def plot_cartopy_background(ax):
    """Draw coastlines, land, ocean, borders and lon/lat gridlines on ``ax``.

    Gridline labels are disabled on the top and right sides.
    Returns the same ``ax``.
    """
    ax.coastlines()
    background_features = (
        (cfeature.LAND, {"facecolor": [0.9, 0.9, 0.9]}),
        (cfeature.OCEAN, {"alpha": 0.6}),
        (cfeature.BORDERS, {}),  # BORDERS also draws provinces, ...
    )
    for feature, feature_style in background_features:
        ax.add_feature(feature, **feature_style)

    gridliner = ax.gridlines(
        crs=ccrs.PlateCarree(),
        draw_labels=True,
        linewidth=1,
        color="gray",
        alpha=0.1,
        linestyle="-",
    )
    gridliner.top_labels = False
    gridliner.right_labels = False
    gridliner.xlines = True
    gridliner.ylines = True
    return ax
305

306

307
def plot_sides(sides, ax, **plot_kwargs):
    """Plot boundary sides as geodetic lines.

    Parameters
    ----------
    sides : iterable of tuple
        Sequence of ``(lon, lat)`` array pairs, one per side.
    ax : cartopy.mpl.geoaxes.GeoAxes
        Axes where the lines are drawn.
    **plot_kwargs
        Additional arguments passed to ``ax.plot``.

    Returns
    -------
    matplotlib.lines.Line2D or None
        The line of the last plotted side, or ``None`` when ``sides`` is empty
        (previously this raised ``UnboundLocalError``).
    """
    line = None
    for side in sides:
        line = ax.plot(*side, transform=ccrs.Geodetic(), **plot_kwargs)[0]
    return line
315

316

317
def plot_colorbar(p, ax, cbar_kwargs=None):
    """Add a colorbar next to a matplotlib/cartopy axes.

    The colorbar axes is appended to ``ax`` with ``make_axes_locatable``:
    on the right for ``orientation="vertical"`` (default) and at the bottom
    for ``orientation="horizontal"``.

    Parameters
    ----------
    p : matplotlib.cm.ScalarMappable
        Mappable (e.g. ``AxesImage``/``QuadMesh``) the colorbar refers to.
    ax : matplotlib.axes.Axes or cartopy.mpl.geoaxes.GeoAxes
        Axes next to which the colorbar is placed.
    cbar_kwargs : dict, optional
        Colorbar options. The special keys ``'size'`` (default ``"5%"``) and
        ``'pad'`` (default 0.1 vertical, 0.25 horizontal) control the colorbar
        axes size and its spacing from ``ax``. ``'ticklabels'`` sets custom
        tick labels. Remaining keys are passed to ``plt.colorbar``.
        The caller's dictionary is not modified.

    Returns
    -------
    matplotlib.colorbar.Colorbar

    Raises
    ------
    ValueError
        If ``orientation`` is neither ``'vertical'`` nor ``'horizontal'``.
    """
    cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs)
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    orientation = cbar_kwargs.get("orientation", "vertical")

    # Validate before touching the figure
    if orientation == "vertical":
        position, default_pad = "right", 0.1
    elif orientation == "horizontal":
        position, default_pad = "bottom", 0.25
    else:
        raise ValueError("Invalid orientation. Choose 'vertical' or 'horizontal'.")

    # 'size' and 'pad' are layout options for append_axes: pop them so they are
    # not forwarded to plt.colorbar (Colorbar rejects 'size' with a TypeError).
    size = cbar_kwargs.pop("size", "5%")
    pad = cbar_kwargs.pop("pad", default_pad)

    divider = make_axes_locatable(ax)
    cax = divider.append_axes(position, size=size, pad=pad, axes_class=plt.Axes)
    p.figure.add_axes(cax)
    cbar = plt.colorbar(p, cax=cax, ax=ax, **cbar_kwargs)

    if ticklabels is not None:
        if orientation == "vertical":
            cbar.ax.set_yticklabels(ticklabels)
        else:
            cbar.ax.set_xticklabels(ticklabels)
    return cbar
353

354

355
####--------------------------------------------------------------------------.
356

357

358
def get_dataarray_extent(da, x="lon", y="lat"):
    """Return ``(x_min, x_max, y_min, y_max)`` of the ``x`` and ``y`` coordinates.

    The bounds are the min/max of the coordinate values (pixel centroids),
    not of the pixel edges.
    """
    # TODO: compute corners array to estimate the extent
    # - OR increase by 1° in every direction and then wrap between -180, 180, -90, 90
    x_values = da[x]
    y_values = da[y]
    return (x_values.min(), x_values.max(), y_values.min(), y_values.max())
365

366

367
def _compute_extent(x_coords, y_coords):
1✔
368
    """
369
    Compute the extent (x_min, x_max, y_min, y_max) from the pixel centroids in x and y coordinates.
370
    This function assumes that the spacing between each pixel is uniform.
371
    """
372
    # Calculate the pixel size assuming uniform spacing between pixels
373
    pixel_size_x = (x_coords[-1] - x_coords[0]) / (len(x_coords) - 1)
1✔
374
    pixel_size_y = (y_coords[-1] - y_coords[0]) / (len(y_coords) - 1)
1✔
375

376
    # Adjust min and max to get the corners of the outer pixels
377
    x_min, x_max = x_coords[0] - pixel_size_x / 2, x_coords[-1] + pixel_size_x / 2
1✔
378
    y_min, y_max = y_coords[0] - pixel_size_y / 2, y_coords[-1] + pixel_size_y / 2
1✔
379

380
    return [x_min, x_max, y_min, y_max]
1✔
381

382

383
def _plot_cartopy_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot a regular lon/lat grid with ``ax.imshow`` on a cartopy axes.

    The image extent is derived from the pixel centroids (assuming uniform
    spacing) and the map extent is set to the centroid bounds.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        Target axes.
    da : xarray.DataArray
        Data with 1D ``x`` and ``y`` coordinates (dimension names).
    x, y : str
        Names of the longitude and latitude dimensions/coordinates.
    interpolation : str, optional
        Passed to ``imshow``. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    plot_kwargs : dict, optional
        Additional ``imshow`` arguments. The default is ``None``.
    cbar_kwargs : dict, optional
        Colorbar options passed to ``plot_colorbar``.

    Returns
    -------
    matplotlib.image.AxesImage
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs

    # Ensure (y, x) dimension order for the image array
    da = da.transpose(y, x)
    arr = np.asanyarray(da.data)

    x_coords = da[x].to_numpy()
    y_coords = da[y].to_numpy()

    # Image extent spans the outer pixel edges
    image_extent = _compute_extent(x_coords=x_coords, y_coords=y_coords)

    # Increasing y values -> origin "lower"; decreasing -> origin "upper"
    origin = "lower" if y_coords[1] > y_coords[0] else "upper"

    p = ax.imshow(
        arr,
        transform=ccrs.PlateCarree(),
        extent=image_extent,
        origin=origin,
        interpolation=interpolation,
        **plot_kwargs,
    )

    # Map extent: use the caller-provided coordinate names (was hard-coded to lon/lat)
    map_extent = get_dataarray_extent(da, x=x, y=y)
    ax.set_extent(map_extent)

    if add_colorbar:
        # --> TODO: set axis proportion in a meaningful way ...
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
428

429

430
def _mask_antimeridian_crossing_arr(arr, antimeridian_mask, rgb):
1✔
431
    if np.ma.is_masked(arr):
1✔
432
        if rgb:
×
433
            data_mask = np.broadcast_to(np.expand_dims(antimeridian_mask, axis=-1), arr.shape)
×
434
            combined_mask = np.logical_or(data_mask, antimeridian_mask)
×
435
        else:
436
            combined_mask = np.logical_or(arr.mask, antimeridian_mask)
×
437
        arr = np.ma.masked_where(combined_mask, arr)
×
438
    else:
439
        if rgb:
1✔
440
            antimeridian_mask = np.broadcast_to(
1✔
441
                np.expand_dims(antimeridian_mask, axis=-1), arr.shape
442
            )
443
        arr = np.ma.masked_where(antimeridian_mask, arr)
1✔
444
    return arr
1✔
445

446

447
def _plot_cartopy_pcolormesh(
    ax,
    da,
    x,
    y,
    rgb=False,
    add_colorbar=True,
    add_swath_lines=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot a swath (2D lon/lat coordinates) with ``ax.pcolormesh`` on a cartopy axes.

    Invalid coordinates are infilled and the corresponding data masked.
    When ``gpm.config`` ``"viz_hide_antimeridian_data"`` is set, cells
    crossing the antimeridian are masked. The function currently does not
    allow zooming on regions across the antimeridian.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        Target axes.
    da : xarray.DataArray
        Data to plot. If ``rgb=True``, the RGB(A) dimension must be last.
    x, y : str
        Names of the longitude and latitude coordinates.
    rgb : bool, optional
        Whether ``da`` is an RGB(A) image. Disables the colorbar.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    add_swath_lines : bool, optional
        Whether to draw the first and last scan-line edges as dashed lines.
    plot_kwargs : dict, optional
        Additional ``pcolormesh`` arguments. The caller's dict is not modified.
    cbar_kwargs : dict, optional
        Colorbar options passed to ``plot_colorbar``.

    Returns
    -------
    matplotlib.collections.QuadMesh

    Raises
    ------
    ValueError
        If ``rgb=True`` and the last dimension does not have 3 or 4 channels.
    """
    # Private copy: the 'cmap' entry may be replaced below
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)

    da = da.compute()
    lon = da[x].data
    lat = da[y].data
    arr = da.data

    if rgb and arr.shape[-1] not in (3, 4):
        raise ValueError("RGB array must have 3 or 4 channels in the last dimension.")

    # Infill invalid coordinates and mask the related data
    lon, lat, arr = get_valid_pcolormesh_inputs(lon, lat, arr, rgb=rgb)

    if rgb:
        add_colorbar = False

    # Cell-corner coordinates enable correct masking of antimeridian-crossing cells
    from gpm.utils.area import _get_lonlat_corners

    lon, lat = _get_lonlat_corners(lon, lat)

    # Cartopy still bugs with several projections when data cross the antimeridian.
    # Disable with gpm.config.set({"viz_hide_antimeridian_data": False})
    if gpm.config.get("viz_hide_antimeridian_data"):
        antimeridian_mask = get_antimeridian_mask(lon)
        if np.any(antimeridian_mask):
            arr = _mask_antimeridian_crossing_arr(arr, antimeridian_mask=antimeridian_mask, rgb=rgb)

            # Cartopy requires a transparent bad color for masked cells.
            # Resolve names and copy, so a shared/registered colormap is not altered.
            cmap = plot_kwargs.get("cmap", None)
            if cmap is not None:
                cmap = plt.get_cmap(cmap).copy()
                bad = cmap.get_bad()
                bad[3] = 0  # alpha -> fully transparent
                cmap.set_bad(bad)
                plot_kwargs["cmap"] = cmap

    p = ax.pcolormesh(
        lon,
        lat,
        arr,
        transform=ccrs.PlateCarree(),
        **plot_kwargs,
    )

    if add_swath_lines:
        sides = [(lon[0, :], lat[0, :]), (lon[-1, :], lat[-1, :])]
        plot_sides(sides=sides, ax=ax, linestyle="--", color="black")

    if add_colorbar:
        # --> TODO: set axis proportion in a meaningful way ...
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
526

527

528
def _plot_mpl_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot a 2D DataArray with plain matplotlib ``imshow`` (``origin="upper"``).

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes.
    da : xarray.DataArray
        2D data; it is transposed to ``(y, x)`` before plotting.
    x, y : str
        Names of the horizontal and vertical dimensions.
    interpolation : str, optional
        Passed to ``imshow``. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    plot_kwargs : dict, optional
        Additional ``imshow`` arguments. The default is ``None``.
    cbar_kwargs : dict, optional
        Colorbar options passed to ``plot_colorbar``.

    Returns
    -------
    matplotlib.image.AxesImage
    """
    # Avoid a shared mutable default argument
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs

    da = da.transpose(y, x)
    arr = np.asanyarray(da.data)

    p = ax.imshow(
        arr,
        origin="upper",
        interpolation=interpolation,
        **plot_kwargs,
    )
    if add_colorbar:
        # --> TODO: set axis proportion in a meaningful way ...
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
556

557

558
def set_colorbar_fully_transparent(p):
    """Hide the colorbar of ``p`` while keeping its place in the figure.

    Useful for animations where the colorbar should not be shown in every
    frame but the plot area must stay fixed. The colorbar axes is made
    invisible and an empty rectangle (no face, no edge) is added at its
    position in figure coordinates.
    """
    cbar_ax = p.colorbar.ax
    cbar_pos = cbar_ax.get_position()

    cbar_ax.set_visible(False)

    # Use the colorbar's own figure: plt.gcf() may return a different figure
    fig = cbar_ax.figure
    rect = plt.Rectangle(
        (cbar_pos.x0, cbar_pos.y0),
        cbar_pos.width,
        cbar_pos.height,
        transform=fig.transFigure,
        facecolor="none",
        edgecolor="none",
    )
    fig.add_artist(rect)
585

586

587
def _plot_xr_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
    visible_colorbar=True,
):
    """Plot imshow with xarray.

    The colorbar is added by xarray so that calling this function several
    times on different fields can display multiple colorbars.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes.
    da : xarray.DataArray
        Data to plot.
    x, y : str
        Names of the horizontal and vertical dimensions.
    interpolation : str, optional
        Passed to ``imshow``. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    plot_kwargs : dict, optional
        Additional plotting arguments. The default is ``None``.
    cbar_kwargs : dict, optional
        Colorbar options. ``'ticklabels'`` sets custom tick labels.
        The caller's dictionary is not modified. The default is ``None``.
    visible_colorbar : bool, optional
        If ``False``, the colorbar is added but made fully transparent.

    Returns
    -------
    The object returned by ``da.plot.imshow``.
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    # Copy so 'ticklabels' is not popped from the caller's dict (and None is supported)
    cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs)

    # --> BUG with colorbar: https://github.com/pydata/xarray/issues/7014
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    if not add_colorbar:
        cbar_kwargs = None
    p = da.plot.imshow(
        x=x,
        y=y,
        ax=ax,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )
    plt.title(da.name)
    if add_colorbar and ticklabels is not None:
        cbar = p.colorbar
        # Tick labels lie on the x-axis of a horizontal colorbar
        if getattr(cbar, "orientation", "vertical") == "horizontal":
            cbar.ax.set_xticklabels(ticklabels)
        else:
            cbar.ax.set_yticklabels(ticklabels)

    # Make the colorbar fully transparent with a smart trick ;)
    # - TODO: this still cause issues when plotting 2 colorbars !
    if add_colorbar and not visible_colorbar:
        set_colorbar_fully_transparent(p)
    return p
639

640

641
def _plot_xr_pcolormesh(
    ax,
    da,
    x,
    y,
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot pcolormesh with xarray.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes.
    da : xarray.DataArray
        Data to plot.
    x, y : str
        Names of the horizontal and vertical coordinates.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    plot_kwargs : dict, optional
        Additional plotting arguments. The default is ``None``.
    cbar_kwargs : dict, optional
        Colorbar options. ``'ticklabels'`` sets custom tick labels.
        The caller's dictionary is not modified. The default is ``None``.

    Returns
    -------
    The object returned by ``da.plot.pcolormesh``.
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    # Copy so 'ticklabels' is not popped from the caller's dict (and None is supported)
    cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs)

    ticklabels = cbar_kwargs.pop("ticklabels", None)
    if not add_colorbar:
        cbar_kwargs = None
    p = da.plot.pcolormesh(
        x=x,
        y=y,
        ax=ax,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )
    plt.title(da.name)
    if add_colorbar and ticklabels is not None:
        cbar = p.colorbar
        # Tick labels lie on the x-axis of a horizontal colorbar
        if getattr(cbar, "orientation", "vertical") == "horizontal":
            cbar.ax.set_xticklabels(ticklabels)
        else:
            cbar.ax.set_yticklabels(ticklabels)
    return p
666

667

668
####--------------------------------------------------------------------------.
669

670

671
def plot_map(
    da,
    x="lon",
    y="lat",
    ax=None,
    add_colorbar=True,
    add_swath_lines=True,  # used only for GPM orbit objects
    add_background=True,
    rgb=False,
    interpolation="nearest",  # used only for GPM grid objects
    fig_kwargs=None,
    subplot_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """
    Plot a GPM ORBIT or GRID DataArray on a geographic map.

    ORBIT objects are dispatched to ``plot_orbit_map`` and GRID objects to
    ``plot_grid_map``.

    Parameters
    ----------
    da : xr.DataArray
        xarray DataArray.
    x : str, optional
        Longitude coordinate name. The default is `"lon"`.
    y : str, optional
        Latitude coordinate name. The default is `"lat"`.
    ax : cartopy.GeoAxes, optional
        The cartopy GeoAxes where to plot the map.
        If `None`, a figure is initialized using the
        specified `fig_kwargs` and `subplot_kwargs`.
        The default is `None`.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is `True`.
    add_swath_lines : bool, optional
        Whether to plot the swath sides with a dashed line. The default is `True`.
        Only applies to ORBIT objects.
    add_background : bool, optional
        Whether to add the map background. The default is `True`.
    rgb : bool, optional
        Whether the input DataArray has a rgb dimension. The default is `False`.
        Only forwarded for ORBIT objects.
    interpolation : str, optional
        Argument passed to imshow. Only applies to GRID objects.
        The default is `"nearest"`.
    fig_kwargs : dict, optional
        Figure options passed to `plt.subplots`. The default is `None`.
        Only used if `ax` is `None`.
    subplot_kwargs : dict, optional
        Keyword arguments for Matplotlib subplots.
        Must contain the Cartopy CRS `'projection'` key if specified.
        The default is `None`. Only used if `ax` is `None`.
    cbar_kwargs : dict, optional
        Colorbar options. The default is `None`.
    **plot_kwargs
        Additional arguments passed to the plotting function.
        Examples include `cmap`, `norm`, `vmin`, `vmax`, `levels`, ...
        For FacetGrid plots, specify `row`, `col` and `col_wrap`.

    Returns
    -------
    The object returned by the dispatched plotting function.

    Raises
    ------
    ValueError
        If `da` is neither a GPM ORBIT nor a GPM GRID object.
    """
    from gpm.checks import is_grid, is_orbit
    from gpm.visualization.grid import plot_grid_map
    from gpm.visualization.orbit import plot_orbit_map

    # Arguments shared by the ORBIT and GRID map functions
    shared_kwargs = {
        "da": da,
        "x": x,
        "y": y,
        "ax": ax,
        "add_colorbar": add_colorbar,
        "add_background": add_background,
        "fig_kwargs": fig_kwargs,
        "subplot_kwargs": subplot_kwargs,
        "cbar_kwargs": cbar_kwargs,
    }
    if is_orbit(da):
        return plot_orbit_map(add_swath_lines=add_swath_lines, rgb=rgb, **shared_kwargs, **plot_kwargs)
    if is_grid(da):
        return plot_grid_map(interpolation=interpolation, **shared_kwargs, **plot_kwargs)
    raise ValueError("Can not plot. It's neither a GPM grid, neither a GPM orbit.")
769

770

771
def plot_image(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """
    Plot data using imshow.

    ORBIT objects are dispatched to ``plot_orbit_image`` and GRID objects to
    ``plot_grid_image``.

    Parameters
    ----------
    da : xr.DataArray
        xarray DataArray.
    x : str, optional
        X dimension name.
        If ``None``, takes the second dimension.
        The default is `None`.
    y : str, optional
        Y dimension name.
        If ``None``, takes the first dimension.
        The default is `None`.
    ax : matplotlib.axes.Axes, optional
        The matplotlib axes where to plot the image.
        If ``None``, a figure is initialized using the
        specified `fig_kwargs`.
        The default is `None`.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is `True`.
    interpolation : str, optional
        Argument to be passed to imshow.
        The default is `"nearest"`.
    fig_kwargs : dict, optional
        Figure options to be passed to `plt.subplots`.
        The default is ``None``.
        Only used if `ax` is None.
    cbar_kwargs : dict, optional
        Colorbar options. The default is `None`.
    **plot_kwargs
        Additional arguments to be passed to the plotting function.
        Examples include `cmap`, `norm`, `vmin`, `vmax`, `levels`, ...
        For FacetGrid plots, specify `row`, `col` and `col_wrap`.

    Returns
    -------
    The object returned by the dispatched plotting function.

    Raises
    ------
    ValueError
        If `da` is neither a GPM ORBIT nor a GPM GRID object.
    """
    from gpm.checks import is_grid, is_orbit
    from gpm.visualization.grid import plot_grid_image
    from gpm.visualization.orbit import plot_orbit_image

    # Plot orbit
    if is_orbit(da):
        p = plot_orbit_image(
            da=da,
            x=x,
            y=y,
            ax=ax,
            add_colorbar=add_colorbar,
            interpolation=interpolation,
            fig_kwargs=fig_kwargs,
            cbar_kwargs=cbar_kwargs,
            **plot_kwargs,
        )
    # Plot grid
    elif is_grid(da):
        p = plot_grid_image(
            da=da,
            x=x,
            y=y,
            ax=ax,
            add_colorbar=add_colorbar,
            interpolation=interpolation,
            fig_kwargs=fig_kwargs,
            cbar_kwargs=cbar_kwargs,
            **plot_kwargs,
        )
    else:
        raise ValueError("Can not plot. It's neither a GPM GRID, neither a GPM ORBIT.")
    # Return mappable
    return p
857

858

859
####--------------------------------------------------------------------------.
860

861

862
def create_grid_mesh_data_array(xr_obj, x, y):
    """
    Build a NaN-filled 2D DataArray whose coordinates mesh two 1D coordinates.

    The 1D coordinate arrays named `x` and `y` are read from `xr_obj`.
    They are expanded into 2D arrays with ``numpy.meshgrid`` (``"xy"`` indexing)
    and attached as 2D coordinates of a new DataArray whose values are all NaN.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Object holding the 1D coordinate arrays.
    x : str
        Name of the x-coordinate in `xr_obj`.
    y : str
        Name of the y-coordinate in `xr_obj`.

    Returns
    -------
    da_mesh : xarray.DataArray
        DataArray of shape ``(len(xr_obj[y]), len(xr_obj[x]))`` with dimensions
        ``("y", "x")``, 2D coordinates named `x` and `y`, and NaN data values.

    Notes
    -----
    The dimension names are always the literal strings ``"y"`` and ``"x"``,
    regardless of the coordinate names passed in. The coordinate values are
    taken directly from the input 1D coordinate arrays.
    """
    # NOTE(review): if `x` or `y` is literally "x"/"y", the 2D coordinate would
    # share its name with a dimension, which xarray rejects — confirm callers
    # never pass those names.
    mesh_dims = ("y", "x")

    # "xy" indexing -> both meshes have shape (n_y, n_x), matching `mesh_dims`
    x_mesh, y_mesh = np.meshgrid(
        xr_obj[x].to_numpy(),
        xr_obj[y].to_numpy(),
        indexing="xy",
    )

    nan_values = np.full(x_mesh.shape, np.nan)
    mesh_coords = {
        x: (mesh_dims, x_mesh),
        y: (mesh_dims, y_mesh),
    }
    return xr.DataArray(nan_values, coords=mesh_coords, dims=mesh_dims)
903

904

905
def plot_map_mesh(
    xr_obj,
    x="lon",
    y="lat",
    ax=None,
    edgecolors="k",
    linewidth=0.1,
    add_background=True,
    fig_kwargs=None,
    subplot_kwargs=None,
    **plot_kwargs,
):
    """Plot the mesh of a GPM ORBIT or GRID object on a cartographic map.

    If `xr_obj` is recognized as a GPM ORBIT, the work is delegated to
    ``plot_orbit_mesh``, which receives the `y` coordinate DataArray as ``da``.
    Any other object is delegated to ``plot_grid_mesh``.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        GPM object whose mesh is plotted.
    x, y : str, optional
        Names of the x and y coordinates. Defaults are ``"lon"`` and ``"lat"``.
    ax : matplotlib axes, optional
        Axes to plot on. The default is ``None``.
    edgecolors : str, optional
        Color of the mesh edges. The default is ``"k"``.
    linewidth : float, optional
        Width of the mesh edges. The default is ``0.1``.
    add_background : bool, optional
        Forwarded to the delegated plotting function. The default is ``True``.
    fig_kwargs, subplot_kwargs : dict, optional
        Forwarded to the delegated plotting function.
    **plot_kwargs
        Additional arguments forwarded to the delegated plotting function.

    Returns
    -------
    p
        The object returned by ``plot_orbit_mesh`` or ``plot_grid_mesh``.
    """
    from gpm.checks import is_orbit

    from .grid import plot_grid_mesh
    from .orbit import plot_orbit_mesh

    # Arguments common to both the orbit and grid mesh plotting functions
    shared_kwargs = {
        "x": x,
        "y": y,
        "ax": ax,
        "edgecolors": edgecolors,
        "linewidth": linewidth,
        "add_background": add_background,
        "fig_kwargs": fig_kwargs,
        "subplot_kwargs": subplot_kwargs,
    }

    if is_orbit(xr_obj):
        return plot_orbit_mesh(da=xr_obj[y], **shared_kwargs, **plot_kwargs)
    return plot_grid_mesh(xr_obj=xr_obj, **shared_kwargs, **plot_kwargs)
951

952

953
def plot_map_mesh_centroids(
    xr_obj,
    x="lon",
    y="lat",
    ax=None,
    c="r",
    s=1,
    add_background=True,
    fig_kwargs=None,
    subplot_kwargs=None,
    **plot_kwargs,
):
    """Scatter-plot the mesh centroids of a GPM object on a cartographic map.

    For GPM GRID objects, the 1D `x`/`y` coordinates are first expanded into a
    2D mesh, so that every grid cell centre is drawn. For other objects
    (e.g. ORBIT), the `x`/`y` coordinates are used as they are. Coordinates are
    interpreted in ``ccrs.PlateCarree()``.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        GPM object whose centroids are plotted.
    x, y : str, optional
        Names of the x and y coordinates. Defaults are ``"lon"`` and ``"lat"``.
    ax : matplotlib axes, optional
        Axes to plot on. If ``None``, one is created by ``initialize_cartopy_plot``.
    c : color, optional
        Marker color. The default is ``"r"``.
    s : float, optional
        Marker size. The default is ``1``.
    add_background, fig_kwargs, subplot_kwargs
        Forwarded to ``initialize_cartopy_plot``.
    **plot_kwargs
        Additional arguments forwarded to ``ax.scatter``.

    Returns
    -------
    p
        The mappable returned by ``ax.scatter``.
    """
    from gpm.checks import is_grid

    # Create the cartographic axes if none was provided
    ax = initialize_cartopy_plot(
        ax=ax,
        fig_kwargs=fig_kwargs,
        subplot_kwargs=subplot_kwargs,
        add_background=add_background,
    )

    # GRID coordinates are 1D: mesh them so that every cell centre gets a point
    source = create_grid_mesh_data_array(xr_obj, x=x, y=y) if is_grid(xr_obj) else xr_obj
    centroids_x = source[x].to_numpy()
    centroids_y = source[y].to_numpy()

    return ax.scatter(
        centroids_x,
        centroids_y,
        transform=ccrs.PlateCarree(),
        c=c,
        s=s,
        **plot_kwargs,
    )
986

987

988
####--------------------------------------------------------------------------.
989

990

991
def _plot_labels(
    xr_obj,
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs=None,
    **plot_kwargs,
):
    """Plot a label array.

    The labels are relabelled with ``redefine_label_array`` (from 1 onward),
    values <= 0 are masked (set to NaN), and the colormap/colorbar settings are
    obtained from ``get_label_colorbar_settings``.
    The colorbar is displayed only if the number of labels does not exceed
    `max_n_labels`.

    Parameters
    ----------
    xr_obj : xarray.Dataset or xarray.DataArray
        Object containing the label array.
    label_name : str, optional
        Name of the label variable. Required if `xr_obj` is a Dataset.
        If `xr_obj` is a DataArray and `label_name` is given,
        ``xr_obj[label_name]`` is plotted; otherwise `xr_obj` itself is plotted.
    max_n_labels : int, optional
        Maximum number of labels for which the colorbar is displayed.
        The default is 50.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    interpolation : str, optional
        Interpolation passed to ``plot_image``. The default is ``"nearest"``.
    cmap : str, optional
        Base colormap for the labels. The default is ``"Paired"``.
    fig_kwargs : dict, optional
        Figure options passed to ``plot_image``.
    **plot_kwargs
        Additional arguments passed to ``plot_image``. They override the
        label-specific defaults.

    Returns
    -------
    p
        The object returned by ``plot_image``.

    Raises
    ------
    ValueError
        If `xr_obj` is a Dataset and `label_name` is ``None``.
    """
    from ximage.labels.labels import get_label_indices, redefine_label_array
    from ximage.labels.plot_labels import get_label_colorbar_settings

    from gpm.visualization.plot import plot_image

    # Select the label DataArray
    if isinstance(xr_obj, xr.Dataset):
        if label_name is None:
            raise ValueError("'label_name' must be specified when plotting labels of an xr.Dataset.")
        dataarray = xr_obj[label_name]
    else:
        dataarray = xr_obj[label_name] if label_name is not None else xr_obj

    # Load the data into memory before extracting the labels
    dataarray = dataarray.compute()
    label_indices = get_label_indices(dataarray)
    n_labels = len(label_indices)
    if add_colorbar and n_labels > max_n_labels:
        msg = (
            f"The array currently contains {n_labels} labels "
            f"and 'max_n_labels' is set to {max_n_labels}. The colorbar is not displayed!"
        )
        print(msg)
        add_colorbar = False

    # Relabel array from 1 to ... for plotting
    dataarray = redefine_label_array(dataarray, label_indices=label_indices)
    # Replace values <= 0 (i.e. 0 = no label) with NaN
    dataarray = dataarray.where(dataarray > 0)

    # Define appropriate colormap (user plot_kwargs take precedence)
    default_plot_kwargs, cbar_kwargs = get_label_colorbar_settings(label_indices, cmap=cmap)
    default_plot_kwargs.update(plot_kwargs)

    # Plot image
    return plot_image(
        dataarray,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        fig_kwargs=fig_kwargs,
        **default_plot_kwargs,
    )
1039

1040

1041
def plot_labels(
    obj,  # Dataset, DataArray or generator
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs=None,
    **plot_kwargs,
):
    """Plot label arrays.

    Parameters
    ----------
    obj : xarray.Dataset, xarray.DataArray or generator
        Object containing the labels. If a generator, it must yield
        ``(label_id, xr_obj)`` tuples; each yielded object is plotted and
        displayed with ``plt.show()``.
    label_name : str, optional
        Name of the label variable. Required for Datasets.
    max_n_labels : int, optional
        Maximum number of labels for which the colorbar is displayed.
        The default is 50.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    interpolation : str, optional
        Image interpolation. The default is ``"nearest"``.
    cmap : str, optional
        Base colormap for the labels. The default is ``"Paired"``.
    fig_kwargs : dict, optional
        Figure options.
    **plot_kwargs
        Additional arguments passed to the plotting function.

    Returns
    -------
    p
        The mappable of the (last) plotted image. Returns ``None`` if `obj` is
        a generator that yields nothing.
    """
    # Arguments forwarded unchanged to `_plot_labels`
    shared_kwargs = {
        "label_name": label_name,
        "max_n_labels": max_n_labels,
        "add_colorbar": add_colorbar,
        "interpolation": interpolation,
        "cmap": cmap,
        "fig_kwargs": fig_kwargs,
    }

    if not is_generator(obj):
        return _plot_labels(xr_obj=obj, **shared_kwargs, **plot_kwargs)

    # Initialize so that an empty generator returns None instead of raising
    p = None
    for _, xr_obj in obj:  # label_id, xr_obj
        p = _plot_labels(xr_obj=xr_obj, **shared_kwargs, **plot_kwargs)
        plt.show()
    return p
1076

1077

1078
def plot_patches(
    patch_gen,
    variable=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot and display each patch yielded by a patch generator.

    Parameters
    ----------
    patch_gen : generator
        Generator yielding ``(label_id, xr_patch)`` tuples.
    variable : str, optional
        Variable to plot when the patches are xarray.Dataset objects.
        Required in that case.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    interpolation : str, optional
        Image interpolation. The default is ``"nearest"``.
    fig_kwargs : dict, optional
        Figure options passed to ``plot_image``.
    cbar_kwargs : dict, optional
        Colorbar options passed to ``plot_image``.
    **plot_kwargs
        Additional arguments passed to ``plot_image``.

    Raises
    ------
    ValueError
        If a patch is an xarray.Dataset and `variable` is ``None``.

    Notes
    -----
    Plotting is best-effort. A patch that fails to plot is skipped with a
    warning, so that the remaining patches are still displayed.
    """
    import warnings

    from gpm.visualization.plot import plot_image

    for label_id, xr_patch in patch_gen:
        if isinstance(xr_patch, xr.Dataset):
            if variable is None:
                raise ValueError("'variable' must be specified when plotting xr.Dataset patches.")
            xr_patch = xr_patch[variable]
        try:
            plot_image(
                xr_patch,
                interpolation=interpolation,
                add_colorbar=add_colorbar,
                fig_kwargs=fig_kwargs,
                cbar_kwargs=cbar_kwargs,
                **plot_kwargs,
            )
            plt.show()
        except Exception as err:
            # Best-effort: skip the failing patch, but report why it was skipped
            warnings.warn(f"Unable to plot patch {label_id!r}: {err}", stacklevel=2)
1109

1110

1111
####--------------------------------------------------------------------------.
1112

1113

1114
def get_inset_bounds(
    ax,
    loc="upper right",
    inset_height=0.2,
    inside_figure=True,
    aspect_ratio=1,
):
    """
    Calculate the bounds for an inset axes placed in a corner of a parent axes.

    The returned bounds are expressed in the normalized coordinate system of
    the parent axes (``(0, 0)`` is its lower-left corner, ``(1, 1)`` its
    upper-right corner). This is the coordinate system expected by
    ``ax.inset_axes``. The function facilitates adding insets (e.g., for maps
    or zoomed plots) at predefined positions.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The parent axes in which the inset is placed.
    loc : str
        The location of the inset within the parent axes. Valid options are
        'lower left', 'lower right', 'upper left', and 'upper right'.
        The default is 'upper right'.
    inset_height : float
        The inset height, specified as a fraction of the parent axes height.
        For example, a value of 0.2 indicates that the inset's height will be
        20% of the axes height. The aspect ratio governs the inset width.
    inside_figure : bool, optional
        If `True` (default), the inset is placed fully within the parent axes.
        If `False`, the inset is shifted by half its width and half its height
        towards the selected corner, so that it lies half outside the parent axes.
    aspect_ratio : float, optional
        The on-screen width-to-height ratio of the inset.
        A value greater than 1 indicates an inset wider than it is tall,
        and a value less than 1 an inset taller than it is wide.
        The default value is 1.0, indicating a square inset.

    Returns
    -------
    inset_bounds : list of float
        The bounds of the inset in the format [x0, y0, width, height].
        `x0` and `y0` are the parent-axes normalized coordinates of the lower
        left corner of the inset; `width` and `height` are its normalized
        width and height.

    Raises
    ------
    KeyError
        If `loc` is not one of the valid options.
    """
    # Position of the parent axes in normalized (sub)figure coordinates
    bbox = ax.get_position()
    # Size of the (sub)figure in display units. For non-square figures, one
    # normalized unit along x and along y correspond to different lengths.
    fig_bbox = ax.figure.bbox
    parent_width = bbox.width * fig_bbox.width
    parent_height = bbox.height * fig_bbox.height

    # Inset width relative to the parent axes width, chosen such that the
    # on-screen width/height ratio of the inset equals `aspect_ratio`
    inset_width = inset_height * parent_height * aspect_ratio / parent_width

    loc_mapping = {
        "upper right": (1 - inset_width, 1 - inset_height),
        "upper left": (0, 1 - inset_height),
        "lower right": (1 - inset_width, 0),
        "lower left": (0, 0),
    }
    inset_x, inset_y = loc_mapping[loc]

    # Adjust for insets that are allowed to be half outside of the parent axes
    if not inside_figure:
        inset_x += inset_width / 2 * (-1 if loc.endswith("left") else 1)
        inset_y += inset_height / 2 * (-1 if loc.startswith("lower") else 1)

    return [inset_x, inset_y, inset_width, inset_height]
1180

1181

1182
def add_map_inset(ax, loc="upper left", inset_height=0.2, projection=None, inside_figure=True):
    """
    Add an inset map to a cartopy axis, outlining the extent of the main plot.

    This function creates a smaller global map inset within a larger map plot,
    to show the geographic context of the main plot's extent. The (padded)
    extent of the main plot is drawn as a red outline on the inset.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        The main cartopy axis object where the geographic data is plotted.
    loc : str, optional
        The location of the inset map within the main plot.
        Options include 'lower left', 'lower right', 'upper left', 'upper right'.
        The default is 'upper left'.
    inset_height : float
        The inset height, specified as a fraction of the main axes height.
        For example, a value of 0.2 indicates that the inset's height will be
        20% of the axes height.
        The aspect ratio of the inset projection governs the inset width.
    projection : cartopy.crs.Projection, optional
        Projection of the inset map. If ``None``, an Orthographic projection
        centered on the center of the main plot's extent is used.
    inside_figure : bool, optional
        If `True` (default), the inset is placed fully within the main axes.
        If `False`, the inset can extend half outside of the main axes.

    Returns
    -------
    ax2 : cartopy.mpl.geoaxes.GeoAxes
        The cartopy GeoAxes object of the inset map.

    Notes
    -----
    The inset map always shows the whole globe (``set_global``). The extent of
    the main plot is extended with ``extend_geographic_extent(padding=0.5)``
    before being drawn as a red outline on the inset.

    Examples
    --------
    >>> p = da.gpm.plot_map()
    >>> add_map_inset(ax=p.axes, loc="upper left", inset_height=0.15)

    This example creates a main plot and adds an upper-left inset map
    showing the global context of the main plot's extent.
    """
    from shapely import Polygon

    from gpm.utils.geospatial import extend_geographic_extent

    # Retrieve the geographic (lon/lat) extent of the main plot.
    # Without `crs`, GeoAxes.get_extent returns the extent in the axes' own
    # projection coordinates, which are not lon/lat for most projections.
    extent = ax.get_extent(crs=ccrs.PlateCarree())
    extent = extend_geographic_extent(extent, padding=0.5)
    lon_min, lon_max, lat_min, lat_max = extent

    # Shapely polygon of the (padded) extent: from_bounds(xmin, ymin, xmax, ymax)
    polygon = Polygon.from_bounds(lon_min, lat_min, lon_max, lat_max)

    # Default to an Orthographic projection centered on the extent center
    if projection is None:
        projection = ccrs.Orthographic(
            central_latitude=(lat_min + lat_max) / 2,
            central_longitude=(lon_min + lon_max) / 2,
        )

    # Width-to-height ratio of the projection domain (drives the inset width)
    x_min, x_max = projection.x_limits
    y_min, y_max = projection.y_limits
    aspect_ratio = float((x_max - x_min) / (y_max - y_min))

    # Inset location relative to the main axes in normalized units:
    # [x0, y0, width, height] of the inset lower-left corner and size
    inset_bounds = get_inset_bounds(
        ax=ax,
        loc=loc,
        inset_height=inset_height,
        inside_figure=inside_figure,
        aspect_ratio=aspect_ratio,
    )
    ax2 = ax.inset_axes(inset_bounds, projection=projection)

    # Add global map
    ax2.set_global()
    ax2.add_feature(cfeature.LAND)
    ax2.add_feature(cfeature.OCEAN)

    # Outline the extent of the main plot
    ax2.add_geometries(
        [polygon],
        ccrs.PlateCarree(),
        facecolor="none",
        edgecolor="red",
        linewidth=0.3,
    )
    return ax2
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc