• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ghiggi / gpm_api / 14223363147

02 Apr 2025 03:21PM UTC coverage: 91.17% (+0.1%) from 91.034%
14223363147

push

github

ghiggi
Fix GRID map visualization for custom projections

100 of 122 new or added lines in 12 files covered. (81.97%)

8 existing lines in 3 files now uncovered.

15683 of 17202 relevant lines covered (91.17%)

0.91 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.44
/gpm/visualization/plot.py
1
# -----------------------------------------------------------------------------.
2
# MIT License
3

4
# Copyright (c) 2024 GPM-API developers
5
#
6
# This file is part of GPM-API.
7

8
# Permission is hereby granted, free of charge, to any person obtaining a copy
9
# of this software and associated documentation files (the "Software"), to deal
10
# in the Software without restriction, including without limitation the rights
11
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
# copies of the Software, and to permit persons to whom the Software is
13
# furnished to do so, subject to the following conditions:
14
#
15
# The above copyright notice and this permission notice shall be included in all
16
# copies or substantial portions of the Software.
17
#
18
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
# SOFTWARE.
25

26
# -----------------------------------------------------------------------------.
27
"""This module contains basic functions for GPM-API data visualization."""
28
import inspect
1✔
29
import warnings
1✔
30

31
import cartopy
1✔
32
import cartopy.crs as ccrs
1✔
33
import cartopy.feature as cfeature
1✔
34
import matplotlib.pyplot as plt
1✔
35
import numpy as np
1✔
36
import xarray as xr
1✔
37
from pycolorbar import plot_colorbar, set_colorbar_fully_transparent
1✔
38
from pycolorbar.utils.mpl_legend import get_inset_bounds
1✔
39
from scipy.interpolate import griddata
1✔
40

41
import gpm
1✔
42
from gpm import get_plot_kwargs
1✔
43
from gpm.dataset.crs import compute_extent
1✔
44
from gpm.utils.area import get_lonlat_corners_from_centroids
1✔
45

46

47
def is_generator(obj):
    """Return True if ``obj`` is a generator function or a generator iterator."""
    checks = (inspect.isgeneratorfunction, inspect.isgenerator)
    return any(check(obj) for check in checks)
49

50

51
def _call_optimize_layout(self):
    """Adapt the figure size to the axes of ``self`` and then tighten the layout.

    Intended to be bound to a plot object exposing ``axes`` and ``figure``
    attributes (see ``add_optimize_layout_method``).
    """
    axes = self.axes
    adapt_fig_size(ax=axes)
    figure = self.figure
    figure.tight_layout()
55

56

57
def add_optimize_layout_method(p):
    """Attach an ``optimize_layout()`` method to ``p`` by monkey patching.

    ``_call_optimize_layout`` is bound to ``p`` through the descriptor
    protocol, so that ``p.optimize_layout()`` receives ``p`` as ``self``.
    The same object ``p`` is returned.
    """
    bound_method = _call_optimize_layout.__get__(p, type(p))
    p.optimize_layout = bound_method
    return p
61

62

63
def adapt_fig_size(ax, nrow=1, ncol=1):
    """Resize the figure of ``ax`` so that its height matches the data aspect ratio.

    Call this function once all plotting has been completed.
    ``ax`` is taken as representative of every subplot of the figure: all
    subplots are expected to share the same aspect ratio and size.

    The figure width is kept and a new height is derived from it. If the new
    height turns out to be less than half of the original height, width and
    height are both scaled up by the same factor.

    The implementation is inspired by Mathias Hauser's mplotutils set_map_layout function.

    Parameters
    ----------
    ax : matplotlib axis
        Axis whose figure is resized and whose data ratio is used.
    nrow : int, optional
        Number of subplot rows in the figure. The default is 1.
    ncol : int, optional
        Number of subplot columns in the figure. The default is 1.
    """
    figure = ax.get_figure()
    fig_width, initial_height = figure.get_size_inches()

    # Draw the canvas first, so that the figure geometry read below is up-to-date.
    figure.canvas.draw()

    # Figure margins and spacing between subplots.
    pars = figure.subplotpars
    horizontal_margin = pars.left + (1 - pars.right)
    vertical_margin = pars.bottom + (1 - pars.top)

    # Height/width ratio of the data displayed in the subplot.
    data_ratio = ax.get_data_ratio()

    # Size of a single subplot: the width follows from the figure width,
    # the height from the width and the data ratio.
    plot_width = (fig_width - fig_width * horizontal_margin) / (ncol + (ncol - 1) * pars.wspace)
    plot_height = plot_width * data_ratio

    # Figure height required to host ``nrow`` subplots of that height.
    new_height = (plot_height * (nrow + ((nrow - 1) * pars.hspace))) / (1.0 - vertical_margin)

    # If the height has been more than halved, enlarge the whole figure
    # (keeping its aspect ratio) to stay closer to the original size.
    shrink_ratio = initial_height / new_height
    if shrink_ratio > 2:
        enlargement = shrink_ratio / 2
        fig_width *= enlargement
        new_height *= enlargement

    figure.set_figwidth(fig_width)
    figure.set_figheight(new_height)
124

125

126
####--------------------------------------------------------------------------.
127

128

129
def infill_invalid_coords(xr_obj, x="lon", y="lat"):
    """Return a copy of ``xr_obj`` with the invalid ``x`` and ``y`` coordinates infilled.

    The coordinates are interpolated within the convex hull of the data.
    Nearest neighbour values are used outside the convex hull of the data.
    The input object is not modified: the infilled coordinates are assigned
    to a copy.
    """
    infilled_obj = xr_obj.copy()
    x_values = np.asanyarray(infilled_obj[x].data)
    y_values = np.asanyarray(infilled_obj[y].data)
    # No data are passed here: only the coordinates are of interest.
    x_values, y_values, _ = get_valid_pcolormesh_inputs(
        x=x_values,
        y=y_values,
        data=None,
        mask_data=False,
    )
    infilled_obj[x].data = x_values
    infilled_obj[y].data = y_values
    return infilled_obj
144

145

146
def get_valid_pcolormesh_inputs(x, y, data, rgb=False, mask_data=True):
    """Infill non-finite coordinates and optionally mask the corresponding data.

    This operation is required to plot with pcolormesh since it
    does not accept non-finite values in the coordinates.

    Coordinates are interpolated within the convex hull of the valid
    coordinates, and filled with the nearest neighbour outside of it.

    If ``mask_data=True``, data values with invalid coordinates are masked
    and a numpy masked array is returned.
    Masked data values are not displayed in pcolormesh !
    If ``rgb=True``, it assumes the RGB dimension is the last data dimension.

    If all coordinates are finite, ``x``, ``y`` and ``data`` are returned
    unchanged. A ``ValueError`` is raised if no coordinate pair is valid.
    """
    invalid_x = ~np.isfinite(x)
    invalid_y = ~np.isfinite(y)
    invalid = np.logical_or(invalid_x, invalid_y)

    # Nothing to do if every coordinate is valid
    if not np.any(invalid):
        return x, y, data

    # At least some valid coordinates are required to infill the others
    if np.all(invalid):
        raise ValueError("No valid coordinates.")

    # Mask the data located at invalid coordinates
    if not mask_data:
        masked = data
    elif rgb:
        rgb_invalid = np.broadcast_to(np.expand_dims(invalid, axis=-1), data.shape)
        masked = np.ma.masked_where(rgb_invalid, data)
    else:
        masked = np.ma.masked_where(invalid, data)

    # Infill x and y: first linear interpolation, then nearest neighbours
    # for what remains outside the convex hull
    # - Note: this currently causes issues if NaN occurs when crossing the antimeridian ...
    # --> TODO: interpolation should be done in X,Y,Z
    infill_methods = ("linear", "nearest")
    if np.any(invalid_x):
        for infill_method in infill_methods:
            x = _interpolate_data(x, method=infill_method)
    if np.any(invalid_y):
        for infill_method in infill_methods:
            y = _interpolate_data(y, method=infill_method)
    return x, y, masked
194

195

196
def _interpolate_data(arr, method="linear"):
    """Infill the non-finite values of a 1D or 2D coordinate array.

    1D arrays (i.e. along_track/cross_track view) are handled by
    ``_interpolate_1d_coord``, anything else (swath image) by
    ``_interpolate_2d_coord``.
    """
    interpolator = _interpolate_1d_coord if arr.ndim == 1 else _interpolate_2d_coord
    return interpolator(arr, method=method)
202

203

204
def _interpolate_1d_coord(arr, method="linear"):
1✔
205
    # Find invalid locations
206
    is_invalid = ~np.isfinite(arr)
1✔
207

208
    # Find the indices of NaN values
209
    nan_indices = np.where(is_invalid)[0]
1✔
210

211
    # Return array if not NaN values
212
    if len(nan_indices) == 0:
1✔
213
        return arr
1✔
214

215
    # Find the indices of non-NaN values
216
    non_nan_indices = np.where(~is_invalid)
1✔
217

218
    # Create indices
219
    indices = np.arange(len(arr))
1✔
220

221
    # Points where we have valid data
222
    points = indices[non_nan_indices]
1✔
223

224
    # Points where data is NaN
225
    points_nan = indices[nan_indices]
1✔
226

227
    # Values at the non-NaN points
228
    values = arr[non_nan_indices]
1✔
229

230
    # Interpolate using griddata
231
    arr_new = arr.copy()
1✔
232
    arr_new[nan_indices] = griddata(points, values, points_nan, method=method)
1✔
233
    return arr_new
1✔
234

235

236
def _interpolate_2d_coord(arr, method="linear"):
1✔
237
    # Find invalid locations
238
    is_invalid = ~np.isfinite(arr)
1✔
239

240
    # Find the indices of NaN values
241
    nan_indices = np.where(is_invalid)
1✔
242

243
    # Return array if not NaN values
244
    if len(nan_indices) == 0:
1✔
245
        return arr
×
246

247
    # Find the indices of non-NaN values
248
    non_nan_indices = np.where(~is_invalid)
1✔
249

250
    # Create a meshgrid of indices
251
    x, y = np.meshgrid(range(arr.shape[1]), range(arr.shape[0]))
1✔
252

253
    # Points (X, Y) where we have valid data
254
    points = np.array([y[non_nan_indices], x[non_nan_indices]]).T
1✔
255

256
    # Points where data is NaN
257
    points_nan = np.array([y[nan_indices], x[nan_indices]]).T
1✔
258

259
    # Values at the non-NaN points
260
    values = arr[non_nan_indices]
1✔
261

262
    # Interpolate using griddata
263
    arr_new = arr.copy()
1✔
264
    arr_new[nan_indices] = griddata(points, values, points_nan, method=method)
1✔
265
    return arr_new
1✔
266

267

268
def _mask_antimeridian_crossing_arr(arr, antimeridian_mask, rgb):
1✔
269
    if np.ma.is_masked(arr):
1✔
270
        if rgb:
1✔
271
            antimeridian_mask = np.broadcast_to(np.expand_dims(antimeridian_mask, axis=-1), arr.shape)
1✔
272
            combined_mask = np.logical_or(arr.mask, antimeridian_mask)
1✔
273
        else:
274
            combined_mask = np.logical_or(arr.mask, antimeridian_mask)
1✔
275
        arr = np.ma.masked_where(combined_mask, arr)
1✔
276
    else:
277
        if rgb:
1✔
278
            antimeridian_mask = np.broadcast_to(
1✔
279
                np.expand_dims(antimeridian_mask, axis=-1),
280
                arr.shape,
281
            )
282
        arr = np.ma.masked_where(antimeridian_mask, arr)
1✔
283
    return arr
1✔
284

285

286
def mask_antimeridian_crossing_array(arr, lon, rgb, plot_kwargs):
    """Mask the array cells crossing the antimeridian.

    Here we assume there are no invalid lon coordinates anymore.
    Cartopy still bugs with several projections when data cross the antimeridian.
    By default, GPM-API masks data crossing the antimeridian.
    The GPM-API configuration default can be modified with: ``gpm.config.set({"viz_hide_antimeridian_data": False})``

    Returns the (possibly masked) array and the (possibly sanitized) ``plot_kwargs``.
    Both are returned unchanged if no cell crosses the antimeridian.
    """
    crossing_mask = get_antimeridian_mask(lon)
    if not np.any(crossing_mask):
        return arr, plot_kwargs

    # Sanitize cmap to avoid cartopy bug related to cmap bad color
    # - Cartopy requires the bad color to be fully transparent
    plot_kwargs = _sanitize_cartopy_plot_kwargs(plot_kwargs)

    # Mask data based on GPM-API config 'viz_hide_antimeridian_data' (default is True)
    hide_crossing_cells = gpm.config.get("viz_hide_antimeridian_data")
    if hide_crossing_cells:
        arr = _mask_antimeridian_crossing_arr(arr, antimeridian_mask=crossing_mask, rgb=rgb)
    return arr, plot_kwargs
304

305

306
def get_antimeridian_mask(lons):
    """Get mask of longitude coordinates neighbors crossing the antimeridian.

    ``lons`` is a 2D array of shape (n_y, n_x). The returned boolean mask has
    shape (n_y - 1, n_x - 1). An edge is considered to cross the antimeridian
    when the longitudes at its two ends differ by more than 180 in absolute
    value.
    """
    from scipy.ndimage import binary_dilation

    n_rows, n_cols = lons.shape
    crossing = np.zeros((n_rows - 1, n_cols - 1))

    # Edges between consecutive rows
    jump_along_rows = np.abs(np.diff(lons, axis=0)) > 180
    rows, cols = np.where(jump_along_rows)
    crossing[rows, np.clip(cols - 1, 0, n_cols - 1)] = 1

    # Edges between consecutive columns
    jump_along_cols = np.abs(np.diff(lons, axis=1)) > 180
    rows, cols = np.where(jump_along_cols)
    crossing[np.clip(rows - 1, 0, n_rows - 1), cols] = 1

    # Buffer by 1 in all directions to avoid plotting cells neighbour to those crossing the antimeridian
    # --> This should not be needed, but it's needed to avoid cartopy bugs !
    return binary_dilation(crossing)
324

325

326
####--------------------------------------------------------------------------.
327
########################
328
#### Plot utilities ####
329
########################
330

331

332
def preprocess_rgb_dataarray(da, rgb):
    """Move the ``rgb`` dimension of the DataArray to the last position.

    ``da`` is returned unchanged if ``rgb`` is falsy.
    A ``ValueError`` is raised if ``rgb`` is not a dimension of ``da`` or if
    its size is neither 3 nor 4.
    """
    if not rgb:
        return da
    if rgb not in da.dims:
        raise ValueError(f"The specified rgb='{rgb}' must be a dimension of the DataArray.")
    if da[rgb].size not in [3, 4]:
        raise ValueError("The RGB dimension must have size 3 or 4.")
    return da.transpose(..., rgb)
340

341

342
def check_object_format(da, plot_kwargs, check_function, **function_kwargs):
    """Check object format and valid dimension names.

    The DataArray is squeezed and its RGB dimension (if any) moved last.
    The ``rgb``, ``col`` and ``row`` dimensions specified in ``plot_kwargs``
    must exist in the DataArray. ``check_function`` is then applied to the
    first slice along such dimensions. The preprocessed DataArray is returned.
    """
    # Preprocess RGB DataArrays
    da = preprocess_rgb_dataarray(da.squeeze(), plot_kwargs.get("rgb", False))

    # Retrieve the rgb or FacetGrid column/row dimensions and check they are available
    isel_dict = {}
    for key in ["rgb", "col", "row"]:
        dim = plot_kwargs.get(key, None)
        if not dim:
            continue
        if dim not in da.dims:
            raise ValueError(f"The DataArray does not have a {key}='{dim}' dimension.")
        isel_dict[dim] = 0

    # Subset the DataArray to check if it complies with the specific check function
    da_subset = da.isel(isel_dict)
    check_function(da_subset, **function_kwargs)
    return da
357

358

359
def preprocess_figure_args(ax, fig_kwargs=None, subplot_kwargs=None, is_facetgrid=False):
    """Validate the figure arguments and return ``fig_kwargs`` as a dictionary.

    ``ax`` can not be specified when plotting with FacetGrid, and
    ``fig_kwargs``/``subplot_kwargs`` can only be provided when ``ax`` is None.
    A ``ValueError`` is raised otherwise.
    """
    if is_facetgrid and ax is not None:
        raise ValueError("When plotting with FacetGrid, do not specify the 'ax'.")
    if fig_kwargs is None:
        fig_kwargs = {}
    if subplot_kwargs is None:
        subplot_kwargs = {}
    # Without a user-provided axis, any figure/subplot argument is acceptable
    if ax is None:
        return fig_kwargs
    if len(subplot_kwargs) >= 1:
        raise ValueError("Provide `subplot_kwargs`only if ``ax``is None")
    if len(fig_kwargs) >= 1:
        raise ValueError("Provide `fig_kwargs` only if ``ax``is None")
    return fig_kwargs
370

371

372
def preprocess_subplot_kwargs(subplot_kwargs, infer_crs=False, xr_obj=None):
    """Return a copy of ``subplot_kwargs`` that always includes a ``projection``.

    If the projection is not specified, it is set to the cartopy CRS of
    ``xr_obj`` (``xr_obj.gpm.cartopy_crs``) when ``infer_crs=True`` and to
    ``ccrs.PlateCarree()`` otherwise. The input dictionary is not modified.
    """
    kwargs = {} if subplot_kwargs is None else subplot_kwargs.copy()
    if "projection" in kwargs:
        return kwargs
    kwargs["projection"] = xr_obj.gpm.cartopy_crs if infer_crs else ccrs.PlateCarree()
    return kwargs
381

382

383
def infer_xy_labels(da, x=None, y=None, rgb=None):
    """Infer the x and y labels of the DataArray using the xarray plotting utility."""
    from xarray.plot.utils import _infer_xy_labels

    # imshow=True is a dummy flag required to deal with the rgb argument
    return _infer_xy_labels(da, x=x, y=y, imshow=True, rgb=rgb)
389

390

391
def infer_map_xy_coords(da, x=None, y=None):
    """
    Infer possible map x and y coordinates for the given DataArray.

    Parameters
    ----------
    da : xarray.DataArray
        The input DataArray.
    x : str, optional
        The name of the x (i.e. longitude) coordinate. If None, the first of
        ``"x"``, ``"lon"``, ``"longitude"`` found in ``da.coords`` is used.
    y : str, optional
        The name of the y (i.e. latitude) coordinate. If None, the first of
        ``"y"``, ``"lat"``, ``"latitude"`` found in ``da.coords`` is used.

    Returns
    -------
    tuple
        The inferred (x, y) coordinates.

    Raises
    ------
    ValueError
        If a coordinate is not provided and can not be inferred.
    """
    if x is None:
        x = next((name for name in ("x", "lon", "longitude") if name in da.coords), None)
        if x is None:
            raise ValueError("Cannot infer x coordinate. Please provide the x coordinate.")

    if y is None:
        y = next((name for name in ("y", "lat", "latitude") if name in da.coords), None)
        if y is None:
            raise ValueError("Cannot infer y coordinate. Please provide the y coordinate.")

    return x, y
429

430

431
def _get_proj_str(crs):
1✔
432
    with warnings.catch_warnings():
1✔
433
        proj_str = crs.to_dict().get("proj", "")
1✔
434
    return proj_str
1✔
435

436

437
def initialize_cartopy_plot(
    ax,
    fig_kwargs,
    subplot_kwargs,
    add_background,
    add_gridlines,
    add_labels,
    infer_crs=False,
    xr_obj=None,
):
    """Initialize figure for cartopy plot if necessary.

    A new figure and axis are created only if ``ax`` is None. The cartopy
    background and the gridlines/labels are then added on request.
    The axis is returned.
    """
    # Create the figure and the axis if no axis is provided
    if ax is None:
        fig_kwargs = preprocess_figure_args(ax=ax, fig_kwargs=fig_kwargs, subplot_kwargs=subplot_kwargs)
        subplot_kwargs = preprocess_subplot_kwargs(subplot_kwargs, infer_crs=infer_crs, xr_obj=xr_obj)
        ax = plt.subplots(subplot_kw=subplot_kwargs, **fig_kwargs)[1]

    # Add the cartopy background
    if add_background:
        ax = plot_cartopy_background(ax)

    # Add gridlines and/or labels
    needs_gridliner = add_gridlines or add_labels
    if needs_gridliner:
        plot_cartopy_gridlines_and_labels(ax, add_gridlines=add_gridlines, add_labels=add_labels)

    return ax
467

468

469
def plot_cartopy_gridlines_and_labels(ax, add_gridlines=True, add_labels=True):
    """Add cartopy gridlines and labels.

    The gridlines are always drawn, but made fully transparent if
    ``add_gridlines=False``. Top and right labels are disabled.
    The object returned by ``ax.gridlines`` is returned.
    """
    line_alpha = 0.1 if add_gridlines else 0
    gridliner_kwargs = {
        "crs": ccrs.PlateCarree(),
        "draw_labels": add_labels,
        "linewidth": 1,
        "color": "gray",
        "alpha": line_alpha,
        "linestyle": "-",
    }
    gridliner = ax.gridlines(**gridliner_kwargs)
    gridliner.top_labels = False
    gridliner.right_labels = False
    gridliner.xlines = True
    gridliner.ylines = True
    return gridliner
485

486

487
def plot_cartopy_background(ax):
    """Plot cartopy background.

    Coastlines and borders are always added. Land and ocean are skipped for
    the ``laea`` projection. The axis is returned.
    """
    features = cartopy.feature

    # Coastlines
    ax.coastlines()

    # Land and ocean
    # --> Raise error with some projections currently (shapely bug)
    # --> https://github.com/SciTools/cartopy/issues/2176
    proj_str = _get_proj_str(ax.projection)
    if proj_str not in ["laea"]:
        ax.add_feature(features.LAND, facecolor=[0.9, 0.9, 0.9])
        ax.add_feature(features.OCEAN, alpha=0.6)

    # Borders (BORDERS also draws provinces, ...)
    ax.add_feature(features.BORDERS)
    return ax
500

501

502
def plot_sides(sides, ax, **plot_kwargs):
    """Plot boundary sides.

    Parameters
    ----------
    sides : list
        List of (lon, lat) tuples. Each tuple is drawn as a line using the
        ``ccrs.Geodetic()`` transform.
    ax : axis
        Axis on which ``ax.plot`` is called.
    **plot_kwargs
        Additional keyword arguments passed to ``ax.plot``.

    Returns
    -------
    object or None
        The first element of the list returned by ``ax.plot`` for the last
        side, or ``None`` if ``sides`` is empty.
    """
    lines = None
    for side in sides:
        lines = ax.plot(*side, transform=ccrs.Geodetic(), **plot_kwargs)
    # Without sides nothing has been plotted: avoid referencing an unbound variable
    if lines is None:
        return None
    return lines[0]
510

511

512
####--------------------------------------------------------------------------.
513
##########################
514
#### Cartopy wrappers ####
515
##########################
516

517

518
def _sanitize_cartopy_plot_kwargs(plot_kwargs):
1✔
519
    """Sanitize 'cmap' to avoid cartopy bug related to cmap bad color.
520

521
    Cartopy requires the bad color to be fully transparent.
522
    """
523
    cmap = plot_kwargs.get("cmap", None)
1✔
524
    if cmap is not None:
1✔
525
        bad = cmap.get_bad()
1✔
526
        bad[3] = 0  # enforce to 0 (transparent)
1✔
527
        cmap.set_bad(bad)
1✔
528
        plot_kwargs["cmap"] = cmap
1✔
529
    return plot_kwargs
1✔
530

531

532
def plot_cartopy_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot imshow with cartopy.

    Parameters
    ----------
    ax : axis
        Axis on which ``imshow`` and ``set_extent`` are called
        (presumably a cartopy GeoAxes).
    da : xarray.DataArray
        DataArray to plot.
    x, y : str or None
        Name of the x and y coordinates. Inferred with ``infer_xy_labels`` if None.
    interpolation : str, optional
        Interpolation argument passed to ``ax.imshow``. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar. No colorbar is added for RGB images.
    plot_kwargs : dict, optional
        Keyword arguments passed to ``ax.imshow``.
        The ``rgb`` key, if present, is removed from the dictionary.
    cbar_kwargs : dict, optional
        Keyword arguments passed to ``plot_colorbar``.

    Returns
    -------
    The object returned by ``ax.imshow``.
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    # cbar_kwargs is unpacked with ** when adding the colorbar: None is not a valid mapping
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs

    # Infer x and y
    x, y = infer_xy_labels(da, x=x, y=y, rgb=plot_kwargs.get("rgb", None))

    # Align x,y, data dimensions
    # - Ensure image with correct dimensions orders
    # - It can happen that x/y coords do not have the same dimension order of the data array.
    da = da.transpose(*da[y].dims, *da[x].dims, ...)

    # - Retrieve data
    arr = np.asanyarray(da.data)

    # - Compute coordinates
    x_coords = da[x].to_numpy()
    y_coords = da[y].to_numpy()

    # Compute extent
    extent = compute_extent(x_coords=x_coords, y_coords=y_coords)

    # Infer the CRS of the data (used to set the extent)
    # - Best effort: if the pyresample area can not be derived, assume lon/lat CRS
    try:
        area_def = da.gpm.pyresample_area
        crs = area_def.to_cartopy_crs()
    except Exception:
        crs = ccrs.PlateCarree()

    # - Determine origin based on the orientation of da[y] values
    # - On the map, the y coordinates should grow from bottom to top
    # -->  If y coordinate is increasing, set origin="lower"
    # -->  If y coordinate is decreasing, set origin="upper"
    y_increasing = y_coords[1] > y_coords[0]
    origin = "lower" if y_increasing else "upper"

    # Deal with decreasing y: swap the y limits of the extent
    if not y_increasing:
        extent = [extent[i] for i in [0, 1, 3, 2]]

    # Deal with out of limits x
    # - PlateCarree coordinates are out of bounds when lons are defined as 0-360 (with pm=0)
    # - In such case the axis extent is not set
    set_extent = True
    if extent[1] > crs.x_limits[1] or extent[0] < crs.x_limits[0]:
        set_extent = False

    # - Add variable field with cartopy
    # --> TODO: specify transform argument only if data CRS different from cartopy CRS
    # --> GPM-API automatically create the Cartopy GeoAxes with correct CRS
    rgb = plot_kwargs.pop("rgb", False)
    p = ax.imshow(
        arr,
        # transform=crs, # if uncommented,  cuts away half of first and last row pixels
        extent=extent,
        origin=origin,
        interpolation=interpolation,
        **plot_kwargs,
    )

    # - Set the extent
    # --> If some background is globally displayed, this zoom on the actual data region
    if set_extent:
        ax.set_extent(extent, crs=crs)

    # - Add colorbar
    if add_colorbar and not rgb:
        _ = plot_colorbar(p=p, ax=ax, **cbar_kwargs)
    return p
614

615

616
def plot_cartopy_pcolormesh(
    ax,
    da,
    x,
    y,
    add_colorbar=True,
    add_swath_lines=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot pcolormesh with cartopy.

    x and y must represent longitudes and latitudes.
    The function currently does not allow to zoom on regions across the antimeridian.
    The function masks scanning pixels which span across the antimeridian.
    If the DataArray has a RGB dimension, plot_kwargs should contain the ``rgb``
    key with the name of the RGB dimension.

    Parameters
    ----------
    ax : axis
        Axis on which ``pcolormesh`` is called (presumably a cartopy GeoAxes).
    da : xarray.DataArray
        DataArray to plot.
    x, y : str
        Name of the longitude and latitude coordinates.
    add_colorbar : bool, optional
        Whether to add a colorbar. No colorbar is added for RGB images.
    add_swath_lines : bool, optional
        Whether to draw the first and last row of the cell corners as dashed lines.
        Not applied when the coordinates are 1D.
    plot_kwargs : dict, optional
        Keyword arguments passed to ``ax.pcolormesh``.
        The ``rgb`` key, if present, is removed from the dictionary.
    cbar_kwargs : dict, optional
        Keyword arguments passed to ``plot_colorbar``.

    Returns
    -------
    The object returned by ``ax.pcolormesh``.
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    # cbar_kwargs is unpacked with ** when adding the colorbar: None is not a valid mapping
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs

    # Remove RGB from plot_kwargs
    rgb = plot_kwargs.pop("rgb", False)

    # Align x,y, data dimensions
    # - Ensure image with correct dimensions orders
    # - It can happen that x/y coords do not have the same dimension order of the data array.
    da = da.transpose(*da[y].dims, ...)

    # Get x, y, and array to plot
    da = preprocess_rgb_dataarray(da, rgb=rgb)
    da = da.compute()
    lon = da[x].data.copy()
    lat = da[y].data.copy()
    arr = da.data

    # Check if 1D coordinate (orbit nadir-view / transect / cross-section case)
    is_1d_case = lon.ndim == 1

    # Infill invalid value and mask data at invalid coordinates
    # - No invalid values after this function call
    lon, lat, arr = get_valid_pcolormesh_inputs(lon, lat, arr, rgb=rgb, mask_data=True)
    if is_1d_case:
        arr = np.expand_dims(arr, axis=1)

    # RGB images do not have a colorbar
    if rgb:
        add_colorbar = False

    # Compute coordinates of cell corners for pcolormesh quadrilateral mesh
    # - This enable correct masking of cells crossing the antimeridian
    lon, lat = get_lonlat_corners_from_centroids(lon, lat, parallel=False)

    # Mask cells crossing the antimeridian
    # - with gpm.config.set({"viz_hide_antimeridian_data": False}): can be used to modify the masking behaviour
    arr, plot_kwargs = mask_antimeridian_crossing_array(arr, lon, rgb, plot_kwargs)

    # Add variable field with cartopy
    _ = plot_kwargs.setdefault("shading", "flat")
    p = ax.pcolormesh(
        lon,
        lat,
        arr,
        transform=ccrs.PlateCarree(),
        **plot_kwargs,
    )

    # Add swath lines
    # - TODO: currently assume that dimensions are (cross_track, along_track)
    if add_swath_lines and not is_1d_case:
        sides = [(lon[0, :], lat[0, :]), (lon[-1, :], lat[-1, :])]
        plot_sides(sides=sides, ax=ax, linestyle="--", color="black")

    # Add colorbar
    if add_colorbar:
        _ = plot_colorbar(p=p, ax=ax, **cbar_kwargs)
    return p
692

693

694
####-------------------------------------------------------------------------------.
695
#########################
696
#### Xarray wrappers ####
697
#########################
698

699

700
def _preprocess_xr_kwargs(add_colorbar, plot_kwargs, cbar_kwargs):
1✔
701
    if not add_colorbar:
1✔
702
        cbar_kwargs = None
1✔
703

704
    if "rgb" in plot_kwargs:
1✔
705
        cbar_kwargs = None
1✔
706
        add_colorbar = False
1✔
707
        args_to_keep = ["rgb", "col", "row", "origin"]  # alpha currently skipped if RGB
1✔
708
        plot_kwargs = {k: plot_kwargs[k] for k in args_to_keep if plot_kwargs.get(k, None) is not None}
1✔
709
    return add_colorbar, plot_kwargs, cbar_kwargs
1✔
710

711

712
def plot_xr_pcolormesh(
    ax,
    da,
    x,
    y,
    add_colorbar=True,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot pcolormesh with xarray.

    Parameters
    ----------
    ax : axis or None
        Axis passed to ``da.plot.pcolormesh``.
    da : xarray.DataArray
        DataArray to plot.
    x, y : str
        Name of the x and y coordinates.
    add_colorbar : bool, optional
        Whether to add a colorbar. No colorbar is added if ``rgb`` is in ``plot_kwargs``.
    cbar_kwargs : dict, optional
        Colorbar keyword arguments. The ``ticklabels`` key, if present, is
        removed from the dictionary and applied to the colorbar y tick labels.
    **plot_kwargs
        Additional keyword arguments passed to ``da.plot.pcolormesh``.

    Returns
    -------
    The object returned by ``da.plot.pcolormesh``.
    """
    is_facetgrid = bool("col" in plot_kwargs or "row" in plot_kwargs)
    # cbar_kwargs defaults to None: ensure a dictionary before popping the ticklabels
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    add_colorbar, plot_kwargs, cbar_kwargs = _preprocess_xr_kwargs(
        add_colorbar=add_colorbar,
        plot_kwargs=plot_kwargs,
        cbar_kwargs=cbar_kwargs,
    )
    p = da.plot.pcolormesh(
        x=x,
        y=y,
        ax=ax,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )

    # Add variable name as title (if not FacetGrid)
    if not is_facetgrid:
        p.axes.set_title(da.name)

    # Add colorbar ticklabels
    if add_colorbar and ticklabels is not None:
        p.colorbar.ax.set_yticklabels(ticklabels)
    return p
745

746

747
def plot_xr_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    add_labels=True,
    cbar_kwargs=None,
    visible_colorbar=True,
    **plot_kwargs,
):
    """Plot imshow with xarray.

    The colorbar is added with xarray to enable to display multiple colorbars
    when calling this function multiple times on different fields with
    different colorbars.

    Parameters
    ----------
    ax : axis or None
        Axis passed to ``da.plot.imshow``.
    da : xarray.DataArray
        DataArray to plot.
    x, y : str
        Name of the x and y coordinates.
    interpolation : str, optional
        Interpolation argument passed to ``da.plot.imshow``. The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar. No colorbar is added if ``rgb`` is in ``plot_kwargs``.
    add_labels : bool, optional
        ``add_labels`` argument passed to ``da.plot.imshow``.
    cbar_kwargs : dict, optional
        Colorbar keyword arguments. The ``ticklabels`` key, if present, is
        removed from the dictionary and applied to the colorbar y tick labels.
    visible_colorbar : bool, optional
        If False, the colorbar is added but made fully transparent.
    **plot_kwargs
        Additional keyword arguments passed to ``da.plot.imshow``.

    Returns
    -------
    The object returned by ``da.plot.imshow``.
    """
    is_facetgrid = bool("col" in plot_kwargs or "row" in plot_kwargs)
    # cbar_kwargs defaults to None: ensure a dictionary before popping the ticklabels
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    add_colorbar, plot_kwargs, cbar_kwargs = _preprocess_xr_kwargs(
        add_colorbar=add_colorbar,
        plot_kwargs=plot_kwargs,
        cbar_kwargs=cbar_kwargs,
    )
    # Allow using coords as x/y axis
    # BUG - Current bug in xarray
    # - For RGB images, make x and y dimensions of the DataArray if they are not
    if plot_kwargs.get("rgb", None) is not None:
        if x not in da.dims:
            da = da.swap_dims({list(da[x].dims)[0]: x})
        if y not in da.dims:
            da = da.swap_dims({list(da[y].dims)[0]: y})

    p = da.plot.imshow(
        x=x,
        y=y,
        ax=ax,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )

    # Add variable name as title (if not FacetGrid)
    if not is_facetgrid:
        p.axes.set_title(da.name)

    # Add colorbar ticklabels
    if add_colorbar and ticklabels is not None:
        p.colorbar.ax.set_yticklabels(ticklabels)

    # Make the colorbar fully transparent with a smart trick ;)
    # - TODO: this still cause issues when plotting 2 colorbars !
    if add_colorbar and not visible_colorbar:
        set_colorbar_fully_transparent(p)
    return p
817

818

819
####--------------------------------------------------------------------------.
820
####################
821
#### Plot Image ####
822
####################
823

824

825
def _plot_image(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    add_labels=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot a single DataArray as an image with xarray ``imshow``.

    A new figure is created with ``fig_kwargs`` when ``ax`` is ``None``.
    Default plot and colorbar settings are retrieved with ``get_plot_kwargs``
    based on the DataArray name, and the x/y labels with ``infer_xy_labels``.
    When ``add_labels`` is ``True``, the axis labels of ORBIT and GRID objects
    are replaced by a human-readable name when one is defined for ``x``/``y``.

    Returns the object produced by ``plot_xr_imshow``. Unless the call comes
    from a FacetGrid, ``add_optimize_layout_method`` is applied to it first.
    """
    from gpm.checks import is_grid, is_orbit
    from gpm.visualization.facetgrid import sanitize_facetgrid_plot_kwargs

    # Human-readable axis labels keyed by the x/y name
    orbit_axis_labels = {
        "along_track": "Along-Track",
        "x": "Along-Track",
        "cross_track": "Cross-Track",
        "y": "Cross-Track",
    }
    grid_axis_labels = {
        "lon": "Longitude",
        "longitude": "Longitude",
        "x": "Longitude",
        "lat": "Latitude",
        "latitude": "Latitude",
        "y": "Latitude",
    }

    # Preprocess the figure arguments and create the axes if not provided
    fig_kwargs = preprocess_figure_args(ax=ax, fig_kwargs=fig_kwargs)
    if ax is None:
        _, ax = plt.subplots(**fig_kwargs)

    # Read the FacetGrid flag before the FacetGrid-specific kwargs are sanitized
    called_from_facetgrid = plot_kwargs.get("_is_facetgrid", False)
    plot_kwargs = sanitize_facetgrid_plot_kwargs(plot_kwargs)

    # Complete plot_kwargs and cbar_kwargs with the defaults of the variable
    plot_kwargs, cbar_kwargs = get_plot_kwargs(
        name=da.name,
        user_plot_kwargs=plot_kwargs,
        user_cbar_kwargs=cbar_kwargs,
    )

    # Infer the x and y labels
    x, y = infer_xy_labels(da=da, x=x, y=y, rgb=plot_kwargs.get("rgb", None))

    # Draw the image
    p = plot_xr_imshow(
        ax=ax,
        da=da,
        x=x,
        y=y,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )

    # Replace the axis labels (fall back to the x/y name if no label is defined)
    if add_labels:
        if is_orbit(da):
            axis_labels = orbit_axis_labels
        elif is_grid(da):
            axis_labels = grid_axis_labels
        else:
            axis_labels = None
        if axis_labels is not None:
            ax.set_xlabel(axis_labels.get(x, x))
            ax.set_ylabel(axis_labels.get(y, y))

    # Monkey patch the mappable to add optimize_layout (not for FacetGrid calls)
    if called_from_facetgrid:
        return p
    return add_optimize_layout_method(p)
1✔
900

901

902
def _plot_image_facetgrid(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    add_labels=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot a DataArray as a grid of image facets using ``ImageFacetGrid``.

    The facet layout options (``col``, ``row``, ``col_wrap``, ``axes_pad``,
    ``aspect``, ``facet_height``, ``facet_aspect``) are taken out of
    ``plot_kwargs``. Each facet is drawn by ``_plot_image`` without its own
    colorbar; a colorbar is added to the FacetGrid afterwards if requested.
    The colorbar is disabled when ``rgb`` is specified.

    Returns the ``ImageFacetGrid`` instance.
    """
    from gpm.visualization.facetgrid import ImageFacetGrid

    # Preprocess the figure arguments
    fig_kwargs = preprocess_figure_args(ax=ax, fig_kwargs=fig_kwargs, is_facetgrid=True)

    # Complete plot_kwargs and cbar_kwargs with the GPM-API defaults of the variable
    plot_kwargs, cbar_kwargs = get_plot_kwargs(
        name=da.name,
        user_plot_kwargs=plot_kwargs,
        user_cbar_kwargs=cbar_kwargs,
    )

    # RGB images do not have a colorbar
    # - TODO: move this to pycolorbar !
    # - TODO: also remove cmap, norm, vmin and vmax in plot_kwargs
    if plot_kwargs.get("rgb", False):
        add_colorbar, cbar_kwargs = False, {}

    # Take the FacetGrid layout options out of plot_kwargs
    facetgrid_options = {
        "col": plot_kwargs.pop("col", None),
        "row": plot_kwargs.pop("row", None),
        "col_wrap": plot_kwargs.pop("col_wrap", None),
        "axes_pad": plot_kwargs.pop("axes_pad", None),
        "aspect": plot_kwargs.pop("aspect", False),
        "facet_height": plot_kwargs.pop("facet_height", 3),
        "facet_aspect": plot_kwargs.pop("facet_aspect", 1),
    }

    # Create the FacetGrid
    fc = ImageFacetGrid(
        data=da.compute(),
        fig_kwargs=fig_kwargs,
        cbar_kwargs=cbar_kwargs,
        add_colorbar=add_colorbar,
        **facetgrid_options,
    )

    # Draw every facet (the colorbar is added once for the whole grid below)
    fc = fc.map_dataarray(
        _plot_image,
        x=x,
        y=y,
        add_colorbar=False,
        add_labels=add_labels,
        interpolation=interpolation,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )

    # Remove duplicated axis labels, or all ticks and labels if not requested
    fc.remove_duplicated_axis_labels()
    if not add_labels:
        fc.remove_left_ticks_and_labels()
        fc.remove_bottom_ticks_and_labels()

    # Add the shared colorbar
    if add_colorbar:
        fc.add_colorbar(**cbar_kwargs)

    return fc
1✔
974

975

976
def plot_image(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    add_labels=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot a spatial 2D DataArray as an image using imshow.

    A FacetGrid is created if ``col`` or ``row`` is specified in ``plot_kwargs``.

    Parameters
    ----------
    da : xarray.DataArray
        xarray DataArray.
    x : str, optional
        X dimension name.
        If ``None``, takes the second dimension.
        The default is ``None``.
    y : str, optional
        Y dimension name.
        If ``None``, takes the first dimension.
        The default is ``None``.
    ax : matplotlib.axes.Axes, optional
        The matplotlib axes where to plot the image.
        If ``None``, a figure is initialized using the
        specified ``fig_kwargs``.
        The default is ``None``.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    add_labels : bool, optional
        Whether to add labels to the plot. The default is ``True``.
    interpolation : str, optional
        Argument to be passed to imshow.
        The default is ``"nearest"``.
    fig_kwargs : dict, optional
        Figure options to be passed to :py:class:`matplotlib.pyplot.subplots`.
        The default is ``None``.
        Only used if ``ax`` is ``None``.
    cbar_kwargs : dict, optional
        Colorbar options. The default is ``None``.
    **plot_kwargs
        Additional arguments to be passed to the plotting function.
        Examples include ``cmap``, ``norm``, ``vmin``, ``vmax``, ``levels``, ...
        For FacetGrid plots, specify ``row``, ``col`` and ``col_wrap``.
        With ``rgb`` you can specify the name of the xarray.DataArray RGB dimension.

    Returns
    -------
    p
        The object returned by the image plotting routine, or the FacetGrid
        object if ``col`` or ``row`` is specified.

    Raises
    ------
    ValueError
        If ``da`` is not a spatial 2D object.

    """
    from gpm.checks import check_is_spatial_2d, is_spatial_2d

    # Only spatial 2D objects can be displayed as an image
    if not is_spatial_2d(da, strict=False):
        raise ValueError("Can not plot. It's not a spatial 2D object.")

    # Check inputs
    da = check_object_format(da, plot_kwargs=plot_kwargs, check_function=check_is_spatial_2d, strict=True)

    # Select the FacetGrid routine if 'col' or 'row' is specified
    use_facetgrid = "col" in plot_kwargs or "row" in plot_kwargs
    plot_function = _plot_image_facetgrid if use_facetgrid else _plot_image

    # Plot and return the mappable (or the FacetGrid)
    return plot_function(
        da=da,
        x=x,
        y=y,
        ax=ax,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        interpolation=interpolation,
        fig_kwargs=fig_kwargs,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )
1✔
1071

1072

1073
####--------------------------------------------------------------------------.
1074
##################
1075
#### Plot map ####
1076
##################
1077

1078

1079
def plot_map(
    da,
    x=None,
    y=None,
    ax=None,
    interpolation="nearest",  # used only for GPM grid objects
    add_colorbar=True,
    add_background=True,
    add_labels=True,
    add_gridlines=True,
    add_swath_lines=True,  # used only for GPM orbit objects
    fig_kwargs=None,
    subplot_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot data on a geographic map.

    ORBIT objects are plotted with ``plot_orbit_map`` and GRID objects
    with ``plot_grid_map``.

    Parameters
    ----------
    da : xarray.DataArray
        xarray DataArray.
    x : str, optional
        Longitude coordinate name.
        If ``None``, takes the second dimension.
        The default is ``None``.
    y : str, optional
        Latitude coordinate name.
        If ``None``, takes the first dimension.
        The default is ``None``.
    ax : cartopy.mpl.geoaxes.GeoAxes, optional
        The cartopy GeoAxes where to plot the map.
        If ``None``, a figure is initialized using the
        specified ``fig_kwargs`` and ``subplot_kwargs``.
        The default is ``None``.
    interpolation : str, optional
        Argument to be passed to :py:class:`matplotlib.axes.Axes.imshow`.
        Only applies for GRID objects.
        The default is ``"nearest"``.
    add_colorbar : bool, optional
        Whether to add a colorbar. The default is ``True``.
    add_background : bool, optional
        Whether to add the map background. The default is ``True``.
    add_labels : bool, optional
        Whether to add cartopy labels to the plot. The default is ``True``.
    add_gridlines : bool, optional
        Whether to add cartopy gridlines to the plot. The default is ``True``.
    add_swath_lines : bool, optional
        Whether to plot the swath sides with a dashed line. The default is ``True``.
        This argument only applies for ORBIT objects.
    fig_kwargs : dict, optional
        Figure options to be passed to :py:class:`matplotlib.pyplot.subplots`.
        The default is ``None``.
        Only used if ``ax`` is ``None``.
    subplot_kwargs : dict, optional
        Dictionary of keyword arguments for :py:class:`matplotlib.pyplot.subplots`.
        Must contain the Cartopy CRS ``projection`` key if specified.
        The default is ``None``.
        Only used if ``ax`` is ``None``.
    cbar_kwargs : dict, optional
        Colorbar options. The default is ``None``.
    **plot_kwargs
        Additional arguments to be passed to the plotting function.
        Examples include ``cmap``, ``norm``, ``vmin``, ``vmax``, ``levels``, ...
        For FacetGrid plots, specify ``row``, ``col`` and ``col_wrap``.
        With ``rgb`` you can specify the name of the xarray.DataArray RGB dimension.

    Returns
    -------
    p
        The object returned by ``plot_orbit_map`` or ``plot_grid_map``.

    Raises
    ------
    ValueError
        If ``da`` is neither a GPM ORBIT with a spatial dimension nor
        a spatial 2D GPM GRID object.

    """
    from gpm.checks import has_spatial_dim, is_grid, is_orbit, is_spatial_2d
    from gpm.visualization.grid import plot_grid_map
    from gpm.visualization.orbit import plot_orbit_map

    # Arguments shared by the orbit and grid plotting routines
    common_kwargs = {
        "da": da,
        "x": x,
        "y": y,
        "ax": ax,
        "add_colorbar": add_colorbar,
        "add_background": add_background,
        "add_gridlines": add_gridlines,
        "add_labels": add_labels,
        "fig_kwargs": fig_kwargs,
        "subplot_kwargs": subplot_kwargs,
        "cbar_kwargs": cbar_kwargs,
    }

    # Plot orbit
    # - allow vertical or other dimensions for FacetGrid
    # - allow to plot a swath of size 1 (i.e. nadir-looking)
    if is_orbit(da) and has_spatial_dim(da):
        return plot_orbit_map(add_swath_lines=add_swath_lines, **common_kwargs, **plot_kwargs)

    # Plot grid
    if is_grid(da) and is_spatial_2d(da, strict=False):
        return plot_grid_map(interpolation=interpolation, **common_kwargs, **plot_kwargs)

    raise ValueError("Can not plot. It's neither a GPM GRID or GPM ORBIT spatial 2D object.")
1✔
1191

1192

1193
def plot_map_mesh(
    xr_obj,
    x=None,
    y=None,
    ax=None,
    edgecolors="k",
    linewidth=0.1,
    add_background=True,
    add_gridlines=True,
    add_labels=True,
    fig_kwargs=None,
    subplot_kwargs=None,
    **plot_kwargs,
):
    """Plot the mesh of a GPM ORBIT or GRID object on a map.

    ORBIT objects are plotted with ``plot_orbit_mesh`` after inferring the
    ``x`` and ``y`` coordinates with ``infer_map_xy_coords``.
    GRID objects are plotted with ``plot_grid_mesh``.

    Raises a ``ValueError`` if ``xr_obj`` is neither an ORBIT nor a GRID object.
    Returns the object returned by the selected plotting routine.
    """
    from gpm.checks import is_grid, is_orbit
    from gpm.visualization.grid import plot_grid_mesh
    from gpm.visualization.orbit import plot_orbit_mesh

    # Arguments shared by the orbit and grid mesh routines
    mesh_kwargs = {
        "ax": ax,
        "edgecolors": edgecolors,
        "linewidth": linewidth,
        "add_background": add_background,
        "add_gridlines": add_gridlines,
        "add_labels": add_labels,
        "fig_kwargs": fig_kwargs,
        "subplot_kwargs": subplot_kwargs,
    }

    # Plot orbit mesh (the DataArray passed on is the one of the y coordinate)
    if is_orbit(xr_obj):
        x, y = infer_map_xy_coords(xr_obj, x=x, y=y)
        return plot_orbit_mesh(da=xr_obj[y], x=x, y=y, **mesh_kwargs, **plot_kwargs)

    # Plot grid mesh
    if is_grid(xr_obj):
        return plot_grid_mesh(xr_obj=xr_obj, x=x, y=y, **mesh_kwargs, **plot_kwargs)

    raise ValueError("Can not plot. It's neither a GPM GRID or GPM ORBIT spatial object.")
1✔
1247

1248

1249
def plot_map_mesh_centroids(
    xr_obj,
    x=None,
    y=None,
    ax=None,
    c="r",
    s=1,
    add_background=True,
    add_gridlines=True,
    add_labels=True,
    fig_kwargs=None,
    subplot_kwargs=None,
    **plot_kwargs,
):
    """Plot the mesh centroids of a GPM ORBIT or GRID object in a cartographic map.

    For ORBIT objects, the ``x`` and ``y`` coordinates are inferred with
    ``infer_map_xy_coords``. For GRID objects, they are inferred with
    ``infer_xy_labels`` and a 2D mesh of the coordinates is created with
    ``create_grid_mesh_data_array``. The centroids are drawn with
    ``ax.scatter`` using the ``ccrs.PlateCarree()`` transform.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        The xarray object holding the coordinates to plot.
    x, y : str, optional
        Names of the x and y coordinates.
        They must both be specified if ``xr_obj`` is neither an ORBIT nor a GRID object.
    ax : cartopy.mpl.geoaxes.GeoAxes, optional
        The axes where to plot. Passed to ``initialize_cartopy_plot``.
    c, s
        Color and size arguments passed to ``ax.scatter``.
    add_background, add_gridlines, add_labels, fig_kwargs, subplot_kwargs
        Arguments passed to ``initialize_cartopy_plot``.
    **plot_kwargs
        Additional arguments passed to ``ax.scatter``.

    Returns
    -------
    p
        The object returned by ``ax.scatter``.

    Raises
    ------
    ValueError
        If ``xr_obj`` is neither an ORBIT nor a GRID object and ``x`` or ``y`` is ``None``.

    """
    from gpm.checks import is_grid, is_orbit

    has_orbit_format = is_orbit(xr_obj)
    has_grid_format = is_grid(xr_obj)

    # Fail early (before creating any figure) if the coordinates can not be determined.
    # Without this check, the code below would index the object with None.
    if not has_orbit_format and not has_grid_format and (x is None or y is None):
        raise ValueError(
            "Can not plot. It's neither a GPM GRID or GPM ORBIT spatial object "
            "and the 'x' and 'y' coordinates are not both specified.",
        )

    # Initialize figure if necessary
    ax = initialize_cartopy_plot(
        ax=ax,
        fig_kwargs=fig_kwargs,
        subplot_kwargs=subplot_kwargs,
        add_background=add_background,
        add_gridlines=add_gridlines,
        add_labels=add_labels,
        infer_crs=True,
        xr_obj=xr_obj,
    )

    # Retrieve orbits lon, lat coordinates
    if has_orbit_format:
        x, y = infer_map_xy_coords(xr_obj, x=x, y=y)

    # Retrieve grid centroids mesh
    if has_grid_format:
        x, y = infer_xy_labels(xr_obj, x=x, y=y)
        xr_obj = create_grid_mesh_data_array(xr_obj, x=x, y=y)

    # Extract numpy arrays
    lon = xr_obj[x].to_numpy()
    lat = xr_obj[y].to_numpy()

    # Plot centroids
    p = ax.scatter(lon, lat, transform=ccrs.PlateCarree(), c=c, s=s, **plot_kwargs)

    # Return mappable
    return p
1✔
1296

1297

1298
def create_grid_mesh_data_array(xr_obj, x, y):
    """Build a NaN-filled 2D DataArray carrying the mesh of the x and y coordinates.

    The 1D coordinate arrays ``x`` and ``y`` of ``xr_obj`` are expanded into
    2D arrays with :py:func:`numpy.meshgrid` (``indexing="xy"``) and attached
    as coordinates to a DataArray whose values are all NaN.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        The input xarray object containing the 1D coordinate arrays.
    x : str
        The name of the x-coordinate in `xr_obj`.
    y : str
        The name of the y-coordinate in `xr_obj`.

    Returns
    -------
    da_mesh : xarray.DataArray
        A 2D xarray.DataArray with mesh coordinates for `x` and `y`, and NaN values for data points.

    Notes
    -----
    The dimensions of the returned xarray.DataArray are always named 'y' and 'x',
    whatever the names of the input coordinates are.
    The 2D coordinates are stored under the names given by `x` and `y`.

    """
    # Expand the 1D coordinates into 2D arrays of shape (len(y), len(x))
    x_mesh, y_mesh = np.meshgrid(
        xr_obj[x].to_numpy(),
        xr_obj[y].to_numpy(),
        indexing="xy",
    )

    # Define the 2D coordinates on the ('y', 'x') dimensions
    mesh_dims = ("y", "x")
    mesh_coords = {
        x: (mesh_dims, x_mesh),
        y: (mesh_dims, y_mesh),
    }

    # Create the DataArray with NaN values and the shape of the mesh
    nan_values = np.full(x_mesh.shape, np.nan)
    return xr.DataArray(nan_values, coords=mesh_coords, dims=mesh_dims)
1344

1345

1346
####--------------------------------------------------------------------------.
1347

1348

1349
def _plot_labels(
    xr_obj,
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs=None,
    **plot_kwargs,
):
    """Plot the labels of a single xarray object as an image.

    The label array is selected with ``label_name`` (mandatory for a Dataset,
    optional for a DataArray), relabeled with ``redefine_label_array`` and
    the values that are not larger than 0 are masked before plotting.

    If a colorbar is requested but the number of labels exceeds
    'max_n_labels', a message is printed and the colorbar is not displayed.
    """
    from ximage.labels.labels import get_label_indices, redefine_label_array
    from ximage.labels.plot_labels import get_label_colorbar_settings

    from gpm.visualization.plot import plot_image

    # Select the label array
    if isinstance(xr_obj, xr.Dataset) or label_name is not None:
        dataarray = xr_obj[label_name]
    else:
        dataarray = xr_obj
    dataarray = dataarray.compute()

    # Retrieve the labels and disable the colorbar if there are too many
    label_indices = get_label_indices(dataarray)
    n_labels = len(label_indices)
    too_many_labels = n_labels > max_n_labels
    if add_colorbar and too_many_labels:
        msg = f"""The array currently contains {n_labels} labels
        and 'max_n_labels' is set to {max_n_labels}. The colorbar is not displayed!"""
        print(msg)
        add_colorbar = False

    # Relabel array from 1 to ... for plotting and replace 0 with nan
    dataarray = redefine_label_array(dataarray, label_indices=label_indices)
    dataarray = dataarray.where(dataarray > 0)

    # Define the colormap settings (user plot_kwargs take precedence)
    label_plot_kwargs, cbar_kwargs = get_label_colorbar_settings(label_indices, cmap=cmap)
    label_plot_kwargs.update(plot_kwargs)

    # Plot image
    return plot_image(
        dataarray,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        fig_kwargs=fig_kwargs,
        **label_plot_kwargs,
    )
1397

1398

1399
def plot_labels(
    obj,  # Dataset, DataArray or generator
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs=None,
    **plot_kwargs,
):
    """Plot the labels of an xarray object, or of every object yielded by a generator.

    If ``obj`` is a generator (as determined by ``is_generator``), it must yield
    pairs whose second element is the xarray object to plot. Each object is
    plotted with ``_plot_labels`` and displayed with ``plt.show()``.
    Otherwise ``obj`` is plotted directly with ``_plot_labels``.

    All the other arguments are passed to ``_plot_labels``.

    Returns
    -------
    p
        The object returned by ``_plot_labels`` for the last plotted object,
        or ``None`` if the generator does not yield any object.

    """
    labels_kwargs = {
        "label_name": label_name,
        "max_n_labels": max_n_labels,
        "add_colorbar": add_colorbar,
        "interpolation": interpolation,
        "cmap": cmap,
        "fig_kwargs": fig_kwargs,
    }

    if not is_generator(obj):
        return _plot_labels(xr_obj=obj, **labels_kwargs, **plot_kwargs)

    # Initialize the return value so that an empty generator does not
    # lead to an UnboundLocalError on return
    p = None
    for _, xr_obj in obj:  # label_id, xr_obj
        p = _plot_labels(xr_obj=xr_obj, **labels_kwargs, **plot_kwargs)
        plt.show()
    return p
1✔
1434

1435

1436
def plot_patches(
    patch_gen,
    variable=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot every patch yielded by a patch generator as an image.

    ``patch_gen`` must yield pairs whose second element is the xarray object
    to plot. Each patch is plotted with ``plot_image`` and displayed with
    ``plt.show()``.

    Plotting is best-effort: if a patch can not be plotted, a ``RuntimeWarning``
    is emitted and the next patch is processed.

    Parameters
    ----------
    patch_gen : generator
        Generator yielding ``(identifier, xarray object)`` pairs.
    variable : str, optional
        Name of the variable to plot. Required if the patches are xarray.Dataset.
    add_colorbar, interpolation, fig_kwargs, cbar_kwargs
        Arguments passed to ``plot_image``.
    **plot_kwargs
        Additional arguments passed to ``plot_image``.

    Raises
    ------
    ValueError
        If a patch is a xarray.Dataset and ``variable`` is not specified.

    """
    import warnings

    from gpm.visualization.plot import plot_image

    # Plot patches
    for patch_id, xr_patch in patch_gen:  # label_id, xr_obj
        if isinstance(xr_patch, xr.Dataset):
            if variable is None:
                raise ValueError("'variable' must be specified when plotting xarray.Dataset patches.")
            xr_patch = xr_patch[variable]
        try:
            plot_image(
                xr_patch,
                interpolation=interpolation,
                add_colorbar=add_colorbar,
                fig_kwargs=fig_kwargs,
                cbar_kwargs=cbar_kwargs,
                **plot_kwargs,
            )
            plt.show()
        except Exception as err:
            # Keep plotting the remaining patches, but report the failure
            # instead of discarding it silently
            warnings.warn(
                f"Skipping patch {patch_id!r}: plotting failed with {type(err).__name__}: {err}",
                RuntimeWarning,
                stacklevel=2,
            )
1✔
1466

1467

1468
####--------------------------------------------------------------------------.
1469

1470

1471
def add_map_inset(ax, loc="upper left", inset_height=0.2, projection=None, inside_figure=True, border_pad=0):
    """Adds an inset map to a Cartopy axis, highlighting the extent of the main plot.

    This function creates a smaller map inset within a larger map plot to show a global view or
    contextual location of the main plot's extent.

    It uses Cartopy for map projections and plotting, and it outlines the extent of the main plot
    within the inset to provide geographical context.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
        The main cartopy axis object where the geographic data is plotted.
    loc : str, optional
        The location of the inset map within the main plot.
        Options include ``'lower left'``, ``'lower right'``,
        ``'upper left'``, and ``'upper right'``. The default is ``'upper left'``.
    inset_height : float, optional
        The size of the inset height, specified as a fraction of the figure's height.
        For example, a value of 0.2 indicates that the inset's height will be 20% of the figure's height.
        The aspect ratio (of the map inset) will govern the ``inset_width``.
    projection : cartopy.crs.Projection, optional
        A cartopy projection. If ``None``, an Orthographic projection centered on the extent center is used.
    inside_figure : bool, optional
        Determines whether the inset is constrained to be fully inside the figure bounds. If ``True`` (default),
        the inset is placed fully within the figure. If ``False``, the inset can extend beyond the figure's edges,
        allowing for a half-outside placement.
    border_pad : optional
        Padding argument passed to ``get_inset_bounds`` when computing the inset location.
        The default is 0.

    Returns
    -------
    ax2 : cartopy.mpl.geoaxes.GeoAxes
        The Cartopy GeoAxes object for the inset map.

    Notes
    -----
    The extent of the main plot is retrieved in geographic (PlateCarree) coordinates,
    whatever the projection of ``ax`` is, and is padded for visual clarity.
    The padded extent is overlaid on the global inset map as a red outline.

    Examples
    --------
    >>> p = da.gpm.plot_map()
    >>> add_map_inset(ax=p.axes, loc="upper left", inset_height=0.15)

    This example creates a main plot with a specified extent and adds an upper-left inset map
    showing the global context of the main plot's extent.

    """
    from shapely import Polygon

    from gpm.utils.geospatial import extend_geographic_extent

    # Retrieve map extent and bounds
    # - Request the extent in geographic coordinates. Without a crs, cartopy returns
    #   the extent in the projection of the axes, which is not lon/lat for
    #   projections other than PlateCarree.
    extent = ax.get_extent(crs=ccrs.PlateCarree())
    extent = extend_geographic_extent(extent, padding=0.5)
    # - Reorder (xmin, xmax, ymin, ymax) to (xmin, ymin, xmax, ymax)
    bounds = [extent[i] for i in [0, 2, 1, 3]]

    # Create shapely Polygon of the extent
    polygon = Polygon.from_bounds(*bounds)

    # Define Orthographic projection centered on the extent
    if projection is None:
        lon_min, lon_max, lat_min, lat_max = extent
        projection = ccrs.Orthographic(
            central_latitude=(lat_min + lat_max) / 2,
            central_longitude=(lon_min + lon_max) / 2,
        )

    # Define aspect ratio of the map inset (width / height of the projection limits)
    aspect_ratio = float(np.diff(projection.x_limits).item() / np.diff(projection.y_limits).item())

    # Define inset location relative to main plot (ax) in normalized units
    # - Lower-left corner of inset Axes, and its width and height
    # - [x0, y0, width, height]
    inset_bounds = get_inset_bounds(
        ax=ax,
        loc=loc,
        inset_height=inset_height,
        inside_figure=inside_figure,
        aspect_ratio=aspect_ratio,
        border_pad=border_pad,
    )

    ax2 = ax.inset_axes(
        inset_bounds,
        projection=projection,
    )

    # Add global map
    ax2.set_global()
    ax2.add_feature(cfeature.LAND)
    ax2.add_feature(cfeature.OCEAN)

    # Add extent polygon
    _ = ax2.add_geometries(
        [polygon],
        ccrs.PlateCarree(),
        facecolor="none",
        edgecolor="red",
        linewidth=0.3,
    )
    return ax2
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc