• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ghiggi / gpm_api / 7798551251

06 Feb 2024 10:55AM UTC coverage: 59.56%. Remained the same
7798551251

push

github

ghiggi
Fix documentation

3305 of 5549 relevant lines covered (59.56%)

0.6 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

11.89
/gpm_api/visualization/plot.py
1
#!/usr/bin/env python3
2
"""
1✔
3
Created on Sat Dec 10 18:42:28 2022
4

5
@author: ghiggi
6
"""
7
import inspect
1✔
8

9
import cartopy
1✔
10
import cartopy.crs as ccrs
1✔
11
import matplotlib.pyplot as plt
1✔
12
import numpy as np
1✔
13
import xarray as xr
1✔
14
from matplotlib.collections import PolyCollection
1✔
15
from mpl_toolkits.axes_grid1 import make_axes_locatable
1✔
16

17
### TODO: Add xarray + cartopy  (xr_carto) (xr_mpl)
18
# _plot_cartopy_xr_imshow
19
# _plot_cartopy_xr_pcolormesh
20

21

22
def is_generator(obj):
    """Return True if ``obj`` is a generator object or a generator function."""
    generator_checks = (inspect.isgeneratorfunction, inspect.isgenerator)
    return any(check(obj) for check in generator_checks)
24

25

26
def _preprocess_figure_args(ax, fig_kwargs={}, subplot_kwargs={}):
1✔
27
    if ax is not None:
×
28
        if len(subplot_kwargs) >= 1:
×
29
            raise ValueError("Provide `subplot_kwargs`only if `ax`is None")
×
30
        if len(fig_kwargs) >= 1:
×
31
            raise ValueError("Provide `fig_kwargs` only if `ax`is None")
×
32

33
    # If ax is not specified, specify the figure defaults
34
    # if ax is None:
35
    # Set default figure size and dpi
36
    # fig_kwargs['figsize'] = (12, 10)
37
    # fig_kwargs['dpi'] = 100
38

39

40
def _preprocess_subplot_kwargs(subplot_kwargs):
1✔
41
    subplot_kwargs = subplot_kwargs.copy()
×
42
    if "projection" not in subplot_kwargs:
×
43
        subplot_kwargs["projection"] = ccrs.PlateCarree()
×
44
    return subplot_kwargs
×
45

46

47
def get_extent(da, x="lon", y="lat"):
    """Return the bounding box ``(x_min, x_max, y_min, y_max)`` of ``da``.

    The bounds are the min/max of the ``x`` and ``y`` coordinate values
    (i.e. pixel centroids, not pixel edges).

    TODO: compute the corners array to estimate the extent,
    OR increase by 1° in every direction and then wrap within
    [-180, 180] x [-90, 90].
    """
    x_values = da[x]
    y_values = da[y]
    return (x_values.min(), x_values.max(), y_values.min(), y_values.max())
55

56

57
def get_antimeridian_mask(lons, buffer=True):
    """Get mask of longitude coordinates neighbors crossing the antimeridian.

    A cell is flagged when the absolute longitude difference with its
    vertical or horizontal neighbour exceeds 180°. Both cells of such a pair
    are flagged.

    Parameters
    ----------
    lons : numpy.ndarray
        2D array of longitudes.
    buffer : bool, optional
        If True (default), grow the mask by one cell with
        ``scipy.ndimage.binary_dilation`` (default 4-connected structure).
        Previously the dilation was applied regardless of this flag.

    Returns
    -------
    numpy.ndarray
        Boolean mask with the same shape as ``lons``.
    """
    from scipy.ndimage import binary_dilation

    mask = np.zeros(lons.shape, dtype=bool)

    # Vertical neighbours: jump between rows i and i+1 flags both rows.
    # np.ma.filled treats masked comparisons (if any) as "no jump".
    vertical_jump = np.ma.filled(np.abs(np.diff(lons, axis=0)) > 180, False)
    mask[:-1, :] |= vertical_jump
    mask[1:, :] |= vertical_jump

    # Horizontal neighbours: jump between columns j and j+1 flags both columns
    horizontal_jump = np.ma.filled(np.abs(np.diff(lons, axis=1)) > 180, False)
    mask[:, :-1] |= horizontal_jump
    mask[:, 1:] |= horizontal_jump

    # Buffer by 1 cell to ensure the edges around the crossing are included
    if buffer:
        mask = binary_dilation(mask)
    return mask
79

80

81
def get_masked_cells_polycollection(x, y, arr, mask, plot_kwargs):
    """Build a PolyCollection with the QuadMesh cells selected by ``mask``.

    The mask is first grown by one cell (``binary_dilation``) so that the
    neighbours of every masked cell are included. Each selected cell becomes
    one polygon coloured by the matching value of ``arr`` and drawn with a
    ``ccrs.Geodetic()`` transform.

    Parameters
    ----------
    x, y : numpy.ndarray
        Cell-centre longitude / latitude arrays of the QuadMesh.
    arr : numpy.ndarray
        Cell values, indexed as ``arr[row, col]``.
    mask : numpy.ndarray
        2D boolean mask of the cells to extract.
    plot_kwargs : dict
        Extra PolyCollection keyword arguments. Not modified.

    Returns
    -------
    matplotlib.collections.PolyCollection
    """
    from scipy.ndimage import binary_dilation

    from gpm_api.utils.area import _from_corners_to_bounds, _get_lonlat_corners, is_vertex_clockwise

    # Indices of the selected cells (mask buffered by 1)
    rows, cols = np.where(binary_dilation(mask))

    # Values used to colour the selected cells
    cell_values = arr[rows, cols]

    # QuadMesh corners (M+1 x N+1), then per-cell bounds indexed as [row, col]
    lon_corners, lat_corners = _get_lonlat_corners(x, y)
    lon_bounds = _from_corners_to_bounds(lon_corners)
    lat_bounds = _from_corners_to_bounds(lat_corners)

    # Polygon vertices of the selected cells: (n_selected, 4, 2)
    polygons = np.stack((lon_bounds[rows, cols], lat_bounds[rows, cols]), axis=2)

    # Enforce counterclockwise orientation (only the first polygon is checked)
    # TODO: this check should be updated to use pyresample.future.spherical
    if is_vertex_clockwise(polygons[0, :, :]):
        polygons = polygons[:, ::-1, :]

    # PolyCollection styling: user values win for edgecolors/linewidth
    collection_kwargs = plot_kwargs.copy()
    collection_kwargs.setdefault("edgecolors", "face")  # alternative: 'none'
    collection_kwargs.setdefault("linewidth", 0)
    collection_kwargs["antialiaseds"] = False  # to better plotting quality

    return PolyCollection(
        verts=polygons,
        array=cell_values,
        transform=ccrs.Geodetic(),
        **collection_kwargs,
    )
129

130

131
def get_valid_pcolormesh_inputs(x, y, data, rgb=False):
    """
    Fill non-finite values with neighbour valid coordinates.

    pcolormesh does not accept non-finite values in the coordinates.
    This function:
    - Infills NaN/Inf in ``x`` and ``y`` by linear interpolation over the
      flattened (C-order) index of the valid coordinates.
    - Masks the corresponding pixels in the data so that they are not displayed.

    If ``rgb=True``, the RGB channels are expected in the last dimension of ``data``.

    Returns
    -------
    tuple
        ``(x, y, data)`` unchanged if all coordinates are finite, otherwise
        ``(x_filled, y_filled, masked_data)`` (the inputs are not modified).

    Raises
    ------
    ValueError
        If no pixel has finite ``x`` and ``y`` coordinates.
    """
    # TODO:
    # - Instead of np.interp, can use nearest neighbors or just 0 to speed up?

    # Pixels where at least one coordinate is invalid
    mask = np.logical_or(~np.isfinite(x), ~np.isfinite(y))

    # If no invalid coordinates, return original data
    if not mask.any():
        return x, y, data

    # np.interp would otherwise fail with an obscure "empty sample points" error
    if mask.all():
        raise ValueError("All coordinates are non-finite. Impossible to infill them for pcolormesh.")

    # Mask the data
    if rgb:
        data_mask = np.broadcast_to(np.expand_dims(mask, axis=-1), data.shape)
        data_masked = np.ma.masked_where(data_mask, data)
    else:
        data_masked = np.ma.masked_where(mask, data)

    # TODO: should be done in XYZ?
    invalid_idx = np.flatnonzero(mask)
    valid_idx = np.flatnonzero(~mask)
    x_dummy = x.copy()
    x_dummy[mask] = np.interp(invalid_idx, valid_idx, x[~mask])
    y_dummy = y.copy()
    y_dummy[mask] = np.interp(invalid_idx, valid_idx, y[~mask])
    return x_dummy, y_dummy, data_masked
168

169

170
####--------------------------------------------------------------------------.
171
def plot_cartopy_background(ax):
    """Draw a simple cartopy basemap on ``ax`` and return it.

    Adds coastlines, light-grey land, semi-transparent ocean, country
    borders and faint gridlines labelled only on the bottom and left sides.
    """
    # - Basemap features
    ax.coastlines()
    basemap_features = (
        (cartopy.feature.LAND, {"facecolor": [0.9, 0.9, 0.9]}),
        (cartopy.feature.OCEAN, {"alpha": 0.6}),
        (cartopy.feature.BORDERS, {}),
    )
    for feature, feature_kwargs in basemap_features:
        ax.add_feature(feature, **feature_kwargs)

    # - Faint grid lines
    gridline_style = {"linewidth": 1, "color": "gray", "alpha": 0.1, "linestyle": "-"}
    gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True, **gridline_style)
    # Hide labels on the top and right sides (formerly gl.xlabels_top / gl.ylabels_right)
    gl.top_labels = gl.right_labels = False
    gl.xlines = gl.ylines = True
    return ax
192

193

194
def plot_colorbar(p, ax, cbar_kwargs={}, size="5%", pad=0.1):
    """Add a colorbar on the right side of a matplotlib/cartopy plot.

    Parameters
    ----------
    p : matplotlib mappable
        The plotted artist (e.g. ``AxesImage`` or ``QuadMesh``).
    ax : matplotlib.axes.Axes
        Axes next to which the colorbar is placed (e.g. a cartopy GeoAxes).
    cbar_kwargs : dict
        Extra ``plt.colorbar`` arguments. The special key ``"ticklabels"``
        sets the colorbar y tick labels. The dict is not modified.
    size, pad :
        Width of the colorbar axes and spacing, passed to ``new_horizontal``.

    Returns
    -------
    matplotlib.colorbar.Colorbar
    """
    # Work on a copy so that popping "ticklabels" does not alter the caller's dict
    colorbar_kwargs = dict(cbar_kwargs)
    ticklabels = colorbar_kwargs.pop("ticklabels", None)

    colorbar_ax = make_axes_locatable(ax).new_horizontal(size=size, pad=pad, axes_class=plt.Axes)
    p.figure.add_axes(colorbar_ax)

    cbar = plt.colorbar(p, cax=colorbar_ax, ax=ax, **colorbar_kwargs)
    if ticklabels is not None:
        cbar.ax.set_yticklabels(ticklabels)
    return cbar
210

211

212
####--------------------------------------------------------------------------.
213

214

215
def _compute_extent(x_coords, y_coords):
1✔
216
    """
217
    Compute the extent (x_min, x_max, y_min, y_max) from the pixel centroids in x and y coordinates.
218
    This function assumes that the spacing between each pixel is uniform.
219
    """
220
    # Calculate the pixel size assuming uniform spacing between pixels
221
    pixel_size_x = (x_coords[-1] - x_coords[0]) / (len(x_coords) - 1)
×
222
    pixel_size_y = (y_coords[-1] - y_coords[0]) / (len(y_coords) - 1)
×
223

224
    # Adjust min and max to get the corners of the outer pixels
225
    x_min, x_max = x_coords[0] - pixel_size_x / 2, x_coords[-1] + pixel_size_x / 2
×
226
    y_min, y_max = y_coords[0] - pixel_size_y / 2, y_coords[-1] + pixel_size_y / 2
×
227

228
    return [x_min, x_max, y_min, y_max]
×
229

230

231
def _plot_cartopy_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs={},
    cbar_kwargs={},
):
    """Plot imshow with cartopy.

    Parameters
    ----------
    ax : cartopy GeoAxes
        Axes to draw on.
    da : xarray.DataArray
        2D field with 1D, uniformly spaced ``x`` and ``y`` coordinates
        expressed in PlateCarree (longitude / latitude).
    x, y : str
        Names of the x (longitude) and y (latitude) coordinates/dimensions.
    interpolation : str
        ``imshow`` interpolation method.
    add_colorbar : bool
        Whether to add a colorbar with ``plot_colorbar``.
    plot_kwargs, cbar_kwargs : dict
        Extra ``imshow`` and colorbar arguments.

    Returns
    -------
    matplotlib.image.AxesImage
    """
    # - Ensure image with correct dimensions orders
    da = da.transpose(y, x)
    arr = np.asanyarray(da.data)

    # - Retrieve the 1D pixel-centroid coordinates
    x_coords = da[x].values
    y_coords = da[y].values

    # - Derive the extent of the outer pixel edges
    extent = _compute_extent(x_coords=x_coords, y_coords=y_coords)

    # - Determine origin based on the orientation of da[y] values
    # -->  If increasing, set origin="lower"
    # -->  If decreasing, set origin="upper"
    origin = "lower" if y_coords[1] > y_coords[0] else "upper"

    # - Add variable field with cartopy
    p = ax.imshow(
        arr,
        transform=ccrs.PlateCarree(),
        extent=extent,
        origin=origin,
        interpolation=interpolation,
        **plot_kwargs,
    )

    # - Set the map extent using the caller's coordinate names
    #   (previously hard-coded to "lon"/"lat", failing for other names)
    ax.set_extent(get_extent(da, x=x, y=y))

    # - Add colorbar
    if add_colorbar:
        # --> TODO: set axis proportion in a meaningful way ...
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
276

277

278
def _plot_rgb_pcolormesh(x, y, image, ax, **kwargs):
1✔
279
    """Plot xarray RGB DataArray with non uniform-coordinates.
280

281
    Matplotlib, cartopy and xarray pcolormesh currently does not support RGB(A) arrays.
282
    This is a temporary workaround !
283
    """
284
    if image.shape[2] not in [3, 4]:
×
285
        raise ValueError("Expecting RGB or RGB(A) arrays.")
×
286

287
    colorTuple = image.reshape((image.shape[0] * image.shape[1], image.shape[2]))
×
288
    im = ax.pcolormesh(
×
289
        x,
290
        y,
291
        image[:, :, 1],  # dummy to work ...
292
        color=colorTuple,
293
        **kwargs,
294
    )
295
    # im.set_array(None)
296
    return im
×
297

298

299
def _plot_cartopy_pcolormesh(
    ax,
    da,
    x,
    y,
    rgb=False,
    add_colorbar=True,
    plot_kwargs={},
    cbar_kwargs={},
):
    """Plot pcolormesh with cartopy.

    Non-finite coordinates are infilled and the corresponding pixels masked
    (see ``get_valid_pcolormesh_inputs``). Cells whose longitude jumps by more
    than 180° with a neighbour (i.e. crossing the antimeridian) are masked in
    the QuadMesh and redrawn as a PolyCollection with a geodetic transform.

    The function currently does not allow to zoom on regions across the antimeridian.
    If rgb=True, expect rgb dimension to be at last position (no colorbar is drawn).
    x and y must be the names of the longitude and latitude coordinates.

    ``plot_kwargs`` and any colormap it contains are not modified.

    Returns
    -------
    The pcolormesh mappable.
    """
    import copy

    # Work on a copy: this function adds/replaces keys ("cmap", "facecolors")
    # which must not leak into the caller's dict or the shared default.
    plot_kwargs = dict(plot_kwargs)

    # - Get x, y, and array to plot
    da = da.compute()
    x = da[x].data
    y = da[y].data
    arr = da.data

    # - Infill invalid value and add mask if necessary
    x, y, arr = get_valid_pcolormesh_inputs(x, y, arr, rgb=rgb)

    # - RGB(A) images have no colormap, hence no colorbar
    if rgb:
        add_colorbar = False

    # - Mask cells crossing the antimeridian
    # --> Here assume not invalid coordinates anymore
    antimeridian_mask = get_antimeridian_mask(x, buffer=True)
    is_crossing_antimeridian = np.any(antimeridian_mask)
    if is_crossing_antimeridian:
        # masked_where requires a condition with the same shape as the data:
        # broadcast the 2D mask over the trailing RGB(A) dimension.
        data_mask = antimeridian_mask
        if rgb:
            data_mask = np.broadcast_to(np.expand_dims(antimeridian_mask, axis=-1), arr.shape)
        # masked_where also keeps any pre-existing mask (e.g. invalid coordinates)
        arr = np.ma.masked_where(data_mask, arr)

        # Sanitize cmap bad color (fully transparent) to avoid cartopy bug.
        # Use a copy so the caller's / registered colormap is not modified.
        if "cmap" in plot_kwargs:
            cmap = plot_kwargs["cmap"]
            if cmap is None or isinstance(cmap, str):
                cmap = plt.get_cmap(cmap)
            cmap = copy.copy(cmap)
            bad = cmap.get_bad()
            bad[3] = 0  # enforce alpha to 0 (transparent)
            cmap.set_bad(bad)
            plot_kwargs["cmap"] = cmap

    # - Add variable field with cartopy
    if not rgb:
        p = ax.pcolormesh(
            x,
            y,
            arr,
            transform=ccrs.PlateCarree(),
            **plot_kwargs,
        )
        # - Add PolyCollection of QuadMesh cells crossing the antimeridian
        if is_crossing_antimeridian:
            coll = get_masked_cells_polycollection(
                x, y, arr.data, mask=antimeridian_mask, plot_kwargs=plot_kwargs
            )
            p.axes.add_collection(coll)

    # - Add RGB
    else:
        # NOTE(review): unlike the scalar branch, no transform is passed here.
        p = _plot_rgb_pcolormesh(x, y, arr, ax=ax, **plot_kwargs)
        if is_crossing_antimeridian:
            # NOTE(review): facecolors holds one colour per grid cell, while the
            # collection only contains the (buffered) antimeridian cells and
            # receives a 3D `array`; this path likely needs rework.
            plot_kwargs["facecolors"] = arr.reshape((arr.shape[0] * arr.shape[1], arr.shape[2]))
            coll = get_masked_cells_polycollection(
                x, y, arr.data, mask=antimeridian_mask, plot_kwargs=plot_kwargs
            )
            p.axes.add_collection(coll)

    # - Set the extent
    # --> To be set in projection coordinates of crs !!!
    #     lon/lat conversion to proj required !
    # extent = get_extent(da, x="lon", y="lat")
    # ax.set_extent(extent)

    # - Add colorbar
    if add_colorbar:
        # --> TODO: set axis proportion in a meaningful way ...
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
388

389

390
def _plot_mpl_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs={},
    cbar_kwargs={},
):
    """Plot ``da`` with plain matplotlib ``imshow`` (no georeferencing).

    The array is transposed to ``(y, x)`` and drawn with ``origin="upper"``.
    A colorbar is added with ``plot_colorbar`` when ``add_colorbar`` is True.

    Returns
    -------
    matplotlib.image.AxesImage
    """
    image = np.asanyarray(da.transpose(y, x).data)
    p = ax.imshow(image, origin="upper", interpolation=interpolation, **plot_kwargs)
    if add_colorbar:
        # --> TODO: set axis proportion in a meaningful way ...
        plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
418

419

420
# def _get_colorbar_inset_axes_kwargs(p):
421
#     from mpl_toolkits.axes_grid1.inset_locator import inset_axes
422

423
#     colorbar_axes = p.colorbar.ax
424

425
#     # Get the position and size of the colorbar axes in figure coordinates
426
#     bbox = colorbar_axes.get_position()
427

428
#     # Extract the width and height of the colorbar axes in figure coordinates
429
#     width = bbox.x1 - bbox.x0
430
#     height = bbox.y1 - bbox.y0
431

432
#     # Get the location of the colorbar axes ('upper', 'lower', 'center', etc.)
433
#     # This information will be used to set the 'loc' parameter of inset_axes
434
#     loc = 'upper right'  # Modify this according to your preference
435

436
#     # Get the transformation of the colorbar axes with respect to the image axes
437
#     # This information will be used to set the 'bbox_transform' parameter of inset_axes
438
#     bbox_transform = colorbar_axes.get_transform()
439

440
#     # Calculate the coordinates of the colorbar axes relative to the image axes
441
#     x0, y0 = bbox_transform.transform((bbox.x0, bbox.y0))
442
#     x1, y1 = bbox_transform.transform((bbox.x1, bbox.y1))
443
#     bbox_to_anchor = (x0, y0, x1 - x0, y1 - y0)
444

445

446
def set_colorbar_fully_transparent(p):
    """Hide the colorbar attached to the mappable ``p``.

    The colorbar axes is made invisible and an invisible rectangle
    (no face, no edge) is added to the figure at the colorbar position.

    An ``inset_axes``-based alternative was explored (see the commented-out
    ``_get_colorbar_inset_axes_kwargs`` sketch in this module).

    Parameters
    ----------
    p : matplotlib mappable
        Artist whose ``.colorbar`` attribute is set (e.g. returned by xarray
        plotting with ``add_colorbar=True``).
    """
    colorbar_ax = p.colorbar.ax

    # Position and size of the colorbar in figure coordinates
    bbox = colorbar_ax.get_position()

    # Remove the colorbar
    colorbar_ax.set_visible(False)

    # Use the colorbar's own figure: plt.gcf() may return a different figure
    fig = colorbar_ax.figure
    rect = plt.Rectangle(
        (bbox.x0, bbox.y0),
        bbox.width,
        bbox.height,
        transform=fig.transFigure,
        facecolor="none",
        edgecolor="none",
    )
    # Public API for adding an artist to the figure
    fig.add_artist(rect)
501

502

503
def _plot_xr_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs={},
    cbar_kwargs={},
    xarray_colorbar=True,
    visible_colorbar=True,
):
    """Plot imshow with xarray.

    The colorbar is added with xarray to enable to display multiple colorbars
    when calling this function multiple times on different fields with
    different colorbars.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Axes to draw on. If None, xarray uses the current axes.
    da : xarray.DataArray
        Field to plot. Its name is used as the axes title.
    x, y : str
        Coordinates passed to ``da.plot.imshow``.
    interpolation : str
        ``imshow`` interpolation method.
    add_colorbar : bool
        Whether xarray adds a colorbar.
    plot_kwargs : dict
        Extra ``da.plot.imshow`` arguments.
    cbar_kwargs : dict
        Colorbar arguments. The special key ``"ticklabels"`` sets the colorbar
        y tick labels. The dict is not modified.
    xarray_colorbar : bool
        Currently unused.
    visible_colorbar : bool
        If False (and ``add_colorbar`` is True), the colorbar is hidden with
        ``set_colorbar_fully_transparent``.

    Returns
    -------
    The mappable returned by ``da.plot.imshow``.
    """
    # --> BUG with colorbar: https://github.com/pydata/xarray/issues/7014
    # Copy before popping so the caller's dict (and the shared default) is untouched
    cbar_kwargs = dict(cbar_kwargs)
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    if not add_colorbar:
        cbar_kwargs = {}
    p = da.plot.imshow(
        x=x,
        y=y,
        ax=ax,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )
    # plt.title() targets the *current* axes, which may differ from `ax`
    title_ax = ax if ax is not None else plt.gca()
    title_ax.set_title(da.name)
    if add_colorbar and ticklabels is not None:
        p.colorbar.ax.set_yticklabels(ticklabels)

    # Make the colorbar fully transparent with a smart trick ;)
    # - TODO: this still cause issues when plotting 2 colorbars !
    if add_colorbar and not visible_colorbar:
        set_colorbar_fully_transparent(p)

    # Alternative: disable the xarray colorbar (add_colorbar=False) and add it
    # manually with plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs).
    return p
556

557

558
def _plot_xr_pcolormesh(
    ax,
    da,
    x,
    y,
    add_colorbar=True,
    plot_kwargs={},
    cbar_kwargs={},
):
    """Plot pcolormesh with xarray.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Axes to draw on. If None, xarray uses the current axes.
    da : xarray.DataArray
        Field to plot. Its name is used as the axes title.
    x, y : str
        Coordinates passed to ``da.plot.pcolormesh``.
    add_colorbar : bool
        Whether xarray adds a colorbar.
    plot_kwargs : dict
        Extra ``da.plot.pcolormesh`` arguments.
    cbar_kwargs : dict
        Colorbar arguments. The special key ``"ticklabels"`` sets the colorbar
        y tick labels. The dict is not modified.

    Returns
    -------
    The mappable returned by ``da.plot.pcolormesh``.
    """
    # Copy before popping so the caller's dict (and the shared default) is untouched
    cbar_kwargs = dict(cbar_kwargs)
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    if not add_colorbar:
        cbar_kwargs = {}
    p = da.plot.pcolormesh(
        x=x,
        y=y,
        ax=ax,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )
    # plt.title() targets the *current* axes, which may differ from `ax`
    title_ax = ax if ax is not None else plt.gca()
    title_ax.set_title(da.name)
    if add_colorbar and ticklabels is not None:
        p.colorbar.ax.set_yticklabels(ticklabels)
    return p
583

584

585
####--------------------------------------------------------------------------.
586
#### TODO: doc
587
# figsize, dpi, subplot_kw only used if ax is None
588

589

590
def plot_map(
    da,
    x="lon",
    y="lat",
    ax=None,
    add_colorbar=True,
    add_swath_lines=True,  # used only for GPM orbit objects
    add_background=True,
    rgb=False,
    interpolation="nearest",  # used only for GPM grid objects
    fig_kwargs={},
    subplot_kwargs={},
    cbar_kwargs={},
    **plot_kwargs,
):
    """Plot a GPM orbit or grid object on a cartographic map.

    Dispatches to ``plot_orbit_map`` (orbit objects; uses ``add_swath_lines``
    and ``rgb``) or ``plot_grid_map`` (grid objects; uses ``interpolation``).
    ``fig_kwargs`` and ``subplot_kwargs`` are presumably only used when
    ``ax`` is None — TODO confirm in the callee functions.

    Returns
    -------
    The mappable returned by the dispatched function.

    Raises
    ------
    ValueError
        If ``da`` is neither a GPM orbit nor a GPM grid.
    """
    from gpm_api.checks import is_grid, is_orbit
    from gpm_api.visualization.grid import plot_grid_map
    from gpm_api.visualization.orbit import plot_orbit_map

    shared_kwargs = dict(
        da=da,
        x=x,
        y=y,
        ax=ax,
        add_colorbar=add_colorbar,
        add_background=add_background,
        fig_kwargs=fig_kwargs,
        subplot_kwargs=subplot_kwargs,
        cbar_kwargs=cbar_kwargs,
    )
    if is_orbit(da):
        return plot_orbit_map(
            **shared_kwargs, add_swath_lines=add_swath_lines, rgb=rgb, **plot_kwargs
        )
    if is_grid(da):
        return plot_grid_map(**shared_kwargs, interpolation=interpolation, **plot_kwargs)
    raise ValueError("Can not plot. It's neither a GPM grid, neither a GPM orbit.")
644

645

646
def plot_image(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs={},
    cbar_kwargs={},
    **plot_kwargs,
):
    """Plot a GPM orbit or grid object as a (non-geographic) image.

    Dispatches to ``plot_orbit_image`` or ``plot_grid_image`` with the same
    arguments. ``fig_kwargs`` (figsize, dpi, ...) is presumably only used
    when ``ax`` is None — TODO confirm in the callee functions.

    Returns
    -------
    The mappable returned by the dispatched function.

    Raises
    ------
    ValueError
        If ``da`` is neither a GPM orbit nor a GPM grid.
    """
    from gpm_api.checks import is_grid, is_orbit
    from gpm_api.visualization.grid import plot_grid_image
    from gpm_api.visualization.orbit import plot_orbit_image

    image_kwargs = dict(
        da=da,
        x=x,
        y=y,
        ax=ax,
        add_colorbar=add_colorbar,
        interpolation=interpolation,
        fig_kwargs=fig_kwargs,
        cbar_kwargs=cbar_kwargs,
    )
    if is_orbit(da):
        return plot_orbit_image(**image_kwargs, **plot_kwargs)
    if is_grid(da):
        return plot_grid_image(**image_kwargs, **plot_kwargs)
    raise ValueError("Can not plot. It's neither a GPM GRID, neither a GPM ORBIT.")
692

693

694
####--------------------------------------------------------------------------.
695

696

697
def create_grid_mesh_data_array(xr_obj, x, y):
    """
    Create a 2D xarray DataArray with mesh coordinates based on the 1D coordinate arrays
    from an existing xarray object (Dataset or DataArray).

    The 1D ``x`` and ``y`` coordinates are expanded into 2D meshes (``xy``
    indexing) and the data values are all NaN.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        The input xarray object containing the 1D coordinate arrays.
    x : str
        The name of the x-coordinate in xr_obj.
    y : str
        The name of the y-coordinate in xr_obj.

    Returns
    -------
    da_mesh : xarray.DataArray
        A 2D DataArray with dimensions ('y', 'x'), 2D coordinates named
        ``x`` and ``y``, and NaN data values.
    """
    mesh_x, mesh_y = np.meshgrid(xr_obj[x].values, xr_obj[y].values, indexing="xy")
    nan_values = np.full(mesh_x.shape, np.nan)
    mesh_dims = ("y", "x")
    return xr.DataArray(
        nan_values,
        coords={x: (mesh_dims, mesh_x), y: (mesh_dims, mesh_y)},
        dims=mesh_dims,
    )
739

740

741
def plot_map_mesh(
    xr_obj,
    x="lon",
    y="lat",
    ax=None,
    edgecolors="k",
    linewidth=0.1,
    add_background=True,
    fig_kwargs={},
    subplot_kwargs={},
    **plot_kwargs,
):
    """Plot the cell mesh of a GPM orbit or grid object on a map.

    Orbit objects are passed to ``plot_orbit_mesh`` as ``da=xr_obj[y]``;
    any other object is passed to ``plot_grid_mesh``.

    Returns
    -------
    The mappable returned by the dispatched function.
    """
    from gpm_api.checks import is_orbit  # is_grid

    from .grid import plot_grid_mesh
    from .orbit import plot_orbit_mesh

    mesh_kwargs = dict(
        ax=ax,
        x=x,
        y=y,
        edgecolors=edgecolors,
        linewidth=linewidth,
        add_background=add_background,
        fig_kwargs=fig_kwargs,
        subplot_kwargs=subplot_kwargs,
    )
    if is_orbit(xr_obj):
        return plot_orbit_mesh(da=xr_obj[y], **mesh_kwargs, **plot_kwargs)
    return plot_grid_mesh(xr_obj=xr_obj, **mesh_kwargs, **plot_kwargs)
787

788

789
def plot_map_mesh_centroids(
    xr_obj,
    x="lon",
    y="lat",
    ax=None,
    c="r",
    s=1,
    add_background=True,
    fig_kwargs={},
    subplot_kwargs={},
    **plot_kwargs,
):
    """Plot GPM orbit or grid mesh centroids in a cartographic map.

    A new figure is created when ``ax`` is None (``fig_kwargs`` and
    ``subplot_kwargs`` are only allowed in that case). For grid objects, the
    1D coordinates are first expanded to a 2D mesh so that every cell
    centroid is drawn.

    Returns
    -------
    The PathCollection returned by ``ax.scatter``.
    """
    from gpm_api.checks import is_grid

    # - Validate inputs and create the figure if needed
    _preprocess_figure_args(ax=ax, fig_kwargs=fig_kwargs, subplot_kwargs=subplot_kwargs)
    if ax is None:
        _, ax = plt.subplots(subplot_kw=_preprocess_subplot_kwargs(subplot_kwargs), **fig_kwargs)

    if add_background:
        ax = plot_cartopy_background(ax)

    # - Grids carry 1D coordinates: expand them to the 2D cell centroids
    if is_grid(xr_obj):
        xr_obj = create_grid_mesh_data_array(xr_obj, x=x, y=y)

    return ax.scatter(
        xr_obj[x].values,
        xr_obj[y].values,
        transform=ccrs.PlateCarree(),
        c=c,
        s=s,
        **plot_kwargs,
    )
827

828

829
####--------------------------------------------------------------------------.
830

831

832
def _plot_labels(
    xr_obj,
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs={},
    **plot_kwargs,
):
    """Plot labels.

    Labels are renumbered from 1 (0 becomes NaN, i.e. background) and plotted
    with a categorical colormap built by ximage's ``get_label_colorbar_settings``.
    The colorbar is disabled when there are more than ``max_n_labels`` labels.

    Parameters
    ----------
    xr_obj : xarray.Dataset or xarray.DataArray
        Object containing the label array.
    label_name : str, optional
        Name of the label variable. Required when ``xr_obj`` is a Dataset.
    max_n_labels : int
        Maximum number of labels for which a colorbar is displayed.
    add_colorbar : bool
        Whether to add a colorbar.
    interpolation : str
        ``imshow`` interpolation method.
    cmap : str
        Colormap used for the labels (previously ignored: "Paired" was always used).
    fig_kwargs : dict
        Figure arguments passed to ``plot_image``.
    **plot_kwargs
        Extra plotting arguments. The label colormap settings take precedence
        on conflicting keys. (Previously these were silently discarded.)

    Returns
    -------
    The mappable returned by ``plot_image``.

    Raises
    ------
    ValueError
        If ``xr_obj`` is a Dataset and ``label_name`` is not specified.
    """
    from ximage.labels.labels import get_label_indices, redefine_label_array
    from ximage.labels.plot_labels import get_label_colorbar_settings

    # - Select the label array
    if label_name is not None:
        dataarray = xr_obj[label_name]
    elif isinstance(xr_obj, xr.Dataset):
        raise ValueError("'label_name' must be specified when plotting a xr.Dataset.")
    else:
        dataarray = xr_obj

    dataarray = dataarray.compute()
    label_indices = get_label_indices(dataarray)
    n_labels = len(label_indices)
    if add_colorbar and n_labels > max_n_labels:
        msg = (
            f"The array currently contains {n_labels} labels and 'max_n_labels' "
            f"is set to {max_n_labels}. The colorbar is not displayed!"
        )
        print(msg)
        add_colorbar = False
    # Relabel array from 1 to ... for plotting
    dataarray = redefine_label_array(dataarray, label_indices=label_indices)
    # Replace 0 with nan
    dataarray = dataarray.where(dataarray > 0)
    # Define appropriate colormap (honouring the `cmap` argument)
    label_plot_kwargs, cbar_kwargs = get_label_colorbar_settings(label_indices, cmap=cmap)
    # Keep the user's extra kwargs, but let the label settings win on conflicts
    plot_kwargs = {**plot_kwargs, **label_plot_kwargs}
    # Plot image (module-level plot_image: no self-import needed)
    p = plot_image(
        dataarray,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        fig_kwargs=fig_kwargs,
        **plot_kwargs,
    )
    return p
883

884

885
def plot_labels(
    obj,  # Dataset, DataArray or generator
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs={},
    **plot_kwargs,
):
    """Plot a label array, or each label array yielded by a generator.

    Parameters
    ----------
    obj : xarray.Dataset, xarray.DataArray or generator
        Object to plot. A generator must yield ``(label_id, xr_obj)`` pairs;
        each item is plotted and shown with ``plt.show()``.
    label_name, max_n_labels, add_colorbar, interpolation, cmap, fig_kwargs, **plot_kwargs
        Forwarded to ``_plot_labels``.

    Returns
    -------
    The mappable of the (last) plot, or None if the generator yields nothing
    (previously this raised UnboundLocalError).
    """
    label_kwargs = dict(
        label_name=label_name,
        max_n_labels=max_n_labels,
        add_colorbar=add_colorbar,
        interpolation=interpolation,
        cmap=cmap,
        fig_kwargs=fig_kwargs,
    )
    if not is_generator(obj):
        return _plot_labels(xr_obj=obj, **label_kwargs, **plot_kwargs)

    p = None
    for _, xr_obj in obj:
        p = _plot_labels(xr_obj=xr_obj, **label_kwargs, **plot_kwargs)
        plt.show()
    return p
920

921

922
def plot_patches(
    patch_gen,
    variable=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs={},
    cbar_kwargs={},
    **plot_kwargs,
):
    """Plot patches.

    Each ``(label_id, xr_patch)`` pair yielded by ``patch_gen`` is plotted
    with ``plot_image`` and shown with ``plt.show()``. Patches that fail to
    plot are skipped with a warning (previously a bare ``except: pass``
    silently swallowed every error, including KeyboardInterrupt).

    Raises
    ------
    ValueError
        If a patch is a Dataset and ``variable`` is not specified.
    """
    import warnings

    for label_id, xr_patch in patch_gen:
        if isinstance(xr_patch, xr.Dataset):
            if variable is None:
                raise ValueError("'variable' must be specified when plotting xr.Dataset patches.")
            xr_patch = xr_patch[variable]
        try:
            plot_image(
                xr_patch,
                interpolation=interpolation,
                add_colorbar=add_colorbar,
                fig_kwargs=fig_kwargs,
                cbar_kwargs=cbar_kwargs,
                **plot_kwargs,
            )
            plt.show()
        except Exception as err:
            # Best effort: skip the patch that cannot be plotted, but report it
            warnings.warn(f"Skipping patch {label_id}: {err!r}", stacklevel=2)
    return
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc