• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ghiggi / gpm_api / 8518249396

02 Apr 2024 06:11AM UTC coverage: 87.861% (+0.2%) from 87.669%
8518249396

push

github

ghiggi
Add and fix pydocstyle rules

28 of 29 new or added lines in 15 files covered. (96.55%)

534 existing lines in 45 files now uncovered.

9004 of 10248 relevant lines covered (87.86%)

0.88 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.45
/gpm/visualization/plot.py
1
# -----------------------------------------------------------------------------.
2
# MIT License
3

4
# Copyright (c) 2024 GPM-API developers
5
#
6
# This file is part of GPM-API.
7

8
# Permission is hereby granted, free of charge, to any person obtaining a copy
9
# of this software and associated documentation files (the "Software"), to deal
10
# in the Software without restriction, including without limitation the rights
11
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
# copies of the Software, and to permit persons to whom the Software is
13
# furnished to do so, subject to the following conditions:
14
#
15
# The above copyright notice and this permission notice shall be included in all
16
# copies or substantial portions of the Software.
17
#
18
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
# SOFTWARE.
25

26
# -----------------------------------------------------------------------------.
27
"""This module contains basic functions for GPM-API data visualization."""
1✔
28
import inspect
1✔
29

30
import cartopy
1✔
31
import cartopy.crs as ccrs
1✔
32
import cartopy.feature as cfeature
1✔
33
import matplotlib.pyplot as plt
1✔
34
import numpy as np
1✔
35
import xarray as xr
1✔
36
from mpl_toolkits.axes_grid1 import make_axes_locatable
1✔
37
from scipy.interpolate import griddata
1✔
38

39
import gpm
1✔
40

41

42
def is_generator(obj):
    """Return True if ``obj`` is a generator function or a generator object."""
    generator_checks = (inspect.isgeneratorfunction, inspect.isgenerator)
    return any(check(obj) for check in generator_checks)
44

45

46
def _call_optimize_layout(self):
    """Optimize the layout of the figure that owns this plot object.

    Meant to be bound as the ``optimize_layout`` method of a plot object
    (see ``add_optimize_layout_method``). The figure is first resized with
    ``adapt_fig_size`` using ``self.axes``, then ``tight_layout`` is applied.
    """
    plot_axes = self.axes
    adapt_fig_size(ax=plot_axes)
    self.figure.tight_layout()
50

51

52
def add_optimize_layout_method(p):
    """Attach an ``optimize_layout`` method to ``p`` via monkey patching.

    ``_call_optimize_layout`` is bound to ``p``, so ``p.optimize_layout()``
    runs it with ``self=p``. The same object ``p`` is returned.
    """
    bound_method = _call_optimize_layout.__get__(p, type(p))
    p.optimize_layout = bound_method
    return p
56

57

58
def adapt_fig_size(ax, nrow=1, ncol=1):
    """Adjust the figure size to the data aspect ratio of cartopy subplots.

    Call this function after all plotting has been completed. ``ax`` is taken
    as representative of every subplot: all subplots are assumed to share the
    same aspect ratio and size.

    The figure height is recomputed from the data aspect ratio of ``ax``, the
    subplot grid (``nrow`` x ``ncol``) and the figure subplot parameters
    (margins and inter-subplot spacing), keeping the current width. If the
    resulting height is less than half of the original height, width and
    height are scaled up by the same factor so that the final height equals
    half of the original height.

    The implementation is inspired by Mathias Hauser's mplotutils
    ``set_map_layout`` function.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Representative axis of the figure.
    nrow : int, optional
        Number of subplot rows. The default is 1.
    ncol : int, optional
        Number of subplot columns. The default is 1.
    """
    # NOTE: nrow/ncol could be inferred with ax.get_subplotspec().get_geometry()
    fig = ax.get_figure()
    fig_width, initial_height = fig.get_size_inches()

    # Drawing the canvas makes sure the figure geometry is up to date
    # before reading the layout parameters.
    fig.canvas.draw()

    params = fig.subplotpars
    margin_bottom = params.bottom
    margin_top = params.top
    margin_left = params.left
    margin_right = params.right
    row_gap = params.hspace  # vertical space between subplots
    col_gap = params.wspace  # horizontal space between subplots

    data_aspect = ax.get_data_ratio()

    # Width of one subplot: usable width shared by ncol subplots and (ncol - 1) gaps
    subplot_width = (fig_width - fig_width * (margin_left + (1 - margin_right))) / (ncol + (ncol - 1) * col_gap)

    # Height of one subplot, following the data aspect ratio
    subplot_height = subplot_width * data_aspect

    # Figure height: nrow subplots and (nrow - 1) gaps, plus top and bottom margins
    new_height = (subplot_height * (nrow + ((nrow - 1) * row_gap))) / (1.0 - (margin_bottom + (1 - margin_top)))

    # If the figure would shrink to less than half its height, enlarge both
    # dimensions so that the final height is half of the original one.
    if initial_height / new_height > 2:
        scale = initial_height / new_height / 2
        fig_width *= scale
        new_height *= scale

    fig.set_figwidth(fig_width)
    fig.set_figheight(new_height)
119

120

121
####--------------------------------------------------------------------------.
122

123

124
def get_antimeridian_mask(lons):
    """Get mask of longitude coordinates neighbors crossing the antimeridian.

    ``lons`` is a 2D array of shape (n_y, n_x) (the caller in this module
    passes cell-corner longitudes). The returned boolean mask has shape
    (n_y - 1, n_x - 1). A cell is flagged when a longitude jump larger than
    180 degrees is found between adjacent values along either axis. The flags
    are then dilated by one cell with ``scipy.ndimage.binary_dilation`` using
    its default (cross-shaped) structuring element.
    """
    from scipy.ndimage import binary_dilation

    n_rows, n_cols = lons.shape
    cell_mask = np.zeros((n_rows - 1, n_cols - 1))

    # Jumps between consecutive rows
    jump_rows, jump_cols = np.nonzero(np.abs(np.diff(lons, axis=0)) > 180)
    cell_mask[jump_rows, np.clip(jump_cols - 1, 0, n_cols - 1)] = 1

    # Jumps between consecutive columns
    jump_rows, jump_cols = np.nonzero(np.abs(np.diff(lons, axis=1)) > 180)
    cell_mask[np.clip(jump_rows - 1, 0, n_rows - 1), jump_cols] = 1

    # Buffer by 1 cell to also hide neighbours of the crossing cells.
    # --> This should not be needed, but it's needed to avoid cartopy bugs !
    return binary_dilation(cell_mask)
142

143

144
def infill_invalid_coords(xr_obj, x="lon", y="lat"):
    """Return a copy of ``xr_obj`` where non-finite ``x``/``y`` coordinates are infilled.

    Invalid coordinate values are interpolated within the convex hull of the
    valid coordinates, and filled with the nearest neighbour outside of it
    (see ``get_valid_pcolormesh_inputs``). Data values are not masked.
    """
    filled = xr_obj.copy()
    lon_values, lat_values, _ = get_valid_pcolormesh_inputs(
        x=np.asanyarray(filled[x].data),
        y=np.asanyarray(filled[y].data),
        data=None,
        mask_data=False,
    )
    filled[x].data = lon_values
    filled[y].data = lat_values
    return filled
159

160

161
def get_valid_pcolormesh_inputs(x, y, data, rgb=False, mask_data=True):
    """Return pcolormesh inputs whose coordinates contain only finite values.

    ``pcolormesh`` does not accept non-finite coordinates. Invalid ``x`` and
    ``y`` values are therefore interpolated within the convex hull of the
    valid coordinates, and filled with the nearest valid value outside of it.

    If ``mask_data=True``, data values located at invalid coordinates are
    masked and a numpy masked array is returned (masked values are not
    displayed by pcolormesh). If ``rgb=True``, the RGB dimension is expected
    to be the last dimension of ``data``.

    If all coordinates are finite, ``x``, ``y`` and ``data`` are returned
    unchanged (same objects).
    """
    x_is_invalid = ~np.isfinite(x)
    y_is_invalid = ~np.isfinite(y)
    invalid_coords = np.logical_or(x_is_invalid, y_is_invalid)

    # Nothing to fix
    if np.all(~invalid_coords):
        return x, y, data

    # Hide data values located at invalid coordinates
    if not mask_data:
        data_masked = data
    elif rgb:
        channel_mask = np.broadcast_to(np.expand_dims(invalid_coords, axis=-1), data.shape)
        data_masked = np.ma.masked_where(channel_mask, data)
    else:
        data_masked = np.ma.masked_where(invalid_coords, data)

    # Linear interpolation inside the convex hull, then nearest neighbours outside
    if np.any(x_is_invalid):
        x = _interpolate_data(_interpolate_data(x, method="linear"), method="nearest")
    if np.any(y_is_invalid):
        y = _interpolate_data(_interpolate_data(y, method="linear"), method="nearest")
    return x, y, data_masked
203

204

205
def _interpolate_data(arr, method="linear"):
1✔
206
    # Find invalid locations
207
    is_invalid = ~np.isfinite(arr)
1✔
208

209
    # Find the indices of NaN values
210
    nan_indices = np.where(is_invalid)
1✔
211

212
    # Return array if not NaN values
213
    if len(nan_indices) == 0:
1✔
UNCOV
214
        return arr
×
215

216
    # Find the indices of non-NaN values
217
    non_nan_indices = np.where(~is_invalid)
1✔
218

219
    # Create a meshgrid of indices
220
    x, y = np.meshgrid(range(arr.shape[1]), range(arr.shape[0]))
1✔
221

222
    # Points (X, Y) where we have valid data
223
    points = np.array([y[non_nan_indices], x[non_nan_indices]]).T
1✔
224

225
    # Points where data is NaN
226
    points_nan = np.array([y[nan_indices], x[nan_indices]]).T
1✔
227

228
    # Values at the non-NaN points
229
    values = arr[non_nan_indices]
1✔
230

231
    # Interpolate using griddata
232
    arr_new = arr.copy()
1✔
233
    arr_new[nan_indices] = griddata(points, values, points_nan, method=method)
1✔
234
    return arr_new
1✔
235

236

237
####--------------------------------------------------------------------------.
238

239

240
def preprocess_figure_args(ax, fig_kwargs=None, subplot_kwargs=None):
    """Validate the figure arguments and return the figure keyword arguments.

    ``fig_kwargs`` and ``subplot_kwargs`` are only meaningful when a new
    figure must be created, i.e. when ``ax`` is ``None``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Existing axis, or ``None`` if a figure has to be created.
    fig_kwargs : dict, optional
        Figure options. The default is ``None``.
    subplot_kwargs : dict, optional
        Subplot options. The default is ``None``.

    Returns
    -------
    dict
        ``fig_kwargs``, or an empty dictionary if it is ``None``.

    Raises
    ------
    ValueError
        If ``ax`` is provided together with non-empty ``subplot_kwargs`` or
        ``fig_kwargs``.
    """
    fig_kwargs = {} if fig_kwargs is None else fig_kwargs
    subplot_kwargs = {} if subplot_kwargs is None else subplot_kwargs
    if ax is not None:
        if subplot_kwargs:
            raise ValueError("Provide `subplot_kwargs` only if `ax` is None")
        if fig_kwargs:
            raise ValueError("Provide `fig_kwargs` only if `ax` is None")
    return fig_kwargs
249

250

251
def preprocess_subplot_kwargs(subplot_kwargs):
    """Return a copy of ``subplot_kwargs`` that always defines a ``"projection"``.

    If no ``"projection"`` key is present, ``ccrs.PlateCarree()`` is used.
    The input dictionary is not modified.
    """
    kwargs = {} if subplot_kwargs is None else subplot_kwargs.copy()
    if "projection" not in kwargs:
        kwargs["projection"] = ccrs.PlateCarree()
    return kwargs
257

258

259
def initialize_cartopy_plot(
    ax,
    fig_kwargs,
    subplot_kwargs,
    add_background,
):
    """Return the axis to plot on, creating a cartopy figure if necessary.

    If ``ax`` is ``None``, a new figure and axis are created with
    ``plt.subplots`` from ``fig_kwargs`` and ``subplot_kwargs``. The latter
    gets a default ``PlateCarree`` projection via ``preprocess_subplot_kwargs``.
    If ``add_background=True``, the cartopy background is drawn with
    ``plot_cartopy_background``.

    NOTE(review): ``preprocess_figure_args`` raises only when ``ax`` is not
    None, but it is called here only when ``ax`` is None. As a result, passing
    ``ax`` together with non-empty ``fig_kwargs``/``subplot_kwargs`` is
    accepted and the kwargs are ignored. Confirm whether this is intended.
    """
    if ax is None:
        figure_kwargs = preprocess_figure_args(
            ax=None,
            fig_kwargs=fig_kwargs,
            subplot_kwargs=subplot_kwargs,
        )
        _, ax = plt.subplots(subplot_kw=preprocess_subplot_kwargs(subplot_kwargs), **figure_kwargs)

    if add_background:
        ax = plot_cartopy_background(ax)
    return ax
280

281

282
####--------------------------------------------------------------------------.
283

284

285
def plot_cartopy_background(ax):
    """Draw the default cartopy background on ``ax`` and return ``ax``.

    Adds coastlines, light grey land, semi-transparent ocean, country borders
    and faint gridlines. Gridline labels are disabled on the top and right
    sides.
    """
    ax.coastlines()
    background_features = (
        (cartopy.feature.LAND, {"facecolor": [0.9, 0.9, 0.9]}),
        (cartopy.feature.OCEAN, {"alpha": 0.6}),
        (cartopy.feature.BORDERS, {}),  # BORDERS also draws provinces, ...
    )
    for feature, feature_style in background_features:
        ax.add_feature(feature, **feature_style)

    gridlines = ax.gridlines(
        crs=ccrs.PlateCarree(),
        draw_labels=True,
        linewidth=1,
        color="gray",
        alpha=0.1,
        linestyle="-",
    )
    # Older cartopy versions used gl.xlabels_top / gl.ylabels_right
    gridlines.top_labels = False
    gridlines.right_labels = False
    gridlines.xlines = True
    gridlines.ylines = True
    return ax
306

307

308
def plot_sides(sides, ax, **plot_kwargs):
    """Plot boundary sides as geodetic lines on a cartopy axis.

    Parameters
    ----------
    sides : iterable of tuple
        ``(lon, lat)`` tuples, one per side.
    ax : cartopy.mpl.geoaxes.GeoAxes
        Axis where to draw the sides.
    **plot_kwargs
        Additional arguments passed to ``ax.plot``.

    Returns
    -------
    matplotlib.lines.Line2D or None
        The line of the last plotted side, or ``None`` if ``sides`` is empty.
    """
    last_line = None
    for side in sides:
        # ax.plot returns a list of lines; keep the first (only) one
        last_line = ax.plot(*side, transform=ccrs.Geodetic(), **plot_kwargs)[0]
    return last_line
316

317

318
def plot_colorbar(p, ax, cbar_kwargs=None):
    """Add a colorbar to a matplotlib/cartopy plot.

    The colorbar axis is appended next to ``ax`` with ``make_axes_locatable``.

    Parameters
    ----------
    p : matplotlib.cm.ScalarMappable
        Mappable (e.g. ``AxesImage`` or ``QuadMesh``) described by the colorbar.
    ax : matplotlib.axes.Axes or cartopy.mpl.geoaxes.GeoAxes
        Axis next to which the colorbar is placed.
    cbar_kwargs : dict, optional
        Colorbar options. The input dictionary is not modified.
        The following keys are consumed here and not forwarded to ``plt.colorbar``:

        - ``size``: size of the colorbar axis. The default is ``"5%"``.
        - ``pad``: padding between ``ax`` and the colorbar. The default is
          ``0.1`` for vertical and ``0.25`` for horizontal colorbars.
        - ``ticklabels``: custom tick labels.

        ``orientation`` (``"vertical"`` (default) or ``"horizontal"``) decides
        whether the colorbar is placed on the right or at the bottom, and is
        forwarded to ``plt.colorbar`` together with all other keys.

    Returns
    -------
    matplotlib.colorbar.Colorbar

    Raises
    ------
    ValueError
        If ``orientation`` is neither ``"vertical"`` nor ``"horizontal"``.
    """
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs.copy()  # do not modify the caller's dict
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    orientation = cbar_kwargs.get("orientation", "vertical")

    # Placement of the colorbar axis and default padding
    if orientation == "vertical":
        position, default_pad = "right", 0.1
    elif orientation == "horizontal":
        position, default_pad = "bottom", 0.25
    else:
        raise ValueError("Invalid orientation. Choose 'vertical' or 'horizontal'.")

    # `size` and `pad` only configure the divider axis: pop them so that they
    # are not forwarded to plt.colorbar (`size` is not a Colorbar argument).
    size = cbar_kwargs.pop("size", "5%")
    pad = cbar_kwargs.pop("pad", default_pad)

    divider = make_axes_locatable(ax)
    cax = divider.append_axes(position, size=size, pad=pad, axes_class=plt.Axes)
    # Ensure the colorbar axis belongs to the figure of the mappable
    p.figure.add_axes(cax)
    cbar = plt.colorbar(p, cax=cax, ax=ax, **cbar_kwargs)

    if ticklabels is not None:
        if orientation == "vertical":
            cbar.ax.set_yticklabels(ticklabels)
        else:
            cbar.ax.set_xticklabels(ticklabels)
    return cbar
354

355

356
####--------------------------------------------------------------------------.
357

358

359
def get_dataarray_extent(da, x="lon", y="lat"):
    """Return the ``(x_min, x_max, y_min, y_max)`` extent of the ``x`` and ``y`` coordinates.

    The extent is computed from the coordinate values themselves. If these
    are pixel centroids, the outer half pixels are therefore not included.
    """
    # TODO: compute the corners array to estimate the extent
    # - OR increase by 1° in every direction and then clip to [-180, 180] and [-90, 90]
    x_coord = da[x]
    y_coord = da[y]
    return (x_coord.min(), x_coord.max(), y_coord.min(), y_coord.max())
366

367

368
def _compute_extent(x_coords, y_coords):
1✔
369
    """Compute the extent (x_min, x_max, y_min, y_max) from the pixel centroids in x and y coordinates.
370

371
    This function assumes that the spacing between each pixel is uniform.
372
    """
373
    # Calculate the pixel size assuming uniform spacing between pixels
374
    pixel_size_x = (x_coords[-1] - x_coords[0]) / (len(x_coords) - 1)
1✔
375
    pixel_size_y = (y_coords[-1] - y_coords[0]) / (len(y_coords) - 1)
1✔
376

377
    # Adjust min and max to get the corners of the outer pixels
378
    x_min, x_max = x_coords[0] - pixel_size_x / 2, x_coords[-1] + pixel_size_x / 2
1✔
379
    y_min, y_max = y_coords[0] - pixel_size_y / 2, y_coords[-1] + pixel_size_y / 2
1✔
380

381
    return [x_min, x_max, y_min, y_max]
1✔
382

383

384
def _plot_cartopy_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot imshow with cartopy.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        Axis where to plot.
    da : xarray.DataArray
        2D DataArray with 1D ``x`` and ``y`` coordinates. These are assumed to
        be uniformly spaced pixel centroids in PlateCarree (lon/lat) coordinates.
    x, y : str
        Names of the x and y coordinates/dimensions.
    interpolation : str, optional
        Passed to ``imshow``. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar with ``plot_colorbar``. The default is ``True``.
    plot_kwargs : dict, optional
        Additional arguments passed to ``imshow``.
    cbar_kwargs : dict, optional
        Colorbar options passed to ``plot_colorbar``.

    Returns
    -------
    matplotlib.image.AxesImage
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs

    # - Ensure image with correct dimensions orders
    da = da.transpose(y, x)
    arr = np.asanyarray(da.data)

    # - Derive the image extent (pixel edges) from the centroids
    x_coords = da[x].to_numpy()
    y_coords = da[y].to_numpy()
    extent = _compute_extent(x_coords=x_coords, y_coords=y_coords)

    # - Increasing y values -> origin="lower"; decreasing -> origin="upper"
    origin = "lower" if y_coords[1] > y_coords[0] else "upper"

    # - Add variable field with cartopy
    p = ax.imshow(
        arr,
        transform=ccrs.PlateCarree(),
        extent=extent,
        origin=origin,
        interpolation=interpolation,
        **plot_kwargs,
    )

    # - Set the map extent using the requested x/y coordinates
    ax.set_extent(get_dataarray_extent(da, x=x, y=y))

    # - Add colorbar
    if add_colorbar:
        # --> TODO: set axis proportion in a meaningful way ...
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
429

430

431
def _mask_antimeridian_crossing_arr(arr, antimeridian_mask, rgb):
1✔
432
    if np.ma.is_masked(arr):
1✔
433
        if rgb:
×
434
            data_mask = np.broadcast_to(np.expand_dims(antimeridian_mask, axis=-1), arr.shape)
×
435
            combined_mask = np.logical_or(data_mask, antimeridian_mask)
×
436
        else:
437
            combined_mask = np.logical_or(arr.mask, antimeridian_mask)
×
438
        arr = np.ma.masked_where(combined_mask, arr)
×
439
    else:
440
        if rgb:
1✔
441
            antimeridian_mask = np.broadcast_to(
1✔
442
                np.expand_dims(antimeridian_mask, axis=-1),
443
                arr.shape,
444
            )
445
        arr = np.ma.masked_where(antimeridian_mask, arr)
1✔
446
    return arr
1✔
447

448

449
def _plot_cartopy_pcolormesh(
    ax,
    da,
    x,
    y,
    rgb=False,
    add_colorbar=True,
    add_swath_lines=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot pcolormesh with cartopy.

    Non-finite coordinates are infilled, and the corresponding data values are
    masked, with ``get_valid_pcolormesh_inputs``. The coordinates of the cell
    corners are then computed with ``_get_lonlat_corners`` for the
    pcolormesh quadrilateral mesh.

    If the ``gpm.config`` option ``"viz_hide_antimeridian_data"`` is set, cells
    crossing the antimeridian (and their neighbours) are masked to avoid
    cartopy bugs. In that case, the "bad" color of a provided ``cmap`` is made
    transparent on a copy of the colormap.

    The function currently does not allow to zoom on regions across the antimeridian.
    If ``rgb=True``, the RGB(A) dimension must be the last one and no colorbar is added.
    ``x`` and ``y`` must be the longitude and latitude coordinates.

    Parameters
    ----------
    plot_kwargs : dict, optional
        Additional arguments passed to ``ax.pcolormesh``. Not modified.
    cbar_kwargs : dict, optional
        Colorbar options passed to ``plot_colorbar``.

    Returns
    -------
    The object returned by ``ax.pcolormesh``.

    Raises
    ------
    ValueError
        If ``rgb=True`` and the last dimension does not have 3 or 4 channels.
    """
    # Copy so that the caller's dictionary is never modified below
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs.copy()

    # Get x, y, and array to plot
    da = da.compute()
    lon = da[x].data
    lat = da[y].data
    arr = da.data

    # If RGB, expect last dimension to have 3 or 4 channels
    if rgb and arr.shape[-1] not in (3, 4):
        raise ValueError("RGB array must have 3 or 4 channels in the last dimension.")

    # Infill invalid value and add mask if necessary
    lon, lat, arr = get_valid_pcolormesh_inputs(lon, lat, arr, rgb=rgb)

    # No colorbar for RGB images
    if rgb:
        add_colorbar = False

    # Compute coordinates of cell corners for pcolormesh quadrilateral mesh
    # - This enables correct masking of cells crossing the antimeridian
    from gpm.utils.area import _get_lonlat_corners

    lon, lat = _get_lonlat_corners(lon, lat)

    # Mask cells crossing the antimeridian
    # --> Here we assume no invalid coordinates anymore
    # --> Cartopy still bugs with several projections when data cross the antimeridian
    # --> This flag can be unset with gpm.config.set({"viz_hide_antimeridian_data": False})
    if gpm.config.get("viz_hide_antimeridian_data"):
        antimeridian_mask = get_antimeridian_mask(lon)
        if np.any(antimeridian_mask):
            arr = _mask_antimeridian_crossing_arr(arr, antimeridian_mask=antimeridian_mask, rgb=rgb)

            # Sanitize cmap bad color to avoid cartopy bug
            # - TODO cartopy requires bad_color to be transparent ...
            cmap = plot_kwargs.get("cmap", None)
            if cmap is not None:
                # Resolve colormap names and work on a copy so that the
                # caller's (or a shared) Colormap object is not modified.
                cmap = plt.get_cmap(cmap).copy()
                bad = cmap.get_bad()
                bad[3] = 0  # enforce to 0 (transparent)
                cmap.set_bad(bad)
                plot_kwargs["cmap"] = cmap

    # Add variable field with cartopy
    p = ax.pcolormesh(
        lon,
        lat,
        arr,
        transform=ccrs.PlateCarree(),
        **plot_kwargs,
    )

    # Add swath lines along the first and last rows of cell corners
    if add_swath_lines:
        sides = [(lon[0, :], lat[0, :]), (lon[-1, :], lat[-1, :])]
        plot_sides(sides=sides, ax=ax, linestyle="--", color="black")

    # Add colorbar
    # --> TODO: set axis proportion in a meaningful way ...
    if add_colorbar:
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
528

529

530
def _plot_mpl_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs={},
    cbar_kwargs=None,
):
    """Plot a 2D DataArray with matplotlib ``imshow`` (without georeferencing).

    The array is transposed to ``(y, x)`` and drawn with ``origin="upper"``.
    ``plot_kwargs`` is forwarded to ``imshow``. If ``add_colorbar=True``, a
    colorbar is added with ``plot_colorbar`` using ``cbar_kwargs``.

    Returns
    -------
    matplotlib.image.AxesImage
    """
    image = np.asanyarray(da.transpose(y, x).data)
    mappable = ax.imshow(
        image,
        origin="upper",
        interpolation=interpolation,
        **plot_kwargs,
    )
    if add_colorbar:
        # --> TODO: set axis proportion in a meaningful way ...
        plot_colorbar(p=mappable, ax=ax, cbar_kwargs=cbar_kwargs)
    return mappable
558

559

560
def set_colorbar_fully_transparent(p):
    """Hide the colorbar of ``p`` while keeping its axis in the figure.

    This is useful for animations where the colorbar should not be shown in
    all frames but the plot area must stay fixed. The colorbar axis is made
    invisible (not removed), and a rectangle with no face and no edge color is
    added to the figure at the colorbar position.

    Parameters
    ----------
    p : matplotlib mappable
        Plot object with an attached colorbar (``p.colorbar``).
    """
    cbar_ax = p.colorbar.ax

    # Position of the colorbar in figure coordinates
    cbar_pos = cbar_ax.get_position()

    # Hide the colorbar
    cbar_ax.set_visible(False)

    # Use the figure owning the colorbar, not the current pyplot figure,
    # which may differ when several figures are open.
    fig = cbar_ax.figure
    rect = plt.Rectangle(
        (cbar_pos.x0, cbar_pos.y0),
        cbar_pos.width,
        cbar_pos.height,
        transform=fig.transFigure,
        facecolor="none",
        edgecolor="none",
    )
    fig.patches.append(rect)
587

588

589
def _plot_xr_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
    visible_colorbar=True,
):
    """Plot imshow with xarray.

    The colorbar is added with xarray, so that multiple colorbars can be
    displayed when this function is called several times on different fields
    with different colorbars.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Axis where to plot. If ``None``, xarray's default axis handling applies.
    da : xarray.DataArray
        DataArray to plot. Its name is used as the plot title.
    x, y : str
        Names of the x and y coordinates.
    interpolation : str, optional
        Passed to ``imshow``. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether xarray adds a colorbar. The default is ``True``.
    plot_kwargs : dict, optional
        Additional arguments passed to ``da.plot.imshow``.
    cbar_kwargs : dict, optional
        Colorbar options passed to xarray. The extra ``"ticklabels"`` key sets
        custom tick labels. The input dictionary is not modified.
    visible_colorbar : bool, optional
        If ``False``, the colorbar is hidden with
        ``set_colorbar_fully_transparent``. The default is ``True``.

    Returns
    -------
    The object returned by ``da.plot.imshow``.
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    # Copy so that popping "ticklabels" does not modify the caller's dictionary
    cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs)
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    orientation = cbar_kwargs.get("orientation", "vertical")

    # --> BUG with colorbar: https://github.com/pydata/xarray/issues/7014
    p = da.plot.imshow(
        x=x,
        y=y,
        ax=ax,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        # cbar_kwargs must only be given when a colorbar is added
        cbar_kwargs=cbar_kwargs if add_colorbar else None,
        **plot_kwargs,
    )

    # Title the axis that was drawn on (plt.title targets the current axis,
    # which can differ from `ax`)
    if ax is not None:
        ax.set_title(da.name)
    else:
        plt.title(da.name)

    if add_colorbar and ticklabels is not None:
        if orientation == "horizontal":
            p.colorbar.ax.set_xticklabels(ticklabels)
        else:
            p.colorbar.ax.set_yticklabels(ticklabels)

    # Make the colorbar fully transparent with a smart trick ;)
    # - TODO: this still cause issues when plotting 2 colorbars !
    if add_colorbar and not visible_colorbar:
        set_colorbar_fully_transparent(p)

    # Alternative: plot with add_colorbar=False and add the colorbar
    # manually with plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs).
    return p
641

642

643
def _plot_xr_pcolormesh(
    ax,
    da,
    x,
    y,
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot pcolormesh with xarray.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Axis where to plot. If ``None``, xarray's default axis handling applies.
    da : xarray.DataArray
        DataArray to plot. Its name is used as the plot title.
    x, y : str
        Names of the x and y coordinates.
    add_colorbar : bool, optional
        Whether xarray adds a colorbar. The default is ``True``.
    plot_kwargs : dict, optional
        Additional arguments passed to ``da.plot.pcolormesh``.
    cbar_kwargs : dict, optional
        Colorbar options passed to xarray. The extra ``"ticklabels"`` key sets
        custom tick labels. The input dictionary is not modified.

    Returns
    -------
    The object returned by ``da.plot.pcolormesh``.
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    # Copy so that popping "ticklabels" does not modify the caller's dictionary
    cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs)
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    orientation = cbar_kwargs.get("orientation", "vertical")

    p = da.plot.pcolormesh(
        x=x,
        y=y,
        ax=ax,
        add_colorbar=add_colorbar,
        # cbar_kwargs must only be given when a colorbar is added
        cbar_kwargs=cbar_kwargs if add_colorbar else None,
        **plot_kwargs,
    )

    # Title the axis that was drawn on (plt.title targets the current axis,
    # which can differ from `ax`)
    if ax is not None:
        ax.set_title(da.name)
    else:
        plt.title(da.name)

    if add_colorbar and ticklabels is not None:
        if orientation == "horizontal":
            p.colorbar.ax.set_xticklabels(ticklabels)
        else:
            p.colorbar.ax.set_yticklabels(ticklabels)
    return p
668

669

670
####--------------------------------------------------------------------------.
671

672

673
def plot_map(
    da,
    x="lon",
    y="lat",
    ax=None,
    add_colorbar=True,
    add_swath_lines=True,  # used only for GPM orbit objects
    add_background=True,
    rgb=False,
    interpolation="nearest",  # used only for GPM grid objects
    fig_kwargs=None,
    subplot_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot data on a geographic map.

    ORBIT objects are plotted with ``plot_orbit_map`` and GRID objects with
    ``plot_grid_map``.

    Parameters
    ----------
    da : xr.DataArray
        xarray DataArray.
    x : str, optional
        Longitude coordinate name. The default is `"lon"`.
    y : str, optional
        Latitude coordinate name. The default is `"lat"`.
    ax : cartopy.GeoAxes, optional
        The cartopy GeoAxes where to plot the map.
        If `None`, a figure is initialized using the
        specified `fig_kwargs` and `subplot_kwargs`.
        The default is `None`.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is `True`.
    add_swath_lines : bool, optional
        Whether to plot the swath sides with a dashed line. The default is `True`.
        This argument only applies to ORBIT objects.
    add_background : bool, optional
        Whether to add the map background. The default is `True`.
    rgb : bool, optional
        Whether the input DataArray has a rgb dimension. The default is `False`.
        This argument only applies to ORBIT objects.
    interpolation : str, optional
        Argument to be passed to imshow. Only applies to GRID objects.
        The default is `"nearest"`.
    fig_kwargs : dict, optional
        Figure options to be passed to `plt.subplots`.
        The default is `None`.
        Only used if `ax` is `None`.
    subplot_kwargs : dict, optional
        Dictionary of keyword arguments for Matplotlib subplots.
        Must contain the Cartopy CRS `'projection'` key if specified.
        The default is `None`.
        Only used if `ax` is `None`.
    cbar_kwargs : dict, optional
        Colorbar options. The default is `None`.
    **plot_kwargs
        Additional arguments to be passed to the plotting function.
        Examples include `cmap`, `norm`, `vmin`, `vmax`, `levels`, ...
        For FacetGrid plots, specify `row`, `col` and `col_wrap`.

    Returns
    -------
    The mappable returned by the orbit or grid plotting function.

    Raises
    ------
    ValueError
        If ``da`` is neither a GPM ORBIT nor a GPM GRID object.
    """
    from gpm.checks import is_grid, is_orbit
    from gpm.visualization.grid import plot_grid_map
    from gpm.visualization.orbit import plot_orbit_map

    # Arguments shared by the orbit and grid plotting functions
    shared_kwargs = {
        "da": da,
        "x": x,
        "y": y,
        "ax": ax,
        "add_colorbar": add_colorbar,
        "add_background": add_background,
        "fig_kwargs": fig_kwargs,
        "subplot_kwargs": subplot_kwargs,
        "cbar_kwargs": cbar_kwargs,
    }
    if is_orbit(da):
        return plot_orbit_map(
            add_swath_lines=add_swath_lines,
            rgb=rgb,
            **shared_kwargs,
            **plot_kwargs,
        )
    if is_grid(da):
        return plot_grid_map(
            interpolation=interpolation,
            **shared_kwargs,
            **plot_kwargs,
        )
    raise ValueError("Can not plot. It's neither a GPM grid, neither a GPM orbit.")
771

772

773
def plot_image(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot data using imshow.

    ORBIT objects are plotted with ``plot_orbit_image`` and GRID objects with
    ``plot_grid_image``.

    Parameters
    ----------
    da : xr.DataArray
        xarray DataArray.
    x : str, optional
        X dimension name.
        If ``None``, takes the second dimension.
        The default is `None`.
    y : str, optional
        Y dimension name.
        If ``None``, takes the first dimension.
        The default is `None`.
    ax : matplotlib.axes.Axes, optional
        The matplotlib axes where to plot the image.
        If ``None``, a figure is initialized using the
        specified `fig_kwargs`.
        The default is `None`.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is `True`.
    interpolation : str, optional
        Argument to be passed to imshow.
        The default is `"nearest"`.
    fig_kwargs : dict, optional
        Figure options to be passed to `plt.subplots`.
        The default is ``None``.
        Only used if `ax` is None.
    cbar_kwargs : dict, optional
        Colorbar options. The default is `None`.
    **plot_kwargs
        Additional arguments to be passed to the plotting function.
        Examples include `cmap`, `norm`, `vmin`, `vmax`, `levels`, ...
        For FacetGrid plots, specify `row`, `col` and `col_wrap`.

    Returns
    -------
    The mappable returned by the orbit or grid plotting function.

    Raises
    ------
    ValueError
        If ``da`` is neither a GPM ORBIT nor a GPM GRID object.
    """
    # figsize, dpi, subplot_kw only used if ax is None
    from gpm.checks import is_grid, is_orbit
    from gpm.visualization.grid import plot_grid_image
    from gpm.visualization.orbit import plot_orbit_image

    # Plot orbit
    if is_orbit(da):
        p = plot_orbit_image(
            da=da,
            x=x,
            y=y,
            ax=ax,
            add_colorbar=add_colorbar,
            interpolation=interpolation,
            fig_kwargs=fig_kwargs,
            cbar_kwargs=cbar_kwargs,
            **plot_kwargs,
        )
    # Plot grid
    elif is_grid(da):
        p = plot_grid_image(
            da=da,
            x=x,
            y=y,
            ax=ax,
            add_colorbar=add_colorbar,
            interpolation=interpolation,
            fig_kwargs=fig_kwargs,
            cbar_kwargs=cbar_kwargs,
            **plot_kwargs,
        )
    else:
        raise ValueError("Can not plot. It's neither a GPM GRID, neither a GPM ORBIT.")
    # Return mappable
    return p
859

860

861
####--------------------------------------------------------------------------.
862

863

864
def create_grid_mesh_data_array(xr_obj, x, y):
    """Build a NaN-filled DataArray carrying the 2D mesh of two 1D coordinates.

    The 1D coordinate arrays named ``x`` and ``y`` are read from ``xr_obj``
    and expanded with :func:`numpy.meshgrid` (``"xy"`` indexing). The result
    is attached as 2D coordinates of a new DataArray whose values are all NaN.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Object providing the 1D coordinate arrays.
    x : str
        Name of the x coordinate in `xr_obj`.
    y : str
        Name of the y coordinate in `xr_obj`.

    Returns
    -------
    da_mesh : xarray.DataArray
        NaN-filled DataArray with dimensions ``("y", "x")`` and 2D coordinates
        named `x` and `y`.

    Notes
    -----
    The dimension names are always the literal strings ``"y"`` and ``"x"``,
    whatever coordinate names are passed in. With ``"xy"`` indexing, rows
    follow the `y` coordinate and columns follow the `x` coordinate.

    """
    # Expand the 1D coordinates into 2D (n_y, n_x) arrays
    mesh_x, mesh_y = np.meshgrid(
        xr_obj[x].to_numpy(),
        xr_obj[y].to_numpy(),
        indexing="xy",
    )
    mesh_dims = ("y", "x")
    # Values are placeholders: only the mesh coordinates are of interest
    return xr.DataArray(
        np.full(mesh_x.shape, np.nan),
        coords={x: (mesh_dims, mesh_x), y: (mesh_dims, mesh_y)},
        dims=mesh_dims,
    )
908

909

910
def plot_map_mesh(
    xr_obj,
    x="lon",
    y="lat",
    ax=None,
    edgecolors="k",
    linewidth=0.1,
    add_background=True,
    fig_kwargs=None,
    subplot_kwargs=None,
    **plot_kwargs,
):
    """Plot the mesh of a GPM ORBIT or GRID object on a cartographic map.

    Objects recognized by ``gpm.checks.is_orbit`` are drawn with
    ``plot_orbit_mesh``, which receives the `y` coordinate DataArray.
    Every other object is passed to ``plot_grid_mesh``.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Object whose mesh is plotted.
    x : str, optional
        Name of the x coordinate. The default is ``"lon"``.
    y : str, optional
        Name of the y coordinate. The default is ``"lat"``.
    ax : optional
        Axes to draw on. If ``None``, the underlying plotting function handles
        figure creation using `fig_kwargs` and `subplot_kwargs`.
    edgecolors : optional
        Color of the mesh edges. The default is ``"k"``.
    linewidth : float, optional
        Width of the mesh edges. The default is ``0.1``.
    add_background : bool, optional
        Forwarded to the underlying plotting function. The default is ``True``.
    fig_kwargs : dict, optional
        Figure options forwarded to the underlying plotting function.
    subplot_kwargs : dict, optional
        Subplot options forwarded to the underlying plotting function.
    **plot_kwargs
        Additional arguments forwarded to the underlying plotting function.

    Returns
    -------
    p
        The mappable returned by ``plot_orbit_mesh`` or ``plot_grid_mesh``.

    """
    from gpm.checks import is_orbit

    from .grid import plot_grid_mesh
    from .orbit import plot_orbit_mesh

    # Arguments shared by the ORBIT and GRID mesh plotting functions
    shared_kwargs = {
        "ax": ax,
        "x": x,
        "y": y,
        "edgecolors": edgecolors,
        "linewidth": linewidth,
        "add_background": add_background,
        "fig_kwargs": fig_kwargs,
        "subplot_kwargs": subplot_kwargs,
    }
    if is_orbit(xr_obj):
        return plot_orbit_mesh(da=xr_obj[y], **shared_kwargs, **plot_kwargs)
    # Anything that is not an ORBIT is plotted as a GRID
    return plot_grid_mesh(xr_obj=xr_obj, **shared_kwargs, **plot_kwargs)
1✔
956

957

958
def plot_map_mesh_centroids(
    xr_obj,
    x="lon",
    y="lat",
    ax=None,
    c="r",
    s=1,
    add_background=True,
    fig_kwargs=None,
    subplot_kwargs=None,
    **plot_kwargs,
):
    """Plot the mesh centroids of a GPM ORBIT or GRID object on a cartographic map.

    For GRID objects, the 1D `x` / `y` coordinates are first expanded into a
    2D mesh with ``create_grid_mesh_data_array``. Centroids are then drawn
    with ``ax.scatter`` in the PlateCarree coordinate system.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Object whose mesh centroids are plotted.
    x : str, optional
        Name of the x (longitude) coordinate. The default is ``"lon"``.
    y : str, optional
        Name of the y (latitude) coordinate. The default is ``"lat"``.
    ax : optional
        Axes to draw on. If ``None``, a new cartopy plot is initialized.
    c : optional
        Marker color passed to ``ax.scatter``. The default is ``"r"``.
    s : float, optional
        Marker size passed to ``ax.scatter``. The default is ``1``.
    add_background : bool, optional
        Passed to ``initialize_cartopy_plot``. The default is ``True``.
    fig_kwargs : dict, optional
        Figure options passed to ``initialize_cartopy_plot``.
    subplot_kwargs : dict, optional
        Subplot options passed to ``initialize_cartopy_plot``.
    **plot_kwargs
        Additional arguments passed to ``ax.scatter``.

    Returns
    -------
    p
        The mappable returned by ``ax.scatter``.

    """
    from gpm.checks import is_grid

    # Create (or reuse) the cartographic axes
    ax = initialize_cartopy_plot(
        ax=ax,
        fig_kwargs=fig_kwargs,
        subplot_kwargs=subplot_kwargs,
        add_background=add_background,
    )

    # GRID coordinates are 1D: expand them to a 2D mesh to get one point per cell
    source = create_grid_mesh_data_array(xr_obj, x=x, y=y) if is_grid(xr_obj) else xr_obj
    centroid_lons = source[x].to_numpy()
    centroid_lats = source[y].to_numpy()

    return ax.scatter(
        centroid_lons,
        centroid_lats,
        transform=ccrs.PlateCarree(),
        c=c,
        s=s,
        **plot_kwargs,
    )
991

992

993
####--------------------------------------------------------------------------.
994

995

996
def _plot_labels(
    xr_obj,
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs=None,
    **plot_kwargs,
):
    """Plot a label array with a categorical colormap.

    The labels are redefined by ``redefine_label_array`` (relabelling from 1
    onwards), values not greater than 0 are masked, and the result is drawn
    with ``plot_image``. The colorbar is disabled when the array contains more
    than `max_n_labels` labels.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Object containing the labels.
    label_name : str, optional
        Name of the label variable. Required when `xr_obj` is a Dataset.
        For a DataArray, ``xr_obj[label_name]`` is plotted if given,
        otherwise `xr_obj` itself.
    max_n_labels : int, optional
        Maximum number of labels for which a colorbar is displayed.
        The default is ``50``.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    interpolation : str, optional
        Interpolation passed to ``plot_image``. The default is ``"nearest"``.
    cmap : str, optional
        Colormap used to build the label color settings. The default is ``"Paired"``.
    fig_kwargs : dict, optional
        Figure options passed to ``plot_image``.
    **plot_kwargs
        Additional arguments passed to ``plot_image``. They override the
        defaults derived from the label colormap.

    Returns
    -------
    p
        The mappable returned by ``plot_image``.

    Raises
    ------
    ValueError
        If `xr_obj` is a Dataset and `label_name` is not specified.

    """
    from ximage.labels.labels import get_label_indices, redefine_label_array
    from ximage.labels.plot_labels import get_label_colorbar_settings

    from gpm.visualization.plot import plot_image

    # Select the label DataArray
    if label_name is not None:
        dataarray = xr_obj[label_name]
    elif isinstance(xr_obj, xr.Dataset):
        raise ValueError("'label_name' must be specified when plotting labels of an xr.Dataset.")
    else:
        dataarray = xr_obj

    dataarray = dataarray.compute()
    label_indices = get_label_indices(dataarray)
    n_labels = len(label_indices)
    if add_colorbar and n_labels > max_n_labels:
        msg = (
            f"The array currently contains {n_labels} labels and 'max_n_labels' is set to {max_n_labels}. "
            "The colorbar is not displayed!"
        )
        print(msg)
        add_colorbar = False
    # Relabel array from 1 to ... for plotting
    dataarray = redefine_label_array(dataarray, label_indices=label_indices)
    # Mask values not greater than 0 (background) with NaN
    dataarray = dataarray.where(dataarray > 0)
    # Define appropriate colormap; user-provided plot_kwargs take precedence
    default_plot_kwargs, cbar_kwargs = get_label_colorbar_settings(label_indices, cmap=cmap)
    default_plot_kwargs.update(plot_kwargs)
    # Plot image
    return plot_image(
        dataarray,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        fig_kwargs=fig_kwargs,
        **default_plot_kwargs,
    )
1044

1045

1046
def plot_labels(
    obj,  # Dataset, DataArray or generator
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs=None,
    **plot_kwargs,
):
    """Plot label arrays from an xarray object or from a generator of them.

    If `obj` is a generator, each item is unpacked as ``(label_id, xr_obj)``.
    Each ``xr_obj`` is plotted with ``_plot_labels`` and then shown with
    ``plt.show()``. Otherwise `obj` itself is plotted once.

    Parameters
    ----------
    obj : xarray.DataArray, xarray.Dataset or generator
        Object(s) containing the labels.
    label_name : str, optional
        Name of the label variable. Required for Datasets.
    max_n_labels : int, optional
        Maximum number of labels for which a colorbar is displayed.
        The default is ``50``.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    interpolation : str, optional
        Interpolation passed to the image plot. The default is ``"nearest"``.
    cmap : str, optional
        Colormap used for the labels. The default is ``"Paired"``.
    fig_kwargs : dict, optional
        Figure options passed to the image plot.
    **plot_kwargs
        Additional arguments passed to the image plot.

    Returns
    -------
    p
        The mappable of the (last) plotted object, or ``None`` if the
        generator yielded no item.

    """
    label_kwargs = {
        "label_name": label_name,
        "max_n_labels": max_n_labels,
        "add_colorbar": add_colorbar,
        "interpolation": interpolation,
        "cmap": cmap,
        "fig_kwargs": fig_kwargs,
    }
    if not is_generator(obj):
        return _plot_labels(xr_obj=obj, **label_kwargs, **plot_kwargs)

    # Initialize so that an empty generator returns None instead of raising UnboundLocalError
    p = None
    for _, xr_obj in obj:  # label_id, xr_obj
        p = _plot_labels(xr_obj=xr_obj, **label_kwargs, **plot_kwargs)
        plt.show()
    return p
1✔
1081

1082

1083
def plot_patches(
    patch_gen,
    variable=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot and display each patch yielded by a patch generator.

    Each item of `patch_gen` is unpacked as ``(label_id, patch)``, plotted with
    ``plot_image`` and shown with ``plt.show()``. A patch that fails to plot is
    skipped with a warning, so the remaining patches are still displayed.

    Parameters
    ----------
    patch_gen : iterable
        Iterable yielding ``(label_id, patch)`` pairs, where ``patch`` is an
        xarray DataArray or Dataset.
    variable : str, optional
        Variable to plot. Required when the patches are Datasets.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    interpolation : str, optional
        Interpolation passed to ``plot_image``. The default is ``"nearest"``.
    fig_kwargs : dict, optional
        Figure options passed to ``plot_image``.
    cbar_kwargs : dict, optional
        Colorbar options passed to ``plot_image``.
    **plot_kwargs
        Additional arguments passed to ``plot_image``.

    Raises
    ------
    ValueError
        If a patch is a Dataset and `variable` is not specified.

    """
    import warnings

    from gpm.visualization.plot import plot_image

    # Plot patches
    for label_id, xr_patch in patch_gen:
        if isinstance(xr_patch, xr.Dataset):
            if variable is None:
                raise ValueError("'variable' must be specified when plotting xr.Dataset patches.")
            xr_patch = xr_patch[variable]
        try:
            plot_image(
                xr_patch,
                interpolation=interpolation,
                add_colorbar=add_colorbar,
                fig_kwargs=fig_kwargs,
                cbar_kwargs=cbar_kwargs,
                **plot_kwargs,
            )
            plt.show()
        except Exception as err:
            # Best-effort: skip patches that cannot be plotted, but report why
            warnings.warn(
                f"Patch {label_id!r} could not be plotted and was skipped: {err}",
                stacklevel=2,
            )
1✔
1113

1114

1115
####--------------------------------------------------------------------------.
1116

1117

1118
def get_inset_bounds(
    ax,
    loc="upper right",
    inset_height=0.2,
    inside_figure=True,
    aspect_ratio=1,
):
    """Calculate the bounds for an inset axes placed in a corner of a parent axes.

    The bounds are expressed in the normalized coordinates of the parent axes
    (0 to 1 along each side of `ax`). This is the coordinate system expected
    by ``matplotlib.axes.Axes.inset_axes``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The parent axes. Its drawn size is used so that the inset keeps the
        requested on-screen `aspect_ratio`.
    loc : str
        The location of the inset within the parent axes. Valid options are
        'lower left', 'lower right', 'upper left', and 'upper right'.
        The default is 'upper right'.
    inset_height : float
        The inset height, as a fraction of the parent axes height.
        For example, a value of 0.2 makes the inset 20% as tall as the parent axes.
        The inset width is derived from `aspect_ratio`.
    inside_figure : bool, optional
        If `True` (default), the inset is placed fully within the parent axes.
        If `False`, the inset is shifted by half its width and height, so that
        it straddles the chosen corner and extends beyond the parent axes edges.
    aspect_ratio : float, optional
        The on-screen width-to-height ratio of the inset.
        A value greater than 1 gives an inset wider than it is tall, and a
        value less than 1 gives an inset taller than it is wide.
        The default value is 1.0, indicating a square inset.

    Returns
    -------
    inset_bounds : list of float
        The bounds of the inset, in the format [x0, y0, width, height].
        `x0` and `y0` locate the lower left corner of the inset, and `width`
        and `height` are its size, all in parent axes coordinates.

    Raises
    ------
    ValueError
        If `loc` is not one of the valid options.

    """
    valid_locs = ("upper right", "upper left", "lower right", "lower left")
    if loc not in valid_locs:
        raise ValueError(f"Invalid loc '{loc}'. Valid options are {list(valid_locs)}.")

    # get_position() applies any fixed aspect of the parent axes (e.g. cartopy GeoAxes).
    # The window extent queried afterwards therefore reflects the size at which
    # the axes is actually drawn, in display units.
    ax.get_position()
    parent_extent = ax.get_window_extent()

    # Derive the inset width (as a fraction of the parent width) so that the inset
    # has the requested width-to-height ratio on screen, whatever the figure shape
    inset_width = inset_height * aspect_ratio * parent_extent.height / parent_extent.width

    is_right = loc.endswith("right")
    is_upper = loc.startswith("upper")
    inset_x = 1 - inset_width if is_right else 0
    inset_y = 1 - inset_height if is_upper else 0

    # Adjust for insets that are allowed to be half outside of the parent axes
    if not inside_figure:
        inset_x += inset_width / 2 * (1 if is_right else -1)
        inset_y += inset_height / 2 * (1 if is_upper else -1)

    return [inset_x, inset_y, inset_width, inset_height]
×
1183

1184

1185
def add_map_inset(ax, loc="upper left", inset_height=0.2, projection=None, inside_figure=True):
    """Add an inset map to a Cartopy axis, highlighting the extent of the main plot.

    This function creates a smaller map inset within a larger map plot. The
    inset shows a global view, giving the contextual location of the main
    plot's extent.

    It uses Cartopy for map projections and plotting, and it outlines the extent of the main plot
    within the inset to provide geographical context.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        The main cartopy axis object where the geographic data is plotted.
        It must provide ``get_extent``.
    loc : str, optional
        The location of the inset map within the main plot.
        Options include 'lower left', 'lower right', 'upper left', 'upper right'.
        The default is 'upper left'.
    inset_height : float
        The inset height, as a fraction of the main axes height.
        For example, a value of 0.2 makes the inset 20% as tall as the main axes.
        The aspect ratio of the inset map projection governs the inset width.
    projection : cartopy.crs.Projection, optional
        A cartopy projection. If ``None``, an Orthographic projection centered
        on the extent center is used.
    inside_figure : bool, optional
        If `True` (default), the inset is placed fully within the main axes.
        If `False`, the inset can extend beyond the main axes edges, allowing
        for a half-outside placement.

    Returns
    -------
    ax2 : cartopy.mpl.geoaxes.GeoAxes
        The Cartopy GeoAxes object for the inset map.

    Notes
    -----
    The geographic (longitude/latitude) extent of the main plot is padded by
    0.5 using ``extend_geographic_extent``. The padded extent is then drawn
    as a red outline on the inset map.

    Examples
    --------
    >>> p = da.gpm.plot_map()
    >>> add_map_inset(ax=p.axes, loc="upper left", inset_height=0.15)

    This example creates a main plot with a specified extent and adds an upper-left inset map
    showing the global context of the main plot's extent.

    """
    from shapely import Polygon

    from gpm.utils.geospatial import extend_geographic_extent

    # Retrieve the geographic extent (lon_min, lon_max, lat_min, lat_max) of the main plot.
    # - GeoAxes.get_extent() defaults to the axes' own CRS: explicitly request lon/lat
    extent = ax.get_extent(crs=ccrs.PlateCarree())
    extent = extend_geographic_extent(extent, padding=0.5)
    bounds = [extent[i] for i in [0, 2, 1, 3]]  # (lon_min, lat_min, lon_max, lat_max)
    # Create shapely Polygon of the main plot extent
    polygon = Polygon.from_bounds(*bounds)
    # Define Orthographic projection centered on the extent
    if projection is None:
        lon_min, lon_max, lat_min, lat_max = extent
        projection = ccrs.Orthographic(
            central_latitude=(lat_min + lat_max) / 2,
            central_longitude=(lon_min + lon_max) / 2,
        )

    # Define aspect ratio of the map inset (from plain scalars, avoiding the
    # deprecated conversion of a 1-element numpy array to float)
    x_min, x_max = projection.x_limits
    y_min, y_max = projection.y_limits
    aspect_ratio = float((x_max - x_min) / (y_max - y_min))

    # Define inset location relative to main plot (ax) in normalized units
    # - Lower-left corner of inset Axes, and its width and height
    # - [x0, y0, width, height]
    inset_bounds = get_inset_bounds(
        ax=ax,
        loc=loc,
        inset_height=inset_height,
        inside_figure=inside_figure,
        aspect_ratio=aspect_ratio,
    )

    ax2 = ax.inset_axes(
        inset_bounds,
        projection=projection,
    )

    # Add global map
    ax2.set_global()
    ax2.add_feature(cfeature.LAND)
    ax2.add_feature(cfeature.OCEAN)
    # Add extent polygon
    _ = ax2.add_geometries(
        [polygon],
        ccrs.PlateCarree(),
        facecolor="none",
        edgecolor="red",
        linewidth=0.3,
    )
    return ax2
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc