• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

ghiggi / gpm_api / 8053365388

26 Feb 2024 05:51PM UTC coverage: 62.974% (-0.3%) from 63.244%
8053365388

push

github

ghiggi
Drop support for python 3.8

3723 of 5912 relevant lines covered (62.97%)

0.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

12.01
/gpm_api/visualization/plot.py
1
# -----------------------------------------------------------------------------.
2
# MIT License
3

4
# Copyright (c) 2024 GPM-API developers
5
#
6
# This file is part of GPM-API.
7

8
# Permission is hereby granted, free of charge, to any person obtaining a copy
9
# of this software and associated documentation files (the "Software"), to deal
10
# in the Software without restriction, including without limitation the rights
11
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
# copies of the Software, and to permit persons to whom the Software is
13
# furnished to do so, subject to the following conditions:
14
#
15
# The above copyright notice and this permission notice shall be included in all
16
# copies or substantial portions of the Software.
17
#
18
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
# SOFTWARE.
25

26
# -----------------------------------------------------------------------------.
27
"""This module contains basic functions for GPM-API data visualization."""
1✔
28
import inspect
1✔
29

30
import cartopy
1✔
31
import cartopy.crs as ccrs
1✔
32
import matplotlib.pyplot as plt
1✔
33
import numpy as np
1✔
34
import xarray as xr
1✔
35
from mpl_toolkits.axes_grid1 import make_axes_locatable
1✔
36

37
import gpm_api
1✔
38

39
### TODO: Add xarray + cartopy  (xr_carto) (xr_mpl)
40
# _plot_cartopy_xr_imshow
41
# _plot_cartopy_xr_pcolormesh
42

43

44
def is_generator(obj):
    """Return True if ``obj`` is a generator object or a generator function."""
    generator_checks = (inspect.isgenerator, inspect.isgeneratorfunction)
    return any(check(obj) for check in generator_checks)
46

47

48
def _preprocess_figure_args(ax, fig_kwargs={}, subplot_kwargs={}):
1✔
49
    if ax is not None:
×
50
        if len(subplot_kwargs) >= 1:
×
51
            raise ValueError("Provide `subplot_kwargs`only if `ax`is None")
×
52
        if len(fig_kwargs) >= 1:
×
53
            raise ValueError("Provide `fig_kwargs` only if `ax`is None")
×
54

55
    # If ax is not specified, specify the figure defaults
56
    # if ax is None:
57
    # Set default figure size and dpi
58
    # fig_kwargs['figsize'] = (12, 10)
59
    # fig_kwargs['dpi'] = 100
60

61

62
def _preprocess_subplot_kwargs(subplot_kwargs):
1✔
63
    subplot_kwargs = subplot_kwargs.copy()
×
64
    if "projection" not in subplot_kwargs:
×
65
        subplot_kwargs["projection"] = ccrs.PlateCarree()
×
66
    return subplot_kwargs
×
67

68

69
def get_extent(da, x="lon", y="lat"):
    """Return the ``(x_min, x_max, y_min, y_max)`` bounding box of the coordinates.

    The bounds are the minimum and maximum of the coordinate values ``da[x]``
    and ``da[y]`` (no half-pixel padding is added).
    """
    # TODO: compute corners array to estimate the extent
    # - OR increase by 1° in every direction and then wrap between -180, 180, -90, 90
    x_values = da[x]
    y_values = da[y]
    return (x_values.min(), x_values.max(), y_values.min(), y_values.max())
77

78

79
def get_antimeridian_mask(lons, buffer=True):
    """Get the mask of mesh cells crossing the antimeridian.

    A cell is flagged when one of its edges joins two corner longitudes whose
    absolute difference exceeds 180°. Both cells sharing such an edge are flagged.

    Parameters
    ----------
    lons : numpy.ndarray
        2D array of cell-corner longitudes with shape ``(n_y, n_x)``.
        NOTE(review): assumes longitudes in degrees wrapped to [-180, 180].
    buffer : bool, optional
        If True (default), the mask is dilated by one cell along both axes
        (``scipy.ndimage.binary_dilation`` default cross structure).

    Returns
    -------
    numpy.ndarray
        Boolean mask of shape ``(n_y - 1, n_x - 1)`` (one value per cell).
    """
    from scipy.ndimage import binary_dilation

    n_y, n_x = lons.shape
    mask = np.zeros((n_y - 1, n_x - 1), dtype=bool)

    # Edges along axis 0: corners (r, c)-(r + 1, c) are shared by cells (r, c - 1) and (r, c)
    row_idx, col_idx = np.where(np.abs(np.diff(lons, axis=0)) > 180)
    mask[row_idx, np.clip(col_idx - 1, 0, n_x - 2)] = True
    mask[row_idx, np.clip(col_idx, 0, n_x - 2)] = True

    # Edges along axis 1: corners (r, c)-(r, c + 1) are shared by cells (r - 1, c) and (r, c)
    row_idx, col_idx = np.where(np.abs(np.diff(lons, axis=1)) > 180)
    mask[np.clip(row_idx - 1, 0, n_y - 2), col_idx] = True
    mask[np.clip(row_idx, 0, n_y - 2), col_idx] = True

    # Buffer by 1 in all directions to avoid plotting cells neighbour to those crossing the antimeridian
    # --> This should not be needed, but it's needed to avoid cartopy bugs !
    if buffer:
        mask = binary_dilation(mask)
    return mask
98

99

100
def get_antimeridian_mask_old(lons, buffer=True):
    """Legacy antimeridian mask defined on the coordinate grid itself.

    Flags both pixels of every pair of adjacent longitudes (along either axis)
    whose absolute difference exceeds 180°, then dilates the result by one pixel.
    The returned boolean mask has the same shape as ``lons``.
    ``buffer`` is accepted for API compatibility but is not used: the dilation is always applied.
    """
    from scipy.ndimage import binary_dilation

    flagged = np.zeros(lons.shape, dtype=bool)

    # Jumps between consecutive rows: flag the pixel before and after the jump
    row_jumps = np.abs(np.diff(lons, axis=0)) > 180
    flagged[:-1, :] |= row_jumps
    flagged[1:, :] |= row_jumps

    # Jumps between consecutive columns: flag the pixel before and after the jump
    col_jumps = np.abs(np.diff(lons, axis=1)) > 180
    flagged[:, :-1] |= col_jumps
    flagged[:, 1:] |= col_jumps

    # Buffer by 1 in all directions to ensure edges not crossing the antimeridian
    return binary_dilation(flagged)
122

123

124
def get_valid_pcolormesh_inputs(x, y, data, rgb=False):
    """
    Fill non-finite coordinates with neighbour valid values and mask the data.

    pcolormesh does not accept non-finite values in the coordinates.
    This function:
    - Infills NaN/Inf values in x/y by linear interpolation over the flattened
      (C-order) indices of the valid coordinates.
    - Masks the corresponding pixels in the data, which must not be displayed.

    Parameters
    ----------
    x, y : numpy.ndarray
        Coordinate arrays with the same shape.
    data : numpy.ndarray
        Values to plot. If ``rgb=True``, the RGB(A) channels must be in the last
        dimension and the remaining dimensions must match ``x``.
    rgb : bool, optional
        Whether ``data`` has a trailing RGB(A) channel dimension. Default False.

    Returns
    -------
    tuple
        ``(x, y, data)`` unchanged if all coordinates are finite; otherwise
        infilled copies of ``x`` and ``y`` and a ``numpy.ma.MaskedArray`` of ``data``.

    Raises
    ------
    ValueError
        If no coordinate is finite (there is nothing to infill from).
    """
    # TODO: Instead of np.interp, can use nearest neighbors or just 0 to speed up?

    # Retrieve mask of invalid coordinates
    mask = np.logical_or(~np.isfinite(x), ~np.isfinite(y))

    # If no invalid coordinates, return original data
    if not mask.any():
        return x, y, data
    if mask.all():
        raise ValueError("Impossible to infill the coordinates: all x/y coordinates are non-finite.")

    # Mask the data (broadcast the 2D mask over the RGB(A) channels if needed)
    data_mask = np.broadcast_to(np.expand_dims(mask, axis=-1), data.shape) if rgb else mask
    data_masked = np.ma.masked_where(data_mask, data)

    # Infill invalid coordinates by interpolating over the flat indices of the valid ones
    # TODO: should be done in XYZ?
    invalid_idx = np.flatnonzero(mask)
    valid_idx = np.flatnonzero(~mask)
    x_dummy = x.copy()
    x_dummy[mask] = np.interp(invalid_idx, valid_idx, x[~mask])
    y_dummy = y.copy()
    y_dummy[mask] = np.interp(invalid_idx, valid_idx, y[~mask])
    return x_dummy, y_dummy, data_masked
161

162

163
####--------------------------------------------------------------------------.
164
def plot_cartopy_background(ax):
    """Draw the default cartopy background on ``ax`` and return ``ax``.

    Adds coastlines, light-grey land, semi-transparent ocean, borders and
    faint grey gridlines labelled only on the bottom and left sides.
    """
    ax.coastlines()
    background_features = [
        (cartopy.feature.LAND, {"facecolor": [0.9, 0.9, 0.9]}),
        (cartopy.feature.OCEAN, {"alpha": 0.6}),
        (cartopy.feature.BORDERS, {}),  # BORDERS also draws provinces, ...
    ]
    for feature, feature_kwargs in background_features:
        ax.add_feature(feature, **feature_kwargs)

    gridliner = ax.gridlines(
        crs=ccrs.PlateCarree(),
        draw_labels=True,
        linewidth=1,
        color="gray",
        alpha=0.1,
        linestyle="-",
    )
    # Keep labels only on the bottom and left sides
    gridliner.top_labels = False
    gridliner.right_labels = False
    gridliner.xlines = True
    gridliner.ylines = True
    return ax
185

186

187
def plot_colorbar(p, ax, cbar_kwargs=None):
    """Add a colorbar to a matplotlib/cartopy plot.

    The colorbar axis is appended next to ``ax`` using ``make_axes_locatable``.

    Parameters
    ----------
    p : matplotlib mappable
        The plotted object, e.g. ``matplotlib.image.AxesImage`` or ``QuadMesh``.
    ax : matplotlib.axes.Axes or cartopy.mpl.geoaxes.GeoAxes
        The axis the colorbar is attached to.
    cbar_kwargs : dict, optional
        Arguments for ``plt.colorbar``. The dict is not modified. These keys are
        consumed here and not forwarded:

        - ``size``: colorbar thickness relative to ``ax`` (default ``"5%"``).
        - ``pad``: spacing between ``ax`` and the colorbar
          (default 0.1 if vertical, 0.25 if horizontal).
        - ``ticklabels``: custom tick labels.

        ``orientation`` (``"vertical"`` or ``"horizontal"``) selects the side
        and is also forwarded to ``plt.colorbar``.

    Returns
    -------
    matplotlib.colorbar.Colorbar

    Raises
    ------
    ValueError
        If ``orientation`` is neither 'vertical' nor 'horizontal'.
    """
    cbar_kwargs = dict(cbar_kwargs) if cbar_kwargs is not None else {}  # never modify caller's dict
    ticklabels = cbar_kwargs.pop("ticklabels", None)
    orientation = cbar_kwargs.get("orientation", "vertical")

    if orientation == "vertical":
        side, default_pad = "right", 0.1
    elif orientation == "horizontal":
        side, default_pad = "bottom", 0.25
    else:
        raise ValueError("Invalid orientation. Choose 'vertical' or 'horizontal'.")

    # 'size' and 'pad' configure the axes divider; 'size' is not a valid colorbar argument
    size = cbar_kwargs.pop("size", "5%")
    pad = cbar_kwargs.pop("pad", default_pad)

    divider = make_axes_locatable(ax)
    cax = divider.append_axes(side, size=size, pad=pad, axes_class=plt.Axes)
    p.figure.add_axes(cax)
    cbar = plt.colorbar(p, cax=cax, ax=ax, **cbar_kwargs)

    if ticklabels is not None:
        if orientation == "vertical":
            _ = cbar.ax.set_yticklabels(ticklabels)
        else:  # horizontal: ticks are along the x axis
            _ = cbar.ax.set_xticklabels(ticklabels)
    return cbar
221

222

223
####--------------------------------------------------------------------------.
224

225

226
def _compute_extent(x_coords, y_coords):
1✔
227
    """
228
    Compute the extent (x_min, x_max, y_min, y_max) from the pixel centroids in x and y coordinates.
229
    This function assumes that the spacing between each pixel is uniform.
230
    """
231
    # Calculate the pixel size assuming uniform spacing between pixels
232
    pixel_size_x = (x_coords[-1] - x_coords[0]) / (len(x_coords) - 1)
×
233
    pixel_size_y = (y_coords[-1] - y_coords[0]) / (len(y_coords) - 1)
×
234

235
    # Adjust min and max to get the corners of the outer pixels
236
    x_min, x_max = x_coords[0] - pixel_size_x / 2, x_coords[-1] + pixel_size_x / 2
×
237
    y_min, y_max = y_coords[0] - pixel_size_y / 2, y_coords[-1] + pixel_size_y / 2
×
238

239
    return [x_min, x_max, y_min, y_max]
×
240

241

242
def _plot_cartopy_imshow(
    ax,
    da,
    x,
    y,
    interpolation="nearest",
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot imshow with cartopy.

    ``x`` and ``y`` must be the (1D, uniformly spaced) longitude and latitude
    coordinates of ``da``. The image is drawn with a PlateCarree transform and
    the map extent is set to the coordinate bounds.

    Returns
    -------
    The mappable returned by ``ax.imshow``.
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs

    # - Ensure image with correct dimensions orders
    da = da.transpose(y, x)
    arr = np.asanyarray(da.data)

    # - Compute coordinates
    x_coords = da[x].values
    y_coords = da[y].values

    # - Derive extent from the pixel centroids
    x_min, x_max, y_start, y_end = _compute_extent(x_coords=x_coords, y_coords=y_coords)
    # - Sort the y bounds so that extent is (left, right, bottom, top) with bottom < top.
    #   matplotlib puts pixel [0, 0] at (left, top) for origin="upper" and at
    #   (left, bottom) for origin="lower", so with sorted bounds the origin chosen
    #   below places the first row of `arr` at the correct latitude.
    extent = [x_min, x_max, min(y_start, y_end), max(y_start, y_end)]

    # - Determine origin based on the orientation of da[y] values
    # -->  If increasing, set origin="lower"
    # -->  If decreasing, set origin="upper"
    origin = "lower" if y_coords[1] > y_coords[0] else "upper"

    # - Add variable field with cartopy
    p = ax.imshow(
        arr,
        transform=ccrs.PlateCarree(),
        extent=extent,
        origin=origin,
        interpolation=interpolation,
        **plot_kwargs,
    )

    # - Set the map extent using the requested coordinates (not hard-coded lon/lat)
    ax.set_extent(get_extent(da, x=x, y=y))

    # - Add colorbar
    if add_colorbar:
        # --> TODO: set axis proportion in a meaningful way ...
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
287

288

289
def _plot_rgb_pcolormesh(x, y, image, ax, **kwargs):
1✔
290
    """Plot xarray RGB DataArray with non uniform-coordinates.
291

292
    Matplotlib, cartopy and xarray pcolormesh currently does not support RGB(A) arrays.
293
    This is a temporary workaround !
294
    """
295
    if image.shape[2] not in [3, 4]:
×
296
        raise ValueError("Expecting RGB or RGB(A) arrays.")
×
297

298
    colorTuple = image.reshape((image.shape[0] * image.shape[1], image.shape[2]))
×
299
    im = ax.pcolormesh(
×
300
        x,
301
        y,
302
        image[:, :, 1],  # dummy to work ...
303
        color=colorTuple,
304
        **kwargs,
305
    )
306
    # im.set_array(None)
307
    return im
×
308

309

310
def _plot_cartopy_pcolormesh(
    ax,
    da,
    x,
    y,
    rgb=False,
    add_colorbar=True,
    plot_kwargs=None,
    cbar_kwargs=None,
):
    """Plot pcolormesh with cartopy.

    The function currently does not allow to zoom on regions across the antimeridian.
    If the ``viz_hide_antimeridian_data`` configuration flag is set, the cells
    crossing the antimeridian are masked.
    If rgb=True, the RGB(A) dimension is expected at the last position and no
    colorbar is added.
    x and y must represent longitudes and latitudes.

    Returns
    -------
    The mappable returned by ``pcolormesh``.
    """
    import copy

    # Work on a copy: the cmap/transform handling below must not modify the caller's dict
    plot_kwargs = dict(plot_kwargs) if plot_kwargs is not None else {}
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs

    # Get x, y, and array to plot
    da = da.compute()
    lons = da[x].data
    lats = da[y].data
    arr = da.data

    # Infill invalid value and add mask if necessary
    lons, lats, arr = get_valid_pcolormesh_inputs(lons, lats, arr, rgb=rgb)

    # Ensure arguments
    if rgb:
        add_colorbar = False

    # Compute coordinates of cell corners for pcolormesh quadrilateral mesh
    # - This enable correct masking of cells crossing the antimeridian
    from gpm_api.utils.area import _get_lonlat_corners

    lons, lats = _get_lonlat_corners(lons, lats)

    # Mask cells crossing the antimeridian
    # --> Here we assume not invalid coordinates anymore
    # --> Cartopy still bugs with several projections when data cross the antimeridian
    # --> This flag can be unset with gpm_api.config.set({"viz_hide_antimeridian_data": False})
    if gpm_api.config.get("viz_hide_antimeridian_data"):
        antimeridian_mask = get_antimeridian_mask(lons, buffer=True)
        if np.any(antimeridian_mask):
            if rgb:
                # Apply the per-cell mask to every RGB(A) channel
                antimeridian_mask = np.broadcast_to(
                    np.expand_dims(antimeridian_mask, axis=-1), arr.shape
                )
            # masked_where combines the condition with any existing mask of `arr`
            arr = np.ma.masked_where(antimeridian_mask, arr)

            # Sanitize cmap bad color to avoid cartopy bug
            # - TODO cartopy requires bad_color to be transparent ...
            if "cmap" in plot_kwargs:
                # Resolve colormap names and copy to avoid altering a shared colormap
                cmap = copy.copy(plt.get_cmap(plot_kwargs["cmap"]))
                bad = cmap.get_bad()
                bad[3] = 0  # enforce to 0 (transparent)
                cmap.set_bad(bad)
                plot_kwargs["cmap"] = cmap

    if rgb:
        # Coordinates are lon/lat: use the same default transform as the non-RGB path
        plot_kwargs.setdefault("transform", ccrs.PlateCarree())
        p = _plot_rgb_pcolormesh(lons, lats, arr, ax=ax, **plot_kwargs)
    else:
        p = ax.pcolormesh(
            lons,
            lats,
            arr,
            transform=ccrs.PlateCarree(),
            **plot_kwargs,
        )

    # Add colorbar
    # --> TODO: set axis proportion in a meaningful way ...
    if add_colorbar:
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
    return p
394

395

396
def _plot_mpl_imshow(
1✔
397
    ax,
398
    da,
399
    x,
400
    y,
401
    interpolation="nearest",
402
    add_colorbar=True,
403
    plot_kwargs={},
404
    cbar_kwargs={},
405
):
406
    """Plot imshow with matplotlib."""
407
    # - Ensure image with correct dimensions orders
408
    da = da.transpose(y, x)
×
409
    arr = np.asanyarray(da.data)
×
410

411
    # - Add variable field with matplotlib
412
    p = ax.imshow(
×
413
        arr,
414
        origin="upper",
415
        interpolation=interpolation,
416
        **plot_kwargs,
417
    )
418
    # - Add colorbar
419
    if add_colorbar:
×
420
        # --> TODO: set axis proportion in a meaningful way ...
421
        _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
×
422
    # - Return mappable
423
    return p
×
424

425

426
def set_colorbar_fully_transparent(p):
    """Hide the colorbar of ``p`` while keeping a placeholder at its position.

    This is useful for animations where the colorbar should not be shown in
    all frames but the plot area must stay fixed. The colorbar axis is made
    invisible and an empty, fully transparent rectangle is added to the figure
    at the colorbar position.

    Parameters
    ----------
    p : matplotlib mappable
        A plotted object with an attached ``colorbar`` attribute.
    """
    cbar_ax = p.colorbar.ax
    # Use the figure owning the colorbar: plt.gcf() may be a different figure
    fig = cbar_ax.figure

    # Get the position of the colorbar (figure coordinates)
    cbar_pos = cbar_ax.get_position()

    # Remove the colorbar
    cbar_ax.set_visible(False)

    # Now plot an empty rectangle in its place
    rect = plt.Rectangle(
        (cbar_pos.x0, cbar_pos.y0),
        cbar_pos.width,
        cbar_pos.height,
        transform=fig.transFigure,
        facecolor="none",
        edgecolor="none",
    )
    fig.add_artist(rect)
453

454

455
def _plot_xr_imshow(
1✔
456
    ax,
457
    da,
458
    x,
459
    y,
460
    interpolation="nearest",
461
    add_colorbar=True,
462
    plot_kwargs={},
463
    cbar_kwargs={},
464
    xarray_colorbar=True,
465
    visible_colorbar=True,
466
):
467
    """Plot imshow with xarray.
468

469
    The colorbar is added with xarray to enable to display multiple colorbars
470
    when calling this function multiple times on different fields with
471
    different colorbars.
472
    """
473
    # --> BUG with colorbar: https://github.com/pydata/xarray/issues/7014
474
    ticklabels = cbar_kwargs.pop("ticklabels", None)
×
475
    if not add_colorbar:
×
476
        cbar_kwargs = {}
×
477
    p = da.plot.imshow(
×
478
        x=x,
479
        y=y,
480
        ax=ax,
481
        interpolation=interpolation,
482
        add_colorbar=add_colorbar,
483
        cbar_kwargs=cbar_kwargs,
484
        **plot_kwargs,
485
    )
486
    plt.title(da.name)
×
487
    if add_colorbar and ticklabels is not None:
×
488
        p.colorbar.ax.set_yticklabels(ticklabels)
×
489

490
    # Make the colorbar fully transparent with a smart trick ;)
491
    # - TODO: this still cause issues when plotting 2 colorbars !
492
    if add_colorbar and not visible_colorbar:
×
493
        set_colorbar_fully_transparent(p)
×
494

495
    # Add manually the colorbar
496
    # p = da.plot.imshow(
497
    #     x=x,
498
    #     y=y,
499
    #     ax=ax,
500
    #     interpolation=interpolation,
501
    #     add_colorbar=False,
502
    #     **plot_kwargs,
503
    # )
504
    # plt.title(da.name)
505
    # if add_colorbar:
506
    #     _ = plot_colorbar(p=p, ax=ax, cbar_kwargs=cbar_kwargs)
507
    return p
×
508

509

510
def _plot_xr_pcolormesh(
1✔
511
    ax,
512
    da,
513
    x,
514
    y,
515
    add_colorbar=True,
516
    plot_kwargs={},
517
    cbar_kwargs={},
518
):
519
    """Plot pcolormesh with xarray."""
520
    ticklabels = cbar_kwargs.pop("ticklabels", None)
×
521
    if not add_colorbar:
×
522
        cbar_kwargs = {}
×
523
    p = da.plot.pcolormesh(
×
524
        x=x,
525
        y=y,
526
        ax=ax,
527
        add_colorbar=add_colorbar,
528
        cbar_kwargs=cbar_kwargs,
529
        **plot_kwargs,
530
    )
531
    plt.title(da.name)
×
532
    if add_colorbar and ticklabels is not None:
×
533
        p.colorbar.ax.set_yticklabels(ticklabels)
×
534
    return p
×
535

536

537
####--------------------------------------------------------------------------.
538
#### TODO: doc
539
# figsize, dpi, subplot_kw only used if ax is None
540

541

542
def plot_map(
    da,
    x="lon",
    y="lat",
    ax=None,
    add_colorbar=True,
    add_swath_lines=True,  # used only for GPM orbit objects
    add_background=True,
    rgb=False,
    interpolation="nearest",  # used only for GPM grid objects
    fig_kwargs={},
    subplot_kwargs={},
    cbar_kwargs={},
    **plot_kwargs,
):
    """Plot a GPM DataArray on a cartographic map.

    Dispatches to ``plot_orbit_map`` for GPM orbit objects and to
    ``plot_grid_map`` for GPM grid objects. ``add_swath_lines`` and ``rgb``
    are only forwarded for orbits; ``interpolation`` only for grids.

    Returns
    -------
    The mappable returned by the orbit/grid plotting function.

    Raises
    ------
    ValueError
        If ``da`` is neither a GPM orbit nor a GPM grid.
    """
    from gpm_api.checks import is_grid, is_orbit
    from gpm_api.visualization.grid import plot_grid_map
    from gpm_api.visualization.orbit import plot_orbit_map

    # Arguments common to the orbit and grid plotting functions
    shared_kwargs = {
        "da": da,
        "x": x,
        "y": y,
        "ax": ax,
        "add_colorbar": add_colorbar,
        "add_background": add_background,
        "fig_kwargs": fig_kwargs,
        "subplot_kwargs": subplot_kwargs,
        "cbar_kwargs": cbar_kwargs,
    }
    if is_orbit(da):
        return plot_orbit_map(
            **shared_kwargs, add_swath_lines=add_swath_lines, rgb=rgb, **plot_kwargs
        )
    if is_grid(da):
        return plot_grid_map(**shared_kwargs, interpolation=interpolation, **plot_kwargs)
    raise ValueError("Can not plot. It's neither a GPM grid, neither a GPM orbit.")
596

597

598
def plot_image(
    da,
    x=None,
    y=None,
    ax=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs={},
    cbar_kwargs={},
    **plot_kwargs,
):
    """Plot a GPM DataArray as an image (non-georeferenced view).

    Dispatches to ``plot_orbit_image`` for GPM orbit objects and to
    ``plot_grid_image`` for GPM grid objects, forwarding the same arguments.

    Returns
    -------
    The mappable returned by the orbit/grid plotting function.

    Raises
    ------
    ValueError
        If ``da`` is neither a GPM orbit nor a GPM grid.
    """
    # figsize, dpi, subplot_kw only used if ax is None
    from gpm_api.checks import is_grid, is_orbit
    from gpm_api.visualization.grid import plot_grid_image
    from gpm_api.visualization.orbit import plot_orbit_image

    image_kwargs = {
        "da": da,
        "x": x,
        "y": y,
        "ax": ax,
        "add_colorbar": add_colorbar,
        "interpolation": interpolation,
        "fig_kwargs": fig_kwargs,
        "cbar_kwargs": cbar_kwargs,
    }
    if is_orbit(da):
        return plot_orbit_image(**image_kwargs, **plot_kwargs)
    if is_grid(da):
        return plot_grid_image(**image_kwargs, **plot_kwargs)
    raise ValueError("Can not plot. It's neither a GPM GRID, neither a GPM ORBIT.")
644

645

646
####--------------------------------------------------------------------------.
647

648

649
def create_grid_mesh_data_array(xr_obj, x, y):
    """
    Create a 2D xarray DataArray with mesh coordinates based on the 1D coordinate arrays
    from an existing xarray object (Dataset or DataArray).

    The function creates a 2D grid (mesh) of x and y coordinates and initializes
    the data values to NaN.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        The input xarray object containing the 1D coordinate arrays.
    x : str
        The name of the x-coordinate in xr_obj.
    y : str
        The name of the y-coordinate in xr_obj.

    Returns
    -------
    da_mesh : xarray.DataArray
        A 2D xarray DataArray with mesh coordinates for x and y, and NaN values for data points.

    Notes
    -----
    The resulting DataArray has dimensions named 'y' and 'x', corresponding to the y and x
    coordinates respectively. If ``x`` or ``y`` is itself named 'x' or 'y', the dimensions
    are named ``f"{y}_dim"`` and ``f"{x}_dim"`` instead, because xarray does not allow a
    multi-dimensional coordinate to share its name with one of its dimensions.
    The coordinate values are taken directly from the input 1D coordinate arrays, and the
    data values are set to NaN.
    """
    # Extract 1D coordinate arrays
    x_coords = xr_obj[x].values
    y_coords = xr_obj[y].values

    # Create 2D meshgrid for x and y coordinates (shape: (n_y, n_x))
    X, Y = np.meshgrid(x_coords, y_coords, indexing="xy")

    # Create a 2D array of NaN values with the same shape as the meshgrid
    dummy_values = np.full(X.shape, np.nan)

    # Avoid 2D coordinates named like one of their dimensions (rejected by xarray)
    dims = ("y", "x")
    if x in dims or y in dims:
        dims = (f"{y}_dim", f"{x}_dim")

    # Create a new DataArray with 2D coordinates and NaN values
    da_mesh = xr.DataArray(dummy_values, coords={x: (dims, X), y: (dims, Y)}, dims=dims)
    return da_mesh
691

692

693
def plot_map_mesh(
    xr_obj,
    x="lon",
    y="lat",
    ax=None,
    edgecolors="k",
    linewidth=0.1,
    add_background=True,
    fig_kwargs={},
    subplot_kwargs={},
    **plot_kwargs,
):
    """Plot the mesh of a GPM orbit or grid object on a cartographic map.

    For orbits, ``xr_obj[y]`` is passed to ``plot_orbit_mesh``; any other
    object is passed as-is to ``plot_grid_mesh``.

    Returns
    -------
    The mappable returned by the orbit/grid mesh plotting function.
    """
    from gpm_api.checks import is_orbit  # is_grid

    from .grid import plot_grid_mesh
    from .orbit import plot_orbit_mesh

    mesh_kwargs = {
        "ax": ax,
        "x": x,
        "y": y,
        "edgecolors": edgecolors,
        "linewidth": linewidth,
        "add_background": add_background,
        "fig_kwargs": fig_kwargs,
        "subplot_kwargs": subplot_kwargs,
    }
    if not is_orbit(xr_obj):
        return plot_grid_mesh(xr_obj=xr_obj, **mesh_kwargs, **plot_kwargs)
    return plot_orbit_mesh(da=xr_obj[y], **mesh_kwargs, **plot_kwargs)
739

740

741
def plot_map_mesh_centroids(
    xr_obj,
    x="lon",
    y="lat",
    ax=None,
    c="r",
    s=1,
    add_background=True,
    fig_kwargs={},
    subplot_kwargs={},
    **plot_kwargs,
):
    """Plot the mesh centroids of a GPM orbit or grid object on a cartographic map.

    For grid objects, the 1D coordinates are first expanded into a 2D mesh.
    A new figure is created when ``ax`` is None (``fig_kwargs`` and
    ``subplot_kwargs`` are only allowed in that case).

    Returns
    -------
    The ``PathCollection`` returned by ``ax.scatter``.
    """
    from gpm_api.checks import is_grid

    # Reject figure options combined with an existing axis
    _preprocess_figure_args(ax=ax, fig_kwargs=fig_kwargs, subplot_kwargs=subplot_kwargs)

    if ax is None:
        _, ax = plt.subplots(subplot_kw=_preprocess_subplot_kwargs(subplot_kwargs), **fig_kwargs)

    if add_background:
        ax = plot_cartopy_background(ax)

    # Grids carry 1D coordinates: expand them into 2D centroid arrays
    source = create_grid_mesh_data_array(xr_obj, x=x, y=y) if is_grid(xr_obj) else xr_obj
    centroid_lons = source[x].values
    centroid_lats = source[y].values

    return ax.scatter(
        centroid_lons, centroid_lats, transform=ccrs.PlateCarree(), c=c, s=s, **plot_kwargs
    )
779

780

781
####--------------------------------------------------------------------------.
782

783

784
def _plot_labels(
    xr_obj,
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs=None,
    **plot_kwargs,
):
    """Plot labels.

    The label array is renumbered with ``redefine_label_array``, values <= 0
    are set to NaN (not displayed) and a label-specific colormap is used.
    The colorbar is displayed only if the number of labels does not exceed
    ``max_n_labels``.

    Parameters
    ----------
    xr_obj : xarray.Dataset or xarray.DataArray
        Object containing the label array.
    label_name : str, optional
        Name of the label variable. Required if ``xr_obj`` is a Dataset.
    cmap : str
        Colormap name used to build the label colormap.
    **plot_kwargs
        Extra plotting arguments; they override the label colormap defaults.

    Returns
    -------
    The mappable returned by ``plot_image``.

    Raises
    ------
    ValueError
        If ``xr_obj`` is a Dataset and ``label_name`` is None.
    """
    import warnings

    from ximage.labels.labels import get_label_indices, redefine_label_array
    from ximage.labels.plot_labels import get_label_colorbar_settings

    from gpm_api.visualization.plot import plot_image

    fig_kwargs = {} if fig_kwargs is None else fig_kwargs

    # Select the label array
    if isinstance(xr_obj, xr.Dataset):
        if label_name is None:
            raise ValueError("'label_name' must be specified when plotting a xr.Dataset.")
        dataarray = xr_obj[label_name]
    elif label_name is not None:
        dataarray = xr_obj[label_name]
    else:
        dataarray = xr_obj

    dataarray = dataarray.compute()
    label_indices = get_label_indices(dataarray)
    n_labels = len(label_indices)
    if add_colorbar and n_labels > max_n_labels:
        warnings.warn(
            f"The array currently contains {n_labels} labels and 'max_n_labels' "
            f"is set to {max_n_labels}. The colorbar is not displayed!",
            UserWarning,
            stacklevel=2,
        )
        add_colorbar = False
    # Relabel array from 1 to ... for plotting
    dataarray = redefine_label_array(dataarray, label_indices=label_indices)
    # Replace 0 with nan
    dataarray = dataarray.where(dataarray > 0)
    # Define appropriate colormap, honouring the requested `cmap`
    label_plot_kwargs, cbar_kwargs = get_label_colorbar_settings(label_indices, cmap=cmap)
    # User-provided plotting arguments are no longer discarded: they override the defaults
    plot_kwargs = {**label_plot_kwargs, **plot_kwargs}
    # Plot image
    p = plot_image(
        dataarray,
        interpolation=interpolation,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        fig_kwargs=fig_kwargs,
        **plot_kwargs,
    )
    return p
835

836

837
def plot_labels(
    obj,  # Dataset, DataArray or generator
    label_name=None,
    max_n_labels=50,
    add_colorbar=True,
    interpolation="nearest",
    cmap="Paired",
    fig_kwargs=None,
    **plot_kwargs,
):
    """Plot a label array, or each label array yielded by a generator.

    If ``obj`` is a generator, it must yield ``(label_id, xr_obj)`` tuples; each
    ``xr_obj`` is plotted and shown with ``plt.show()``.

    Returns
    -------
    The mappable of the (last) plot, or None if the generator yields nothing.
    """
    fig_kwargs = {} if fig_kwargs is None else fig_kwargs
    label_kwargs = {
        "label_name": label_name,
        "max_n_labels": max_n_labels,
        "add_colorbar": add_colorbar,
        "interpolation": interpolation,
        "cmap": cmap,
        "fig_kwargs": fig_kwargs,
    }
    if not is_generator(obj):
        return _plot_labels(xr_obj=obj, **label_kwargs, **plot_kwargs)

    # Initialize so that an empty generator returns None instead of raising UnboundLocalError
    p = None
    for _, xr_obj in obj:
        p = _plot_labels(xr_obj=xr_obj, **label_kwargs, **plot_kwargs)
        plt.show()
    return p
872

873

874
def plot_patches(
    patch_gen,
    variable=None,
    add_colorbar=True,
    interpolation="nearest",
    fig_kwargs=None,
    cbar_kwargs=None,
    **plot_kwargs,
):
    """Plot and show each patch yielded by ``patch_gen``.

    ``patch_gen`` must yield ``(label_id, patch)`` tuples, where ``patch`` is an
    xarray DataArray or Dataset. For Datasets, ``variable`` selects the field to plot.
    Patches that fail to plot are skipped with a warning (best effort).

    Raises
    ------
    ValueError
        If a Dataset patch is yielded and ``variable`` is None.
    """
    import warnings

    from gpm_api.visualization.plot import plot_image

    fig_kwargs = {} if fig_kwargs is None else fig_kwargs
    cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs

    # Plot patches
    for label_id, xr_patch in patch_gen:
        if isinstance(xr_patch, xr.Dataset):
            if variable is None:
                raise ValueError("'variable' must be specified when plotting xr.Dataset patches.")
            xr_patch = xr_patch[variable]
        try:
            plot_image(
                xr_patch,
                interpolation=interpolation,
                add_colorbar=add_colorbar,
                fig_kwargs=fig_kwargs,
                # Fresh copy per patch: downstream code may pop keys (e.g. 'ticklabels')
                cbar_kwargs=dict(cbar_kwargs),
                **plot_kwargs,
            )
            plt.show()
        except Exception as err:
            # Best effort: skip patches that cannot be plotted, but report why
            warnings.warn(
                f"Skipping patch {label_id}: plotting failed with {type(err).__name__}: {err}",
                stacklevel=2,
            )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc