• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ghiggi / gpm_api / 13679277988

05 Mar 2025 03:17PM UTC coverage: 89.223% (-5.2%) from 94.43%
13679277988

push

github

web-flow
Update PMW Tutorial (#74)

* Fix gridlines removal for cartopy artist update

* Add TC-PRIMED tutorial

* Update PMW 1C tutorial

65 of 181 new or added lines in 10 files covered. (35.91%)

909 existing lines in 41 files now uncovered.

14911 of 16712 relevant lines covered (89.22%)

0.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.81
/gpm/visualization/plot.py
1
# -----------------------------------------------------------------------------.
2
# MIT License
3

4
# Copyright (c) 2024 GPM-API developers
5
#
6
# This file is part of GPM-API.
7

8
# Permission is hereby granted, free of charge, to any person obtaining a copy
9
# of this software and associated documentation files (the "Software"), to deal
10
# in the Software without restriction, including without limitation the rights
11
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
# copies of the Software, and to permit persons to whom the Software is
13
# furnished to do so, subject to the following conditions:
14
#
15
# The above copyright notice and this permission notice shall be included in all
16
# copies or substantial portions of the Software.
17
#
18
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
# SOFTWARE.
25

26
# -----------------------------------------------------------------------------.
27
"""This module contains basic functions for GPM-API data visualization."""
28
import inspect
1✔
29

30
import cartopy
1✔
31
import cartopy.crs as ccrs
1✔
32
import cartopy.feature as cfeature
1✔
33
import matplotlib.pyplot as plt
1✔
34
import numpy as np
1✔
35
import xarray as xr
1✔
36
from pycolorbar import plot_colorbar, set_colorbar_fully_transparent
1✔
37
from pycolorbar.utils.mpl_legend import get_inset_bounds
1✔
38
from scipy.interpolate import griddata
1✔
39

40
import gpm
1✔
41
from gpm import get_plot_kwargs
1✔
42
from gpm.utils.area import get_lonlat_corners_from_centroids
1✔
43

44

45
def is_generator(obj):
    """Return True if ``obj`` is a generator function or a generator object."""
    if inspect.isgeneratorfunction(obj):
        return True
    return inspect.isgenerator(obj)
47

48

49
def _call_optimize_layout(self):
    """Optimize the figure layout.

    Meant to be bound as a method (see ``add_optimize_layout_method``): it adapts
    the figure size to ``self.axes`` and then applies ``tight_layout`` to
    ``self.figure``.
    """
    axes = self.axes
    adapt_fig_size(ax=axes)
    figure = self.figure
    figure.tight_layout()
53

54

55
def add_optimize_layout_method(p):
    """Monkey-patch an ``optimize_layout`` bound method onto ``p`` and return ``p``."""
    bound_method = _call_optimize_layout.__get__(p, type(p))
    p.optimize_layout = bound_method
    return p
59

60

61
def adapt_fig_size(ax, nrow=1, ncol=1):
    """Adjust the figure size to the data aspect ratio of cartopy subplots.

    Intended to be called once all plotting has been completed. All subplots are
    assumed to share the aspect ratio and size of ``ax``. The figure width is kept
    and the height is recomputed from the subplot parameters (margins and spacings)
    and ``ax.get_data_ratio()``. If the new height is less than half of the current
    one, width and height are both scaled up (keeping their ratio) so that the final
    height is half of the current one.

    The implementation is inspired by Mathias Hauser's mplotutils set_map_layout function.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes considered representative of all subplots of the figure.
    nrow : int, optional
        Number of subplot rows. The default is 1.
    ncol : int, optional
        Number of subplot columns. The default is 1.
    """
    fig = ax.get_figure()
    fig_width, current_height = fig.get_size_inches()

    # Drawing the canvas makes sure the figure geometry is up-to-date
    # before the layout computations below.
    fig.canvas.draw()

    params = fig.subplotpars
    horizontal_margin = params.left + (1 - params.right)
    vertical_margin = params.bottom + (1 - params.top)

    aspect = ax.get_data_ratio()

    # Width of a single subplot (margins and inter-column spacing removed),
    # then its height from the data aspect ratio.
    subplot_width = (fig_width - fig_width * horizontal_margin) / (ncol + (ncol - 1) * params.wspace)
    subplot_height = subplot_width * aspect

    # Figure height needed to host all rows plus spacing and margins.
    new_height = (subplot_height * (nrow + ((nrow - 1) * params.hspace))) / (1.0 - vertical_margin)

    # Avoid shrinking the figure by more than a factor of 2.
    shrink_ratio = current_height / new_height
    if shrink_ratio > 2:
        scale_factor = shrink_ratio / 2
        fig_width *= scale_factor
        new_height *= scale_factor

    fig.set_figwidth(fig_width)
    fig.set_figheight(new_height)
122

123

124
####--------------------------------------------------------------------------.
125

126

127
def infill_invalid_coords(xr_obj, x="lon", y="lat"):
    """Return a copy of ``xr_obj`` whose ``x``/``y`` coordinates contain no non-finite values.

    Invalid coordinates are interpolated within the convex hull of the valid
    coordinates and filled by nearest neighbour outside of it.
    The data values are not masked.
    """
    xr_obj = xr_obj.copy()
    lon, lat, _ = get_valid_pcolormesh_inputs(
        x=np.asanyarray(xr_obj[x].data),
        y=np.asanyarray(xr_obj[y].data),
        data=None,
        mask_data=False,
    )
    xr_obj[x].data = lon
    xr_obj[y].data = lat
    return xr_obj
142

143

144
def get_valid_pcolormesh_inputs(x, y, data, rgb=False, mask_data=True):
    """Infill invalid (non-finite) coordinates so that they can be used with pcolormesh.

    pcolormesh does not accept non-finite values in the coordinates. Invalid
    coordinates are therefore interpolated within the convex hull of the valid
    ones, then filled by nearest neighbour outside of it.

    If ``mask_data=True``, the data values located at invalid coordinates are
    masked and a numpy masked array is returned (masked values are not displayed
    by pcolormesh). If ``rgb=True``, the RGB dimension is assumed to be the last
    dimension of ``data``.

    Returns
    -------
    tuple
        ``(x, y, data)``. The inputs are returned untouched when all coordinates are valid.

    Raises
    ------
    ValueError
        If no coordinate is valid.
    """
    x_invalid = ~np.isfinite(x)
    y_invalid = ~np.isfinite(y)
    invalid = np.logical_or(x_invalid, y_invalid)

    # All coordinates valid: nothing to do
    if not np.any(invalid):
        return x, y, data

    # Check there is at least one valid coordinate
    if np.all(invalid):
        raise ValueError("No valid coordinates.")

    if not mask_data:
        data_masked = data
    elif rgb:
        rgb_mask = np.broadcast_to(np.expand_dims(invalid, axis=-1), data.shape)
        data_masked = np.ma.masked_where(rgb_mask, data)
    else:
        data_masked = np.ma.masked_where(invalid, data)

    # Infill x and y: linear interpolation first, then nearest neighbour for the
    # values left outside the convex hull.
    # NOTE: interpolation is done on the raw lon/lat values, which currently causes
    # issues when invalid coordinates lie across the antimeridian.
    # TODO: interpolation should be done in X, Y, Z.
    if np.any(x_invalid):
        x = _interpolate_data(_interpolate_data(x, method="linear"), method="nearest")
    if np.any(y_invalid):
        y = _interpolate_data(_interpolate_data(y, method="linear"), method="nearest")
    return x, y, data_masked
192

193

194
def _interpolate_data(arr, method="linear"):
    """Fill the non-finite values of a coordinate array.

    1D arrays (along_track / cross_track view) go to ``_interpolate_1d_coord``.
    All other arrays (e.g. 2D swath image coordinates) go to ``_interpolate_2d_coord``.
    """
    interpolator = _interpolate_1d_coord if arr.ndim == 1 else _interpolate_2d_coord
    return interpolator(arr, method=method)
200

201

202
def _interpolate_1d_coord(arr, method="linear"):
1✔
203
    # Find invalid locations
204
    is_invalid = ~np.isfinite(arr)
1✔
205

206
    # Find the indices of NaN values
207
    nan_indices = np.where(is_invalid)[0]
1✔
208

209
    # Return array if not NaN values
210
    if len(nan_indices) == 0:
1✔
211
        return arr
1✔
212

213
    # Find the indices of non-NaN values
214
    non_nan_indices = np.where(~is_invalid)
1✔
215

216
    # Create indices
217
    indices = np.arange(len(arr))
1✔
218

219
    # Points where we have valid data
220
    points = indices[non_nan_indices]
1✔
221

222
    # Points where data is NaN
223
    points_nan = indices[nan_indices]
1✔
224

225
    # Values at the non-NaN points
226
    values = arr[non_nan_indices]
1✔
227

228
    # Interpolate using griddata
229
    arr_new = arr.copy()
1✔
230
    arr_new[nan_indices] = griddata(points, values, points_nan, method=method)
1✔
231
    return arr_new
1✔
232

233

234
def _interpolate_2d_coord(arr, method="linear"):
1✔
235
    # Find invalid locations
236
    is_invalid = ~np.isfinite(arr)
1✔
237

238
    # Find the indices of NaN values
239
    nan_indices = np.where(is_invalid)
1✔
240

241
    # Return array if not NaN values
242
    if len(nan_indices) == 0:
1✔
UNCOV
243
        return arr
×
244

245
    # Find the indices of non-NaN values
246
    non_nan_indices = np.where(~is_invalid)
1✔
247

248
    # Create a meshgrid of indices
249
    x, y = np.meshgrid(range(arr.shape[1]), range(arr.shape[0]))
1✔
250

251
    # Points (X, Y) where we have valid data
252
    points = np.array([y[non_nan_indices], x[non_nan_indices]]).T
1✔
253

254
    # Points where data is NaN
255
    points_nan = np.array([y[nan_indices], x[nan_indices]]).T
1✔
256

257
    # Values at the non-NaN points
258
    values = arr[non_nan_indices]
1✔
259

260
    # Interpolate using griddata
261
    arr_new = arr.copy()
1✔
262
    arr_new[nan_indices] = griddata(points, values, points_nan, method=method)
1✔
263
    return arr_new
1✔
264

265

266
def _mask_antimeridian_crossing_arr(arr, antimeridian_mask, rgb):
1✔
267
    if np.ma.is_masked(arr):
1✔
268
        if rgb:
1✔
269
            antimeridian_mask = np.broadcast_to(np.expand_dims(antimeridian_mask, axis=-1), arr.shape)
1✔
270
            combined_mask = np.logical_or(arr.mask, antimeridian_mask)
1✔
271
        else:
272
            combined_mask = np.logical_or(arr.mask, antimeridian_mask)
1✔
273
        arr = np.ma.masked_where(combined_mask, arr)
1✔
274
    else:
275
        if rgb:
1✔
276
            antimeridian_mask = np.broadcast_to(
1✔
277
                np.expand_dims(antimeridian_mask, axis=-1),
278
                arr.shape,
279
            )
280
        arr = np.ma.masked_where(antimeridian_mask, arr)
1✔
281
    return arr
1✔
282

283

284
def mask_antimeridian_crossing_array(arr, lon, rgb, plot_kwargs):
    """Mask the array cells crossing the antimeridian.

    ``lon`` is expected to contain only finite longitudes. It is passed to
    ``get_antimeridian_mask``, which returns one value per cell between adjacent
    longitudes. Cartopy still has bugs with several projections when data cross the
    antimeridian, so by default GPM-API masks these cells. This default can be changed
    with ``gpm.config.set({"viz_hide_antimeridian_data": False})``.

    Returns
    -------
    tuple
        ``(arr, plot_kwargs)``. When some cells cross the antimeridian, the colormap
        in ``plot_kwargs`` is sanitized so that its bad color is fully transparent.
    """
    antimeridian_mask = get_antimeridian_mask(lon)
    if not np.any(antimeridian_mask):
        return arr, plot_kwargs

    # Cartopy requires the cmap bad color to be fully transparent
    plot_kwargs = _sanitize_cartopy_plot_kwargs(plot_kwargs)

    # 'viz_hide_antimeridian_data' defaults to True
    if gpm.config.get("viz_hide_antimeridian_data"):
        arr = _mask_antimeridian_crossing_arr(arr, antimeridian_mask=antimeridian_mask, rgb=rgb)
    return arr, plot_kwargs
302

303

304
def get_antimeridian_mask(lons):
    """Return a boolean mask of the cells crossing the antimeridian.

    Parameters
    ----------
    lons : numpy.ndarray
        2D array of longitudes with shape ``(n_y, n_x)``. Each cell lies between
        adjacent longitude values.

    Returns
    -------
    numpy.ndarray
        Boolean array of shape ``(n_y - 1, n_x - 1)``. A jump larger than 180 degrees
        between neighbouring longitudes flags one adjacent cell. The flags are then
        dilated by one cell, which also covers the other cell sharing that edge and
        the neighbours of crossing cells.
    """
    from scipy.ndimage import binary_dilation

    n_y, n_x = lons.shape
    cells = np.zeros((n_y - 1, n_x - 1))

    # Jumps between vertically adjacent longitudes
    jump_rows, jump_cols = np.nonzero(np.abs(np.diff(lons, axis=0)) > 180)
    cells[jump_rows, np.clip(jump_cols - 1, 0, n_x - 1)] = 1

    # Jumps between horizontally adjacent longitudes
    jump_rows, jump_cols = np.nonzero(np.abs(np.diff(lons, axis=1)) > 180)
    cells[np.clip(jump_rows - 1, 0, n_y - 1), jump_cols] = 1

    # One-cell buffer in all directions. This should not be needed,
    # but it avoids cartopy bugs!
    return binary_dilation(cells)
322

323

324
####--------------------------------------------------------------------------.
325
########################
326
#### Plot utilities ####
327
########################
328

329

330
def preprocess_rgb_dataarray(da, rgb):
    """Validate the RGB dimension of ``da`` and move it to the last position.

    If ``rgb`` is falsy, ``da`` is returned unchanged.

    Raises
    ------
    ValueError
        If ``rgb`` is not a dimension of ``da`` or if its size is neither 3 nor 4.
    """
    if not rgb:
        return da
    if rgb not in da.dims:
        raise ValueError(f"The specified rgb='{rgb}' must be a dimension of the DataArray.")
    if da[rgb].size not in (3, 4):
        raise ValueError("The RGB dimension must have size 3 or 4.")
    return da.transpose(..., rgb)
338

339

340
def check_object_format(da, plot_kwargs, check_function, **function_kwargs):
    """Check object format and valid dimension names.

    The steps are:
    - squeeze ``da``;
    - move its RGB dimension (if any) last;
    - verify that the ``rgb``, ``col`` and ``row`` dimensions requested in
      ``plot_kwargs`` exist;
    - call ``check_function`` on the first slice along those dimensions.

    Returns the preprocessed DataArray.
    """
    da = da.squeeze()
    da = preprocess_rgb_dataarray(da, plot_kwargs.get("rgb", False))

    # Dimensions consumed by RGB composition or FacetGrid columns/rows
    special_dims = {}
    for key in ("rgb", "col", "row"):
        dim = plot_kwargs.get(key, None)
        if dim:
            special_dims[key] = dim

    for key, dim in special_dims.items():
        if dim not in da.dims:
            raise ValueError(f"The DataArray does not have a {key}='{dim}' dimension.")

    first_slice = da.isel({dim: 0 for dim in special_dims.values()})
    check_function(first_slice, **function_kwargs)
    return da
355

356

357
def preprocess_figure_args(ax, fig_kwargs=None, subplot_kwargs=None, is_facetgrid=False):
    """Validate the figure-creation arguments and return ``fig_kwargs`` as a dict.

    ``fig_kwargs`` and ``subplot_kwargs`` only make sense when a new figure is
    created, i.e. when ``ax`` is None.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Existing axes, or None if a new figure is going to be created.
    fig_kwargs : dict, optional
        Figure creation arguments. None is treated as an empty dict.
    subplot_kwargs : dict, optional
        Subplot creation arguments. None is treated as an empty dict.
    is_facetgrid : bool, optional
        Whether the plot is a FacetGrid, in which case ``ax`` must be None.

    Returns
    -------
    dict
        ``fig_kwargs`` (``{}`` if None was given).

    Raises
    ------
    ValueError
        Raised in three cases:
        - ``ax`` is given for a FacetGrid plot;
        - ``ax`` is given together with non-empty ``subplot_kwargs``;
        - ``ax`` is given together with non-empty ``fig_kwargs``.
    """
    if is_facetgrid and ax is not None:
        raise ValueError("When plotting with FacetGrid, do not specify the 'ax'.")
    fig_kwargs = {} if fig_kwargs is None else fig_kwargs
    subplot_kwargs = {} if subplot_kwargs is None else subplot_kwargs
    if ax is not None:
        if subplot_kwargs:
            raise ValueError("Provide `subplot_kwargs` only if `ax` is None.")
        if fig_kwargs:
            raise ValueError("Provide `fig_kwargs` only if `ax` is None.")
    return fig_kwargs
368

369

370
def preprocess_subplot_kwargs(subplot_kwargs):
    """Return a copy of ``subplot_kwargs`` guaranteed to define a ``projection``.

    None is treated as an empty dict. When no projection is given,
    ``ccrs.PlateCarree()`` is used. The input dict is never modified.
    """
    kwargs = {} if subplot_kwargs is None else subplot_kwargs.copy()
    if "projection" not in kwargs:
        kwargs["projection"] = ccrs.PlateCarree()
    return kwargs
376

377

378
def infer_xy_labels(da, x=None, y=None, rgb=None):
    """Infer the names of the x and y dimensions/coordinates used to display ``da``.

    Thin wrapper around xarray's private ``_infer_xy_labels`` helper, called with
    ``imshow=True`` and ``rgb`` (the name of the RGB dimension, if any).
    NOTE(review): relies on a private xarray API; re-check on xarray upgrades.

    Returns
    -------
    tuple
        ``(x, y)`` names.
    """
    from xarray.plot.utils import _infer_xy_labels

    x_label, y_label = _infer_xy_labels(da, x=x, y=y, imshow=True, rgb=rgb)
    return x_label, y_label
384

385

386
def infer_map_xy_coords(da, x=None, y=None):
    """
    Infer possible map x and y coordinates for the given DataArray.

    Candidate names are tried in order: ``x``, ``lon``, ``longitude`` for x, and
    ``y``, ``lat``, ``latitude`` for y. The first one present in ``da.coords`` is used.

    Parameters
    ----------
    da : xarray.DataArray
        The input DataArray.
    x : str, optional
        The name of the x (i.e. longitude) coordinate. If None, it will be inferred.
    y : str, optional
        The name of the y (i.e. latitude) coordinate. If None, it will be inferred.

    Returns
    -------
    tuple
        The inferred (x, y) coordinates.

    Raises
    ------
    ValueError
        If a coordinate is not provided and none of its candidate names is available.
    """
    if x is None:
        x = next((name for name in ("x", "lon", "longitude") if name in da.coords), None)
        if x is None:
            raise ValueError("Cannot infer x coordinate. Please provide the x coordinate.")

    if y is None:
        y = next((name for name in ("y", "lat", "latitude") if name in da.coords), None)
        if y is None:
            raise ValueError("Cannot infer y coordinate. Please provide the y coordinate.")

    return x, y
424

425

426
def initialize_cartopy_plot(
    ax,
    fig_kwargs,
    subplot_kwargs,
    add_background,
    add_gridlines,
    add_labels,
):
    """Return axes ready for a cartopy plot, creating the figure if necessary.

    When ``ax`` is None, a new figure is created with ``fig_kwargs``. Its axes
    use ``subplot_kwargs``, whose projection defaults to PlateCarree.
    The cartopy background and the gridlines/labels are then optionally added.
    """
    if ax is None:
        fig_kwargs = preprocess_figure_args(ax=ax, fig_kwargs=fig_kwargs, subplot_kwargs=subplot_kwargs)
        _, ax = plt.subplots(subplot_kw=preprocess_subplot_kwargs(subplot_kwargs), **fig_kwargs)

    if add_background:
        ax = plot_cartopy_background(ax)

    if add_gridlines or add_labels:
        plot_cartopy_gridlines_and_labels(ax, add_gridlines=add_gridlines, add_labels=add_labels)

    return ax
454

455

456
def plot_cartopy_gridlines_and_labels(ax, add_gridlines=True, add_labels=True):
    """Add PlateCarree gridlines and labels to a cartopy axes and return the gridliner.

    If ``add_gridlines`` is False, the gridlines are still created (so that labels
    can be drawn) but made fully transparent. Top and right labels are disabled.
    """
    gridliner = ax.gridlines(
        crs=ccrs.PlateCarree(),
        draw_labels=add_labels,
        linewidth=1,
        color="gray",
        alpha=0.1 if add_gridlines else 0,
        linestyle="-",
    )
    gridliner.top_labels = False
    gridliner.right_labels = False
    gridliner.xlines = True
    gridliner.ylines = True
    return gridliner
472

473

474
def plot_cartopy_background(ax):
    """Draw a light map background on ``ax`` and return it.

    Adds coastlines, light-gray land, semi-transparent ocean and the cartopy
    ``BORDERS`` feature.
    """
    ax.coastlines()
    background_features = (
        (cfeature.LAND, {"facecolor": [0.9, 0.9, 0.9]}),
        (cfeature.OCEAN, {"alpha": 0.6}),
        (cfeature.BORDERS, {}),
    )
    for feature, feature_kwargs in background_features:
        ax.add_feature(feature, **feature_kwargs)
    return ax
482

483

484
def plot_sides(sides, ax, **plot_kwargs):
    """Plot boundary sides as geodetic lines.

    Parameters
    ----------
    sides : iterable of tuple
        Each item is a ``(lon, lat)`` tuple of coordinate arrays. Each side is drawn
        with ``ax.plot`` using a ``ccrs.Geodetic()`` transform.
    ax : cartopy.mpl.geoaxes.GeoAxes
        Axes to draw on.
    **plot_kwargs
        Forwarded to ``ax.plot``.

    Returns
    -------
    matplotlib.lines.Line2D or None
        The line of the last plotted side, or None if ``sides`` is empty
        (previously an empty input raised UnboundLocalError).
    """
    last_line = None
    for side in sides:
        last_line = ax.plot(*side, transform=ccrs.Geodetic(), **plot_kwargs)[0]
    return last_line
492

493

494
####--------------------------------------------------------------------------.
495
##########################
496
#### Cartopy wrappers ####
497
##########################
498

499

500
def _sanitize_cartopy_plot_kwargs(plot_kwargs):
1✔
501
    """Sanitize 'cmap' to avoid cartopy bug related to cmap bad color.
502

503
    Cartopy requires the bad color to be fully transparent.
504
    """
505
    cmap = plot_kwargs.get("cmap", None)
1✔
506
    if cmap is not None:
1✔
507
        bad = cmap.get_bad()
1✔
508
        bad[3] = 0  # enforce to 0 (transparent)
1✔
509
        cmap.set_bad(bad)
1✔
510
        plot_kwargs["cmap"] = cmap
1✔
511
    return plot_kwargs
1✔
512

513

514
def _compute_extent(x_coords, y_coords):
1✔
515
    """Compute the extent (x_min, x_max, y_min, y_max) from the pixel centroids in x and y coordinates.
516

517
    This function assumes that the spacing between each pixel is uniform.
518
    """
519
    # Calculate the pixel size assuming uniform spacing between pixels
520
    pixel_size_x = (x_coords[-1] - x_coords[0]) / (len(x_coords) - 1)
1✔
521
    pixel_size_y = (y_coords[-1] - y_coords[0]) / (len(y_coords) - 1)
1✔
522

523
    # Adjust min and max to get the corners of the outer pixels
524
    x_min, x_max = x_coords[0] - pixel_size_x / 2, x_coords[-1] + pixel_size_x / 2
1✔
525
    y_min, y_max = y_coords[0] - pixel_size_y / 2, y_coords[-1] + pixel_size_y / 2
1✔
526

527
    return [x_min, x_max, y_min, y_max]
1✔
528

529

530
def plot_cartopy_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot imshow with cartopy.

    The data are assumed to be in PlateCarree coordinates. The image extent is
    derived from the ``x``/``y`` centroids, assuming uniform pixel spacing.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        Axes to draw on.
    da : xarray.DataArray
        Field to display.
    x, y : str or None
        Names of the x and y coordinates (inferred if None).
    interpolation : str, optional
        Forwarded to ``ax.imshow``. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar (never added for RGB images).
    plot_kwargs : dict, optional
        Forwarded to ``ax.imshow``. A ``rgb`` key names the RGB dimension.
        The caller's dict is not modified.
    cbar_kwargs : dict, optional
        Forwarded to ``plot_colorbar``.

    Returns
    -------
    The image artist returned by ``ax.imshow``.
    """
    # Copy so that popping "rgb" below does not modify the caller's dict.
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs

    # Assume CRS of data
    transform = ccrs.PlateCarree()

    # Infer x and y
    x, y = infer_xy_labels(da, x=x, y=y, rgb=plot_kwargs.get("rgb", None))
    rgb = plot_kwargs.pop("rgb", False)

    # Align x, y and data dimensions: the x/y coords may not have the same
    # dimension order as the data array.
    da = da.transpose(*da[y].dims, *da[x].dims, ...)
    arr = np.asanyarray(da.data)

    x_coords = da[x].to_numpy()
    y_coords = da[y].to_numpy()
    extent = _compute_extent(x_coords=x_coords, y_coords=y_coords)

    # On the map, y must grow from bottom to top:
    # - increasing y coordinates --> origin="lower"
    # - decreasing y coordinates --> origin="upper" (and swap the y extent)
    y_increasing = y_coords[1] > y_coords[0]
    origin = "lower" if y_increasing else "upper"
    if not y_increasing:
        extent = [extent[i] for i in [0, 1, 3, 2]]

    # Do not set the axes extent when it falls outside the PlateCarree x limits
    # (e.g. longitudes defined over 0-360).
    set_extent = not (extent[1] > transform.x_limits[1] or extent[0] < transform.x_limits[0])

    p = ax.imshow(
        arr,
        transform=transform,
        extent=extent,
        origin=origin,
        interpolation=interpolation,
        **plot_kwargs,
    )
    if set_extent:
        ax.set_extent(extent)

    if add_colorbar and not rgb:
        plot_colorbar(p=p, ax=ax, **cbar_kwargs)
    return p
600

601

602
def plot_cartopy_pcolormesh(
    ax,
    da,
    x,
    y,
    add_colorbar=True,
    add_swath_lines=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot pcolormesh with cartopy.

    ``x`` and ``y`` must be the names of the longitude and latitude coordinates.
    Non-finite coordinates are infilled, and the data at those locations are masked.
    Cell corners are derived from the centroids so that cells crossing the
    antimeridian can be masked; this is controlled by the GPM-API config
    ``viz_hide_antimeridian_data``. The function currently does not allow zooming
    on regions across the antimeridian.

    If the DataArray has an RGB dimension, ``plot_kwargs`` should contain the ``rgb``
    key with the name of the RGB dimension. No colorbar is drawn in that case.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        Axes to draw on.
    da : xarray.DataArray
        Field to display.
    x, y : str
        Names of the longitude and latitude coordinates.
    add_colorbar : bool, optional
        Whether to add a colorbar.
    add_swath_lines : bool, optional
        Draw dashed lines along the first and last rows of cell corners
        (2D coordinates only).
    plot_kwargs : dict, optional
        Forwarded to ``ax.pcolormesh`` (``shading`` defaults to ``"flat"``).
        The caller's dict is not modified.
    cbar_kwargs : dict, optional
        Forwarded to ``plot_colorbar``.

    Returns
    -------
    The artist returned by ``ax.pcolormesh``.
    """
    # Copy: "rgb" is popped and "shading" may be added below. Mutating the
    # caller's dict would break reusing it for subsequent calls.
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs

    rgb = plot_kwargs.pop("rgb", False)

    # Align x, y and data dimensions: the x/y coords may not have the same
    # dimension order as the data array.
    da = da.transpose(*da[y].dims, ...)

    # Get x, y and the array to plot
    da = preprocess_rgb_dataarray(da, rgb=rgb)
    da = da.compute()
    lon = da[x].data.copy()
    lat = da[y].data.copy()
    arr = da.data

    # 1D coordinates: orbit nadir-view / transect / cross-section case
    is_1d_case = lon.ndim == 1

    # Infill invalid coordinates and mask data there (no invalid coords afterwards)
    lon, lat, arr = get_valid_pcolormesh_inputs(lon, lat, arr, rgb=rgb, mask_data=True)
    if is_1d_case:
        arr = np.expand_dims(arr, axis=1)

    if rgb:
        add_colorbar = False

    # Cell corners for the pcolormesh quadrilateral mesh. They enable a correct
    # masking of cells crossing the antimeridian.
    lon, lat = get_lonlat_corners_from_centroids(lon, lat, parallel=False)

    # Mask cells crossing the antimeridian
    # (see gpm.config.set({"viz_hide_antimeridian_data": False}))
    arr, plot_kwargs = mask_antimeridian_crossing_array(arr, lon, rgb, plot_kwargs)

    plot_kwargs.setdefault("shading", "flat")
    p = ax.pcolormesh(
        lon,
        lat,
        arr,
        transform=ccrs.PlateCarree(),
        **plot_kwargs,
    )

    # Add swath lines
    # TODO: currently assume that dimensions are (cross_track, along_track)
    if add_swath_lines and not is_1d_case:
        sides = [(lon[0, :], lat[0, :]), (lon[-1, :], lat[-1, :])]
        plot_sides(sides=sides, ax=ax, linestyle="--", color="black")

    if add_colorbar:
        plot_colorbar(p=p, ax=ax, **cbar_kwargs)
    return p
678

679

680
####-------------------------------------------------------------------------------.
681
#########################
682
#### Xarray wrappers ####
683
#########################
684

685

686
def _preprocess_xr_kwargs(add_colorbar, plot_kwargs, cbar_kwargs):
1✔
687
    if not add_colorbar:
1✔
688
        cbar_kwargs = None
1✔
689

690
    if "rgb" in plot_kwargs:
1✔
691
        cbar_kwargs = None
1✔
692
        add_colorbar = False
1✔
693
        args_to_keep = ["rgb", "col", "row", "origin"]  # alpha currently skipped if RGB
1✔
694
        plot_kwargs = {k: plot_kwargs[k] for k in args_to_keep if plot_kwargs.get(k, None) is not None}
1✔
695
    return add_colorbar, plot_kwargs, cbar_kwargs
1✔
696

697

698
def plot_xr_pcolormesh(
    ax,
    da,
    x,
    y,
    add_colorbar=True,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot pcolormesh with xarray.

    The variable name is used as the title unless a FacetGrid (``col``/``row``) is
    requested. ``cbar_kwargs`` may contain a ``ticklabels`` entry. It is removed
    before calling xarray, and the labels are then applied to the colorbar ticks
    (x axis for horizontal colorbars, y axis otherwise). The caller's
    ``cbar_kwargs`` is not modified, and None is accepted.

    Returns
    -------
    The object returned by ``da.plot.pcolormesh``.
    """
    is_facetgrid = "col" in plot_kwargs or "row" in plot_kwargs
    # Copy so that popping "ticklabels" does not modify the caller's dict
    cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs)
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    add_colorbar, plot_kwargs, cbar_kwargs = _preprocess_xr_kwargs(
        add_colorbar=add_colorbar,
        plot_kwargs=plot_kwargs,
        cbar_kwargs=cbar_kwargs,
    )
    p = da.plot.pcolormesh(
        x=x,
        y=y,
        ax=ax,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )

    # Add variable name as title (if not FacetGrid)
    if not is_facetgrid:
        p.axes.set_title(da.name)

    # Set the colorbar tick labels on the axis carrying the ticks
    if add_colorbar and ticklabels is not None:
        cbar = p.colorbar
        if getattr(cbar, "orientation", "vertical") == "horizontal":
            cbar.ax.set_xticklabels(ticklabels)
        else:
            cbar.ax.set_yticklabels(ticklabels)
    return p
731

732

733
def plot_xr_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    add_labels=True,
    cbar_kwargs=None,
    visible_colorbar=True,
    **plot_kwargs,
):
    """Plot imshow with xarray.

    The colorbar is added with xarray so that calling this function several
    times, on different fields with different colorbars, can display multiple
    colorbars.

    The variable name is used as the title unless a FacetGrid (``col``/``row``) is
    requested. ``cbar_kwargs`` may contain a ``ticklabels`` entry. It is removed
    before calling xarray, and the labels are then applied to the colorbar ticks
    (x axis for horizontal colorbars, y axis otherwise). The caller's
    ``cbar_kwargs`` is not modified, and None is accepted. With
    ``visible_colorbar=False`` the colorbar is still created but made fully
    transparent.

    Returns
    -------
    The object returned by ``da.plot.imshow``.
    """
    is_facetgrid = "col" in plot_kwargs or "row" in plot_kwargs
    # Copy so that popping "ticklabels" does not modify the caller's dict
    cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs)
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    add_colorbar, plot_kwargs, cbar_kwargs = _preprocess_xr_kwargs(
        add_colorbar=add_colorbar,
        plot_kwargs=plot_kwargs,
        cbar_kwargs=cbar_kwargs,
    )

    # Allow using non-dimension coordinates as x/y axis for RGB images by
    # swapping them in as dimensions (workaround for an xarray bug)
    if plot_kwargs.get("rgb", None) is not None:
        if x not in da.dims:
            da = da.swap_dims({list(da[x].dims)[0]: x})
        if y not in da.dims:
            da = da.swap_dims({list(da[y].dims)[0]: y})

    p = da.plot.imshow(
        x=x,
        y=y,
        ax=ax,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )

    # Add variable name as title (if not FacetGrid)
    if not is_facetgrid:
        p.axes.set_title(da.name)

    # Set the colorbar tick labels on the axis carrying the ticks
    if add_colorbar and ticklabels is not None:
        cbar = p.colorbar
        if getattr(cbar, "orientation", "vertical") == "horizontal":
            cbar.ax.set_xticklabels(ticklabels)
        else:
            cbar.ax.set_yticklabels(ticklabels)

    # Hide the colorbar by making it fully transparent
    # TODO: this still causes issues when plotting 2 colorbars!
    if add_colorbar and not visible_colorbar:
        set_colorbar_fully_transparent(p)

    return p
803

804

805
####--------------------------------------------------------------------------.
806
####################
807
#### Plot Image ####
808
####################
809

810

811
def _plot_image(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    add_labels=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot a GPM spatial 2D DataArray as an image with ``imshow``.

    This is the single-panel worker behind ``plot_image``. It is also passed to
    ``ImageFacetGrid.map_dataarray`` to draw each facet, in which case
    ``plot_kwargs`` carries an internal ``_is_facetgrid`` flag.

    Parameters
    ----------
    da : xarray.DataArray
        Data to plot.
    x, y : str, optional
        Names of the horizontal and vertical axis dimensions/coordinates.
        If ``None``, they are inferred with ``infer_xy_labels``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. If ``None``, a new figure is created using ``fig_kwargs``.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    add_labels : bool, optional
        Whether to set axis labels. For ORBIT and GRID objects, known axis names
        are replaced by readable labels (e.g. ``"Along-Track"``, ``"Longitude"``).
    interpolation : str, optional
        Interpolation passed to ``imshow``. The default is ``"nearest"``.
    fig_kwargs : dict, optional
        Figure options, only used when ``ax`` is ``None``.
    cbar_kwargs : dict, optional
        Colorbar options, completed with the GPM-API defaults for ``da.name``.
    **plot_kwargs
        Additional options for the ``imshow`` call.

    Returns
    -------
    The mappable returned by ``plot_xr_imshow``. Outside a FacetGrid it is
    additionally patched with an ``optimize_layout`` method.
    """
    from gpm.checks import is_grid, is_orbit
    from gpm.visualization.facetgrid import sanitize_facetgrid_plot_kwargs

    # Create a figure only when the caller did not provide an axis.
    fig_kwargs = preprocess_figure_args(ax=ax, fig_kwargs=fig_kwargs)
    if ax is None:
        _, ax = plt.subplots(**fig_kwargs)

    # FacetGrid.map_dataarray injects private keys: remember the flag, then strip them.
    is_facetgrid = plot_kwargs.get("_is_facetgrid", False)
    plot_kwargs = sanitize_facetgrid_plot_kwargs(plot_kwargs)

    # Fill unspecified plot/colorbar options with the defaults for this product variable.
    plot_kwargs, cbar_kwargs = get_plot_kwargs(
        name=da.name,
        user_plot_kwargs=plot_kwargs,
        user_cbar_kwargs=cbar_kwargs,
    )

    # Resolve which dimensions/coordinates go on each axis.
    x, y = infer_xy_labels(da=da, x=x, y=y, rgb=plot_kwargs.get("rgb", None))

    mappable = plot_xr_imshow(
        ax=ax,
        da=da,
        x=x,
        y=y,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )

    if add_labels:
        # Pick the readable-label table matching the object type (if any).
        axis_titles = None
        if is_orbit(da):
            axis_titles = {
                "along_track": "Along-Track",
                "x": "Along-Track",
                "cross_track": "Cross-Track",
                "y": "Cross-Track",
            }
        elif is_grid(da):
            axis_titles = {
                "lon": "Longitude",
                "longitude": "Longitude",
                "x": "Longitude",
                "lat": "Latitude",
                "latitude": "Latitude",
                "y": "Latitude",
            }
        if axis_titles is not None:
            # Unknown axis names are used verbatim.
            ax.set_xlabel(axis_titles.get(x, x))
            ax.set_ylabel(axis_titles.get(y, y))

    # Single panels get the optimize_layout helper; facets are handled by the FacetGrid.
    if not is_facetgrid:
        mappable = add_optimize_layout_method(mappable)
    return mappable
886

887

888
def _plot_image_facetgrid(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    add_labels=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot a DataArray as a grid of ``imshow`` panels (FacetGrid).

    The data are computed once, laid out with ``ImageFacetGrid`` and each
    panel is drawn with ``_plot_image``. A single shared colorbar is added
    at the end when ``add_colorbar`` is ``True``.

    Parameters
    ----------
    da : xarray.DataArray
        Data to plot.
    x, y : str, optional
        Names of the horizontal and vertical axis dimensions/coordinates.
    ax : matplotlib.axes.Axes, optional
        Forwarded to ``preprocess_figure_args`` for validation.
    add_colorbar : bool, optional
        Whether to add a shared colorbar. Forced to ``False`` for RGB images.
    add_labels : bool, optional
        Whether to keep axis labels and ticks.
    interpolation : str, optional
        Interpolation passed to ``imshow``.
    fig_kwargs : dict, optional
        Figure options.
    cbar_kwargs : dict, optional
        Colorbar options, completed with the GPM-API defaults for ``da.name``.
    **plot_kwargs
        ``imshow`` options plus the FacetGrid layout options ``col``, ``row``,
        ``col_wrap``, ``axes_pad``, ``aspect``, ``facet_height`` and ``facet_aspect``.

    Returns
    -------
    ImageFacetGrid
        The populated FacetGrid.
    """
    from gpm.visualization.facetgrid import ImageFacetGrid

    # Validate/complete the figure options for a multi-panel figure.
    fig_kwargs = preprocess_figure_args(ax=ax, fig_kwargs=fig_kwargs, is_facetgrid=True)

    # Complete user options with the GPM-API defaults for this variable.
    plot_kwargs, cbar_kwargs = get_plot_kwargs(
        name=da.name,
        user_plot_kwargs=plot_kwargs,
        user_cbar_kwargs=cbar_kwargs,
    )

    # RGB composites have no colormap, hence no colorbar.
    # TODO: move this to pycolorbar and also drop cmap/norm/vmin/vmax from plot_kwargs.
    if plot_kwargs.get("rgb", False):
        add_colorbar, cbar_kwargs = False, {}

    # Pull the layout options out so they are not forwarded to imshow.
    layout_kwargs = {
        "col": plot_kwargs.pop("col", None),
        "row": plot_kwargs.pop("row", None),
        "col_wrap": plot_kwargs.pop("col_wrap", None),
        "axes_pad": plot_kwargs.pop("axes_pad", None),
        "aspect": plot_kwargs.pop("aspect", False),
        "facet_height": plot_kwargs.pop("facet_height", 3),
        "facet_aspect": plot_kwargs.pop("facet_aspect", 1),
    }
    facet_grid = ImageFacetGrid(
        data=da.compute(),
        fig_kwargs=fig_kwargs,
        cbar_kwargs=cbar_kwargs,
        add_colorbar=add_colorbar,
        **layout_kwargs,
    )

    # Draw every panel without its own colorbar; the shared one is added below.
    facet_grid = facet_grid.map_dataarray(
        _plot_image,
        x=x,
        y=y,
        add_colorbar=False,
        add_labels=add_labels,
        interpolation=interpolation,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )

    # Keep axis labels only on the outer panels (or drop them altogether).
    facet_grid.remove_duplicated_axis_labels()
    if not add_labels:
        facet_grid.remove_left_ticks_and_labels()
        facet_grid.remove_bottom_ticks_and_labels()

    if add_colorbar:
        facet_grid.add_colorbar(**cbar_kwargs)
    return facet_grid
960

961

962
def plot_image(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    add_labels=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot data using imshow.

    Parameters
    ----------
    da : xarray.DataArray
        xarray DataArray.
    x : str, optional
        X dimension name.
        If ``None``, takes the second dimension.
        The default is ``None``.
    y : str, optional
        Y dimension name.
        If ``None``, takes the first dimension.
        The default is ``None``.
    ax : matplotlib.axes.Axes, optional
        The matplotlib axes where to plot the image.
        If ``None``, a figure is initialized using the specified ``fig_kwargs``.
        The default is ``None``.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    add_labels : bool, optional
        Whether to add labels to the plot. The default is ``True``.
    interpolation : str, optional
        Argument to be passed to imshow.
        The default is ``"nearest"``.
    fig_kwargs : dict, optional
        Figure options to be passed to :py:class:`matplotlib.pyplot.subplots`.
        The default is ``None``.
        Only used if ``ax`` is ``None``.
    cbar_kwargs : dict, optional
        Colorbar options. The default is ``None``.
    **plot_kwargs
        Additional arguments to be passed to the plotting function.
        Examples include ``cmap``, ``norm``, ``vmin``, ``vmax``, ``levels``, ...
        For FacetGrid plots, specify ``row``, ``col`` and ``col_wrap``.
        With ``rgb`` you can specify the name of the xarray.DataArray RGB dimension.

    Returns
    -------
    The single-panel mappable, or the ``ImageFacetGrid`` when ``col`` or ``row``
    is specified.

    Raises
    ------
    ValueError
        If ``da`` is not a spatial 2D object.
    """
    from gpm.checks import check_is_spatial_2d, is_spatial_2d

    # Reject anything that is not spatial 2D (extra dims are checked below).
    if not is_spatial_2d(da, strict=False):
        raise ValueError("Can not plot. It's not a spatial 2D object.")

    da = check_object_format(da, plot_kwargs=plot_kwargs, check_function=check_is_spatial_2d, strict=True)

    # ``col``/``row`` request a FacetGrid; otherwise draw a single panel.
    wants_facets = "col" in plot_kwargs or "row" in plot_kwargs
    plotter = _plot_image_facetgrid if wants_facets else _plot_image
    return plotter(
        da=da,
        x=x,
        y=y,
        ax=ax,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        interpolation=interpolation,
        fig_kwargs=fig_kwargs,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )
1057

1058

1059
####--------------------------------------------------------------------------.
1060
##################
1061
#### Plot map ####
1062
##################
1063

1064

1065
def plot_map(
    da,
    x=None,
    y=None,
    ax=None,
    interpolation="nearest",  # used only for GPM grid objects
    add_colorbar=True,
    add_background=True,
    add_labels=True,
    add_gridlines=True,
    add_swath_lines=True,  # used only for GPM orbit objects
    fig_kwargs=None,
    subplot_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot data on a geographic map.

    Parameters
    ----------
    da : xarray.DataArray
        xarray DataArray.
    x : str, optional
        Longitude coordinate name.
        If ``None``, takes the second dimension.
        The default is ``None``.
    y : str, optional
        Latitude coordinate name.
        If ``None``, takes the first dimension.
        The default is ``None``.
    ax : cartopy.mpl.geoaxes.GeoAxes, optional
        The cartopy GeoAxes where to plot the map.
        If ``None``, a figure is initialized using the
        specified ``fig_kwargs`` and ``subplot_kwargs``.
        The default is ``None``.
    interpolation : str, optional
        Argument to be passed to :py:class:`matplotlib.axes.Axes.imshow`. Only applies for GRID objects.
        The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    add_background : bool, optional
        Whether to add the map background. The default is ``True``.
    add_labels : bool, optional
        Whether to add cartopy labels to the plot. The default is ``True``.
    add_gridlines : bool, optional
        Whether to add cartopy gridlines to the plot. The default is ``True``.
    add_swath_lines : bool, optional
        Whether to plot the swath sides with a dashed line. The default is ``True``.
        This argument only applies for ORBIT objects.
    fig_kwargs : dict, optional
        Figure options to be passed to :py:class:`matplotlib.pyplot.subplots`.
        The default is ``None``.
        Only used if ``ax`` is ``None``.
    subplot_kwargs : dict, optional
        Dictionary of keyword arguments for :py:class:`matplotlib.pyplot.subplots`.
        Must contain the Cartopy CRS ``projection`` key if specified.
        The default is ``None``.
        Only used if ``ax`` is ``None``.
    cbar_kwargs : dict, optional
        Colorbar options. The default is ``None``.
    **plot_kwargs
        Additional arguments to be passed to the plotting function.
        Examples include ``cmap``, ``norm``, ``vmin``, ``vmax``, ``levels``, ...
        For FacetGrid plots, specify ``row``, ``col`` and ``col_wrap``.
        With ``rgb`` you can specify the name of the xarray.DataArray RGB dimension.

    Returns
    -------
    The object returned by ``plot_orbit_map`` (ORBIT) or ``plot_grid_map`` (GRID).

    Raises
    ------
    ValueError
        If ``da`` is neither a GPM ORBIT nor a GPM GRID spatial object.
    """
    from gpm.checks import has_spatial_dim, is_grid, is_orbit, is_spatial_2d
    from gpm.visualization.grid import plot_grid_map
    from gpm.visualization.orbit import plot_orbit_map

    # Options understood by both the ORBIT and the GRID map plotters.
    shared_kwargs = {
        "da": da,
        "x": x,
        "y": y,
        "ax": ax,
        "add_colorbar": add_colorbar,
        "add_background": add_background,
        "add_gridlines": add_gridlines,
        "add_labels": add_labels,
        "fig_kwargs": fig_kwargs,
        "subplot_kwargs": subplot_kwargs,
        "cbar_kwargs": cbar_kwargs,
    }

    # ORBIT: extra dims are allowed (FacetGrid), as are swaths of size 1 (e.g. nadir-looking).
    if is_orbit(da) and has_spatial_dim(da):
        return plot_orbit_map(add_swath_lines=add_swath_lines, **shared_kwargs, **plot_kwargs)

    # GRID: must be spatial 2D (extra dims tolerated with strict=False).
    if is_grid(da) and is_spatial_2d(da, strict=False):
        return plot_grid_map(interpolation=interpolation, **shared_kwargs, **plot_kwargs)

    raise ValueError("Can not plot. It's neither a GPM GRID or GPM ORBIT spatial 2D object.")
1177

1178

1179
def plot_map_mesh(
    xr_obj,
    x=None,
    y=None,
    ax=None,
    edgecolors="k",
    linewidth=0.1,
    add_background=True,
    add_gridlines=True,
    add_labels=True,
    fig_kwargs=None,
    subplot_kwargs=None,
    **plot_kwargs,
):
    """Plot the pixel mesh of a GPM ORBIT or GRID object on a cartographic map.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        GPM ORBIT or GRID object.
    x, y : str, optional
        Names of the horizontal/vertical coordinates. If ``None``, they are inferred.
    ax : cartopy.mpl.geoaxes.GeoAxes, optional
        Axis to draw on. If ``None``, a figure is created.
    edgecolors : str, optional
        Color of the mesh edges. The default is ``"k"``.
    linewidth : float, optional
        Width of the mesh edges. The default is ``0.1``.
    add_background, add_gridlines, add_labels : bool, optional
        Cartopy map decorations. The defaults are ``True``.
    fig_kwargs, subplot_kwargs : dict, optional
        Figure/subplot options, used when ``ax`` is ``None``.
    **plot_kwargs
        Additional options forwarded to the mesh plotter.

    Returns
    -------
    The object returned by ``plot_orbit_mesh`` (ORBIT) or ``plot_grid_mesh`` (GRID).

    Raises
    ------
    ValueError
        If ``xr_obj`` is neither a GPM ORBIT nor a GPM GRID object.
    """
    from gpm.checks import is_grid, is_orbit
    from gpm.visualization.grid import plot_grid_mesh
    from gpm.visualization.orbit import plot_orbit_mesh

    if is_orbit(xr_obj):
        x, y = infer_map_xy_coords(xr_obj, x=x, y=y)
        # plot_orbit_mesh takes a DataArray: the ``y`` coordinate array is passed as carrier.
        return plot_orbit_mesh(
            da=xr_obj[y],
            ax=ax,
            x=x,
            y=y,
            edgecolors=edgecolors,
            linewidth=linewidth,
            add_background=add_background,
            add_gridlines=add_gridlines,
            add_labels=add_labels,
            fig_kwargs=fig_kwargs,
            subplot_kwargs=subplot_kwargs,
            **plot_kwargs,
        )

    if is_grid(xr_obj):
        return plot_grid_mesh(
            xr_obj=xr_obj,
            x=x,
            y=y,
            ax=ax,
            edgecolors=edgecolors,
            linewidth=linewidth,
            add_background=add_background,
            add_gridlines=add_gridlines,
            add_labels=add_labels,
            fig_kwargs=fig_kwargs,
            subplot_kwargs=subplot_kwargs,
            **plot_kwargs,
        )

    raise ValueError("Can not plot. It's neither a GPM GRID or GPM ORBIT spatial object.")
1233

1234

1235
def plot_map_mesh_centroids(
    xr_obj,
    x=None,
    y=None,
    ax=None,
    c="r",
    s=1,
    add_background=True,
    add_gridlines=True,
    add_labels=True,
    fig_kwargs=None,
    subplot_kwargs=None,
    **plot_kwargs,
):
    """Plot the pixel centroids of a GPM ORBIT or GRID object on a cartographic map.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        GPM ORBIT or GRID object.
    x, y : str, optional
        Names of the horizontal/vertical coordinates. If ``None``, they are inferred.
    ax : cartopy.mpl.geoaxes.GeoAxes, optional
        Axis to draw on. If ``None``, a figure is created with
        ``fig_kwargs`` and ``subplot_kwargs``.
    c : color, optional
        Marker color passed to ``scatter``. The default is ``"r"``.
    s : float, optional
        Marker size passed to ``scatter``. The default is ``1``.
    add_background, add_gridlines, add_labels : bool, optional
        Cartopy map decorations. The defaults are ``True``.
    fig_kwargs, subplot_kwargs : dict, optional
        Figure/subplot options, used when ``ax`` is ``None``.
    **plot_kwargs
        Additional options forwarded to ``ax.scatter``.

    Returns
    -------
    matplotlib.collections.PathCollection
        The scatter artist.

    Raises
    ------
    ValueError
        If ``xr_obj`` is neither a GPM ORBIT nor a GPM GRID object.
    """
    from gpm.checks import is_grid, is_orbit

    # Resolve the centroid coordinates BEFORE creating a figure: an unsupported
    # object previously fell through with x/y = None and failed obscurely on
    # ``xr_obj[None]``, leaving an empty figure behind.
    if is_orbit(xr_obj):
        x, y = infer_map_xy_coords(xr_obj, x=x, y=y)
    elif is_grid(xr_obj):
        # Expand the 1D grid coordinates into a 2D mesh of centroids.
        x, y = infer_xy_labels(xr_obj, x=x, y=y)
        xr_obj = create_grid_mesh_data_array(xr_obj, x=x, y=y)
    else:
        raise ValueError("Can not plot. It's neither a GPM GRID or GPM ORBIT spatial object.")

    # Initialize figure if necessary
    ax = initialize_cartopy_plot(
        ax=ax,
        fig_kwargs=fig_kwargs,
        subplot_kwargs=subplot_kwargs,
        add_background=add_background,
        add_gridlines=add_gridlines,
        add_labels=add_labels,
    )

    lon = xr_obj[x].to_numpy()
    lat = xr_obj[y].to_numpy()

    # Centroid coordinates are geographic, hence the PlateCarree transform.
    return ax.scatter(lon, lat, transform=ccrs.PlateCarree(), c=c, s=s, **plot_kwargs)
1280

1281

1282
def create_grid_mesh_data_array(xr_obj, x, y):
    """Create a 2D mesh coordinates DataArray.

    Takes as input the 1D coordinate arrays from an existing xarray.DataArray or xarray.Dataset object.

    The function creates a 2D grid (mesh) of x and y coordinates and initializes
    the data values to NaN.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        The input xarray object containing the 1D coordinate arrays.
    x : str
        The name of the x-coordinate in `xr_obj`.
    y : str
        The name of the y-coordinate in `xr_obj`.

    Returns
    -------
    da_mesh : xarray.DataArray
        A 2D xarray.DataArray with mesh coordinates for `x` and `y`, and NaN values for data points.

    Notes
    -----
    The resulting xarray.DataArray has dimensions named ``'y'`` and ``'x'``, corresponding to the
    y and x coordinates respectively. If ``x`` or ``y`` is itself named like one of these
    dimensions (e.g. projected grids with ``x``/``y`` coordinates), ``'_dim'`` is appended to both
    dimension names (e.g. ``'y_dim'``, ``'x_dim'``). xarray does not allow a 2D coordinate to
    share its name with a dimension.
    The coordinate values are taken directly from the input 1D coordinate arrays,
    and the data values are set to NaN.

    """
    # Extract 1D coordinate arrays
    x_coords = xr_obj[x].to_numpy()
    y_coords = xr_obj[y].to_numpy()

    # Create 2D meshgrid; with indexing="xy" the arrays have shape (len(y), len(x)).
    X, Y = np.meshgrid(x_coords, y_coords, indexing="xy")

    # Pick dimension names that do not collide with the 2D coordinate names.
    coord_names = {x, y}
    y_dim, x_dim = "y", "x"
    while y_dim in coord_names or x_dim in coord_names:
        y_dim, x_dim = f"{y_dim}_dim", f"{x_dim}_dim"
    dims = (y_dim, x_dim)

    # Create a new DataArray with 2D coordinates and NaN values
    return xr.DataArray(
        np.full(X.shape, np.nan),
        coords={x: (dims, X), y: (dims, Y)},
        dims=dims,
    )
1328

1329

1330
####--------------------------------------------------------------------------.
1331

1332

1333
def _plot_labels(
    xr_obj,
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs=None,
    **plot_kwargs,
):
    """Plot a label array with a categorical colormap.

    The labels are renumbered from 1 upward (``redefine_label_array``), values
    not greater than 0 are masked to NaN, and the result is drawn with ``plot_image``.
    The maximum allowed number of labels shown in the colorbar is ``max_n_labels``:
    above it, the colorbar is disabled.

    Parameters
    ----------
    xr_obj : xarray.Dataset or xarray.DataArray
        Object holding the label array.
    label_name : str, optional
        Name of the label variable (Dataset) or coordinate (DataArray).
        Required when ``xr_obj`` is a Dataset.
    max_n_labels : int, optional
        Maximum number of labels for which a colorbar is displayed. The default is 50.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    interpolation : str, optional
        Interpolation passed to ``imshow``. The default is ``"nearest"``.
    cmap : str, optional
        Colormap used to build the label colors. The default is ``"Paired"``.
    fig_kwargs : dict, optional
        Figure options.
    **plot_kwargs
        Options overriding the default label plot settings.

    Returns
    -------
    The object returned by ``plot_image``.

    Raises
    ------
    ValueError
        If ``xr_obj`` is a Dataset and ``label_name`` is not specified.
    """
    from ximage.labels.labels import get_label_indices, redefine_label_array
    from ximage.labels.plot_labels import get_label_colorbar_settings

    # ``plot_image`` is defined in this module: no self-import needed.
    if isinstance(xr_obj, xr.Dataset):
        if label_name is None:
            raise ValueError("'label_name' must be specified when plotting labels from an xarray.Dataset.")
        dataarray = xr_obj[label_name]
    else:
        dataarray = xr_obj[label_name] if label_name is not None else xr_obj

    dataarray = dataarray.compute()
    label_indices = get_label_indices(dataarray)
    n_labels = len(label_indices)
    if add_colorbar and n_labels > max_n_labels:
        msg = f"""The array currently contains {n_labels} labels
        and 'max_n_labels' is set to {max_n_labels}. The colorbar is not displayed!"""
        print(msg)
        add_colorbar = False
    # Relabel array from 1 to ... for plotting
    dataarray = redefine_label_array(dataarray, label_indices=label_indices)
    # Replace 0 (background) with nan
    dataarray = dataarray.where(dataarray > 0)
    # Define appropriate colormap; user plot_kwargs take precedence
    default_plot_kwargs, cbar_kwargs = get_label_colorbar_settings(label_indices, cmap=cmap)
    default_plot_kwargs.update(plot_kwargs)
    # Plot image
    return plot_image(
        dataarray,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        fig_kwargs=fig_kwargs,
        **default_plot_kwargs,
    )
1381

1382

1383
def plot_labels(
    obj,  # Dataset, DataArray or generator
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs=None,
    **plot_kwargs,
):
    """Plot label arrays.

    Parameters
    ----------
    obj : xarray.Dataset, xarray.DataArray or generator
        Object to plot. A generator must yield ``(label_id, xr_obj)`` tuples;
        each yielded object is plotted and displayed with ``plt.show()``.
    label_name : str, optional
        Name of the label variable. Required for Datasets.
    max_n_labels : int, optional
        Maximum number of labels for which a colorbar is displayed. The default is 50.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    interpolation : str, optional
        Interpolation passed to ``imshow``. The default is ``"nearest"``.
    cmap : str, optional
        Colormap used to build the label colors. The default is ``"Paired"``.
    fig_kwargs : dict, optional
        Figure options.
    **plot_kwargs
        Options overriding the default label plot settings.

    Returns
    -------
    The mappable of the (last) plot, or ``None`` if a generator yielded nothing.
    """
    label_kwargs = {
        "label_name": label_name,
        "max_n_labels": max_n_labels,
        "add_colorbar": add_colorbar,
        "interpolation": interpolation,
        "cmap": cmap,
        "fig_kwargs": fig_kwargs,
    }
    if not is_generator(obj):
        return _plot_labels(xr_obj=obj, **label_kwargs, **plot_kwargs)

    # Initialized so that an empty generator returns None instead of raising UnboundLocalError.
    p = None
    for _, xr_obj in obj:  # label_id, xr_obj
        p = _plot_labels(xr_obj=xr_obj, **label_kwargs, **plot_kwargs)
        plt.show()
    return p
1418

1419

1420
def plot_patches(
    patch_gen,
    variable=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot and display each patch yielded by a patch generator.

    Plotting is best effort: a patch that ``plot_image`` cannot plot is skipped
    with a warning instead of aborting the whole loop.

    Parameters
    ----------
    patch_gen : iterable
        Iterable yielding ``(label_id, xr_patch)`` tuples.
    variable : str, optional
        Variable to plot. Required when the patches are xarray.Dataset objects.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    interpolation : str, optional
        Interpolation passed to ``imshow``. The default is ``"nearest"``.
    fig_kwargs : dict, optional
        Figure options.
    cbar_kwargs : dict, optional
        Colorbar options.
    **plot_kwargs
        Additional options forwarded to ``plot_image``.

    Raises
    ------
    ValueError
        If a patch is an xarray.Dataset and ``variable`` is not specified.
    """
    import warnings

    # ``plot_image`` is defined in this module: no self-import needed.
    for label_id, xr_patch in patch_gen:
        if isinstance(xr_patch, xr.Dataset):
            if variable is None:
                raise ValueError("'variable' must be specified when plotting xarray.Dataset patches.")
            xr_patch = xr_patch[variable]
        try:
            plot_image(
                xr_patch,
                interpolation=interpolation,
                add_colorbar=add_colorbar,
                fig_kwargs=fig_kwargs,
                cbar_kwargs=cbar_kwargs,
                **plot_kwargs,
            )
            plt.show()
        except Exception as err:
            # Keep going with the remaining patches, but do not hide the failure.
            warnings.warn(f"Skipping patch {label_id!r}: it could not be plotted ({err}).", stacklevel=2)
1450

1451

1452
####--------------------------------------------------------------------------.
1453

1454

1455
def add_map_inset(ax, loc="upper left", inset_height=0.2, projection=None, inside_figure=True, border_pad=0):
    """Add an inset globe map that outlines the extent of a cartopy map.

    The inset shows a global view (land and ocean) in which the extent of the
    main map, slightly enlarged, is drawn as a red rectangle to give
    geographical context.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        The main map axis. Its ``get_extent()`` defines the highlighted region,
        so a plain matplotlib Axes is not sufficient.
    loc : str, optional
        Location of the inset within the main plot: ``'lower left'``,
        ``'lower right'``, ``'upper left'`` or ``'upper right'``.
        The default is ``'upper left'``.
    inset_height : float, optional
        Inset height as a fraction of the figure height (e.g. ``0.2`` means 20%).
        The inset width follows from the aspect ratio of the inset projection.
        The default is ``0.2``.
    projection : cartopy.crs.Projection, optional
        Projection of the inset. If ``None``, an Orthographic projection centered
        on the center of the (enlarged) extent is used.
    inside_figure : bool, optional
        If ``True`` (default), the inset is constrained to lie fully inside the
        figure. If ``False``, it may extend beyond the figure edges.
    border_pad : float, optional
        Padding forwarded to ``get_inset_bounds`` (presumably the spacing between
        the inset and the axis border). The default is ``0``.

    Returns
    -------
    cartopy.mpl.geoaxes.GeoAxes
        The inset map axis.

    Notes
    -----
    The main extent is enlarged with ``extend_geographic_extent(..., padding=0.5)``
    before being outlined, for visual clarity.

    Examples
    --------
    >>> p = da.gpm.plot_map()
    >>> add_map_inset(ax=p.axes, loc="upper left", inset_height=0.15)

    """
    from shapely import Polygon

    from gpm.utils.geospatial import extend_geographic_extent

    # Main map extent (lon_min, lon_max, lat_min, lat_max), slightly enlarged.
    lon_min, lon_max, lat_min, lat_max = extend_geographic_extent(ax.get_extent(), padding=0.5)

    # Outline of the extent, built from (xmin, ymin, xmax, ymax).
    extent_outline = Polygon.from_bounds(lon_min, lat_min, lon_max, lat_max)

    if projection is None:
        # Globe view centered on the middle of the extent.
        projection = ccrs.Orthographic(
            central_latitude=(lat_min + lat_max) / 2,
            central_longitude=(lon_min + lon_max) / 2,
        )

    # Width/height ratio of the projection domain, used to size the inset.
    x_span = np.diff(projection.x_limits).item()
    y_span = np.diff(projection.y_limits).item()
    aspect_ratio = float(x_span / y_span)

    # Inset placement [x0, y0, width, height] in normalized units relative to ax.
    inset_bounds = get_inset_bounds(
        ax=ax,
        loc=loc,
        inset_height=inset_height,
        inside_figure=inside_figure,
        aspect_ratio=aspect_ratio,
        border_pad=border_pad,
    )
    inset_ax = ax.inset_axes(inset_bounds, projection=projection)

    # Global background.
    inset_ax.set_global()
    for feature in (cfeature.LAND, cfeature.OCEAN):
        inset_ax.add_feature(feature)

    # Red rectangle marking the main map extent (geographic coordinates).
    inset_ax.add_geometries(
        [extent_outline],
        ccrs.PlateCarree(),
        facecolor="none",
        edgecolor="red",
        linewidth=0.3,
    )
    return inset_ax
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc