• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ghiggi / gpm_api / 7001896589

27 Nov 2023 07:28AM UTC coverage: 43.96% (-12.2%) from 56.199%
7001896589

push

github

web-flow
Merge pull request #23 from EPFL-ENAC/fix-ci-coverage

Fix CI coverage

2347 of 5339 relevant lines covered (43.96%)

0.44 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

11.99
/gpm_api/visualization/plot.py
1
#!/usr/bin/env python3
2
"""
1✔
3
Created on Sat Dec 10 18:42:28 2022
4

5
@author: ghiggi
6
"""
7
import inspect
1✔
8

9
import cartopy
1✔
10
import cartopy.crs as ccrs
1✔
11
import matplotlib.pyplot as plt
1✔
12
import numpy as np
1✔
13
import xarray as xr
1✔
14
from matplotlib.collections import PolyCollection
1✔
15
from mpl_toolkits.axes_grid1 import make_axes_locatable
1✔
16

17
### TODO: Add xarray + cartopy  (xr_carto) (xr_mpl)
18
# _plot_cartopy_xr_imshow
19
# _plot_cartopy_xr_pcolormesh
20

21

22
def is_generator(obj):
    """Return True if ``obj`` is a generator function or a generator object."""
    predicates = (inspect.isgeneratorfunction, inspect.isgenerator)
    return any(predicate(obj) for predicate in predicates)
24

25

26
def _preprocess_figure_args(ax, fig_kwargs={}, subplot_kwargs={}):
1✔
27
    if ax is not None:
×
28
        if len(subplot_kwargs) >= 1:
×
29
            raise ValueError("Provide `subplot_kwargs`only if `ax`is None")
×
30
        if len(fig_kwargs) >= 1:
×
31
            raise ValueError("Provide `fig_kwargs` only if `ax`is None")
×
32

33
    # If ax is not specified, specify the figure defaults
34
    # if ax is None:
35
    # Set default figure size and dpi
36
    # fig_kwargs['figsize'] = (12, 10)
37
    # fig_kwargs['dpi'] = 100
38

39

40
def _preprocess_subplot_kwargs(subplot_kwargs):
1✔
41
    subplot_kwargs = subplot_kwargs.copy()
×
42
    if "projection" not in subplot_kwargs:
×
43
        subplot_kwargs["projection"] = ccrs.PlateCarree()
×
44
    return subplot_kwargs
×
45

46

47
def get_extent(da, x="lon", y="lat"):
    """Return the ``(x_min, x_max, y_min, y_max)`` bounding box of ``da``.

    The values are converted to Python floats, so the result can be passed
    directly to ``GeoAxes.set_extent``. Otherwise it would contain 0-d
    (possibly lazy) DataArrays.

    Parameters
    ----------
    da : xarray.DataArray or xarray.Dataset
        Object holding the coordinates.
    x, y : str
        Names of the longitude and latitude coordinates.

    Returns
    -------
    tuple of float
        ``(lon_min, lon_max, lat_min, lat_max)``.
    """
    # TODO: compute corners array to estimate the extent
    # - OR increase by 1° in every direction and then wrap between -180, 180, -90, 90
    x_coord = da[x]
    y_coord = da[y]
    return (
        float(x_coord.min()),
        float(x_coord.max()),
        float(y_coord.min()),
        float(y_coord.max()),
    )
55

56

57
def get_antimeridian_mask(lons, buffer=True):
    """Get mask of longitude coordinates neighbors crossing the antimeridian.

    A jump of more than 180 degrees between two adjacent longitudes, along
    either axis, is treated as an antimeridian crossing. Both cells on each
    side of the jump are flagged.

    Parameters
    ----------
    lons : numpy.ndarray
        2D array of longitudes in degrees. Non-finite values never trigger a
        crossing.
    buffer : bool, default True
        If True, dilate the mask by one cell (``scipy.ndimage.binary_dilation``
        with its default cross-shaped structure). This ensures the neighbours
        of crossing cells are included too.

    Returns
    -------
    numpy.ndarray of bool
        Mask with the same shape as ``lons``.
    """
    from scipy.ndimage import binary_dilation

    mask = np.zeros(lons.shape, dtype=bool)

    # Vertical edges: a jump between rows i and i+1 flags both rows
    row_jumps = np.abs(np.diff(lons, axis=0)) > 180
    mask[:-1, :] |= row_jumps
    mask[1:, :] |= row_jumps

    # Horizontal edges: a jump between columns j and j+1 flags both columns
    col_jumps = np.abs(np.diff(lons, axis=1)) > 180
    mask[:, :-1] |= col_jumps
    mask[:, 1:] |= col_jumps

    # Buffer by 1 in all directions to ensure edges not crossing the antimeridian
    if buffer:
        mask = binary_dilation(mask)
    return mask
79

80

81
def get_masked_cells_polycollection(x, y, arr, mask, plot_kwargs):
    """Build a PolyCollection with the quadmesh cells selected by ``mask``.

    The mask is first dilated by one cell (``scipy.ndimage.binary_dilation``).
    Each selected cell becomes a 4-vertex polygon drawn with a
    ``ccrs.Geodetic()`` transform.

    Parameters
    ----------
    x, y : numpy.ndarray
        2D longitude/latitude arrays of the quadmesh. Their corners are
        derived with ``_get_lonlat_corners``.
    arr : numpy.ndarray
        Values sharing the leading 2D shape of ``x``. A 2D array is
        colour-mapped (``array=``). A 3D array is treated as RGB(A) colours
        (channels last) and used as face colours, because matplotlib
        collections can only colour-map rank-1 arrays.
    mask : numpy.ndarray
        2D mask of the cells to draw.
    plot_kwargs : dict
        PolyCollection options; the dictionary is not modified.
        ``vmin``/``vmax`` are applied with ``set_clim`` (they are not
        PolyCollection properties). ``shading``, a pcolormesh-only option,
        is dropped.

    Returns
    -------
    matplotlib.collections.PolyCollection
    """
    from scipy.ndimage import binary_dilation

    from gpm_api.utils.area import _from_corners_to_bounds, _get_lonlat_corners, is_vertex_clockwise

    # - Buffer mask by 1 to derive vertices of all masked QuadMesh
    mask = binary_dilation(mask)

    # - Get index of masked quadmesh
    row_mask, col_mask = np.where(mask)

    # - Retrieve values of masked cells: (n_masked,) or (n_masked, n_channels)
    values = arr[row_mask, col_mask]

    # - Retrieve QuadMesh corners (m+1 x n+1)
    x_corners, y_corners = _get_lonlat_corners(x, y)

    # - Retrieve QuadMesh bounds (m*n x 4)
    x_bounds = _from_corners_to_bounds(x_corners)
    y_bounds = _from_corners_to_bounds(y_corners)

    # - Retrieve vertices of masked QuadMesh (n_masked, 4, 2)
    x_vertices = x_bounds[row_mask, col_mask]
    y_vertices = y_bounds[row_mask, col_mask]
    vertices = np.stack((x_vertices, y_vertices), axis=2)

    # Check that are counterclockwise oriented (check first vertex).
    # Guarded so that an empty selection does not raise an IndexError.
    # TODO: this check should be updated to use pyresample.future.spherical
    if len(vertices) > 0 and is_vertex_clockwise(vertices[0, :, :]):
        vertices = vertices[:, ::-1, :]

    # - Define additional kwargs for PolyCollection
    plot_kwargs = plot_kwargs.copy()
    plot_kwargs.setdefault("edgecolors", "face")
    plot_kwargs.setdefault("linewidth", 0)
    plot_kwargs["antialiaseds"] = False  # to better plotting quality
    plot_kwargs.pop("shading", None)
    vmin = plot_kwargs.pop("vmin", None)
    vmax = plot_kwargs.pop("vmax", None)

    # - Define PolyCollection
    if values.ndim > 1:
        # RGB(A) colours cannot be colour-mapped: use them as face colours
        plot_kwargs["facecolors"] = values
        coll = PolyCollection(
            verts=vertices,
            transform=ccrs.Geodetic(),
            **plot_kwargs,
        )
    else:
        coll = PolyCollection(
            verts=vertices,
            array=values,
            transform=ccrs.Geodetic(),
            **plot_kwargs,
        )
        if vmin is not None or vmax is not None:
            coll.set_clim(vmin, vmax)
    return coll
129

130

131
def get_valid_pcolormesh_inputs(x, y, data, rgb=False):
    """
    Fill non-finite values with neighbour valid coordinates.

    pcolormesh does not accept non-finite values in the coordinates.
    This function:
    - Infills NaN/Inf in x/y by linear interpolation over the flattened
      (C-order) index of the valid coordinates.
    - Masks the corresponding pixels in the data, which must not be displayed.

    A pixel is invalid if either its x or its y coordinate is non-finite.
    If all coordinates are finite, ``x``, ``y`` and ``data`` are returned
    unchanged (same objects). Otherwise copies of ``x``/``y`` and a
    ``numpy.ma.MaskedArray`` of ``data`` are returned; the inputs are not
    modified.

    If rgb=True, the RGB channels are in the last dimension of ``data``.

    Raises
    ------
    ValueError
        If no coordinate is valid, so nothing can be infilled.
    """
    # TODO:
    # - Instead of np.interp, can use nearest neighbors or just 0 to speed up?

    # Retrieve mask of invalid coordinates
    mask = np.logical_or(~np.isfinite(x), ~np.isfinite(y))

    # If no invalid coordinates, return original data
    if not mask.any():
        return x, y, data

    # np.interp cannot infill from an empty set of valid samples
    if mask.all():
        raise ValueError("All x/y coordinates are non-finite: there is nothing to display.")

    # Mask the data (broadcast the 2D mask over the RGB channels if needed)
    if rgb:
        data_mask = np.broadcast_to(np.expand_dims(mask, axis=-1), data.shape)
    else:
        data_mask = mask
    data_masked = np.ma.masked_where(data_mask, data)

    # Boolean indexing and flatnonzero both follow C order, so positions match
    invalid_idx = np.flatnonzero(mask)
    valid_idx = np.flatnonzero(~mask)

    # TODO: should be done in XYZ?
    x_dummy = x.copy()
    x_dummy[mask] = np.interp(invalid_idx, valid_idx, x[~mask])
    y_dummy = y.copy()
    y_dummy[mask] = np.interp(invalid_idx, valid_idx, y[~mask])
    return x_dummy, y_dummy, data_masked
168

169

170
####--------------------------------------------------------------------------.
171
def plot_cartopy_background(ax):
    """Decorate a cartopy GeoAxes with a basic map background.

    Draws coastlines, land, ocean and borders. Adds faint PlateCarree
    gridlines with labels on the bottom and left sides only.

    Returns
    -------
    The same ``ax``, for chaining.
    """
    # - Map features
    ax.coastlines()
    ax.add_feature(cartopy.feature.LAND, facecolor=[0.9, 0.9, 0.9])
    ax.add_feature(cartopy.feature.OCEAN, alpha=0.6)
    ax.add_feature(cartopy.feature.BORDERS)  # BORDERS also draws provinces, ...

    # - Grid lines
    gridline_style = {"linewidth": 1, "color": "gray", "alpha": 0.1, "linestyle": "-"}
    gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True, **gridline_style)
    gl.top_labels = gl.right_labels = False
    gl.xlines = gl.ylines = True
    return ax
192

193

194
def plot_colorbar(p, ax, cbar_kwargs={}, size="5%", pad=0.1):
    """Add a colorbar to a matplotlib/cartopy plot.

    A new axes of width ``size`` is appended horizontally next to ``ax`` (via
    ``make_axes_locatable``) with spacing ``pad``, and the colorbar is drawn
    in it.

    Parameters
    ----------
    p : matplotlib.cm.ScalarMappable
        Mappable to describe, e.g. a matplotlib.image.AxesImage.
    ax : matplotlib.axes.Axes
        Axes the colorbar is attached to, e.g. a cartopy GeoAxes.
    cbar_kwargs : dict
        Options for ``plt.colorbar``. An optional ``"ticklabels"`` entry is
        applied to the colorbar y tick labels. The dict itself is not modified.

    Returns
    -------
    matplotlib.colorbar.Colorbar
    """
    # Local copy so that popping "ticklabels" does not leak to the caller
    options = cbar_kwargs.copy()
    ticklabels = options.pop("ticklabels", None)

    cax = make_axes_locatable(ax).new_horizontal(size=size, pad=pad, axes_class=plt.Axes)
    p.figure.add_axes(cax)
    cbar = plt.colorbar(p, cax=cax, ax=ax, **options)

    if ticklabels is not None:
        cbar.ax.set_yticklabels(ticklabels)
    return cbar
210

211

212
####--------------------------------------------------------------------------.
213
def _plot_cartopy_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot imshow with cartopy.

    The image is placed on the global PlateCarree extent
    ``[-180, 180, -90, 90]`` with ``origin="lower"``. The map view is then
    zoomed on the min/max of the ``x`` and ``y`` coordinates.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
    da : xarray.DataArray
        2D field to plot.
    x, y : str
        Names of the longitude and latitude dimensions/coordinates.
    interpolation : str
        imshow interpolation method.
    add_colorbar : bool
        Whether to add a colorbar with ``plot_colorbar``.
    plot_kwargs : dict, optional
        Extra arguments for ``ax.imshow``.
    cbar_kwargs : dict, optional
        Extra arguments for ``plot_colorbar``.

    Returns
    -------
    matplotlib.image.AxesImage
    """
    # TODO: allow to plot whatever projection (based on CRS) !
    # TODO: allow to plot subset of PlateCarree !
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs

    # - Ensure image with correct dimensions orders
    da = da.transpose(y, x)
    arr = np.asanyarray(da.data)

    # - Image extent: currently assumes a global grid
    image_extent = [-180, 180, -90, 90]  # TODO: Derive from data !!!!

    # TODO: ensure y data is increasing --> origin = "lower"
    # TODO: ensure y data is decreasing --> origin = "upper"

    # - Add variable field with cartopy
    p = ax.imshow(
        arr,
        transform=ccrs.PlateCarree(),
        extent=image_extent,
        origin="lower",
        interpolation=interpolation,
        **plot_kwargs,
    )

    # - Zoom the map on the data, using the coordinate names given by the caller
    ax.set_extent(get_extent(da, x=x, y=y))

    # - Add colorbar
    if add_colorbar:
        # --> TODO: set axis proportion in a meaningful way ...
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
255

256

257
def _plot_rgb_pcolormesh(x, y, image, ax, **kwargs):
1✔
258
    """Plot xarray RGB DataArray with non uniform-coordinates.
259

260
    Matplotlib, cartopy and xarray pcolormesh currently does not support RGB(A) arrays.
261
    This is a temporary workaround !
262
    """
263
    if image.shape[2] not in [3, 4]:
×
264
        raise ValueError("Expecting RGB or RGB(A) arrays.")
×
265

266
    colorTuple = image.reshape((image.shape[0] * image.shape[1], image.shape[2]))
×
267
    im = ax.pcolormesh(
×
268
        x,
269
        y,
270
        image[:, :, 1],  # dummy to work ...
271
        color=colorTuple,
272
        **kwargs,
273
    )
274
    # im.set_array(None)
275
    return im
×
276

277

278
def _plot_cartopy_pcolormesh(
    ax,
    da,
    x,
    y,
    rgb=False,
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot pcolormesh with cartopy.

    ``x`` and ``y`` must be the names of the longitude and latitude
    coordinates of ``da``. Non-finite coordinates are infilled and the
    corresponding pixels masked (``get_valid_pcolormesh_inputs``).

    Quadmesh cells crossing the antimeridian are masked in the pcolormesh and
    drawn separately as a PolyCollection with geodetic edges. The function
    currently does not allow to zoom on regions across the antimeridian.

    If rgb=True, the RGB(A) dimension is expected at last position and no
    colorbar is added.

    Parameters
    ----------
    ax : cartopy.mpl.geoaxes.GeoAxes
    da : xarray.DataArray
    x, y : str
        Names of the longitude and latitude coordinates.
    rgb : bool
        Whether ``da`` holds RGB(A) colours in its last dimension.
    add_colorbar : bool
        Whether to add a colorbar (forced to False when ``rgb=True``).
    plot_kwargs : dict, optional
        Extra arguments for pcolormesh. The dictionary is not modified.
    cbar_kwargs : dict, optional
        Extra arguments for ``plot_colorbar``.

    Returns
    -------
    The mappable returned by pcolormesh.
    """
    # Work on a copy: the "cmap" entry may be replaced below and the caller's
    # dictionary (or a shared default) must not be modified.
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs

    # - Get x, y, and array to plot
    da = da.compute()
    x = da[x].data
    y = da[y].data
    arr = da.data

    # - Infill invalid value and add mask if necessary
    x, y, arr = get_valid_pcolormesh_inputs(x, y, arr, rgb=rgb)

    # - No colorbar for RGB images
    if rgb:
        add_colorbar = False

    # - Mask cells crossing the antimeridian
    # --> Here assume not invalid coordinates anymore
    antimeridian_mask = get_antimeridian_mask(x, buffer=True)
    is_crossing_antimeridian = np.any(antimeridian_mask)
    if is_crossing_antimeridian:
        # For RGB data the 2D mask is broadcast over the channel dimension
        if rgb:
            data_mask = np.broadcast_to(np.expand_dims(antimeridian_mask, axis=-1), arr.shape)
        else:
            data_mask = antimeridian_mask
        # np.ma.masked_where unions the condition with any existing mask of arr
        arr = np.ma.masked_where(data_mask, arr)

        # Sanitize cmap bad color to avoid cartopy bug.
        # plt.get_cmap accepts names and Colormap objects. A copy is used so the
        # caller's (or a registered) colormap is not modified in place.
        if "cmap" in plot_kwargs:
            cmap = plt.get_cmap(plot_kwargs["cmap"]).copy()
            bad = cmap.get_bad()
            bad[3] = 0  # enforce to 0 (transparent)
            cmap.set_bad(bad)
            plot_kwargs["cmap"] = cmap

    # - Add variable field with cartopy
    if rgb:
        # x/y are lon/lat: declare the data CRS, as done for the scalar case
        rgb_kwargs = {"transform": ccrs.PlateCarree(), **plot_kwargs}
        p = _plot_rgb_pcolormesh(x, y, arr, ax=ax, **rgb_kwargs)
    else:
        p = ax.pcolormesh(
            x,
            y,
            arr,
            transform=ccrs.PlateCarree(),
            **plot_kwargs,
        )

    # - Add PolyCollection of QuadMesh cells crossing the antimeridian
    if is_crossing_antimeridian:
        coll = get_masked_cells_polycollection(
            x, y, arr.data, mask=antimeridian_mask, plot_kwargs=plot_kwargs
        )
        if not rgb:
            # Share the mesh colour normalization; otherwise the collection
            # would autoscale on its own subset of values.
            coll.set_norm(p.norm)
        p.axes.add_collection(coll)

    # - Set the extent
    # --> To be set in projection coordinates of crs !!!
    #     lon/lat conversion to proj required !
    # extent = get_extent(da, x="lon", y="lat")
    # ax.set_extent(extent)

    # - Add colorbar
    if add_colorbar:
        # --> TODO: set axis proportion in a meaningful way ...
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
366

367

368
def _plot_mpl_imshow(
1✔
369
    ax,
370
    da,
371
    x,
372
    y,
373
    interpolation="nearest",
374
    add_colorbar=True,
375
    plot_kwargs={},
376
    cbar_kwargs={},
377
):
378
    """Plot imshow with matplotlib."""
379
    # - Ensure image with correct dimensions orders
380
    da = da.transpose(y, x)
×
381
    arr = np.asanyarray(da.data)
×
382

383
    # - Add variable field with matplotlib
384
    p = ax.imshow(
×
385
        arr,
386
        origin="upper",
387
        interpolation=interpolation,
388
        **plot_kwargs,
389
    )
390
    # - Add colorbar
391
    if add_colorbar:
×
392
        # --> TODO: set axis proportion in a meaningful way ...
393
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
×
394
    # - Return mappable
395
    return p
×
396

397

398
# def _get_colorbar_inset_axes_kwargs(p):
399
#     from mpl_toolkits.axes_grid1.inset_locator import inset_axes
400

401
#     colorbar_axes = p.colorbar.ax
402

403
#     # Get the position and size of the colorbar axes in figure coordinates
404
#     bbox = colorbar_axes.get_position()
405

406
#     # Extract the width and height of the colorbar axes in figure coordinates
407
#     width = bbox.x1 - bbox.x0
408
#     height = bbox.y1 - bbox.y0
409

410
#     # Get the location of the colorbar axes ('upper', 'lower', 'center', etc.)
411
#     # This information will be used to set the 'loc' parameter of inset_axes
412
#     loc = 'upper right'  # Modify this according to your preference
413

414
#     # Get the transformation of the colorbar axes with respect to the image axes
415
#     # This information will be used to set the 'bbox_transform' parameter of inset_axes
416
#     bbox_transform = colorbar_axes.get_transform()
417

418
#     # Calculate the coordinates of the colorbar axes relative to the image axes
419
#     x0, y0 = bbox_transform.transform((bbox.x0, bbox.y0))
420
#     x1, y1 = bbox_transform.transform((bbox.x1, bbox.y1))
421
#     bbox_to_anchor = (x0, y0, x1 - x0, y1 - y0)
422

423

424
def set_colorbar_fully_transparent(p):
    """Hide the colorbar attached to mappable ``p`` while keeping its space.

    The colorbar axes is made invisible. A fully transparent rectangle with the
    same position and size (in figure coordinates) is added to the figure that
    owns the colorbar.

    Parameters
    ----------
    p : matplotlib.cm.ScalarMappable
        Mappable with an attached ``colorbar``.
    """
    cbar_ax = p.colorbar.ax

    # Get the position of the colorbar (figure coordinates)
    cbar_pos = cbar_ax.get_position()

    # Remove the colorbar
    cbar_ax.set_visible(False)

    # Use the colorbar's own figure: plt.gcf() may point to another figure
    fig = cbar_ax.figure
    rect = plt.Rectangle(
        (cbar_pos.x0, cbar_pos.y0),
        cbar_pos.width,
        cbar_pos.height,
        transform=fig.transFigure,
        facecolor="none",
        edgecolor="none",
    )
    fig.add_artist(rect)
479

480

481
def _plot_xr_imshow(
1✔
482
    ax,
483
    da,
484
    x,
485
    y,
486
    interpolation="nearest",
487
    add_colorbar=True,
488
    plot_kwargs={},
489
    cbar_kwargs={},
490
    xarray_colorbar=True,
491
    visible_colorbar=True,
492
):
493
    """Plot imshow with xarray.
494

495
    The colorbar is added with xarray to enable to display multiple colorbars
496
    when calling this function multiple times on different fields with
497
    different colorbars.
498
    """
499
    # --> BUG with colorbar: https://github.com/pydata/xarray/issues/7014
500
    ticklabels = cbar_kwargs.pop("ticklabels", None)
×
501
    if not add_colorbar:
×
502
        cbar_kwargs = {}
×
503
    p = da.plot.imshow(
×
504
        x=x,
505
        y=y,
506
        ax=ax,
507
        interpolation=interpolation,
508
        add_colorbar=add_colorbar,
509
        cbar_kwargs=cbar_kwargs,
510
        **plot_kwargs,
511
    )
512
    plt.title(da.name)
×
513
    if add_colorbar and ticklabels is not None:
×
514
        p.colorbar.ax.set_yticklabels(ticklabels)
×
515

516
    # Make the colorbar fully transparent with a smart trick ;)
517
    # - TODO: this still cause issues when plotting 2 colorbars !
518
    if add_colorbar and not visible_colorbar:
×
519
        set_colorbar_fully_transparent(p)
×
520

521
    # Add manually the colorbar
522
    # p = da.plot.imshow(
523
    #     x=x,
524
    #     y=y,
525
    #     ax=ax,
526
    #     interpolation=interpolation,
527
    #     add_colorbar=False,
528
    #     **plot_kwargs,
529
    # )
530
    # plt.title(da.name)
531
    # if add_colorbar:
532
    #     _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
533
    return p
×
534

535

536
def _plot_xr_pcolormesh(
1✔
537
    ax,
538
    da,
539
    x,
540
    y,
541
    add_colorbar=True,
542
    plot_kwargs={},
543
    cbar_kwargs={},
544
):
545
    """Plot pcolormesh with xarray."""
546
    ticklabels = cbar_kwargs.pop("ticklabels", None)
×
547
    if not add_colorbar:
×
548
        cbar_kwargs = {}
×
549
    p = da.plot.pcolormesh(
×
550
        x=x,
551
        y=y,
552
        ax=ax,
553
        add_colorbar=add_colorbar,
554
        cbar_kwargs=cbar_kwargs,
555
        **plot_kwargs,
556
    )
557
    plt.title(da.name)
×
558
    if add_colorbar and ticklabels is not None:
×
559
        p.colorbar.ax.set_yticklabels(ticklabels)
×
560
    return p
×
561

562

563
####--------------------------------------------------------------------------.
564
def plot_map(
    da,
    x="lon",
    y="lat",
    ax=None,
    add_colorbar=True,
    add_swath_lines=True,  # used only for GPM orbit objects
    add_background=True,
    rgb=False,
    interpolation="nearest",  # used only for GPM grid objects
    fig_kwargs={},
    subplot_kwargs={},
    cbar_kwargs={},
    **plot_kwargs,
):
    """Plot a GPM orbit or grid DataArray on a cartographic map.

    Dispatches to ``plot_orbit_map`` when ``is_orbit(da)`` is True, otherwise
    to ``plot_grid_map`` when ``is_grid(da)`` is True. ``add_swath_lines`` and
    ``rgb`` are forwarded only for orbits; ``interpolation`` only for grids.

    Returns
    -------
    The mappable returned by the selected plotting function.

    Raises
    ------
    ValueError
        If ``da`` is neither a GPM orbit nor a GPM grid.
    """
    from gpm_api.checks import is_grid, is_orbit
    from gpm_api.visualization.grid import plot_grid_map
    from gpm_api.visualization.orbit import plot_orbit_map

    # Arguments forwarded identically to both map plotting functions
    common_kwargs = {
        "da": da,
        "x": x,
        "y": y,
        "ax": ax,
        "add_colorbar": add_colorbar,
        "add_background": add_background,
        "fig_kwargs": fig_kwargs,
        "subplot_kwargs": subplot_kwargs,
        "cbar_kwargs": cbar_kwargs,
    }
    if is_orbit(da):
        return plot_orbit_map(
            add_swath_lines=add_swath_lines, rgb=rgb, **common_kwargs, **plot_kwargs
        )
    if is_grid(da):
        return plot_grid_map(interpolation=interpolation, **common_kwargs, **plot_kwargs)
    raise ValueError("Can not plot. It's neither a GPM grid, neither a GPM orbit.")
618

619

620
def plot_image(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs={},
    cbar_kwargs={},
    **plot_kwargs,
):
    """Plot a GPM orbit or grid DataArray as an image (no map projection).

    Dispatches to ``plot_orbit_image`` when ``is_orbit(da)`` is True, otherwise
    to ``plot_grid_image`` when ``is_grid(da)`` is True. Both receive the same
    arguments. Figure options are only meant to be used when ``ax`` is None.

    Returns
    -------
    The mappable returned by the selected plotting function.

    Raises
    ------
    ValueError
        If ``da`` is neither a GPM orbit nor a GPM grid.
    """
    from gpm_api.checks import is_grid, is_orbit
    from gpm_api.visualization.grid import plot_grid_image
    from gpm_api.visualization.orbit import plot_orbit_image

    image_kwargs = dict(
        da=da,
        x=x,
        y=y,
        ax=ax,
        add_colorbar=add_colorbar,
        interpolation=interpolation,
        fig_kwargs=fig_kwargs,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )
    if is_orbit(da):
        return plot_orbit_image(**image_kwargs)
    if is_grid(da):
        return plot_grid_image(**image_kwargs)
    raise ValueError("Can not plot. It's neither a GPM GRID, neither a GPM ORBIT.")
666

667

668
####--------------------------------------------------------------------------.
669

670

671
def plot_map_mesh(
    xr_obj,
    x="lon",
    y="lat",
    ax=None,
    edgecolors="k",
    linewidth=0.1,
    add_background=True,
    fig_kwargs={},
    subplot_kwargs={},
    **plot_kwargs,
):
    """Plot the quadmesh outline of a GPM orbit or grid object on a map.

    Orbit objects (``is_orbit``) are passed to ``plot_orbit_mesh`` as the
    ``xr_obj[y]`` DataArray. Any other object is passed unchanged to
    ``plot_grid_mesh``. ``fig_kwargs``/``subplot_kwargs`` are only meant to
    be used when ``ax`` is None.

    Returns
    -------
    The mappable returned by the selected plotting function.
    """
    from gpm_api.checks import is_orbit  # is_grid

    from .grid import plot_grid_mesh
    from .orbit import plot_orbit_mesh

    # Options forwarded identically to both mesh plotting functions
    mesh_kwargs = dict(
        x=x,
        y=y,
        ax=ax,
        edgecolors=edgecolors,
        linewidth=linewidth,
        add_background=add_background,
        fig_kwargs=fig_kwargs,
        subplot_kwargs=subplot_kwargs,
        **plot_kwargs,
    )
    if is_orbit(xr_obj):
        return plot_orbit_mesh(da=xr_obj[y], **mesh_kwargs)
    return plot_grid_mesh(xr_obj=xr_obj, **mesh_kwargs)
719

720

721
def plot_map_mesh_centroids(
    xr_obj,
    x="lon",
    y="lat",
    ax=None,
    c="r",
    s=1,
    add_background=True,
    fig_kwargs={},
    subplot_kwargs={},
    **plot_kwargs,
):
    """Plot GPM orbit granule mesh centroids in a cartographic map.

    The ``x``/``y`` coordinates of ``xr_obj`` are scattered with a PlateCarree
    transform. A new figure is created when ``ax`` is None; ``fig_kwargs`` and
    ``subplot_kwargs`` are only allowed in that case.

    Returns
    -------
    The mappable returned by ``ax.scatter``.
    """
    _preprocess_figure_args(ax=ax, fig_kwargs=fig_kwargs, subplot_kwargs=subplot_kwargs)

    if ax is None:
        _, ax = plt.subplots(
            subplot_kw=_preprocess_subplot_kwargs(subplot_kwargs), **fig_kwargs
        )

    if add_background:
        ax = plot_cartopy_background(ax)

    lon, lat = xr_obj[x].data, xr_obj[y].data
    return ax.scatter(lon, lat, transform=ccrs.PlateCarree(), c=c, s=s, **plot_kwargs)
753

754

755
####--------------------------------------------------------------------------.
756

757

758
def _plot_labels(
    xr_obj,
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs=None,
    **plot_kwargs,
):
    """Plot labels.

    The labels are renumbered with ``redefine_label_array``. Values that are
    not > 0 are then set to NaN, and the result is drawn with ``plot_image``.
    The maximum allowed number of labels for displaying the colorbar is
    'max_n_labels'; above it a message is printed and the colorbar is disabled.

    Parameters
    ----------
    xr_obj : xarray.Dataset or xarray.DataArray
        Object holding the labels. ``label_name`` is required for a Dataset.
    label_name : str, optional
        Name of the label variable to select from ``xr_obj``.
    max_n_labels : int
        Maximum number of labels for which the colorbar is displayed.
    add_colorbar : bool
    interpolation : str
    cmap : str
        Colormap used by ``get_label_colorbar_settings``.
    fig_kwargs : dict, optional
        Figure options forwarded to ``plot_image``.
    **plot_kwargs
        Extra plotting options; they take precedence over the label settings.

    Returns
    -------
    The mappable returned by ``plot_image``.

    Raises
    ------
    ValueError
        If ``xr_obj`` is a Dataset and ``label_name`` is None.
    """
    from ximage.labels.labels import get_label_indices, redefine_label_array
    from ximage.labels.plot_labels import get_label_colorbar_settings

    fig_kwargs = {} if fig_kwargs is None else fig_kwargs

    if isinstance(xr_obj, xr.Dataset) and label_name is None:
        raise ValueError("'label_name' must be specified when plotting a xr.Dataset.")
    dataarray = xr_obj if label_name is None else xr_obj[label_name]

    dataarray = dataarray.compute()
    label_indices = get_label_indices(dataarray)
    n_labels = len(label_indices)
    if add_colorbar and n_labels > max_n_labels:
        msg = f"""The array currently contains {n_labels} labels
        and 'max_n_labels' is set to {max_n_labels}. The colorbar is not displayed!"""
        print(msg)
        add_colorbar = False
    # Relabel array from 1 to ... for plotting
    dataarray = redefine_label_array(dataarray, label_indices=label_indices)
    # Replace 0 with nan
    dataarray = dataarray.where(dataarray > 0)
    # Define appropriate colormap, honouring the requested `cmap`
    label_plot_kwargs, cbar_kwargs = get_label_colorbar_settings(label_indices, cmap=cmap)
    # User-provided options are not discarded: they override the label defaults
    plot_kwargs = {**label_plot_kwargs, **plot_kwargs}
    # Plot image (plot_image is defined in this module)
    p = plot_image(
        dataarray,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        fig_kwargs=fig_kwargs,
        **plot_kwargs,
    )
    return p
809

810

811
def plot_labels(
    obj,  # Dataset, DataArray or generator
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs=None,
    **plot_kwargs,
):
    """Plot label arrays.

    Parameters
    ----------
    obj : xarray.Dataset, xarray.DataArray or generator
        Object to plot. A generator must yield ``(label_id, xr_obj)`` pairs;
        each ``xr_obj`` is plotted and then displayed with ``plt.show()``.
    label_name, max_n_labels, add_colorbar, interpolation, cmap, fig_kwargs, **plot_kwargs
        Forwarded to ``_plot_labels``.

    Returns
    -------
    The mappable of the (last) plot, or None if the generator yielded nothing.
    """
    fig_kwargs = {} if fig_kwargs is None else fig_kwargs
    label_kwargs = {
        "label_name": label_name,
        "max_n_labels": max_n_labels,
        "add_colorbar": add_colorbar,
        "interpolation": interpolation,
        "cmap": cmap,
        "fig_kwargs": fig_kwargs,
        **plot_kwargs,
    }
    if not is_generator(obj):
        return _plot_labels(xr_obj=obj, **label_kwargs)

    # Initialised so that an empty generator returns None instead of
    # raising UnboundLocalError.
    p = None
    for _, xr_obj in obj:
        p = _plot_labels(xr_obj=xr_obj, **label_kwargs)
        plt.show()
    return p
846

847

848
def plot_patches(
    patch_gen,
    variable=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot patches.

    Each ``(label_id, xr_patch)`` pair yielded by ``patch_gen`` is plotted with
    ``plot_image`` and displayed with ``plt.show()``. Plotting is best-effort:
    a patch that fails to plot is skipped with a warning, not a crash.

    Parameters
    ----------
    patch_gen : iterable
        Yields ``(label_id, xr_patch)`` pairs.
    variable : str, optional
        Variable to select when a patch is an ``xr.Dataset``.
    add_colorbar, interpolation, fig_kwargs, cbar_kwargs, **plot_kwargs
        Forwarded to ``plot_image``.

    Raises
    ------
    ValueError
        If a patch is an ``xr.Dataset`` and ``variable`` is None.
    """
    import warnings

    fig_kwargs = {} if fig_kwargs is None else fig_kwargs
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs

    # Plot patches (plot_image is defined in this module)
    for label_id, xr_patch in patch_gen:
        if isinstance(xr_patch, xr.Dataset):
            if variable is None:
                raise ValueError("'variable' must be specified when plotting xr.Dataset patches.")
            xr_patch = xr_patch[variable]
        try:
            plot_image(
                xr_patch,
                interpolation=interpolation,
                add_colorbar=add_colorbar,
                fig_kwargs=fig_kwargs,
                cbar_kwargs=cbar_kwargs,
                **plot_kwargs,
            )
            plt.show()
        except Exception as err:
            # Keep the best-effort behaviour, but report the failure and let
            # KeyboardInterrupt/SystemExit propagate.
            warnings.warn(f"Unable to plot patch {label_id}: {err}", stacklevel=2)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc