• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ghiggi / gpm_api / 14245507561

03 Apr 2025 02:26PM UTC coverage: 91.18% (+0.01%) from 91.17%
14245507561

push

github

ghiggi
Fix GRID reprojection on-the-fly

12 of 12 new or added lines in 1 file covered. (100.0%)

1 existing line in 1 file now uncovered.

15692 of 17210 relevant lines covered (91.18%)

0.91 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.74
/gpm/visualization/plot.py
1
# -----------------------------------------------------------------------------.
2
# MIT License
3

4
# Copyright (c) 2024 GPM-API developers
5
#
6
# This file is part of GPM-API.
7

8
# Permission is hereby granted, free of charge, to any person obtaining a copy
9
# of this software and associated documentation files (the "Software"), to deal
10
# in the Software without restriction, including without limitation the rights
11
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
# copies of the Software, and to permit persons to whom the Software is
13
# furnished to do so, subject to the following conditions:
14
#
15
# The above copyright notice and this permission notice shall be included in all
16
# copies or substantial portions of the Software.
17
#
18
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
# SOFTWARE.
25

26
# -----------------------------------------------------------------------------.
27
"""This module contains basic functions for GPM-API data visualization."""
28
import inspect
1✔
29
import warnings
1✔
30

31
import cartopy
1✔
32
import cartopy.crs as ccrs
1✔
33
import cartopy.feature as cfeature
1✔
34
import matplotlib.pyplot as plt
1✔
35
import numpy as np
1✔
36
import xarray as xr
1✔
37
from pycolorbar import plot_colorbar, set_colorbar_fully_transparent
1✔
38
from pycolorbar.utils.mpl_legend import get_inset_bounds
1✔
39
from scipy.interpolate import griddata
1✔
40

41
import gpm
1✔
42
from gpm import get_plot_kwargs
1✔
43
from gpm.dataset.crs import compute_extent
1✔
44
from gpm.utils.area import get_lonlat_corners_from_centroids
1✔
45

46

47
def is_generator(obj):
    """Return True if ``obj`` is a generator function or a generator object."""
    checks = (inspect.isgeneratorfunction, inspect.isgenerator)
    return any(check(obj) for check in checks)
49

50

51
def _call_optimize_layout(self):
    """Adapt the figure size to the plotted axes and then tighten the layout.

    Intended to be bound as a method on an object exposing ``axes`` and
    ``figure`` attributes (see ``add_optimize_layout_method``).
    """
    # Resize the figure first ...
    adapt_fig_size(ax=self.axes)
    # ... then let matplotlib re-arrange the figure content
    figure = self.figure
    figure.tight_layout()
55

56

57
def add_optimize_layout_method(p):
    """Monkey patch ``p`` with an ``optimize_layout()`` method and return ``p``."""
    # Bind the module-level function to this specific instance (descriptor protocol)
    bound_method = _call_optimize_layout.__get__(p, type(p))
    setattr(p, "optimize_layout", bound_method)
    return p
61

62

63
def adapt_fig_size(ax, nrow=1, ncol=1):
    """Resize the figure of ``ax`` so that its height matches the data aspect ratio.

    The function is meant to be called once all plotting is done.
    The data ratio of ``ax`` is taken as representative of every subplot, i.e.
    all subplots are expected to share the same aspect ratio and size.

    The implementation is inspired by Mathias Hauser's mplotutils set_map_layout function.

    Parameters
    ----------
    ax : axes object
        Axis providing ``get_figure()`` and ``get_data_ratio()``.
    nrow : int, optional
        Number of subplot rows in the figure. The default is 1.
    ncol : int, optional
        Number of subplot columns in the figure. The default is 1.
    """
    # NOTE: nrow and ncol could be retrieved with ax.get_subplotspec().get_geometry()
    fig = ax.get_figure()

    # Current figure size (in inches)
    width, original_height = fig.get_size_inches()

    # Draw the canvas so that the figure geometry is up-to-date before reading it
    fig.canvas.draw()

    # Figure margins and spacing between subplots
    pars = fig.subplotpars
    bottom = pars.bottom
    top = pars.top
    left = pars.left
    right = pars.right
    hspace = pars.hspace  # vertical space between subplots
    wspace = pars.wspace  # horizontal space between subplots

    # Aspect ratio of the plotted data
    aspect = ax.get_data_ratio()

    # Width of a single subplot (figure width minus margins, shared across columns)
    horizontal_margins = width * (left + (1 - right))
    wp = (width - horizontal_margins) / (ncol + (ncol - 1) * wspace)

    # Height of a single subplot derived from its width and the data aspect ratio
    hp = wp * aspect

    # Figure height accounting for rows, vertical spacing and top/bottom margins
    height = (hp * (nrow + ((nrow - 1) * hspace))) / (1.0 - (bottom + (1 - top)))

    # If the new height is less than half of the original one, scale width and
    # height together so that the new height equals half of the original height
    shrink_ratio = original_height / height
    if shrink_ratio > 2:
        scale_factor = shrink_ratio / 2
        width *= scale_factor
        height *= scale_factor

    # Apply the new figure size
    fig.set_figwidth(width)
    fig.set_figheight(height)
124

125

126
####--------------------------------------------------------------------------.
127

128

129
def infill_invalid_coords(xr_obj, x="lon", y="lat"):
    """Return a copy of ``xr_obj`` whose ``x`` and ``y`` coordinates have no invalid values.

    The coordinates are infilled by ``get_valid_pcolormesh_inputs``:
    interpolation within the convex hull of the valid coordinates,
    nearest neighbour outside of it.
    """
    # Work on a copy of the input object
    infilled_obj = xr_obj.copy()
    x_values = np.asanyarray(infilled_obj[x].data)
    y_values = np.asanyarray(infilled_obj[y].data)
    # Infill the coordinates (no data to mask here)
    x_values, y_values, _ = get_valid_pcolormesh_inputs(
        x=x_values,
        y=y_values,
        data=None,
        mask_data=False,
    )
    infilled_obj[x].data = x_values
    infilled_obj[y].data = y_values
    return infilled_obj
144

145

146
def get_valid_pcolormesh_inputs(x, y, data, rgb=False, mask_data=True):
    """Infill non-finite coordinates and optionally mask the corresponding data.

    Coordinates are interpolated within the convex hull of the valid
    coordinates, and filled with the nearest neighbour outside of it.
    This is required to plot with pcolormesh, which does not accept
    non-finite values in the coordinates.

    If ``mask_data=True``, data values located at invalid coordinates are
    masked and a numpy masked array is returned.
    Masked data values are not displayed in pcolormesh !
    If ``rgb=True``, the RGB dimension is assumed to be the last data dimension.

    Returns
    -------
    tuple
        ``(x, y, data)``. The inputs are returned unchanged if all coordinates are finite.

    Raises
    ------
    ValueError
        If none of the coordinates is valid.
    """
    # Locate invalid coordinates
    invalid_x = ~np.isfinite(x)
    invalid_y = ~np.isfinite(y)
    invalid = np.logical_or(invalid_x, invalid_y)

    # Nothing to do if every coordinate is valid
    if not np.any(invalid):
        return x, y, data

    # At least some valid coordinates are required
    if np.all(invalid):
        raise ValueError("No valid coordinates.")

    # Mask the data located at invalid coordinates
    if not mask_data:
        data_masked = data
    elif rgb:
        rgb_mask = np.broadcast_to(np.expand_dims(invalid, axis=-1), data.shape)
        data_masked = np.ma.masked_where(rgb_mask, data)
    else:
        data_masked = np.ma.masked_where(invalid, data)

    # Infill x and y: linear interpolation first, then nearest neighbour
    # for the locations outside the convex hull
    # - Note: currently cause issue if NaN when crossing antimeridian ...
    # --> TODO: interpolation should be done in X,Y,Z
    methods = ("linear", "nearest")
    if np.any(invalid_x):
        for method in methods:
            x = _interpolate_data(x, method=method)
    if np.any(invalid_y):
        for method in methods:
            y = _interpolate_data(y, method=method)
    return x, y, data_masked
194

195

196
def _interpolate_data(arr, method="linear"):
    """Dispatch the infilling of ``arr`` to the 1D or 2D coordinate interpolator."""
    # 1D coordinate (i.e. along_track/cross_track view) vs. 2D coordinates (swath image)
    interpolator = _interpolate_1d_coord if arr.ndim == 1 else _interpolate_2d_coord
    return interpolator(arr, method=method)
202

203

204
def _interpolate_1d_coord(arr, method="linear"):
1✔
205
    # Find invalid locations
206
    is_invalid = ~np.isfinite(arr)
1✔
207

208
    # Find the indices of NaN values
209
    nan_indices = np.where(is_invalid)[0]
1✔
210

211
    # Return array if not NaN values
212
    if len(nan_indices) == 0:
1✔
213
        return arr
1✔
214

215
    # Find the indices of non-NaN values
216
    non_nan_indices = np.where(~is_invalid)
1✔
217

218
    # Create indices
219
    indices = np.arange(len(arr))
1✔
220

221
    # Points where we have valid data
222
    points = indices[non_nan_indices]
1✔
223

224
    # Points where data is NaN
225
    points_nan = indices[nan_indices]
1✔
226

227
    # Values at the non-NaN points
228
    values = arr[non_nan_indices]
1✔
229

230
    # Interpolate using griddata
231
    arr_new = arr.copy()
1✔
232
    arr_new[nan_indices] = griddata(points, values, points_nan, method=method)
1✔
233
    return arr_new
1✔
234

235

236
def _interpolate_2d_coord(arr, method="linear"):
1✔
237
    # Find invalid locations
238
    is_invalid = ~np.isfinite(arr)
1✔
239

240
    # Find the indices of NaN values
241
    nan_indices = np.where(is_invalid)
1✔
242

243
    # Return array if not NaN values
244
    if len(nan_indices) == 0:
1✔
245
        return arr
×
246

247
    # Find the indices of non-NaN values
248
    non_nan_indices = np.where(~is_invalid)
1✔
249

250
    # Create a meshgrid of indices
251
    x, y = np.meshgrid(range(arr.shape[1]), range(arr.shape[0]))
1✔
252

253
    # Points (X, Y) where we have valid data
254
    points = np.array([y[non_nan_indices], x[non_nan_indices]]).T
1✔
255

256
    # Points where data is NaN
257
    points_nan = np.array([y[nan_indices], x[nan_indices]]).T
1✔
258

259
    # Values at the non-NaN points
260
    values = arr[non_nan_indices]
1✔
261

262
    # Interpolate using griddata
263
    arr_new = arr.copy()
1✔
264
    arr_new[nan_indices] = griddata(points, values, points_nan, method=method)
1✔
265
    return arr_new
1✔
266

267

268
def _mask_antimeridian_crossing_arr(arr, antimeridian_mask, rgb):
1✔
269
    if np.ma.is_masked(arr):
1✔
270
        if rgb:
1✔
271
            antimeridian_mask = np.broadcast_to(np.expand_dims(antimeridian_mask, axis=-1), arr.shape)
1✔
272
            combined_mask = np.logical_or(arr.mask, antimeridian_mask)
1✔
273
        else:
274
            combined_mask = np.logical_or(arr.mask, antimeridian_mask)
1✔
275
        arr = np.ma.masked_where(combined_mask, arr)
1✔
276
    else:
277
        if rgb:
1✔
278
            antimeridian_mask = np.broadcast_to(
1✔
279
                np.expand_dims(antimeridian_mask, axis=-1),
280
                arr.shape,
281
            )
282
        arr = np.ma.masked_where(antimeridian_mask, arr)
1✔
283
    return arr
1✔
284

285

286
def mask_antimeridian_crossing_array(arr, lon, rgb, plot_kwargs):
    """Mask the array cells crossing the antimeridian.

    The longitude coordinates are assumed to not contain invalid values anymore.
    Cartopy still bugs with several projections when data cross the antimeridian.
    By default, GPM-API mask data crossing the antimeridian.
    The GPM-API configuration default can be modified with: ``gpm.config.set({"viz_hide_antimeridian_data": False})``

    Returns
    -------
    tuple
        ``(arr, plot_kwargs)``, both unchanged if no cell crosses the antimeridian.
    """
    crossing_mask = get_antimeridian_mask(lon)
    if not np.any(crossing_mask):
        return arr, plot_kwargs

    # Sanitize cmap to avoid cartopy bug related to cmap bad color
    # - Cartopy requires the bad color to be fully transparent
    plot_kwargs = _sanitize_cartopy_plot_kwargs(plot_kwargs)

    # Mask data based on GPM-API config 'viz_hide_antimeridian_data' (default is True)
    hide_data = gpm.config.get("viz_hide_antimeridian_data")
    if hide_data:
        arr = _mask_antimeridian_crossing_arr(arr, antimeridian_mask=crossing_mask, rgb=rgb)
    return arr, plot_kwargs
304

305

306
def get_antimeridian_mask(lons):
    """Get mask of longitude coordinates neighbors crossing the antimeridian.

    ``lons`` is a 2D array of shape (n_y, n_x); the returned boolean mask has
    shape (n_y - 1, n_x - 1). A crossing is detected where the longitude jumps
    by more than 180 degrees between two consecutive elements.
    """
    from scipy.ndimage import binary_dilation

    n_rows, n_cols = lons.shape
    cell_mask = np.zeros((n_rows - 1, n_cols - 1))

    # Jumps between consecutive rows (vertical edges)
    rows, cols = np.nonzero(np.abs(np.diff(lons, axis=0)) > 180)
    cell_mask[rows, np.clip(cols - 1, 0, n_cols - 1)] = 1

    # Jumps between consecutive columns (horizontal edges)
    rows, cols = np.nonzero(np.abs(np.diff(lons, axis=1)) > 180)
    cell_mask[np.clip(rows - 1, 0, n_rows - 1), cols] = 1

    # Buffer by 1 in all directions to avoid plotting cells neighbour to those crossing the antimeridian
    # --> This should not be needed, but it's needed to avoid cartopy bugs !
    return binary_dilation(cell_mask)
324

325

326
####--------------------------------------------------------------------------.
327
########################
328
#### Plot utilities ####
329
########################
330

331

332
def preprocess_rgb_dataarray(da, rgb):
    """Move the ``rgb`` dimension of ``da`` to the last position.

    ``da`` is returned unchanged if ``rgb`` is falsy.

    Raises
    ------
    ValueError
        If ``rgb`` is not a dimension of ``da`` or if its size is not 3 or 4.
    """
    if not rgb:
        return da
    if rgb not in da.dims:
        raise ValueError(f"The specified rgb='{rgb}' must be a dimension of the DataArray.")
    n_channels = da[rgb].size
    if n_channels not in [3, 4]:
        raise ValueError("The RGB dimension must have size 3 or 4.")
    return da.transpose(..., rgb)
340

341

342
def check_object_format(da, plot_kwargs, check_function, **function_kwargs):
    """Check object format and valid dimension names.

    The DataArray is squeezed, the RGB dimension (if any) is moved to the last
    position, and ``check_function`` is applied to the first slice along the
    ``rgb``/``col``/``row`` dimensions specified in ``plot_kwargs``.

    Returns
    -------
    The squeezed (and RGB-transposed) DataArray.

    Raises
    ------
    ValueError
        If a dimension specified with ``rgb``, ``col`` or ``row`` is not available.
    """
    # Preprocess RGB DataArrays
    da = da.squeeze()
    da = preprocess_rgb_dataarray(da, plot_kwargs.get("rgb", False))

    # Retrieve rgb or FacetGrid column/row dimensions and check they are available
    isel_dict = {}
    for key in ["rgb", "col", "row"]:
        dim = plot_kwargs.get(key, None)
        if not dim:
            continue
        if dim not in da.dims:
            raise ValueError(f"The DataArray does not have a {key}='{dim}' dimension.")
        isel_dict[dim] = 0

    # Subset DataArray to check if complies with specific check function
    check_function(da.isel(isel_dict), **function_kwargs)
    return da
357

358

359
def preprocess_figure_args(ax, fig_kwargs=None, subplot_kwargs=None, is_facetgrid=False):
    """Validate the figure arguments and return ``fig_kwargs`` (``{}`` if None).

    Raises
    ------
    ValueError
        If ``ax`` is specified together with a FacetGrid plot, or together
        with non-empty ``subplot_kwargs`` or ``fig_kwargs``.
    """
    ax_provided = ax is not None
    if ax_provided and is_facetgrid:
        raise ValueError("When plotting with FacetGrid, do not specify the 'ax'.")

    if fig_kwargs is None:
        fig_kwargs = {}
    if subplot_kwargs is None:
        subplot_kwargs = {}

    # Figure/subplot options are meaningful only if the figure is created here
    if ax_provided and len(subplot_kwargs) > 0:
        raise ValueError("Provide `subplot_kwargs`only if ``ax``is None")
    if ax_provided and len(fig_kwargs) > 0:
        raise ValueError("Provide `fig_kwargs` only if ``ax``is None")
    return fig_kwargs
370

371

372
def preprocess_subplot_kwargs(subplot_kwargs, infer_crs=False, xr_obj=None):
    """Return a copy of ``subplot_kwargs`` that always defines a ``projection``.

    If the projection is not specified, it is set to ``xr_obj.gpm.cartopy_crs``
    when ``infer_crs=True`` and to ``ccrs.PlateCarree()`` otherwise.
    """
    kwargs = {} if subplot_kwargs is None else subplot_kwargs.copy()
    if "projection" in kwargs:
        return kwargs
    kwargs["projection"] = xr_obj.gpm.cartopy_crs if infer_crs else ccrs.PlateCarree()
    return kwargs
381

382

383
def infer_xy_labels(da, x=None, y=None, rgb=None):
    """Infer the x and y labels of ``da`` using the xarray plotting utilities.

    Thin wrapper around the private ``xarray.plot.utils._infer_xy_labels``.
    """
    from xarray.plot.utils import _infer_xy_labels as _xr_infer_xy_labels

    # imshow=True is a dummy flag required to account for the rgb argument
    x_label, y_label = _xr_infer_xy_labels(da, x=x, y=y, imshow=True, rgb=rgb)
    return x_label, y_label
389

390

391
def infer_map_xy_coords(da, x=None, y=None):
    """
    Infer possible map x and y coordinates for the given DataArray.

    Candidate names are searched in ``da.coords`` in the order
    ``"x"``, ``"lon"``, ``"longitude"`` (and ``"y"``, ``"lat"``, ``"latitude"``);
    the first one found is used.

    Parameters
    ----------
    da : xarray.DataArray
        The input DataArray.
    x : str, optional
        The name of the x (i.e. longitude) coordinate. If None, it will be inferred.
    y : str, optional
        The name of the y (i.e. latitude) coordinate. If None, it will be inferred.

    Returns
    -------
    tuple
        The inferred (x, y) coordinates.

    Raises
    ------
    ValueError
        If a coordinate is not provided and cannot be inferred.
    """
    if x is None:
        x = next((name for name in ("x", "lon", "longitude") if name in da.coords), None)
        if x is None:
            raise ValueError("Cannot infer x coordinate. Please provide the x coordinate.")

    if y is None:
        y = next((name for name in ("y", "lat", "latitude") if name in da.coords), None)
        if y is None:
            raise ValueError("Cannot infer y coordinate. Please provide the y coordinate.")

    return x, y
429

430

431
def _get_proj_str(crs):
1✔
432
    with warnings.catch_warnings():
1✔
433
        warnings.simplefilter("ignore")
1✔
434
        proj_str = crs.to_dict().get("proj", "")
1✔
435
    return proj_str
1✔
436

437

438
def initialize_cartopy_plot(
    ax,
    fig_kwargs,
    subplot_kwargs,
    add_background,
    add_gridlines,
    add_labels,
    infer_crs=False,
    xr_obj=None,
):
    """Initialize figure for cartopy plot if necessary.

    A new figure and axis are created only if ``ax`` is None.
    The background and the gridlines/labels are then added on request.

    Returns
    -------
    The (possibly newly created) axis.
    """
    # Create figure and axis if not provided
    if ax is None:
        figure_options = preprocess_figure_args(
            ax=ax,
            fig_kwargs=fig_kwargs,
            subplot_kwargs=subplot_kwargs,
        )
        subplot_options = preprocess_subplot_kwargs(subplot_kwargs, infer_crs=infer_crs, xr_obj=xr_obj)
        _, ax = plt.subplots(subplot_kw=subplot_options, **figure_options)

    # Add cartopy background
    if add_background:
        ax = plot_cartopy_background(ax)

    # Add gridlines and labels
    if add_gridlines or add_labels:
        plot_cartopy_gridlines_and_labels(ax, add_gridlines=add_gridlines, add_labels=add_labels)

    return ax
468

469

470
def plot_cartopy_gridlines_and_labels(ax, add_gridlines=True, add_labels=True):
    """Add cartopy gridlines and labels.

    The gridlines are always created (so that labels can be drawn) but are
    made fully transparent when ``add_gridlines=False``.
    Top and right labels are disabled.

    Returns
    -------
    The object returned by ``ax.gridlines``.
    """
    gridlines_kwargs = {
        "crs": ccrs.PlateCarree(),
        "draw_labels": add_labels,
        "linewidth": 1,
        "color": "gray",
        "alpha": 0.1 if add_gridlines else 0,
        "linestyle": "-",
    }
    gl = ax.gridlines(**gridlines_kwargs)
    # Labels only on the left and bottom sides
    gl.top_labels = False  # gl.xlabels_top = False
    gl.right_labels = False  # gl.ylabels_right = False
    gl.xlines = True
    gl.ylines = True
    return gl
486

487

488
def plot_cartopy_background(ax):
    """Plot cartopy background (coastlines, land, ocean and borders) and return ``ax``."""
    # Coastlines
    ax.coastlines()

    # Land and ocean
    # --> Raise error with some projections currently (shapely bug)
    # --> https://github.com/SciTools/cartopy/issues/2176
    proj_name = _get_proj_str(ax.projection)
    if proj_name != "laea":
        ax.add_feature(cartopy.feature.LAND, facecolor=[0.9, 0.9, 0.9])
        ax.add_feature(cartopy.feature.OCEAN, alpha=0.6)

    # Borders (BORDERS also draws provinces, ...)
    ax.add_feature(cartopy.feature.BORDERS)
    return ax
501

502

503
def plot_sides(sides, ax, **plot_kwargs):
    """Plot boundary sides.

    Parameters
    ----------
    sides : iterable
        Iterable of (lon, lat) tuples, one per side.
    ax : axes object
        Axis on which each side is drawn with ``ax.plot`` using the
        ``ccrs.Geodetic()`` transform.
    **plot_kwargs
        Additional keyword arguments passed to ``ax.plot``.

    Returns
    -------
    The first element of the ``ax.plot`` output for the last side,
    or ``None`` if ``sides`` is empty.
    """
    handles = None
    for side in sides:
        handles = ax.plot(*side, transform=ccrs.Geodetic(), **plot_kwargs)
    # With no side to draw there is no line handle to return
    # (previously this raised an UnboundLocalError)
    if handles is None:
        return None
    return handles[0]
511

512

513
####--------------------------------------------------------------------------.
514
##########################
515
#### Cartopy wrappers ####
516
##########################
517

518

519
def _sanitize_cartopy_plot_kwargs(plot_kwargs):
1✔
520
    """Sanitize 'cmap' to avoid cartopy bug related to cmap bad color.
521

522
    Cartopy requires the bad color to be fully transparent.
523
    """
524
    cmap = plot_kwargs.get("cmap", None)
1✔
525
    if cmap is not None:
1✔
526
        bad = cmap.get_bad()
1✔
527
        bad[3] = 0  # enforce to 0 (transparent)
1✔
528
        cmap.set_bad(bad)
1✔
529
        plot_kwargs["cmap"] = cmap
1✔
530
    return plot_kwargs
1✔
531

532

533
def is_same_crs(crs1, crs2):
    """Check if same CRS.

    Two CRS are considered the same if a fixed subset of the parameters
    returned by their ``to_dict()`` method are equal.
    Warnings raised by ``to_dict()`` are silenced.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        params1 = crs1.to_dict()
        params2 = crs2.to_dict()
    # Parameters used for the comparison (missing parameters compare as None)
    compared_keys = ("proj", "lat_0", "lon_0", "x_0", "y_0", "units", "type", "lon_wrap", "over", "pm")
    signature1 = tuple(params1.get(key) for key in compared_keys)
    signature2 = tuple(params2.get(key) for key in compared_keys)
    return signature1 == signature2
543

544

545
def plot_cartopy_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot imshow with cartopy.

    Parameters
    ----------
    ax : cartopy axis
        Axis on which to plot. ``ax.projection`` is compared with the data CRS.
    da : xarray.DataArray
        The data to plot.
    x, y : str or None
        Names of the x and y coordinates. Inferred with ``infer_xy_labels`` if None.
    interpolation : str, optional
        Interpolation passed to ``ax.imshow``. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar (never added for RGB images). The default is True.
    plot_kwargs : dict, optional
        Keyword arguments passed to ``ax.imshow``.
        NOTE: the ``"rgb"`` key is removed from the provided dictionary.
    cbar_kwargs : dict, optional
        Keyword arguments passed to ``plot_colorbar``.

    Returns
    -------
    The object returned by ``ax.imshow``.
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    # cbar_kwargs is unpacked with ** when adding the colorbar: None must become {}
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs

    # Infer x and y
    x, y = infer_xy_labels(da, x=x, y=y, rgb=plot_kwargs.get("rgb", None))

    # Align x,y, data dimensions
    # - Ensure image with correct dimensions orders
    # - It can happen that x/y coords does not have same dimension order of data array.
    da = da.transpose(*da[y].dims, *da[x].dims, ...)

    # - Retrieve data
    arr = np.asanyarray(da.data)

    # - Compute coordinates
    x_coords = da[x].to_numpy()
    y_coords = da[y].to_numpy()

    # Compute extent
    extent = compute_extent(x_coords=x_coords, y_coords=y_coords)

    # Infer CRS of data, extent and cartopy projection
    # - Deliberate best-effort: fall back to a lon/lat CRS if the CRS cannot be retrieved
    try:
        crs = da.gpm.cartopy_crs
    except Exception:
        crs = ccrs.PlateCarree()

    # Determine image origin based on the orientation of da[y] values
    # - Cartopy assume origin is lower
    # - If y coordinate is increasing, set origin="lower"
    # - If y coordinate is decreasing, set origin="upper"
    #   --> Means that the image array is [::-1, :] reversed within cartopy
    y_increasing = y_coords[1] > y_coords[0]
    origin = "lower" if y_increasing else "upper"

    # Deal with out of limits x
    # - PlateCarree coordinates are out of bounds when lons are defined as 0-360 (with pm=0)
    # - In such case the axis extent is not set
    set_extent = True
    if extent[1] > crs.x_limits[1] or extent[0] < crs.x_limits[0]:
        set_extent = False

    # Check if specify transform
    # - Specify transform argument only if data CRS is different from axes CRS
    # - If same crs, specifying transform is slower and might cuts away half of first and last row pixels
    # --> GPM-API automatically create the Cartopy GeoAxes with correct CRS
    transform = None if is_same_crs(crs, ax.projection) else crs

    # - Add variable field with cartopy
    rgb = plot_kwargs.pop("rgb", False)
    p = ax.imshow(
        arr,
        transform=transform,
        extent=extent,
        origin=origin,
        interpolation=interpolation,
        **plot_kwargs,
    )

    # - Set the extent
    # --> If some background is globally displayed, this zoom on the actual data region
    if set_extent:
        ax.set_extent(extent, crs=crs)

    # - Add colorbar
    if add_colorbar and not rgb:
        _ = plot_colorbar(p=p, ax=ax, **cbar_kwargs)
    return p
631

632

633
def plot_cartopy_pcolormesh(
    ax,
    da,
    x,
    y,
    add_colorbar=True,
    add_swath_lines=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot pcolormesh with cartopy.

    x and y must represents longitude and latitudes.
    The function currently does not allow to zoom on regions across the antimeridian.
    The function mask scanning pixels which spans across the antimeridian.
    If the DataArray has a RGB dimension, plot_kwargs should contain the ``rgb``
    key with the name of the RGB dimension.

    Parameters
    ----------
    ax : cartopy axis
        Axis on which to plot.
    da : xarray.DataArray
        The data to plot.
    x, y : str
        Names of the longitude and latitude coordinates.
    add_colorbar : bool, optional
        Whether to add a colorbar (never added for RGB data). The default is True.
    add_swath_lines : bool, optional
        Whether to draw the first and last row of the cell corners as dashed
        lines (2D coordinates only). The default is True.
    plot_kwargs : dict, optional
        Keyword arguments passed to ``ax.pcolormesh``.
        NOTE: the provided dictionary is modified (``"rgb"`` is removed and
        ``"shading"`` defaults to ``"flat"``).
    cbar_kwargs : dict, optional
        Keyword arguments passed to ``plot_colorbar``.

    Returns
    -------
    The object returned by ``ax.pcolormesh``.
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    # cbar_kwargs is unpacked with ** when adding the colorbar: None must become {}
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs

    # Remove RGB from plot_kwargs
    rgb = plot_kwargs.pop("rgb", False)

    # Align x,y, data dimensions
    # - Ensure image with correct dimensions orders
    # - It can happen that x/y coords does not have same dimension order of data array.
    da = da.transpose(*da[y].dims, ...)

    # Get x, y, and array to plot
    da = preprocess_rgb_dataarray(da, rgb=rgb)
    da = da.compute()
    lon = da[x].data.copy()
    lat = da[y].data.copy()
    arr = da.data

    # Check if 1D coordinate (orbit nadir-view / transect / cross-section case)
    is_1d_case = lon.ndim == 1

    # Infill invalid value and mask data at invalid coordinates
    # - No invalid values after this function call
    lon, lat, arr = get_valid_pcolormesh_inputs(lon, lat, arr, rgb=rgb, mask_data=True)
    if is_1d_case:
        arr = np.expand_dims(arr, axis=1)

    # Ensure arguments
    if rgb:
        add_colorbar = False

    # Compute coordinates of cell corners for pcolormesh quadrilateral mesh
    # - This enable correct masking of cells crossing the antimeridian
    lon, lat = get_lonlat_corners_from_centroids(lon, lat, parallel=False)

    # Mask cells crossing the antimeridian
    # - with gpm.config.set({"viz_hide_antimeridian_data": False}): can be used to modify the masking behaviour
    arr, plot_kwargs = mask_antimeridian_crossing_array(arr, lon, rgb, plot_kwargs)

    # Add variable field with cartopy
    _ = plot_kwargs.setdefault("shading", "flat")
    p = ax.pcolormesh(
        lon,
        lat,
        arr,
        transform=ccrs.PlateCarree(),
        **plot_kwargs,
    )
    # Add swath lines
    # - TODO: currently assume that dimensions are (cross_track, along_track)
    if add_swath_lines and not is_1d_case:
        sides = [(lon[0, :], lat[0, :]), (lon[-1, :], lat[-1, :])]
        plot_sides(sides=sides, ax=ax, linestyle="--", color="black")

    # Add colorbar
    if add_colorbar:
        _ = plot_colorbar(p=p, ax=ax, **cbar_kwargs)
    return p
709

710

711
####-------------------------------------------------------------------------------.
712
#########################
713
#### Xarray wrappers ####
714
#########################
715

716

717
def _preprocess_xr_kwargs(add_colorbar, plot_kwargs, cbar_kwargs):
1✔
718
    if not add_colorbar:
1✔
719
        cbar_kwargs = None
1✔
720

721
    if "rgb" in plot_kwargs:
1✔
722
        cbar_kwargs = None
1✔
723
        add_colorbar = False
1✔
724
        args_to_keep = ["rgb", "col", "row", "origin"]  # alpha currently skipped if RGB
1✔
725
        plot_kwargs = {k: plot_kwargs[k] for k in args_to_keep if plot_kwargs.get(k, None) is not None}
1✔
726
    return add_colorbar, plot_kwargs, cbar_kwargs
1✔
727

728

729
def plot_xr_pcolormesh(
    ax,
    da,
    x,
    y,
    add_colorbar=True,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot pcolormesh with xarray.

    Parameters
    ----------
    ax : axes object or None
        Axis passed to ``da.plot.pcolormesh``.
    da : xarray.DataArray
        The data to plot.
    x, y : str
        Names of the x and y coordinates.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is True.
    cbar_kwargs : dict, optional
        Colorbar keyword arguments. An optional ``"ticklabels"`` entry is
        removed from the provided dictionary and applied to the colorbar y axis.
    **plot_kwargs
        Additional keyword arguments passed to ``da.plot.pcolormesh``.

    Returns
    -------
    The object returned by ``da.plot.pcolormesh``.
    """
    is_facetgrid = bool("col" in plot_kwargs or "row" in plot_kwargs)
    # cbar_kwargs defaults to None: there are no ticklabels to extract in that case
    ticklabels = None if cbar_kwargs is None else cbar_kwargs.pop("ticklabels", None)
    add_colorbar, plot_kwargs, cbar_kwargs = _preprocess_xr_kwargs(
        add_colorbar=add_colorbar,
        plot_kwargs=plot_kwargs,
        cbar_kwargs=cbar_kwargs,
    )
    p = da.plot.pcolormesh(
        x=x,
        y=y,
        ax=ax,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )

    # Add variable name as title (if not FacetGrid)
    if not is_facetgrid:
        p.axes.set_title(da.name)

    # Add colorbar ticklabels
    if add_colorbar and ticklabels is not None:
        p.colorbar.ax.set_yticklabels(ticklabels)
    return p
762

763

764
def plot_xr_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    add_labels=True,
    cbar_kwargs=None,
    visible_colorbar=True,
    **plot_kwargs,
):
    """Plot imshow with xarray.

    The colorbar is added with xarray to enable to display multiple colorbars
    when calling this function multiple times on different fields with
    different colorbars.

    Parameters
    ----------
    ax : axes object or None
        Axis passed to ``da.plot.imshow``.
    da : xarray.DataArray
        The data to plot.
    x, y : str
        Names of the x and y coordinates.
    interpolation : str, optional
        Interpolation passed to ``da.plot.imshow``. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is True.
    add_labels : bool, optional
        Passed to ``da.plot.imshow``. The default is True.
    cbar_kwargs : dict, optional
        Colorbar keyword arguments. An optional ``"ticklabels"`` entry is
        removed from the provided dictionary and applied to the colorbar y axis.
    visible_colorbar : bool, optional
        If False, the colorbar is added but made fully transparent. The default is True.
    **plot_kwargs
        Additional keyword arguments passed to ``da.plot.imshow``.

    Returns
    -------
    The object returned by ``da.plot.imshow``.
    """
    is_facetgrid = bool("col" in plot_kwargs or "row" in plot_kwargs)
    # cbar_kwargs defaults to None: there are no ticklabels to extract in that case
    ticklabels = None if cbar_kwargs is None else cbar_kwargs.pop("ticklabels", None)
    add_colorbar, plot_kwargs, cbar_kwargs = _preprocess_xr_kwargs(
        add_colorbar=add_colorbar,
        plot_kwargs=plot_kwargs,
        cbar_kwargs=cbar_kwargs,
    )
    # Allow using coords as x/y axis
    # BUG - Current bug in xarray
    if plot_kwargs.get("rgb", None) is not None:
        if x not in da.dims:
            da = da.swap_dims({list(da[x].dims)[0]: x})
        if y not in da.dims:
            da = da.swap_dims({list(da[y].dims)[0]: y})

    p = da.plot.imshow(
        x=x,
        y=y,
        ax=ax,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )

    # Add variable name as title (if not FacetGrid)
    if not is_facetgrid:
        p.axes.set_title(da.name)

    # Add colorbar ticklabels
    if add_colorbar and ticklabels is not None:
        p.colorbar.ax.set_yticklabels(ticklabels)

    # Make the colorbar fully transparent with a smart trick ;)
    # - TODO: this still cause issues when plotting 2 colorbars !
    if add_colorbar and not visible_colorbar:
        set_colorbar_fully_transparent(p)
    return p
834

835

836
####--------------------------------------------------------------------------.
837
####################
838
#### Plot Image ####
839
####################
840

841

842
def _plot_image(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    add_labels=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Draw a single 2D DataArray as an image using xarray's imshow.

    A new figure is created with ``fig_kwargs`` when ``ax`` is ``None``.
    The plot and colorbar settings are resolved by ``get_plot_kwargs`` from the
    DataArray name and the user-provided ``plot_kwargs`` / ``cbar_kwargs``.
    When ``add_labels`` is ``True``, the axis labels of ORBIT and GRID objects
    are replaced with human-readable names (when a default name is known).

    Returns the object returned by ``plot_xr_imshow``. Unless the function is
    called from a FacetGrid (``_is_facetgrid`` in ``plot_kwargs``), the returned
    object is first passed through ``add_optimize_layout_method``.
    """
    from gpm.checks import is_grid, is_orbit
    from gpm.visualization.facetgrid import sanitize_facetgrid_plot_kwargs

    # - Validate the figure arguments and create the axes if none is provided
    fig_kwargs = preprocess_figure_args(ax=ax, fig_kwargs=fig_kwargs)
    if ax is None:
        ax = plt.subplots(**fig_kwargs)[1]

    # - Remember if we are called by FacetGrid.map_dataarray before sanitizing the kwargs
    called_by_facetgrid = plot_kwargs.get("_is_facetgrid", False)
    plot_kwargs = sanitize_facetgrid_plot_kwargs(plot_kwargs)

    # - Resolve the plot and colorbar settings for this variable
    plot_kwargs, cbar_kwargs = get_plot_kwargs(
        name=da.name,
        user_plot_kwargs=plot_kwargs,
        user_cbar_kwargs=cbar_kwargs,
    )

    # - Determine which dimensions/coordinates go on the x and y axis
    x, y = infer_xy_labels(da=da, x=x, y=y, rgb=plot_kwargs.get("rgb", None))

    # - Draw the image
    mappable = plot_xr_imshow(
        ax=ax,
        da=da,
        x=x,
        y=y,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )

    # - Replace the axis labels with nicer names (falling back to the x/y name itself)
    if add_labels:
        orbit_axis_names = {
            "along_track": "Along-Track",
            "x": "Along-Track",
            "cross_track": "Cross-Track",
            "y": "Cross-Track",
        }
        grid_axis_names = {
            "lon": "Longitude",
            "longitude": "Longitude",
            "x": "Longitude",
            "lat": "Latitude",
            "latitude": "Latitude",
            "y": "Latitude",
        }
        if is_orbit(da):
            axis_names = orbit_axis_names
        elif is_grid(da):
            axis_names = grid_axis_names
        else:
            axis_names = None
        if axis_names is not None:
            ax.set_xlabel(axis_names.get(x, x))
            ax.set_ylabel(axis_names.get(y, y))

    # - Monkey patch the mappable with optimize_layout only for standalone plots
    if called_by_facetgrid:
        return mappable
    return add_optimize_layout_method(mappable)
917

918

919
def _plot_image_facetgrid(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    add_labels=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Draw a multi-panel image plot using ``ImageFacetGrid``.

    The facet layout options (``col``, ``row``, ``col_wrap``, ``axes_pad``, ``aspect``,
    ``facet_height``, ``facet_aspect``) are extracted from ``plot_kwargs``.
    Each panel is drawn by ``_plot_image`` without its own colorbar.
    A colorbar is added by the FacetGrid when ``add_colorbar`` is ``True``
    (it is disabled for RGB plots).

    Returns the ``ImageFacetGrid`` instance.
    """
    from gpm.visualization.facetgrid import ImageFacetGrid

    # - Validate the figure arguments
    fig_kwargs = preprocess_figure_args(ax=ax, fig_kwargs=fig_kwargs, is_facetgrid=True)

    # - Resolve the plot and colorbar settings for this variable
    plot_kwargs, cbar_kwargs = get_plot_kwargs(
        name=da.name,
        user_plot_kwargs=plot_kwargs,
        user_cbar_kwargs=cbar_kwargs,
    )

    # - No colorbar for RGB plots
    # - TODO: move this to pycolorbar and also drop cmap, norm, vmin and vmax from plot_kwargs
    if plot_kwargs.get("rgb", False):
        add_colorbar = False
        cbar_kwargs = {}

    # - Extract the FacetGrid layout options so they are not forwarded to the panel plotting function
    facet_options = {
        "col": plot_kwargs.pop("col", None),
        "row": plot_kwargs.pop("row", None),
        "col_wrap": plot_kwargs.pop("col_wrap", None),
        "axes_pad": plot_kwargs.pop("axes_pad", None),
        "aspect": plot_kwargs.pop("aspect", False),
        "facet_height": plot_kwargs.pop("facet_height", 3),
        "facet_aspect": plot_kwargs.pop("facet_aspect", 1),
    }

    # - Create the FacetGrid
    facetgrid = ImageFacetGrid(
        data=da.compute(),
        fig_kwargs=fig_kwargs,
        cbar_kwargs=cbar_kwargs,
        add_colorbar=add_colorbar,
        **facet_options,
    )

    # - Draw every panel (the colorbar is handled by the FacetGrid)
    facetgrid = facetgrid.map_dataarray(
        _plot_image,
        x=x,
        y=y,
        add_colorbar=False,
        add_labels=add_labels,
        interpolation=interpolation,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )

    # - Remove duplicated labels, or all ticks and labels if not requested
    facetgrid.remove_duplicated_axis_labels()
    if not add_labels:
        facetgrid.remove_left_ticks_and_labels()
        facetgrid.remove_bottom_ticks_and_labels()

    # - Add the shared colorbar
    if add_colorbar:
        facetgrid.add_colorbar(**cbar_kwargs)

    return facetgrid
991

992

993
def plot_image(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    add_labels=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot a spatial 2D DataArray as an image using imshow.

    Parameters
    ----------
    da : xarray.DataArray
        The DataArray to plot. It must be a spatial 2D object
        (additional dimensions are allowed for FacetGrid and RGB plots).
    x : str, optional
        Name of the dimension to display on the x axis.
        If ``None``, it is inferred from the DataArray.
        The default is ``None``.
    y : str, optional
        Name of the dimension to display on the y axis.
        If ``None``, it is inferred from the DataArray.
        The default is ``None``.
    ax : matplotlib.axes.Axes, optional
        The matplotlib axes where to plot the image.
        If ``None``, a figure is initialized using the specified ``fig_kwargs``.
        The default is ``None``.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    add_labels : bool, optional
        Whether to add labels to the plot. The default is ``True``.
    interpolation : str, optional
        Interpolation argument passed to imshow.
        The default is ``"nearest"``.
    fig_kwargs : dict, optional
        Figure options passed to :py:func:`matplotlib.pyplot.subplots`.
        Only used if ``ax`` is ``None``.
        The default is ``None``.
    cbar_kwargs : dict, optional
        Colorbar options. The default is ``None``.
    **plot_kwargs
        Additional arguments passed to the plotting function.
        Examples include ``cmap``, ``norm``, ``vmin``, ``vmax``, ``levels``, ...
        For FacetGrid plots, specify ``row``, ``col`` and ``col_wrap``.
        With ``rgb`` you can specify the name of the xarray.DataArray RGB dimension.

    Returns
    -------
    p
        The object returned by the single-image plotting function, or the
        FacetGrid object when ``row`` or ``col`` is specified.

    Raises
    ------
    ValueError
        If ``da`` is not a spatial 2D object.

    """
    from gpm.checks import check_is_spatial_2d, is_spatial_2d

    # Reject objects that can not be displayed as an image
    if not is_spatial_2d(da, strict=False):
        raise ValueError("Can not plot. It's not a spatial 2D object.")

    # Check inputs
    da = check_object_format(da, plot_kwargs=plot_kwargs, check_function=check_is_spatial_2d, strict=True)

    # Select the FacetGrid implementation if 'row' or 'col' is specified,
    # otherwise the single-image implementation (both share the same signature)
    wants_facetgrid = "col" in plot_kwargs or "row" in plot_kwargs
    plot_function = _plot_image_facetgrid if wants_facetgrid else _plot_image

    # Plot with xarray imshow and return the mappable (or FacetGrid)
    return plot_function(
        da=da,
        x=x,
        y=y,
        ax=ax,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        interpolation=interpolation,
        fig_kwargs=fig_kwargs,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )
1088

1089

1090
####--------------------------------------------------------------------------.
1091
##################
1092
#### Plot map ####
1093
##################
1094

1095

1096
def plot_map(
    da,
    x=None,
    y=None,
    ax=None,
    interpolation="nearest",  # used only for GPM grid objects
    add_colorbar=True,
    add_background=True,
    add_labels=True,
    add_gridlines=True,
    add_swath_lines=True,  # used only for GPM orbit objects
    fig_kwargs=None,
    subplot_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot a GPM ORBIT or GRID DataArray on a geographic map.

    Parameters
    ----------
    da : xarray.DataArray
        The DataArray to plot.
    x : str, optional
        Longitude coordinate name.
        If ``None``, it is inferred by the underlying plotting function.
        The default is ``None``.
    y : str, optional
        Latitude coordinate name.
        If ``None``, it is inferred by the underlying plotting function.
        The default is ``None``.
    ax : cartopy.mpl.geoaxes.GeoAxes, optional
        The cartopy GeoAxes where to plot the map.
        If ``None``, a figure is initialized using the
        specified ``fig_kwargs`` and ``subplot_kwargs``.
        The default is ``None``.
    interpolation : str, optional
        Argument passed to :py:meth:`matplotlib.axes.Axes.imshow`.
        Only applies to GRID objects. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    add_background : bool, optional
        Whether to add the map background. The default is ``True``.
    add_labels : bool, optional
        Whether to add cartopy labels to the plot. The default is ``True``.
    add_gridlines : bool, optional
        Whether to add cartopy gridlines to the plot. The default is ``True``.
    add_swath_lines : bool, optional
        Whether to plot the swath sides with a dashed line.
        Only applies to ORBIT objects. The default is ``True``.
    fig_kwargs : dict, optional
        Figure options passed to :py:func:`matplotlib.pyplot.subplots`.
        Only used if ``ax`` is ``None``.
        The default is ``None``.
    subplot_kwargs : dict, optional
        Subplot options passed to :py:func:`matplotlib.pyplot.subplots`.
        Must contain the Cartopy CRS under the ``projection`` key if specified.
        Only used if ``ax`` is ``None``.
        The default is ``None``.
    cbar_kwargs : dict, optional
        Colorbar options. The default is ``None``.
    **plot_kwargs
        Additional arguments passed to the plotting function.
        Examples include ``cmap``, ``norm``, ``vmin``, ``vmax``, ``levels``, ...
        For FacetGrid plots, specify ``row``, ``col`` and ``col_wrap``.
        With ``rgb`` you can specify the name of the xarray.DataArray RGB dimension.

    Returns
    -------
    p
        The object returned by ``plot_orbit_map`` or ``plot_grid_map``.

    Raises
    ------
    ValueError
        If ``da`` is neither an ORBIT object with a spatial dimension
        nor a spatial 2D GRID object.

    """
    from gpm.checks import has_spatial_dim, is_grid, is_orbit, is_spatial_2d
    from gpm.visualization.grid import plot_grid_map
    from gpm.visualization.orbit import plot_orbit_map

    # Arguments shared by the ORBIT and GRID plotting functions
    shared_kwargs = {
        "da": da,
        "x": x,
        "y": y,
        "ax": ax,
        "add_colorbar": add_colorbar,
        "add_background": add_background,
        "add_gridlines": add_gridlines,
        "add_labels": add_labels,
        "fig_kwargs": fig_kwargs,
        "subplot_kwargs": subplot_kwargs,
        "cbar_kwargs": cbar_kwargs,
        **plot_kwargs,
    }

    # Plot orbit
    # - allow vertical or other dimensions for FacetGrid
    # - allow to plot a swath of size 1 (i.e. nadir-looking)
    if is_orbit(da) and has_spatial_dim(da):
        return plot_orbit_map(add_swath_lines=add_swath_lines, **shared_kwargs)

    # Plot grid
    if is_grid(da) and is_spatial_2d(da, strict=False):
        return plot_grid_map(interpolation=interpolation, **shared_kwargs)

    raise ValueError("Can not plot. It's neither a GPM GRID or GPM ORBIT spatial 2D object.")
1208

1209

1210
def plot_map_mesh(
    xr_obj,
    x=None,
    y=None,
    ax=None,
    edgecolors="k",
    linewidth=0.1,
    add_background=True,
    add_gridlines=True,
    add_labels=True,
    fig_kwargs=None,
    subplot_kwargs=None,
    **plot_kwargs,
):
    """Plot the mesh of a GPM ORBIT or GRID object on a geographic map.

    ORBIT objects are plotted with ``plot_orbit_mesh`` and GRID objects
    with ``plot_grid_mesh``. All the arguments are forwarded to these functions.
    For ORBIT objects, the x and y coordinates are first inferred with
    ``infer_map_xy_coords`` and the y coordinate DataArray is the object being plotted.

    Returns the object returned by the underlying mesh plotting function.
    Raises a ``ValueError`` if ``xr_obj`` is neither an ORBIT nor a GRID object.
    """
    from gpm.checks import is_grid, is_orbit
    from gpm.visualization.grid import plot_grid_mesh
    from gpm.visualization.orbit import plot_orbit_mesh

    # Arguments shared by the ORBIT and GRID mesh plotting functions
    shared_kwargs = {
        "ax": ax,
        "edgecolors": edgecolors,
        "linewidth": linewidth,
        "add_background": add_background,
        "add_gridlines": add_gridlines,
        "add_labels": add_labels,
        "fig_kwargs": fig_kwargs,
        "subplot_kwargs": subplot_kwargs,
        **plot_kwargs,
    }

    # Plot orbit
    if is_orbit(xr_obj):
        x, y = infer_map_xy_coords(xr_obj, x=x, y=y)
        return plot_orbit_mesh(da=xr_obj[y], x=x, y=y, **shared_kwargs)

    # Plot grid
    if is_grid(xr_obj):
        return plot_grid_mesh(xr_obj=xr_obj, x=x, y=y, **shared_kwargs)

    raise ValueError("Can not plot. It's neither a GPM GRID or GPM ORBIT spatial object.")
1264

1265

1266
def plot_map_mesh_centroids(
    xr_obj,
    x=None,
    y=None,
    ax=None,
    c="r",
    s=1,
    add_background=True,
    add_gridlines=True,
    add_labels=True,
    fig_kwargs=None,
    subplot_kwargs=None,
    **plot_kwargs,
):
    """Scatter the mesh centroids of a GPM ORBIT or GRID object on a cartographic map.

    The cartopy axes is initialized with ``initialize_cartopy_plot`` if necessary.
    For ORBIT objects, the x and y coordinate names are inferred with ``infer_map_xy_coords``.
    For GRID objects, the 1D coordinates are expanded to a 2D mesh with ``create_grid_mesh_data_array``.
    The centroids are drawn with ``ax.scatter`` using the ``ccrs.PlateCarree()`` transform,
    with ``c``, ``s`` and ``plot_kwargs`` forwarded to it.

    Returns the object returned by ``ax.scatter``.
    """
    from gpm.checks import is_grid, is_orbit

    # Initialize figure if necessary
    ax = initialize_cartopy_plot(
        ax=ax,
        fig_kwargs=fig_kwargs,
        subplot_kwargs=subplot_kwargs,
        add_background=add_background,
        add_gridlines=add_gridlines,
        add_labels=add_labels,
        infer_crs=True,
        xr_obj=xr_obj,
    )

    # ORBIT: retrieve the name of the lon/lat coordinates
    if is_orbit(xr_obj):
        x, y = infer_map_xy_coords(xr_obj, x=x, y=y)

    # GRID: build the 2D mesh of centroids from the 1D coordinates
    if is_grid(xr_obj):
        x, y = infer_xy_labels(xr_obj, x=x, y=y)
        xr_obj = create_grid_mesh_data_array(xr_obj, x=x, y=y)

    # Plot the centroids and return the mappable
    return ax.scatter(
        xr_obj[x].to_numpy(),
        xr_obj[y].to_numpy(),
        transform=ccrs.PlateCarree(),
        c=c,
        s=s,
        **plot_kwargs,
    )
1313

1314

1315
def create_grid_mesh_data_array(xr_obj, x, y):
    """Build a NaN-filled 2D DataArray carrying the mesh of the x and y coordinates.

    The coordinates named ``x`` and ``y`` are read from ``xr_obj``,
    expanded to 2D arrays with :py:func:`numpy.meshgrid` and attached as
    2D coordinates to a new DataArray.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        The input xarray object containing the coordinate arrays.
    x : str
        The name of the x-coordinate in `xr_obj`.
    y : str
        The name of the y-coordinate in `xr_obj`.

    Returns
    -------
    da_mesh : xarray.DataArray
        A 2D xarray.DataArray with dimensions ``('y', 'x')``, 2D coordinates
        named after `x` and `y`, and NaN as data values.

    """
    mesh_dims = ("y", "x")

    # Expand the coordinates to 2D arrays of shape (len(y), len(x))
    mesh_x, mesh_y = np.meshgrid(
        xr_obj[x].to_numpy(),
        xr_obj[y].to_numpy(),
        indexing="xy",
    )

    # Attach the 2D coordinates under the original coordinate names
    mesh_coords = {
        x: (mesh_dims, mesh_x),
        y: (mesh_dims, mesh_y),
    }

    # The data values are only placeholders (NaN)
    placeholder = np.full(mesh_x.shape, np.nan)
    return xr.DataArray(placeholder, coords=mesh_coords, dims=mesh_dims)
1361

1362

1363
####--------------------------------------------------------------------------.
1364

1365

1366
def _plot_labels(
    xr_obj,
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs=None,
    **plot_kwargs,
):
    """Plot a label array as an image.

    The label array is ``xr_obj[label_name]``, or ``xr_obj`` itself when
    ``xr_obj`` is not a Dataset and ``label_name`` is ``None``.
    When more than ``max_n_labels`` labels are present, a message is printed
    and the colorbar is not displayed.

    Returns the object returned by ``plot_image``.
    """
    from ximage.labels.labels import get_label_indices, redefine_label_array
    from ximage.labels.plot_labels import get_label_colorbar_settings

    from gpm.visualization.plot import plot_image

    # Select the label array
    use_object_itself = label_name is None and not isinstance(xr_obj, xr.Dataset)
    label_array = xr_obj if use_object_itself else xr_obj[label_name]
    label_array = label_array.compute()

    # Disable the colorbar if there are too many labels
    label_indices = get_label_indices(label_array)
    n_labels = len(label_indices)
    if add_colorbar and n_labels > max_n_labels:
        print(
            f"""The array currently contains {n_labels} labels
        and 'max_n_labels' is set to {max_n_labels}. The colorbar is not displayed!""",
        )
        add_colorbar = False

    # Relabel array from 1 to ... for plotting, and mask out the values that are not > 0
    label_array = redefine_label_array(label_array, label_indices=label_indices)
    label_array = label_array.where(label_array > 0)

    # Define appropriate colormap (user-provided plot_kwargs override the label defaults)
    image_kwargs, cbar_kwargs = get_label_colorbar_settings(label_indices, cmap=cmap)
    image_kwargs.update(plot_kwargs)

    # Plot image
    return plot_image(
        label_array,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        fig_kwargs=fig_kwargs,
        **image_kwargs,
    )
1414

1415

1416
def plot_labels(
    obj,  # Dataset, DataArray or generator
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs=None,
    **plot_kwargs,
):
    """Plot the labels of an xarray object, or of each object yielded by a generator.

    Parameters
    ----------
    obj : xarray.Dataset, xarray.DataArray or generator
        The object containing the labels.
        If a generator, it must yield pairs whose second element is the xarray object
        to plot (the first element is ignored). Each object is plotted in turn and
        ``plt.show()`` is called after each plot.
    label_name : str, optional
        Name of the label variable or coordinate in the xarray object.
        The default is ``None``.
    max_n_labels : int, optional
        Maximum number of labels for which the colorbar is displayed. The default is 50.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    interpolation : str, optional
        Interpolation argument passed to the image plotting function. The default is ``"nearest"``.
    cmap : str, optional
        Colormap used to define the label colors. The default is ``"Paired"``.
    fig_kwargs : dict, optional
        Figure options. The default is ``None``.
    **plot_kwargs
        Additional arguments passed to the plotting function.

    Returns
    -------
    p
        The object returned by the last plot.
        ``None`` if ``obj`` is a generator that yields nothing.

    """
    shared_kwargs = {
        "label_name": label_name,
        "max_n_labels": max_n_labels,
        "add_colorbar": add_colorbar,
        "interpolation": interpolation,
        "cmap": cmap,
        "fig_kwargs": fig_kwargs,
        **plot_kwargs,
    }

    # Single xarray object
    if not is_generator(obj):
        return _plot_labels(xr_obj=obj, **shared_kwargs)

    # Generator of (label_id, xr_obj)
    # - 'p' is initialized so that an empty generator returns None instead of
    #   raising UnboundLocalError at the return statement
    p = None
    for _, xr_obj in obj:
        p = _plot_labels(xr_obj=xr_obj, **shared_kwargs)
        plt.show()
    return p
1451

1452

1453
def plot_patches(
    patch_gen,
    variable=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot each patch yielded by a patch generator.

    Plotting is best-effort: a patch that can not be plotted is skipped
    and a ``RuntimeWarning`` describing the failure is emitted.

    Parameters
    ----------
    patch_gen : generator
        Generator yielding pairs whose second element is the xarray.DataArray
        or xarray.Dataset patch to plot (the first element is ignored).
    variable : str, optional
        Name of the variable to plot. Required when the patches are xarray.Dataset.
        The default is ``None``.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    interpolation : str, optional
        Interpolation argument passed to ``plot_image``. The default is ``"nearest"``.
    fig_kwargs : dict, optional
        Figure options passed to ``plot_image``. The default is ``None``.
    cbar_kwargs : dict, optional
        Colorbar options passed to ``plot_image``. The default is ``None``.
    **plot_kwargs
        Additional arguments passed to ``plot_image``.

    Raises
    ------
    ValueError
        If a patch is an xarray.Dataset and ``variable`` is not specified.

    """
    from gpm.visualization.plot import plot_image

    # Plot patches
    for _, xr_patch in patch_gen:  # label_id, xr_obj
        if isinstance(xr_patch, xr.Dataset):
            if variable is None:
                raise ValueError("'variable' must be specified when plotting xarray.Dataset patches.")
            xr_patch = xr_patch[variable]
        try:
            plot_image(
                xr_patch,
                interpolation=interpolation,
                add_colorbar=add_colorbar,
                fig_kwargs=fig_kwargs,
                cbar_kwargs=cbar_kwargs,
                **plot_kwargs,
            )
            plt.show()
        except Exception as err:
            # Keep going with the next patch, but do not hide the failure from the user
            warnings.warn(
                f"Skipping a patch that could not be plotted: {err!r}",
                RuntimeWarning,
                stacklevel=2,
            )
1483

1484

1485
####--------------------------------------------------------------------------.
1486

1487

1488
def add_map_inset(ax, loc="upper left", inset_height=0.2, projection=None, inside_figure=True, border_pad=0):
    """Adds an inset map to a cartopy axis, highlighting the extent of the main plot.

    This function creates a smaller global map inset within a larger map plot to show
    the contextual location of the main plot's extent.
    The (padded) extent of the main plot is outlined in red within the inset.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        The main cartopy axis object where the geographic data is plotted.
        It must provide the ``get_extent`` method.
    loc : str, optional
        The location of the inset map within the main plot.
        Options include ``'lower left'``, ``'lower right'``,
        ``'upper left'``, and ``'upper right'``. The default is ``'upper left'``.
    inset_height : float, optional
        The size of the inset height, specified as a fraction of the figure's height.
        For example, a value of 0.2 indicates that the inset's height will be 20% of the figure's height.
        The aspect ratio (of the map inset) will govern the ``inset_width``.
    projection: cartopy.crs.Projection, optional
        A cartopy projection. If ``None``, an Orthographic projection centered on the extent center is used.
    inside_figure : bool, optional
        Determines whether the inset is constrained to be fully inside the figure bounds. If ``True`` (default),
        the inset is placed fully within the figure. If ``False``, the inset can extend beyond the figure's edges,
        allowing for a half-outside placement.
    border_pad : float, optional
        Padding option forwarded to ``get_inset_bounds``. The default is 0.
        NOTE(review): presumably the padding between the inset and the axis border; confirm the units.

    Returns
    -------
    ax2 : cartopy.mpl.geoaxes.GeoAxes
        The Cartopy GeoAxesSubplot object for the inset map.

    Notes
    -----
    The extent of the main plot is retrieved in geographic (PlateCarree) coordinates,
    independently of the projection of the main plot, and is slightly padded for visual clarity.

    Examples
    --------
    >>> p = da.gpm.plot_map()
    >>> add_map_inset(ax=p.axes, loc="upper left", inset_height=0.15)

    This example creates a main plot with a specified extent and adds an upper-left inset map
    showing the global context of the main plot's extent.

    """
    from shapely import Polygon

    from gpm.utils.geospatial import extend_geographic_extent

    # Retrieve map extent in geographic coordinates
    # - Without 'crs', get_extent returns the extent in the projection of the axis,
    #   which is not lon/lat if the main plot does not use PlateCarree
    extent = ax.get_extent(crs=ccrs.PlateCarree())
    extent = extend_geographic_extent(extent, padding=0.5)
    lon_min, lon_max, lat_min, lat_max = extent

    # Create shapely Polygon of the extent
    # - from_bounds expects (xmin, ymin, xmax, ymax)
    polygon = Polygon.from_bounds(lon_min, lat_min, lon_max, lat_max)

    # Define Orthographic projection centered on the extent
    if projection is None:
        projection = ccrs.Orthographic(
            central_latitude=(lat_min + lat_max) / 2,
            central_longitude=(lon_min + lon_max) / 2,
        )

    # Define aspect ratio (width/height) of the map inset
    aspect_ratio = float(np.diff(projection.x_limits).item() / np.diff(projection.y_limits).item())

    # Define inset location relative to main plot (ax) in normalized units
    # - Lower-left corner of inset Axes, and its width and height
    # - [x0, y0, width, height]
    inset_bounds = get_inset_bounds(
        ax=ax,
        loc=loc,
        inset_height=inset_height,
        inside_figure=inside_figure,
        aspect_ratio=aspect_ratio,
        border_pad=border_pad,
    )

    ax2 = ax.inset_axes(
        inset_bounds,
        projection=projection,
    )

    # Add global map
    ax2.set_global()
    ax2.add_feature(cfeature.LAND)
    ax2.add_feature(cfeature.OCEAN)

    # Add extent polygon
    _ = ax2.add_geometries(
        [polygon],
        ccrs.PlateCarree(),
        facecolor="none",
        edgecolor="red",
        linewidth=0.3,
    )
    return ax2
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc