• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ghiggi / gpm_api / 23598694613

26 Mar 2026 02:06PM UTC coverage: 90.832%. First build
23598694613

Pull #87

github

ghiggi
Update viz tests
Pull Request #87: Add l3 readers

155 of 167 new or added lines in 7 files covered. (92.81%)

16080 of 17703 relevant lines covered (90.83%)

0.91 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.47
/gpm/visualization/plot.py
1
# -----------------------------------------------------------------------------.
2
# MIT License
3

4
# Copyright (c) 2024 GPM-API developers
5
#
6
# This file is part of GPM-API.
7

8
# Permission is hereby granted, free of charge, to any person obtaining a copy
9
# of this software and associated documentation files (the "Software"), to deal
10
# in the Software without restriction, including without limitation the rights
11
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
# copies of the Software, and to permit persons to whom the Software is
13
# furnished to do so, subject to the following conditions:
14
#
15
# The above copyright notice and this permission notice shall be included in all
16
# copies or substantial portions of the Software.
17
#
18
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
# SOFTWARE.
25

26
# -----------------------------------------------------------------------------.
27
"""This module contains basic functions for GPM-API data visualization."""
28
import inspect
1✔
29
import warnings
1✔
30

31
import cartopy
1✔
32
import cartopy.crs as ccrs
1✔
33
import cartopy.feature as cfeature
1✔
34
import matplotlib.pyplot as plt
1✔
35
import numpy as np
1✔
36
import xarray as xr
1✔
37
from cartopy.mpl.gridliner import Gridliner
1✔
38
from pycolorbar import plot_colorbar, set_colorbar_fully_transparent
1✔
39
from pycolorbar.utils.mpl_legend import get_inset_bounds
1✔
40
from scipy.interpolate import griddata
1✔
41

42
import gpm
1✔
43
from gpm import get_plot_kwargs
1✔
44
from gpm.dataset.crs import compute_extent
1✔
45
from gpm.utils.area import get_lonlat_corners_from_centroids
1✔
46

47

48
def is_generator(obj):
    """Return True if ``obj`` is a generator function or a generator object."""
    if inspect.isgeneratorfunction(obj):
        return True
    return inspect.isgenerator(obj)
50

51

52
def _call_optimize_layout(self):
    """Optimize the figure layout.

    Meant to be bound to a plot object as its ``optimize_layout`` method.
    First the figure size is adapted to the data aspect ratio of ``self.axes``.
    Then matplotlib's ``tight_layout`` is applied to ``self.figure``.
    """
    axes = self.axes
    adapt_fig_size(ax=axes)
    figure = self.figure
    figure.tight_layout()
56

57

58
def add_optimize_layout_method(p):
    """Attach an ``optimize_layout`` bound method to ``p`` via monkey patching.

    Returns the same object ``p`` so the call can be chained.
    """
    # Bind the plain function to the instance (descriptor protocol) so that
    # ``p.optimize_layout()`` receives ``p`` as ``self``.
    bound_method = _call_optimize_layout.__get__(p, type(p))
    p.optimize_layout = bound_method
    return p
62

63

64
def adapt_fig_size(ax, nrow=1, ncol=1):
    """Adjust the figure size to the data aspect ratio of cartopy subplots.

    This function is intended to be called after all plotting has been completed.
    It assumes that all subplots within the figure share the same aspect ratio
    and size. The first axis is therefore taken as representative of all others.

    The implementation is inspired by Mathias Hauser's mplotutils ``set_map_layout`` function.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or numpy.ndarray
        A single axis, or an array of axes (e.g. the ``axes`` attribute of a
        FacetGrid). When an array is given, its first element is used.
    nrow : int, optional
        Number of subplot rows in the figure. The default is 1.
    ncol : int, optional
        Number of subplot columns in the figure. The default is 1.
    """
    # Use the first axis as representative when a collection of axes is given.
    # (Previously an array of axes crashed on ``ax.get_figure()``.)
    if isinstance(ax, np.ndarray):
        ax = ax.flat[0]

    # Access the figure object from the axis to manipulate its properties.
    fig = ax.get_figure()

    # Retrieve the current size of the figure in inches.
    width, original_height = fig.get_size_inches()

    # Drawing the canvas makes sure the figure geometry is up-to-date
    # before the layout parameters are read.
    fig.canvas.draw()

    # Margins of the figure and spaces between subplots.
    bottom = fig.subplotpars.bottom
    top = fig.subplotpars.top
    left = fig.subplotpars.left
    right = fig.subplotpars.right
    hspace = fig.subplotpars.hspace  # vertical space between subplots
    wspace = fig.subplotpars.wspace  # horizontal space between subplots

    # Aspect ratio (height / width) of the data displayed in the subplot.
    aspect = ax.get_data_ratio()

    # Width of a single plot, considering left/right margins, the number of
    # columns, and the space between columns.
    wp = (width - width * (left + (1 - right))) / (ncol + (ncol - 1) * wspace)

    # Height of a single plot from its width and the data aspect ratio.
    hp = wp * aspect

    # New figure height, accounting for the number of rows, the space between
    # rows, and the top and bottom margins.
    height = (hp * (nrow + ((nrow - 1) * hspace))) / (1.0 - (bottom + (1 - top)))

    # If the new height is less than half of the original one, scale both
    # width and height up (keeping the aspect ratio) so that the figure height
    # becomes half of the original height.
    if original_height / height > 2:
        scale_factor = original_height / height / 2
        width *= scale_factor
        height *= scale_factor

    # Apply the calculated width and height to adjust the figure size.
    fig.set_figwidth(width)
    fig.set_figheight(height)
125

126

127
####--------------------------------------------------------------------------.
128

129

130
def infill_invalid_coords(xr_obj, x="lon", y="lat"):
    """Infill invalid (non-finite) values of the ``x`` and ``y`` coordinates.

    Invalid coordinates within the convex hull of the valid ones are interpolated.
    The remaining ones are filled with the nearest valid neighbour.
    The operation is applied to a copy of ``xr_obj``, which is returned.
    """
    result = xr_obj.copy()
    x_values = np.asanyarray(result[x].data)
    y_values = np.asanyarray(result[y].data)

    # Only coordinates are infilled here: no data masking is requested.
    x_values, y_values, _ = get_valid_pcolormesh_inputs(
        x=x_values,
        y=y_values,
        data=None,
        mask_data=False,
    )

    result[x].data = x_values
    result[y].data = y_values
    return result
145

146

147
def get_valid_pcolormesh_inputs(x, y, data, rgb=False, mask_data=True):
    """Infill invalid coordinates.

    Interpolate the coordinates within the convex hull of data.
    Use nearest neighbour outside the convex hull of data.

    This operation is required to plot with pcolormesh since it
    does not accept non-finite values in the coordinates.

    If ``mask_data=True``, data values with invalid coordinates are masked
    and a numpy masked array is returned.
    Masked data values are not displayed in pcolormesh !
    If ``rgb=True``, it assumes the RGB dimension is the last data dimension.

    Raises
    ------
    ValueError
        If none of the coordinates is valid.
    """
    x_invalid = ~np.isfinite(x)
    y_invalid = ~np.isfinite(y)
    invalid_coords = x_invalid | y_invalid

    # Nothing to infill: return the inputs untouched
    if not np.any(invalid_coords):
        return x, y, data

    # At least one valid coordinate is required to interpolate from
    if np.all(invalid_coords):
        raise ValueError("No valid coordinates.")

    # Mask data values located at invalid coordinates
    if not mask_data:
        data_masked = data
    else:
        if rgb:
            # Repeat the 2D mask along the trailing RGB dimension
            data_mask = np.broadcast_to(np.expand_dims(invalid_coords, axis=-1), data.shape)
        else:
            data_mask = invalid_coords
        data_masked = np.ma.masked_where(data_mask, data)

    # Infill x and y: linear interpolation first, then nearest neighbour for
    # the values left outside the convex hull of the valid coordinates.
    # - NOTE: currently causes issues if NaN when crossing the antimeridian.
    # - TODO: interpolation should be done in X,Y,Z.
    if np.any(x_invalid):
        x = _interpolate_data(x, method="linear")
        x = _interpolate_data(x, method="nearest")
    if np.any(y_invalid):
        y = _interpolate_data(y, method="linear")
        y = _interpolate_data(y, method="nearest")
    return x, y, data_masked
195

196

197
def _interpolate_data(arr, method="linear"):
    """Infill non-finite values of a coordinate array with ``griddata``.

    1D arrays (along_track/cross_track views) are handled by
    ``_interpolate_1d_coord``. Any other array (swath image) is handled by
    ``_interpolate_2d_coord``.
    """
    interpolator = _interpolate_1d_coord if arr.ndim == 1 else _interpolate_2d_coord
    return interpolator(arr, method=method)
203

204

205
def _interpolate_1d_coord(arr, method="linear"):
    """Infill non-finite values of a 1D array by interpolation along its index.

    Parameters
    ----------
    arr : numpy.ndarray
        1D array possibly containing non-finite values.
    method : str, optional
        Interpolation method passed to ``scipy.interpolate.griddata``.

    Returns
    -------
    numpy.ndarray
        ``arr`` itself if all values are finite, otherwise an infilled copy.
    """
    invalid = ~np.isfinite(arr)
    if not np.any(invalid):
        return arr

    # Interpolate in index space: valid positions/values -> invalid positions
    positions = np.arange(len(arr))
    filled = arr.copy()
    filled[invalid] = griddata(
        positions[~invalid],
        arr[~invalid],
        positions[invalid],
        method=method,
    )
    return filled
235

236

237
def _interpolate_2d_coord(arr, method="linear"):
    """Infill non-finite values of a 2D array by interpolation in (row, column) index space.

    Parameters
    ----------
    arr : numpy.ndarray
        2D array possibly containing non-finite values.
    method : str, optional
        Interpolation method passed to ``scipy.interpolate.griddata``.

    Returns
    -------
    numpy.ndarray
        ``arr`` itself if all values are finite, otherwise an infilled copy.
    """
    # Find invalid locations
    is_invalid = ~np.isfinite(arr)

    # Return the array unchanged if there are no invalid values.
    # NOTE: np.where() on a 2D mask returns a tuple of two index arrays, so its
    # length is always 2; the mask itself must be checked instead.
    if not np.any(is_invalid):
        return arr

    # (row, col) indices of invalid and valid values
    nan_indices = np.where(is_invalid)
    non_nan_indices = np.where(~is_invalid)

    # Points (row, col) where we have valid data, and where data is invalid
    points = np.column_stack(non_nan_indices)
    points_nan = np.column_stack(nan_indices)

    # Values at the valid points
    values = arr[non_nan_indices]

    # Interpolate using griddata
    arr_new = arr.copy()
    arr_new[nan_indices] = griddata(points, values, points_nan, method=method)
    return arr_new
267

268

269
def _mask_antimeridian_crossing_arr(arr, antimeridian_mask, rgb):
    """Return ``arr`` as a masked array with antimeridian-crossing cells masked.

    Existing masked values of ``arr`` are preserved. If ``rgb`` is True, the 2D
    ``antimeridian_mask`` is repeated along the trailing (RGB) dimension of ``arr``.
    """
    if rgb:
        antimeridian_mask = np.broadcast_to(
            np.expand_dims(antimeridian_mask, axis=-1),
            arr.shape,
        )
    if np.ma.is_masked(arr):
        antimeridian_mask = np.logical_or(arr.mask, antimeridian_mask)
    return np.ma.masked_where(antimeridian_mask, arr)
285

286

287
def mask_antimeridian_crossing_array(arr, lon, rgb, plot_kwargs):
    """Mask the array cells crossing the antimeridian.

    ``lon`` is assumed to contain no invalid values anymore.
    Cartopy still bugs with several projections when data cross the antimeridian.
    By default, GPM-API masks data crossing the antimeridian.
    The GPM-API configuration default can be modified with: ``gpm.config.set({"viz_hide_antimeridian_data": False})``

    Returns
    -------
    tuple
        ``(arr, plot_kwargs)``, both unchanged if no cell crosses the antimeridian.
    """
    antimeridian_mask = get_antimeridian_mask(lon)
    if not np.any(antimeridian_mask):
        return arr, plot_kwargs

    # Cartopy requires the cmap 'bad' color to be fully transparent
    plot_kwargs = _sanitize_cartopy_plot_kwargs(plot_kwargs)

    # Masking is controlled by the GPM-API config 'viz_hide_antimeridian_data' (default is True)
    if gpm.config.get("viz_hide_antimeridian_data"):
        arr = _mask_antimeridian_crossing_arr(arr, antimeridian_mask=antimeridian_mask, rgb=rgb)
    return arr, plot_kwargs
305

306

307
def get_antimeridian_mask(lons):
    """Get mask of longitude coordinates neighbors crossing the antimeridian.

    ``lons`` has shape (n_y, n_x). The returned boolean mask has shape
    (n_y - 1, n_x - 1), i.e. ``lons`` is expected to hold cell corners.
    A cell is flagged when two adjacent longitudes differ by more than 180 degrees.
    The mask is then dilated by one cell.
    """
    from scipy.ndimage import binary_dilation

    n_rows, n_cols = lons.shape
    cell_mask = np.zeros((n_rows - 1, n_cols - 1))

    # Jumps between vertically adjacent longitudes
    jump_rows, jump_cols = np.nonzero(np.abs(np.diff(lons, axis=0)) > 180)
    cell_mask[jump_rows, np.clip(jump_cols - 1, 0, n_cols - 1)] = 1

    # Jumps between horizontally adjacent longitudes
    jump_rows, jump_cols = np.nonzero(np.abs(np.diff(lons, axis=1)) > 180)
    cell_mask[np.clip(jump_rows - 1, 0, n_rows - 1), jump_cols] = 1

    # Buffer by one cell to also hide neighbours of the crossing cells.
    # This should not be needed, but it avoids cartopy bugs.
    return binary_dilation(cell_mask)
325

326

327
####--------------------------------------------------------------------------.
328
###########################
329
#### Cartopy utilities ####
330
###########################
331

332

333
def _get_gridliners(ax):
    """Return the cartopy Gridliner artists attached to ``ax``."""
    return [artist for artist in ax.artists if isinstance(artist, Gridliner)]


def remove_bottom_gridlabels(ax):
    """Hide the bottom labels of every cartopy Gridliner attached to ``ax``."""
    for gridliner in _get_gridliners(ax):
        gridliner.bottom_labels = False


def remove_left_gridlabels(ax):
    """Hide the left labels of every cartopy Gridliner attached to ``ax``."""
    for gridliner in _get_gridliners(ax):
        gridliner.left_labels = False
343

344

345
####--------------------------------------------------------------------------.
346
########################
347
#### Plot utilities ####
348
########################
349

350

351
def preprocess_rgb_dataarray(da, rgb):
    """Validate the RGB dimension of ``da`` and move it to the last position.

    If ``rgb`` is falsy, ``da`` is returned unchanged.

    Raises
    ------
    ValueError
        If ``rgb`` is not a dimension of ``da`` or its size is not 3 or 4.
    """
    if not rgb:
        return da
    if rgb not in da.dims:
        raise ValueError(f"The specified rgb='{rgb}' must be a dimension of the DataArray.")
    if da[rgb].size not in [3, 4]:
        raise ValueError("The RGB dimension must have size 3 or 4.")
    return da.transpose(..., rgb)
359

360

361
def check_object_format(da, plot_kwargs, check_function, **function_kwargs):
    """Check object format and valid dimension names.

    The DataArray is squeezed. If ``plot_kwargs["rgb"]`` is set, the RGB dimension
    is validated and moved last. The ``rgb``/``col``/``row`` dimensions given in
    ``plot_kwargs`` must exist. ``check_function`` is then called (with
    ``function_kwargs``) on the first slice along those dimensions.

    Returns the preprocessed DataArray.
    """
    da = preprocess_rgb_dataarray(da.squeeze(), plot_kwargs.get("rgb", False))

    # rgb or FacetGrid column/row dimensions requested by the user
    special_dims = {}
    for key in ("rgb", "col", "row"):
        dim = plot_kwargs.get(key, None)
        if dim:
            special_dims[key] = dim

    for key, dim in special_dims.items():
        if dim not in da.dims:
            raise ValueError(f"The DataArray does not have a {key}='{dim}' dimension.")

    # Run the check on a single slice (index 0 along each special dimension)
    first_slice = {dim: 0 for dim in special_dims.values()}
    check_function(da.isel(first_slice), **function_kwargs)
    return da
376

377

378
def preprocess_figure_args(ax, fig_kwargs=None, subplot_kwargs=None, is_facetgrid=False):
    """Validate figure-creation arguments and return ``fig_kwargs`` (``{}`` if None).

    ``ax`` must not be given when plotting with a FacetGrid. ``fig_kwargs`` and
    ``subplot_kwargs`` may only be non-empty when ``ax`` is None.

    Raises
    ------
    ValueError
        If any of the above constraints is violated.
    """
    if is_facetgrid and ax is not None:
        raise ValueError("When plotting with FacetGrid, do not specify the 'ax'.")
    fig_kwargs = fig_kwargs if fig_kwargs is not None else {}
    subplot_kwargs = subplot_kwargs if subplot_kwargs is not None else {}
    if ax is None:
        return fig_kwargs
    if subplot_kwargs:
        raise ValueError("Provide `subplot_kwargs`only if ``ax``is None")
    if fig_kwargs:
        raise ValueError("Provide `fig_kwargs` only if ``ax``is None")
    return fig_kwargs
389

390

391
def preprocess_subplot_kwargs(subplot_kwargs, infer_crs=False, xr_obj=None):
    """Return a copy of ``subplot_kwargs`` that always contains a ``projection``.

    If no projection is given, it is taken from ``xr_obj.gpm.cartopy_crs`` when
    ``infer_crs=True``. Otherwise ``ccrs.PlateCarree()`` is used.
    """
    kwargs = {} if subplot_kwargs is None else subplot_kwargs.copy()
    if "projection" in kwargs:
        return kwargs
    kwargs["projection"] = xr_obj.gpm.cartopy_crs if infer_crs else ccrs.PlateCarree()
    return kwargs
400

401

402
def infer_xy_labels(da, x=None, y=None, rgb=None):
    """Infer the x and y labels of ``da`` with xarray's private ``_infer_xy_labels`` helper.

    Returns
    -------
    tuple
        The inferred ``(x, y)`` names.
    """
    from xarray.plot.utils import _infer_xy_labels

    # imshow=True is passed only as a flag enabling the ``rgb`` handling
    x_label, y_label = _infer_xy_labels(da, x=x, y=y, imshow=True, rgb=rgb)
    return x_label, y_label
408

409

410
def infer_map_xy_coords(da, x=None, y=None):
    """
    Infer possible map x and y coordinates for the given DataArray.

    Candidate names are tried in order: ``x``, ``lon``, ``longitude`` for x and
    ``y``, ``lat``, ``latitude`` for y. The first one present in ``da.coords`` wins.

    Parameters
    ----------
    da : xarray.DataArray
        The input DataArray.
    x : str, optional
        The name of the x (i.e. longitude) coordinate. If None, it will be inferred.
    y : str, optional
        The name of the y (i.e. latitude) coordinate. If None, it will be inferred.

    Returns
    -------
    tuple
        The inferred (x, y) coordinates.

    Raises
    ------
    ValueError
        If a coordinate is not provided and none of its candidates is found.
    """
    if x is None:
        x = next((name for name in ("x", "lon", "longitude") if name in da.coords), None)
        if x is None:
            raise ValueError("Cannot infer x coordinate. Please provide the x coordinate.")
    if y is None:
        y = next((name for name in ("y", "lat", "latitude") if name in da.coords), None)
        if y is None:
            raise ValueError("Cannot infer y coordinate. Please provide the y coordinate.")
    return x, y
448

449

450
def _get_proj_str(crs):
    """Return the ``proj`` entry of ``crs.to_dict()`` (e.g. ``"laea"``), or ``""`` if absent.

    Warnings emitted while converting the CRS to a dictionary are silenced.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        crs_dict = crs.to_dict()
    return crs_dict.get("proj", "")
455

456

457
def initialize_cartopy_plot(
    ax,
    fig_kwargs,
    subplot_kwargs,
    add_background,
    add_gridlines,
    add_labels,
    infer_crs=False,
    xr_obj=None,
):
    """Initialize figure for cartopy plot if necessary.

    When ``ax`` is None, a new figure and axes are created with ``plt.subplots``.
    The projection is set as in ``preprocess_subplot_kwargs``. The background
    and the gridlines/labels are then optionally added.

    Returns the axes to plot on.
    """
    if ax is None:
        fig_kwargs = preprocess_figure_args(
            ax=None,
            fig_kwargs=fig_kwargs,
            subplot_kwargs=subplot_kwargs,
        )
        subplot_kwargs = preprocess_subplot_kwargs(subplot_kwargs, infer_crs=infer_crs, xr_obj=xr_obj)
        ax = plt.subplots(subplot_kw=subplot_kwargs, **fig_kwargs)[1]

    if add_background:
        ax = plot_cartopy_background(ax)

    if add_gridlines or add_labels:
        plot_cartopy_gridlines_and_labels(ax, add_gridlines=add_gridlines, add_labels=add_labels)

    return ax
487

488

489
def plot_cartopy_gridlines_and_labels(ax, add_gridlines=True, add_labels=True):
    """Add cartopy gridlines and labels to ``ax`` and return the Gridliner.

    Gridlines are computed in PlateCarree coordinates. When ``add_gridlines`` is
    False they are drawn with alpha=0, so only the labels (if any) are visible.
    Top and right labels are disabled.
    """
    gridliner = ax.gridlines(
        crs=ccrs.PlateCarree(),
        draw_labels=add_labels,
        linewidth=1,
        color="gray",
        alpha=0.1 if add_gridlines else 0,
        linestyle="-",
    )
    gridliner.top_labels = False
    gridliner.right_labels = False
    gridliner.xlines = True
    gridliner.ylines = True
    return gridliner
505

506

507
def plot_cartopy_background(ax):
    """Plot cartopy background (coastlines, land/ocean, borders) and return ``ax``."""
    ax.coastlines()

    # Land and ocean currently raise errors with some projections (shapely bug):
    # https://github.com/SciTools/cartopy/issues/2176
    supports_land_ocean = _get_proj_str(ax.projection) not in ["laea"]
    if supports_land_ocean:
        ax.add_feature(cartopy.feature.LAND, facecolor=[0.9, 0.9, 0.9])
        ax.add_feature(cartopy.feature.OCEAN, alpha=0.6)

    ax.add_feature(cartopy.feature.BORDERS)  # BORDERS also draws provinces, ...
    return ax
520

521

522
def plot_sides(sides, ax, **plot_kwargs):
    """Plot boundary sides as geodetic lines.

    Parameters
    ----------
    sides : iterable of tuple
        Each item is a ``(lon, lat)`` pair of coordinate sequences.
    ax : cartopy.mpl.geoaxes.GeoAxes
        Axes on which to draw.
    **plot_kwargs
        Keyword arguments passed to ``ax.plot``.

    Returns
    -------
    matplotlib.lines.Line2D or None
        The first line returned by ``ax.plot`` for the last side, or None if
        ``sides`` is empty (previously this raised UnboundLocalError).
    """
    line = None
    for side in sides:
        line = ax.plot(*side, transform=ccrs.Geodetic(), **plot_kwargs)[0]
    return line
530

531

532
####--------------------------------------------------------------------------.
533
##########################
534
#### Cartopy wrappers ####
535
##########################
536

537

538
def _sanitize_cartopy_plot_kwargs(plot_kwargs):
    """Sanitize 'cmap' to avoid cartopy bug related to cmap bad color.

    Cartopy requires the bad color to be fully transparent.
    The colormap is copied before its bad color is modified. This leaves the
    caller's colormap, which may be shared across plots, unaltered. A colormap
    given by name is resolved with ``plt.get_cmap``.

    Parameters
    ----------
    plot_kwargs : dict
        Plotting keyword arguments. Updated in place and returned.

    Returns
    -------
    dict
        ``plot_kwargs`` with a sanitized ``cmap`` (if one was given).
    """
    cmap = plot_kwargs.get("cmap", None)
    if cmap is None:
        return plot_kwargs
    # Support colormaps specified by name (previously an AttributeError)
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)
    # Work on a copy so that the input colormap is not mutated
    cmap = cmap.copy()
    bad = cmap.get_bad()
    bad[3] = 0  # enforce alpha to 0 (transparent)
    cmap.set_bad(bad)
    plot_kwargs["cmap"] = cmap
    return plot_kwargs
550

551

552
def is_same_crs(crs1, crs2):
    """Check if same CRS.

    Two CRS are considered equal when their ``to_dict()`` representations agree
    on a fixed subset of PROJ keys (missing keys count as None). Warnings emitted
    by ``to_dict()`` are silenced.
    """
    compared_keys = ("proj", "lat_0", "lon_0", "x_0", "y_0", "units", "type", "lon_wrap", "over", "pm")
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        first_dict = crs1.to_dict()
        second_dict = crs2.to_dict()

    def _subset(crs_dict):
        return {key: crs_dict.get(key) for key in compared_keys}

    return _subset(first_dict) == _subset(second_dict)
562

563

564
def plot_cartopy_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot imshow with cartopy.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        Axes on which to draw the image.
    da : xarray.DataArray
        Data to display. The ``x`` and ``y`` coordinates are expected to be 1D.
    x, y : str or None
        Names of the x and y coordinates. Inferred from ``da`` if None.
    interpolation : str, optional
        Interpolation passed to ``ax.imshow``. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar. Ignored for RGB images.
    plot_kwargs : dict, optional
        Keyword arguments passed to ``ax.imshow``. An ``rgb`` key, if present,
        flags an RGB image and is not forwarded. The caller's dict is not modified.
    cbar_kwargs : dict, optional
        Keyword arguments passed to ``plot_colorbar``.

    Returns
    -------
    The artist returned by ``ax.imshow``.
    """
    # Copy to avoid popping 'rgb' from the caller's dict; default cbar_kwargs
    # to {} so that ``**cbar_kwargs`` does not fail when None.
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs

    # Infer x and y
    x, y = infer_xy_labels(da, x=x, y=y, rgb=plot_kwargs.get("rgb", None))

    # Align x, y and data dimensions.
    # The x/y coordinates may not follow the dimension order of the data array.
    da = da.transpose(*da[y].dims, *da[x].dims, ...)

    # Retrieve data and coordinates
    arr = np.asanyarray(da.data)
    x_coords = da[x].to_numpy()
    y_coords = da[y].to_numpy()

    # Compute extent
    extent = compute_extent(x_coords=x_coords, y_coords=y_coords)

    # Infer CRS of data, extent and cartopy projection.
    # Best effort: if the CRS cannot be derived, assume lon/lat coordinates.
    try:
        crs = da.gpm.cartopy_crs
    except Exception:
        crs = ccrs.PlateCarree()

    # Determine image origin based on the orientation of the y values.
    # Cartopy assumes origin="lower". If y decreases, origin="upper" is used,
    # which means the image array is displayed [::-1, :] reversed within cartopy.
    # A single y value is treated as increasing.
    y_increasing = y_coords.size < 2 or y_coords[1] > y_coords[0]
    origin = "lower" if y_increasing else "upper"

    # Do not zoom when the extent is outside the CRS x limits
    # (e.g. PlateCarree with longitudes defined in 0-360 and pm=0).
    set_extent = not (extent[1] > crs.x_limits[1] or extent[0] < crs.x_limits[0])

    # Specify the transform only if the data CRS differs from the axes CRS.
    # With the same CRS, a transform is slower and might cut away half of the
    # first and last row pixels.
    transform = None if is_same_crs(crs, ax.projection) else crs

    # Add variable field with cartopy
    rgb = plot_kwargs.pop("rgb", False)
    p = ax.imshow(
        arr,
        transform=transform,
        extent=extent,
        origin=origin,
        interpolation=interpolation,
        **plot_kwargs,
    )

    # Zoom on the actual data region (relevant if a global background is displayed)
    if set_extent:
        ax.set_extent(extent, crs=crs)

    # Add colorbar
    if add_colorbar and not rgb:
        _ = plot_colorbar(p=p, ax=ax, **cbar_kwargs)
    return p
650

651

652
def plot_cartopy_pcolormesh(
    ax,
    da,
    x,
    y,
    add_colorbar=True,
    add_swath_lines=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot pcolormesh with cartopy.

    x and y must represent longitudes and latitudes.
    The function currently does not allow zooming on regions across the antimeridian.
    The function masks scanning pixels which span across the antimeridian.
    If the DataArray has an RGB dimension, plot_kwargs should contain the ``rgb``
    key with the name of the RGB dimension.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        Axes on which to draw.
    da : xarray.DataArray
        Data to display.
    x, y : str
        Names of the longitude and latitude coordinates.
    add_colorbar : bool, optional
        Whether to add a colorbar. Ignored for RGB data.
    add_swath_lines : bool, optional
        Whether to draw the first and last lines of the cell-corner grid (2D case only).
    plot_kwargs : dict, optional
        Keyword arguments passed to ``ax.pcolormesh``. The caller's dict is not modified.
    cbar_kwargs : dict, optional
        Keyword arguments passed to ``plot_colorbar``.

    Returns
    -------
    The artist returned by ``ax.pcolormesh``.
    """
    # Copy to avoid mutating the caller's dict (pop of 'rgb', setdefault of
    # 'shading'); default cbar_kwargs to {} so ``**cbar_kwargs`` does not fail.
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs

    # Remove RGB from plot_kwargs
    rgb = plot_kwargs.pop("rgb", False)

    # Align x, y and data dimensions.
    # The x/y coordinates may not follow the dimension order of the data array.
    da = da.transpose(*da[y].dims, ...)

    # Get x, y, and array to plot
    da = preprocess_rgb_dataarray(da, rgb=rgb)
    da = da.compute()
    lon = da[x].data.copy()
    lat = da[y].data.copy()
    arr = da.data

    # Check if 1D coordinate (orbit nadir-view / transect / cross-section case)
    is_1d_case = lon.ndim == 1

    # Infill invalid coordinates and mask data at invalid coordinates.
    # No invalid coordinates remain after this call.
    lon, lat, arr = get_valid_pcolormesh_inputs(lon, lat, arr, rgb=rgb, mask_data=True)
    if is_1d_case:
        arr = np.expand_dims(arr, axis=1)

    # No colorbar for RGB images
    if rgb:
        add_colorbar = False

    # Compute coordinates of cell corners for the pcolormesh quadrilateral mesh.
    # This enables correct masking of cells crossing the antimeridian.
    lon, lat = get_lonlat_corners_from_centroids(lon, lat, parallel=False)

    # Mask cells crossing the antimeridian.
    # gpm.config.set({"viz_hide_antimeridian_data": False}) modifies the masking behaviour.
    arr, plot_kwargs = mask_antimeridian_crossing_array(arr, lon, rgb, plot_kwargs)

    # Add variable field with cartopy
    plot_kwargs.setdefault("shading", "flat")
    p = ax.pcolormesh(
        lon,
        lat,
        arr,
        transform=ccrs.PlateCarree(),
        **plot_kwargs,
    )

    # Add swath lines
    # - TODO: currently assumes that dimensions are (cross_track, along_track)
    if add_swath_lines and not is_1d_case:
        sides = [(lon[0, :], lat[0, :]), (lon[-1, :], lat[-1, :])]
        plot_sides(sides=sides, ax=ax, linestyle="--", color="black")

    # Add colorbar
    if add_colorbar:
        _ = plot_colorbar(p=p, ax=ax, **cbar_kwargs)
    return p
728

729

730
####-------------------------------------------------------------------------------.
731
#########################
732
#### Xarray wrappers ####
733
#########################
734

735

736
def _preprocess_xr_kwargs(add_colorbar, plot_kwargs, cbar_kwargs):
    """Reconcile colorbar and plotting arguments before calling xarray plot methods.

    ``cbar_kwargs`` is dropped when no colorbar is requested. If ``plot_kwargs``
    contains an ``rgb`` key, the colorbar is disabled. Only the ``rgb``, ``col``,
    ``row`` and ``origin`` arguments with non-None values are then kept
    (``alpha`` is currently skipped for RGB plots).

    Returns
    -------
    tuple
        ``(add_colorbar, plot_kwargs, cbar_kwargs)``.
    """
    if "rgb" in plot_kwargs:
        kept_kwargs = {}
        for key in ("rgb", "col", "row", "origin"):
            value = plot_kwargs.get(key, None)
            if value is not None:
                kept_kwargs[key] = value
        return False, kept_kwargs, None
    return add_colorbar, plot_kwargs, (cbar_kwargs if add_colorbar else None)
746

747

748
def plot_xr_pcolormesh(
    ax,
    da,
    x,
    y,
    add_colorbar=True,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot pcolormesh with xarray.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Axes on which to draw.
    da : xarray.DataArray
        Data to display.
    x, y : str
        Names of the x and y coordinates.
    add_colorbar : bool, optional
        Whether to add a colorbar. Disabled for RGB plots.
    cbar_kwargs : dict, optional
        Colorbar keyword arguments passed to xarray. A ``ticklabels`` entry, if
        present, is applied to the colorbar y tick labels after plotting.
        The caller's dict is not modified.
    **plot_kwargs
        Keyword arguments passed to ``da.plot.pcolormesh``.

    Returns
    -------
    The object returned by ``da.plot.pcolormesh``.
    """
    is_facetgrid = "col" in plot_kwargs or "row" in plot_kwargs

    # Copy so that popping 'ticklabels' does not alter the caller's dict
    # (which may be reused across calls), and support cbar_kwargs=None.
    cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs)
    ticklabels = cbar_kwargs.pop("ticklabels", None)

    add_colorbar, plot_kwargs, cbar_kwargs = _preprocess_xr_kwargs(
        add_colorbar=add_colorbar,
        plot_kwargs=plot_kwargs,
        cbar_kwargs=cbar_kwargs,
    )
    p = da.plot.pcolormesh(
        x=x,
        y=y,
        ax=ax,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )

    # Add variable name as title (if not FacetGrid)
    if not is_facetgrid:
        p.axes.set_title(da.name)

    # Add colorbar ticklabels
    if add_colorbar and ticklabels is not None:
        p.colorbar.ax.set_yticklabels(ticklabels)
    return p
781

782

783
def plot_xr_imshow(
1✔
784
    ax,
785
    da,
786
    x,
787
    y,
788
    interpolation="nearest",
789
    add_colorbar=True,
790
    add_labels=True,
791
    cbar_kwargs=None,
792
    visible_colorbar=True,
793
    **plot_kwargs,
794
):
795
    """Plot imshow with xarray.
796

797
    The colorbar is added with xarray to enable to display multiple colorbars
798
    when calling this function multiple times on different fields with
799
    different colorbars.
800
    """
801
    is_facetgrid = bool("col" in plot_kwargs or "row" in plot_kwargs)
1✔
802
    ticklabels = cbar_kwargs.pop("ticklabels", None)
1✔
803
    add_colorbar, plot_kwargs, cbar_kwargs = _preprocess_xr_kwargs(
1✔
804
        add_colorbar=add_colorbar,
805
        plot_kwargs=plot_kwargs,
806
        cbar_kwargs=cbar_kwargs,
807
    )
808
    # Allow using coords as x/y axis
809
    # BUG - Current bug in xarray
810
    if plot_kwargs.get("rgb", None) is not None:
1✔
811
        if x not in da.dims:
1✔
812
            da = da.swap_dims({list(da[x].dims)[0]: x})
×
813
        if y not in da.dims:
1✔
814
            da = da.swap_dims({list(da[y].dims)[0]: y})
×
815

816
    p = da.plot.imshow(
1✔
817
        x=x,
818
        y=y,
819
        ax=ax,
820
        interpolation=interpolation,
821
        add_colorbar=add_colorbar,
822
        add_labels=add_labels,
823
        cbar_kwargs=cbar_kwargs,
824
        **plot_kwargs,
825
    )
826

827
    # Add variable name as title (if not FacetGrid)
828
    if not is_facetgrid:
1✔
829
        p.axes.set_title(da.name)
1✔
830

831
    # Add colorbar ticklabels
832
    if add_colorbar and ticklabels is not None:
1✔
833
        p.colorbar.ax.set_yticklabels(ticklabels)
1✔
834

835
    # Make the colorbar fully transparent with a smart trick ;)
836
    # - TODO: this still cause issues when plotting 2 colorbars !
837
    if add_colorbar and not visible_colorbar:
1✔
838
        set_colorbar_fully_transparent(p)
1✔
839

840
    # Add manually the colorbar
841
    # p = da.plot.imshow(
842
    #     x=x,
843
    #     y=y,
844
    #     ax=ax,
845
    #     interpolation=interpolation,
846
    #     add_colorbar=False,
847
    #     **plot_kwargs,
848
    # )
849
    # plt.title(da.name)
850
    # if add_colorbar:
851
    #     _ = plot_colorbar(p=p, ax=ax, **cbar_kwargs)
852
    return p
1✔
853

854

855
####--------------------------------------------------------------------------.
856
####################
857
#### Plot Image ####
858
####################
859

860

861
def _plot_image(
1✔
862
    da,
863
    x=None,
864
    y=None,
865
    ax=None,
866
    add_colorbar=True,
867
    add_labels=True,
868
    interpolation="nearest",
869
    fig_kwargs=None,
870
    cbar_kwargs=None,
871
    **plot_kwargs,
872
):
873
    """Plot GPM orbit granule as in image."""
874
    from gpm.checks import is_grid, is_orbit
1✔
875
    from gpm.visualization.facetgrid import sanitize_facetgrid_plot_kwargs
1✔
876

877
    fig_kwargs = preprocess_figure_args(ax=ax, fig_kwargs=fig_kwargs)
1✔
878

879
    # - Initialize figure
880
    if ax is None:
1✔
881
        _, ax = plt.subplots(**fig_kwargs)
1✔
882

883
    # - Sanitize plot_kwargs set by by xarray FacetGrid.map_dataarray
884
    is_facetgrid = plot_kwargs.get("_is_facetgrid", False)
1✔
885
    plot_kwargs = sanitize_facetgrid_plot_kwargs(plot_kwargs)
1✔
886

887
    # - If not specified, retrieve/update plot_kwargs and cbar_kwargs as function of product name
888
    plot_kwargs, cbar_kwargs = get_plot_kwargs(
1✔
889
        name=da.name,
890
        user_plot_kwargs=plot_kwargs,
891
        user_cbar_kwargs=cbar_kwargs,
892
    )
893

894
    # Define x and y
895
    x, y = infer_xy_labels(da=da, x=x, y=y, rgb=plot_kwargs.get("rgb", None))
1✔
896

897
    # - Plot with xarray
898
    p = plot_xr_imshow(
1✔
899
        ax=ax,
900
        da=da,
901
        x=x,
902
        y=y,
903
        interpolation=interpolation,
904
        add_colorbar=add_colorbar,
905
        add_labels=add_labels,
906
        cbar_kwargs=cbar_kwargs,
907
        **plot_kwargs,
908
    )
909

910
    # Add custom labels
911
    default_labels = {
1✔
912
        "orbit": {"along_track": "Along-Track", "x": "Along-Track", "cross_track": "Cross-Track", "y": "Cross-Track"},
913
        "grid": {
914
            "lon": "Longitude",
915
            "longitude": "Longitude",
916
            "x": "Longitude",
917
            "lat": "Latitude",
918
            "latitude": "Latitude",
919
            "y": "Latitude",
920
        },
921
    }
922

923
    if add_labels:
1✔
924
        if is_orbit(da):
1✔
925
            ax.set_xlabel(default_labels["orbit"].get(x, x))
1✔
926
            ax.set_ylabel(default_labels["orbit"].get(y, y))
1✔
927
        elif is_grid(da):
1✔
928
            ax.set_xlabel(default_labels["grid"].get(x, x))
1✔
929
            ax.set_ylabel(default_labels["grid"].get(y, y))
1✔
930

931
    # - Monkey patch the mappable instance to add optimize_layout
932
    if not is_facetgrid:
1✔
933
        p = add_optimize_layout_method(p)
1✔
934
    # - Return mappable
935
    return p
1✔
936

937

938
def _plot_image_facetgrid(
1✔
939
    da,
940
    x=None,
941
    y=None,
942
    ax=None,
943
    add_colorbar=True,
944
    add_labels=True,
945
    interpolation="nearest",
946
    fig_kwargs=None,
947
    cbar_kwargs=None,
948
    **plot_kwargs,
949
):
950
    """Plot 2D fields with FacetGrid."""
951
    from gpm.visualization.facetgrid import ImageFacetGrid
1✔
952

953
    # Check inputs
954
    fig_kwargs = preprocess_figure_args(ax=ax, fig_kwargs=fig_kwargs, is_facetgrid=True)
1✔
955

956
    # Retrieve GPM-API defaults cmap and cbar kwargs
957
    variable = da.name
1✔
958
    plot_kwargs, cbar_kwargs = get_plot_kwargs(
1✔
959
        name=variable,
960
        user_plot_kwargs=plot_kwargs,
961
        user_cbar_kwargs=cbar_kwargs,
962
    )
963

964
    # Disable colorbar if rgb
965
    # - Move this to pycolorbar !
966
    # - Also remove cmap, norm, vmin and vmax in plot_kwargs
967
    if plot_kwargs.get("rgb", False):
1✔
968
        add_colorbar = False
1✔
969
        cbar_kwargs = {}
1✔
970

971
    # Create FacetGrid
972
    fc = ImageFacetGrid(
1✔
973
        data=da.compute(),
974
        col=plot_kwargs.pop("col", None),
975
        row=plot_kwargs.pop("row", None),
976
        col_wrap=plot_kwargs.pop("col_wrap", None),
977
        axes_pad=plot_kwargs.pop("axes_pad", None),
978
        fig_kwargs=fig_kwargs,
979
        cbar_kwargs=cbar_kwargs,
980
        add_colorbar=add_colorbar,
981
        aspect=plot_kwargs.pop("aspect", False),
982
        facet_height=plot_kwargs.pop("facet_height", 3),
983
        facet_aspect=plot_kwargs.pop("facet_aspect", 1),
984
    )
985

986
    # Plot the maps
987
    fc = fc.map_dataarray(
1✔
988
        _plot_image,
989
        x=x,
990
        y=y,
991
        add_colorbar=False,
992
        add_labels=add_labels,
993
        interpolation=interpolation,
994
        cbar_kwargs=cbar_kwargs,
995
        **plot_kwargs,
996
    )
997

998
    # Remove duplicated or all labels
999
    fc.remove_duplicated_axis_labels()
1✔
1000

1001
    if not add_labels:
1✔
1002
        fc.remove_left_ticks_and_labels()
×
1003
        fc.remove_bottom_ticks_and_labels()
×
1004

1005
    # Add colorbar
1006
    if add_colorbar:
1✔
1007
        fc.add_colorbar(**cbar_kwargs)
1✔
1008

1009
    return fc
1✔
1010

1011

1012
def plot_image(
1✔
1013
    da,
1014
    x=None,
1015
    y=None,
1016
    ax=None,
1017
    add_colorbar=True,
1018
    add_labels=True,
1019
    interpolation="nearest",
1020
    fig_kwargs=None,
1021
    cbar_kwargs=None,
1022
    **plot_kwargs,
1023
):
1024
    """Plot data using imshow.
1025

1026
    Parameters
1027
    ----------
1028
    da : xarray.DataArray
1029
        xarray DataArray.
1030
    x : str, optional
1031
        X dimension name.
1032
        If ``None``, takes the second dimension.
1033
        The default is ``None``.
1034
    y : str, optional
1035
        Y dimension name.
1036
        If ``None``, takes the first dimension.
1037
        The default is ``None``.
1038
    ax : cartopy.mpl.geoaxes.GeoAxes, optional
1039
        The matplotlib axes where to plot the image.
1040
        If ``None``, a figure is initialized using the
1041
        specified ``fig_kwargs``.
1042
        The default is ``None``.
1043
    add_colorbar : bool, optional
1044
        Whether to add a colorbar. The default is ``True``.
1045
    add_labels : bool, optional
1046
        Whether to add labels to the plot. The default is ``True``.
1047
    interpolation : str, optional
1048
        Argument to be passed to imshow.
1049
        The default is ``"nearest"``.
1050
    fig_kwargs : dict, optional
1051
        Figure options to be passed to :py:class:`matplotlib.pyplot.subplots`.
1052
        The default is ``None``.
1053
        Only used if ``ax`` is ``None``.
1054
    subplot_kwargs : dict, optional
1055
        Subplot options to be passed to :py:class:`matplotlib.pyplot.subplots`.
1056
        The default is ``None``.
1057
        Only used if ```ax``` is ``None``.
1058
    cbar_kwargs : dict, optional
1059
        Colorbar options. The default is ``None``.
1060
    **plot_kwargs
1061
        Additional arguments to be passed to the plotting function.
1062
        Examples include ``cmap``, ``norm``, ``vmin``, ``vmax``, ``levels``, ...
1063
        For FacetGrid plots, specify ``row``, ``col`` and ``col_wrap``.
1064
        With ``rgb`` you can specify the name of the xarray.DataArray RGB dimension.
1065

1066

1067
    """
1068
    from gpm.checks import check_is_spatial_2d, is_spatial_2d
1✔
1069

1070
    # Plot orbit
1071
    if not is_spatial_2d(da, strict=False):
1✔
1072
        raise ValueError("Can not plot. It's not a spatial 2D object.")
1✔
1073

1074
    # Check inputs
1075
    da = check_object_format(da, plot_kwargs=plot_kwargs, check_function=check_is_spatial_2d, strict=True)
1✔
1076

1077
    # Plot FacetGrid with xarray imshow
1078
    if "col" in plot_kwargs or "row" in plot_kwargs:
1✔
1079
        p = _plot_image_facetgrid(
1✔
1080
            da=da,
1081
            x=x,
1082
            y=y,
1083
            ax=ax,
1084
            add_colorbar=add_colorbar,
1085
            add_labels=add_labels,
1086
            interpolation=interpolation,
1087
            fig_kwargs=fig_kwargs,
1088
            cbar_kwargs=cbar_kwargs,
1089
            **plot_kwargs,
1090
        )
1091
    # Plot with xarray imshow
1092
    else:
1093
        p = _plot_image(
1✔
1094
            da=da,
1095
            x=x,
1096
            y=y,
1097
            ax=ax,
1098
            add_colorbar=add_colorbar,
1099
            add_labels=add_labels,
1100
            interpolation=interpolation,
1101
            fig_kwargs=fig_kwargs,
1102
            cbar_kwargs=cbar_kwargs,
1103
            **plot_kwargs,
1104
        )
1105
    # Return mappable
1106
    return p
1✔
1107

1108

1109
####--------------------------------------------------------------------------.
1110
##################
1111
#### Plot map ####
1112
##################
1113

1114

1115
def plot_map(
1✔
1116
    da,
1117
    x=None,
1118
    y=None,
1119
    ax=None,
1120
    interpolation="nearest",  # used only for GPM grid objects
1121
    add_colorbar=True,
1122
    add_background=True,
1123
    add_labels=True,
1124
    add_gridlines=True,
1125
    add_swath_lines=True,  # used only for GPM orbit objects
1126
    fig_kwargs=None,
1127
    subplot_kwargs=None,
1128
    cbar_kwargs=None,
1129
    **plot_kwargs,
1130
):
1131
    """Plot data on a geographic map.
1132

1133
    Parameters
1134
    ----------
1135
    da : xarray.DataArray
1136
        xarray DataArray.
1137
    x : str, optional
1138
        Longitude coordinate name.
1139
        If ``None``, takes the second dimension.
1140
        The default is ``None``.
1141
    y : str, optional
1142
        Latitude coordinate name.
1143
        If ``None``, takes the first dimension.
1144
        The default is ``None``.
1145
    ax : cartopy.mpl.geoaxes.GeoAxes, optional
1146
        The cartopy GeoAxes where to plot the map.
1147
        If ``None``, a figure is initialized using the
1148
        specified ``fig_kwargs`` and ``subplot_kwargs``.
1149
        The default is ``None``.
1150
    add_colorbar : bool, optional
1151
        Whether to add a colorbar. The default is ``True``.
1152
    add_labels : bool, optional
1153
        Whether to add cartopy labels to the plot. The default is ``True``.
1154
    add_gridlines : bool, optional
1155
        Whether to add cartopy gridlines to the plot. The default is ``True``.
1156
    add_swath_lines : bool, optional
1157
        Whether to plot the swath sides with a dashed line. The default is ``True``.
1158
        This argument only applies for ORBIT objects.
1159
    add_background : bool, optional
1160
        Whether to add the map background. The default is ``True``.
1161
    interpolation : str, optional
1162
        Argument to be passed to :py:class:`matplotlib.axes.Axes.imshow`. Only applies for GRID objects.
1163
        The default is ``"nearest"``.
1164
    fig_kwargs : dict, optional
1165
        Figure options to be passed to `matplotlib.pyplot.subplots`.
1166
        The default is ``None``.
1167
        Only used if ``ax`` is ``None``.
1168
    subplot_kwargs : dict, optional
1169
        Dictionary of keyword arguments for :py:class:`matplotlib.pyplot.subplots`.
1170
        Must contain the Cartopy CRS ` ``projection`` key if specified.
1171
        The default is ``None``.
1172
        Only used if ``ax`` is ``None``.
1173
    cbar_kwargs : dict, optional
1174
        Colorbar options. The default is ``None``.
1175
    **plot_kwargs
1176
        Additional arguments to be passed to the plotting function.
1177
        Examples include ``cmap``, ``norm``, ``vmin``, ``vmax``, ``levels``, ...
1178
        For FacetGrid plots, specify ``row``, ``col`` and ``col_wrap``.
1179
        With ``rgb`` you can specify the name of the xarray.DataArray RGB dimension.
1180

1181

1182
    """
1183
    from gpm.checks import has_spatial_dim, is_grid, is_orbit, is_spatial_2d
1✔
1184
    from gpm.visualization.grid import plot_grid_map
1✔
1185
    from gpm.visualization.orbit import plot_orbit_map
1✔
1186

1187
    # Plot orbit
1188
    # - allow vertical or other dimensions for FacetGrid
1189
    # - allow to plot a swath of size 1 (i.e. nadir-looking)
1190
    if is_orbit(da) and has_spatial_dim(da):
1✔
1191
        p = plot_orbit_map(
1✔
1192
            da=da,
1193
            x=x,
1194
            y=y,
1195
            ax=ax,
1196
            add_colorbar=add_colorbar,
1197
            add_background=add_background,
1198
            add_gridlines=add_gridlines,
1199
            add_labels=add_labels,
1200
            add_swath_lines=add_swath_lines,
1201
            fig_kwargs=fig_kwargs,
1202
            subplot_kwargs=subplot_kwargs,
1203
            cbar_kwargs=cbar_kwargs,
1204
            **plot_kwargs,
1205
        )
1206
    # Plot grid
1207
    elif is_grid(da) and is_spatial_2d(da, strict=False):
1✔
1208
        p = plot_grid_map(
1✔
1209
            da=da,
1210
            x=x,
1211
            y=y,
1212
            ax=ax,
1213
            interpolation=interpolation,
1214
            add_colorbar=add_colorbar,
1215
            add_background=add_background,
1216
            add_gridlines=add_gridlines,
1217
            add_labels=add_labels,
1218
            fig_kwargs=fig_kwargs,
1219
            subplot_kwargs=subplot_kwargs,
1220
            cbar_kwargs=cbar_kwargs,
1221
            **plot_kwargs,
1222
        )
1223
    else:
1224
        raise ValueError("Can not plot. It's neither a GPM GRID or GPM ORBIT spatial 2D object.")
1✔
1225
    # Return mappable
1226
    return p
1✔
1227

1228

1229
def plot_map_mesh(
1✔
1230
    xr_obj,
1231
    x=None,
1232
    y=None,
1233
    ax=None,
1234
    edgecolors="k",
1235
    linewidth=0.1,
1236
    add_background=True,
1237
    add_gridlines=True,
1238
    add_labels=True,
1239
    fig_kwargs=None,
1240
    subplot_kwargs=None,
1241
    **plot_kwargs,
1242
):
1243
    from gpm.checks import is_grid, is_orbit
1✔
1244
    from gpm.visualization.grid import plot_grid_mesh
1✔
1245
    from gpm.visualization.orbit import plot_orbit_mesh
1✔
1246

1247
    # Plot orbit
1248
    if is_orbit(xr_obj):
1✔
1249
        x, y = infer_map_xy_coords(xr_obj, x=x, y=y)
1✔
1250
        p = plot_orbit_mesh(
1✔
1251
            da=xr_obj[y],
1252
            ax=ax,
1253
            x=x,
1254
            y=y,
1255
            edgecolors=edgecolors,
1256
            linewidth=linewidth,
1257
            add_background=add_background,
1258
            add_gridlines=add_gridlines,
1259
            add_labels=add_labels,
1260
            fig_kwargs=fig_kwargs,
1261
            subplot_kwargs=subplot_kwargs,
1262
            **plot_kwargs,
1263
        )
1264
    elif is_grid(xr_obj):
1✔
1265
        p = plot_grid_mesh(
1✔
1266
            xr_obj=xr_obj,
1267
            x=x,
1268
            y=y,
1269
            ax=ax,
1270
            edgecolors=edgecolors,
1271
            linewidth=linewidth,
1272
            add_background=add_background,
1273
            add_gridlines=add_gridlines,
1274
            add_labels=add_labels,
1275
            fig_kwargs=fig_kwargs,
1276
            subplot_kwargs=subplot_kwargs,
1277
            **plot_kwargs,
1278
        )
1279
    else:
1280
        raise ValueError("Can not plot. It's neither a GPM GRID or GPM ORBIT spatial object.")
×
1281
    # Return mappable
1282
    return p
1✔
1283

1284

1285
def plot_map_mesh_centroids(
1✔
1286
    xr_obj,
1287
    x=None,
1288
    y=None,
1289
    ax=None,
1290
    c="r",
1291
    s=1,
1292
    add_background=True,
1293
    add_gridlines=True,
1294
    add_labels=True,
1295
    fig_kwargs=None,
1296
    subplot_kwargs=None,
1297
    **plot_kwargs,
1298
):
1299
    """Plot GPM orbit granule mesh centroids in a cartographic map."""
1300
    from gpm.checks import is_grid, is_orbit
1✔
1301

1302
    # Initialize figure if necessary
1303
    ax = initialize_cartopy_plot(
1✔
1304
        ax=ax,
1305
        fig_kwargs=fig_kwargs,
1306
        subplot_kwargs=subplot_kwargs,
1307
        add_background=add_background,
1308
        add_gridlines=add_gridlines,
1309
        add_labels=add_labels,
1310
        infer_crs=True,
1311
        xr_obj=xr_obj,
1312
    )
1313

1314
    # Retrieve orbits lon, lat coordinates
1315
    if is_orbit(xr_obj):
1✔
1316
        x, y = infer_map_xy_coords(xr_obj, x=x, y=y)
1✔
1317

1318
    # Retrieve grid centroids mesh
1319
    if is_grid(xr_obj):
1✔
1320
        x, y = infer_xy_labels(xr_obj, x=x, y=y)
1✔
1321
        xr_obj = create_grid_mesh_data_array(xr_obj, x=x, y=y)
1✔
1322

1323
    # Extract numpy arrays
1324
    lon = xr_obj[x].to_numpy()
1✔
1325
    lat = xr_obj[y].to_numpy()
1✔
1326

1327
    # Plot centroids
1328
    p = ax.scatter(lon, lat, transform=ccrs.PlateCarree(), c=c, s=s, **plot_kwargs)
1✔
1329

1330
    # Return mappable
1331
    return p
1✔
1332

1333

1334
def create_grid_mesh_data_array(xr_obj, x, y):
1✔
1335
    """Create a 2D mesh coordinates DataArray.
1336

1337
    Takes as input the 1D coordinate arrays from an existing xarray.DataArray or xarray.Dataset object.
1338

1339
    The function creates a 2D grid (mesh) of x and y coordinates and initializes
1340
    the data values to NaN.
1341

1342
    Parameters
1343
    ----------
1344
    xr_obj : xarray.DataArray or xarray.Dataset
1345
        The input xarray object containing the 1D coordinate arrays.
1346
    x : str
1347
        The name of the x-coordinate in `xr_obj`.
1348
    y : str
1349
        The name of the y-coordinate in `xr_obj`.
1350

1351
    Returns
1352
    -------
1353
    da_mesh : xarray.DataArray
1354
        A 2D xarray.DataArray with mesh coordinates for `x` and `y`, and NaN values for data points.
1355

1356
    Notes
1357
    -----
1358
    The resulting xarray.DataArray has dimensions named 'y' and 'x', corresponding to the
1359
    y and x coordinates respectively.
1360
    The coordinate values are taken directly from the input 1D coordinate arrays,
1361
    and the data values are set to NaN.
1362

1363
    """
1364
    # Extract 1D coordinate arrays
1365
    x_coords = xr_obj[x].to_numpy()
1✔
1366
    y_coords = xr_obj[y].to_numpy()
1✔
1367

1368
    # Create 2D meshgrid for x and y coordinates
1369
    X, Y = np.meshgrid(x_coords, y_coords, indexing="xy")
1✔
1370

1371
    # Create a 2D array of NaN values with the same shape as the meshgrid
1372
    dummy_values = np.full(X.shape, np.nan)
1✔
1373

1374
    # Create a new DataArray with 2D coordinates and NaN values
1375
    return xr.DataArray(
1✔
1376
        dummy_values,
1377
        coords={x: (("y", "x"), X), y: (("y", "x"), Y)},
1378
        dims=("y", "x"),
1379
    )
1380

1381

1382
####--------------------------------------------------------------------------.
1383

1384

1385
def _plot_labels(
1✔
1386
    xr_obj,
1387
    label_name=None,
1388
    max_n_labels=50,
1389
    add_colorbar=True,
1390
    interpolation="nearest",
1391
    cmap="Paired",
1392
    fig_kwargs=None,
1393
    **plot_kwargs,
1394
):
1395
    """Plot labels.
1396

1397
    The maximum allowed number of labels to plot is 'max_n_labels'.
1398
    """
1399
    from ximage.labels.labels import get_label_indices, redefine_label_array
1✔
1400
    from ximage.labels.plot_labels import get_label_colorbar_settings
1✔
1401

1402
    from gpm.visualization.plot import plot_image
1✔
1403

1404
    if isinstance(xr_obj, xr.Dataset):
1✔
1405
        dataarray = xr_obj[label_name]
1✔
1406
    else:
1407
        dataarray = xr_obj[label_name] if label_name is not None else xr_obj
1✔
1408

1409
    dataarray = dataarray.compute()
1✔
1410
    label_indices = get_label_indices(dataarray)
1✔
1411
    n_labels = len(label_indices)
1✔
1412
    if add_colorbar and n_labels > max_n_labels:
1✔
1413
        msg = f"""The array currently contains {n_labels} labels
1✔
1414
        and 'max_n_labels' is set to {max_n_labels}. The colorbar is not displayed!"""
1415
        print(msg)
1✔
1416
        add_colorbar = False
1✔
1417
    # Relabel array from 1 to ... for plotting
1418
    dataarray = redefine_label_array(dataarray, label_indices=label_indices)
1✔
1419
    # Replace 0 with nan
1420
    dataarray = dataarray.where(dataarray > 0)
1✔
1421
    # Define appropriate colormap
1422
    default_plot_kwargs, cbar_kwargs = get_label_colorbar_settings(label_indices, cmap=cmap)
1✔
1423
    default_plot_kwargs.update(plot_kwargs)
1✔
1424
    # Plot image
1425
    return plot_image(
1✔
1426
        dataarray,
1427
        interpolation=interpolation,
1428
        add_colorbar=add_colorbar,
1429
        cbar_kwargs=cbar_kwargs,
1430
        fig_kwargs=fig_kwargs,
1431
        **default_plot_kwargs,
1432
    )
1433

1434

1435
def plot_labels(
1✔
1436
    obj,  # Dataset, DataArray or generator
1437
    label_name=None,
1438
    max_n_labels=50,
1439
    add_colorbar=True,
1440
    interpolation="nearest",
1441
    cmap="Paired",
1442
    fig_kwargs=None,
1443
    **plot_kwargs,
1444
):
1445
    if is_generator(obj):
1✔
1446
        for _, xr_obj in obj:  # label_id, xr_obj
1✔
1447
            p = _plot_labels(
1✔
1448
                xr_obj=xr_obj,
1449
                label_name=label_name,
1450
                max_n_labels=max_n_labels,
1451
                add_colorbar=add_colorbar,
1452
                interpolation=interpolation,
1453
                cmap=cmap,
1454
                fig_kwargs=fig_kwargs,
1455
                **plot_kwargs,
1456
            )
1457
            plt.show()
1✔
1458
    else:
1459
        p = _plot_labels(
1✔
1460
            xr_obj=obj,
1461
            label_name=label_name,
1462
            max_n_labels=max_n_labels,
1463
            add_colorbar=add_colorbar,
1464
            interpolation=interpolation,
1465
            cmap=cmap,
1466
            fig_kwargs=fig_kwargs,
1467
            **plot_kwargs,
1468
        )
1469
    return p
1✔
1470

1471

1472
def plot_patches(
1✔
1473
    patch_gen,
1474
    variable=None,
1475
    add_colorbar=True,
1476
    interpolation="nearest",
1477
    fig_kwargs=None,
1478
    cbar_kwargs=None,
1479
    **plot_kwargs,
1480
):
1481
    """Plot patches."""
1482
    from gpm.visualization.plot import plot_image
1✔
1483

1484
    # Plot patches
1485
    for _, xr_patch in patch_gen:  # label_id, xr_obj
1✔
1486
        if isinstance(xr_patch, xr.Dataset):
1✔
1487
            if variable is None:
1✔
1488
                raise ValueError("'variable' must be specified when plotting xarray.Dataset patches.")
1✔
1489
            xr_patch = xr_patch[variable]
1✔
1490
        try:
1✔
1491
            plot_image(
1✔
1492
                xr_patch,
1493
                interpolation=interpolation,
1494
                add_colorbar=add_colorbar,
1495
                fig_kwargs=fig_kwargs,
1496
                cbar_kwargs=cbar_kwargs,
1497
                **plot_kwargs,
1498
            )
1499
            plt.show()
1✔
1500
        except Exception:
1✔
1501
            pass
1✔
1502

1503

1504
####--------------------------------------------------------------------------.
1505

1506

1507
def add_map_inset(ax, loc="upper left", inset_height=0.2, projection=None, inside_figure=True, border_pad=0):
1✔
1508
    """Adds an inset map to a matplotlib axis using Cartopy, highlighting the extent of the main plot.
1509

1510
    This function creates a smaller map inset within a larger map plot to show a global view or
1511
    contextual location of the main plot's extent.
1512

1513
    It uses Cartopy for map projections and plotting, and it outlines the extent of the main plot
1514
    within the inset to provide geographical context.
1515

1516
    Parameters
1517
    ----------
1518
    ax : matplotlib.axes.Axes or cartopy.mpl.geoaxes.GeoAxes
1519
        The main matplotlib or cartopy axis object where the geographic data is plotted.
1520
    loc : str, optional
1521
        The location of the inset map within the main plot.
1522
        Options include ``'lower left'``, ``'lower right'``,
1523
        ``'upper left'``, and ``'upper right'``. The default is ``'upper left'``.
1524
    inset_height : float, optional
1525
        The size of the inset height, specified as a fraction of the figure's height.
1526
        For example, a value of 0.2 indicates that the inset's height will be 20% of the figure's height.
1527
        The aspect ratio (of the map inset) will govern the ``inset_width``.
1528
    inside_figure : bool, optional
1529
        Determines whether the inset is constrained to be fully inside the figure bounds. If ``True`` (default),
1530
        the inset is placed fully within the figure. If ``False``, the inset can extend beyond the figure's edges,
1531
        allowing for a half-outside placement.
1532
    projection: cartopy.crs.Projection, optional
1533
        A cartopy projection. If ``None``, am Orthographic projection centered on the extent center is used.
1534

1535
    Returns
1536
    -------
1537
    ax2 : cartopy.mpl.geoaxes.GeoAxes
1538
        The Cartopy GeoAxesSubplot object for the inset map.
1539

1540
    Notes
1541
    -----
1542
    The function adjusts the extent of the inset map based on the main plot's extent, adding a
1543
    slight padding for visual clarity. It then overlays a red outline indicating the main plot's
1544
    geographical extent.
1545

1546
    Examples
1547
    --------
1548
    >>> p = da.gpm.plot_map()
1549
    >>> add_map_inset(ax=p.axes, loc="upper left", inset_height=0.15)
1550

1551
    This example creates a main plot with a specified extent and adds an upper-left inset map
1552
    showing the global context of the main plot's extent.
1553

1554
    """
1555
    from shapely import Polygon
1✔
1556

1557
    from gpm.utils.geospatial import extend_geographic_extent
1✔
1558

1559
    # Retrieve map extent and bounds
1560
    extent = ax.get_extent()
1✔
1561
    extent = extend_geographic_extent(extent, padding=0.5)
1✔
1562
    bounds = [extent[i] for i in [0, 2, 1, 3]]
1✔
1563

1564
    # Create Cartopy Polygon
1565
    polygon = Polygon.from_bounds(*bounds)
1✔
1566

1567
    # Define Orthographic projection
1568
    if projection is None:
1✔
1569
        lon_min, lon_max, lat_min, lat_max = extent
1✔
1570
        projection = ccrs.Orthographic(
1✔
1571
            central_latitude=(lat_min + lat_max) / 2,
1572
            central_longitude=(lon_min + lon_max) / 2,
1573
        )
1574

1575
    # Define aspect ratio of the map inset
1576
    aspect_ratio = float(np.diff(projection.x_limits).item() / np.diff(projection.y_limits).item())
1✔
1577

1578
    # Define inset location relative to main plot (ax) in normalized units
1579
    # - Lower-left corner of inset Axes, and its width and height
1580
    # - [x0, y0, width, height]
1581
    inset_bounds = get_inset_bounds(
1✔
1582
        ax=ax,
1583
        loc=loc,
1584
        inset_height=inset_height,
1585
        inside_figure=inside_figure,
1586
        aspect_ratio=aspect_ratio,
1587
        border_pad=border_pad,
1588
    )
1589

1590
    ax2 = ax.inset_axes(
1✔
1591
        inset_bounds,
1592
        projection=projection,
1593
    )
1594

1595
    # Add global map
1596
    ax2.set_global()
1✔
1597
    ax2.add_feature(cfeature.LAND)
1✔
1598
    ax2.add_feature(cfeature.OCEAN)
1✔
1599

1600
    # Add extent polygon
1601
    _ = ax2.add_geometries(
1✔
1602
        [polygon],
1603
        ccrs.PlateCarree(),
1604
        facecolor="none",
1605
        edgecolor="red",
1606
        linewidth=0.3,
1607
    )
1608
    return ax2
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc