• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

alan-turing-institute / deepsensor / 9828217984

07 Jul 2024 02:59PM UTC coverage: 81.333%. Remained the same
9828217984

push

github

web-flow
use unittest's `setUpClass` instead of overriding `__init__` (#117)

1965 of 2416 relevant lines covered (81.33%)

1.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

60.85
/deepsensor/plot.py
1
import numpy as np
2✔
2

3
import matplotlib.pyplot as plt
2✔
4
import pandas as pd
2✔
5
from mpl_toolkits.axes_grid1 import make_axes_locatable
2✔
6
import matplotlib.patches as mpatches
2✔
7

8
import lab as B
2✔
9

10
from typing import Optional, Union, List, Tuple
2✔
11

12
from deepsensor.data.task import Task, flatten_X
2✔
13
from deepsensor.data.loader import TaskLoader
2✔
14
from deepsensor.data.processor import DataProcessor
2✔
15
from deepsensor.model.pred import Prediction
2✔
16
from pandas import DataFrame
2✔
17
from matplotlib.colors import Colormap
2✔
18
from matplotlib.axes import Axes
2✔
19

20

21
def task(
    task: Task,
    task_loader: TaskLoader,
    figsize=3,
    markersize=None,
    equal_aspect=False,
    plot_ticks=False,
    extent=None,
) -> plt.Figure:
    """
    Plot the context and target sets of a task.

    The figure has one column per context set, one column per target set and,
    if the task contains non-None ``"Y_t_aux"``, one extra column for the
    auxiliary data at the target locations. Each row shows one variable of
    the set in that column; cells beyond a set's number of variables are
    hidden. Gridded sets are flattened to point form before plotting.

    Args:
        task (:class:`~.data.task.Task`):
            Task to plot.
        task_loader (:class:`~.data.loader.TaskLoader`):
            Task loader used to load ``task``, containing variable IDs used for
            plotting.
        figsize (int, optional):
            Figure size in inches, by default 3.
        markersize (int, optional):
            Marker size (in units of points squared), by default None. If None,
            the marker size is set to ``(2**2) * figsize / 3``.
        equal_aspect (bool, optional):
            Whether to set the aspect ratio of the plots to be equal, by
            default False.
        plot_ticks (bool, optional):
            Whether to plot the coordinate ticks on the axes, by default False.
        extent (Tuple[int, int, int, int], optional):
            Extent of the plot in format (x2_min, x2_max, x1_min, x1_max).
            Defaults to None (uses the smallest extent that contains all data points
            across all context and target sets).

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure`:
            The figure containing one scatter subplot per set/variable.
    """
    if markersize is None:
        markersize = (2**2) * figsize / 3

    # Scale font size with figure size
    fontsize = 10 * figsize / 3
    params = {
        "axes.labelsize": fontsize,
        "axes.titlesize": fontsize,
        "font.size": fontsize,
        "figure.titlesize": fontsize,
        "legend.fontsize": fontsize,
        "xtick.labelsize": fontsize,
        "ytick.labelsize": fontsize,
    }

    var_IDs = task_loader.context_var_IDs + task_loader.target_var_IDs
    Y_c = task["Y_c"]
    X_c = task["X_c"]
    if task["Y_t"] is not None:
        Y_t = task["Y_t"]
        X_t = task["X_t"]
    else:
        Y_t = []
        X_t = []
    n_context = len(Y_c)
    n_target = len(Y_t)
    if "Y_t_aux" in task and task["Y_t_aux"] is not None:
        # Assumes only 1 target set: the auxiliary data is plotted at the
        # locations of the last target set
        X_t = X_t + [task["X_t"][-1]]
        Y_t = Y_t + [task["Y_t_aux"]]
        var_IDs = var_IDs + (task_loader.aux_at_target_var_IDs,)
        ncols = n_context + n_target + 1
    else:
        ncols = n_context + n_target
    # One row per variable of the set with the most variables
    nrows = max([Y.shape[0] for Y in Y_c + Y_t])

    if extent is None:
        x1_min = np.min([np.min(X[0]) for X in X_c + X_t])
        x1_max = np.max([np.max(X[0]) for X in X_c + X_t])
        x2_min = np.min([np.min(X[1]) for X in X_c + X_t])
        x2_max = np.max([np.max(X[1]) for X in X_c + X_t])
        extent = (x2_min, x2_max, x1_min, x1_max)

    with plt.rc_context(params):
        # `squeeze=False` always yields a 2D (nrows, ncols) array of axes.
        # Without it a 1x1 grid returns a bare Axes object, which cannot be
        # indexed as `axes[i, j]`.
        fig, axes = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            figsize=(ncols * figsize, nrows * figsize),
            squeeze=False,
        )
        # j = loop index over columns/sets
        # i = loop index over rows/variables within a set
        for j, (X, Y) in enumerate(zip(X_c + X_t, Y_c + Y_t)):
            n_vars = Y.shape[0]
            if n_vars > 0:
                if j < n_context:
                    title = f"Context set {j}"
                elif j < n_context + n_target:
                    title = f"Target set {j - n_context}"
                else:
                    title = "Auxiliary at targets"
                axes[0, j].set_title(title)
                if isinstance(X, tuple):
                    # Gridded set: flatten grid coords and values to point form
                    X = flatten_X(X)
                    Y = Y.reshape(n_vars, -1)

            for i in range(n_vars):
                # Cartesian plotting: x2 on the horizontal axis, x1 vertical
                axes[i, j].scatter(X[1, :], X[0, :], c=Y[i], s=markersize, marker=".")
                if equal_aspect:
                    # Don't warp aspect ratio
                    axes[i, j].set_aspect("equal")
                if not plot_ticks:
                    axes[i, j].set_xticks([])
                    axes[i, j].set_yticks([])
                axes[i, j].set_ylabel(var_IDs[j][i])

                axes[i, j].set_xlim(extent[0], extent[1])
                axes[i, j].set_ylim(extent[2], extent[3])

                # Add colorbar with same height as axis
                divider = make_axes_locatable(axes[i, j])
                box = axes[i, j].get_position()
                ratio = 0.3
                pad = 0.1
                width = box.width * ratio
                cax = divider.append_axes("right", size=width, pad=pad)
                fig.colorbar(axes[i, j].collections[0], cax=cax)

            # Hide cells below this set's last variable
            for i in range(n_vars, nrows):
                axes[i, j].axis("off")

        plt.tight_layout()

    return fig
151

152

153
def context_encoding(
    model,
    task: Task,
    task_loader: TaskLoader,
    batch_idx: int = 0,
    context_set_idxs: Optional[Union[List[int], int]] = None,
    land_idx: Optional[int] = None,
    cbar: bool = True,
    clim: Optional[Tuple] = None,
    cmap: Union[str, Colormap] = "viridis",
    verbose_titles: bool = True,
    titles: Optional[dict] = None,
    size: int = 3,
    return_axes: bool = False,
):
    """Plot the ``ConvNP`` SetConv encoding of a context set in a task.

    Each plotted context set gets one row. Column 0 is the set's density
    channel and the following columns are its data channels. Rows for sets
    with fewer channels than the widest set have their trailing cells hidden.

    Args:
        model (:class:`~.model.convnp.ConvNP`):
            ConvNP model.
        task (:class:`~.data.task.Task`):
            Task containing the context set(s) whose encoding is plotted.
        task_loader (:class:`~.data.loader.TaskLoader`):
            DataLoader used to load the data, containing context set metadata
            used for plotting.
        batch_idx (int, optional):
            Batch index in encoding to plot, by default 0.
        context_set_idxs (List[int] | int, optional):
            Indices of context sets to plot, by default None (plots all context
            sets).
        land_idx (int, optional):
            Index of the land mask in the encoding (used to overlay land
            contour on plots), by default None.
        cbar (bool, optional):
            Whether to add a colorbar to the plots, by default True.
        clim (tuple, optional):
            Colorbar limits, by default None.
        cmap (str | matplotlib.colors.Colormap, optional):
            Color map to use for the plots, by default "viridis".
        verbose_titles (bool, optional):
            Whether to include verbose titles for the variable IDs in the
            context set (including the time index), by default True.
        titles (dict, optional):
            Dict of titles to override for each subplot, keyed by encoding
            channel index, by default None. If None, titles are generated
            from context set metadata.
        size (int, optional):
            Size of each subplot in inches, by default 3.
        return_axes (bool, optional):
            Whether to return the axes of the figure, by default False.

    Returns:
        :obj:`matplotlib.figure.Figure` | Tuple[:obj:`matplotlib.figure.Figure`, :obj:`matplotlib.pyplot.Axes`]:
            Either a figure containing the context set encoding plots, or a
            tuple containing the :obj:`figure <matplotlib.figure.Figure>` and
            the 2D array of :obj:`axes <matplotlib.axes.Axes>` of the figure
            (if ``return_axes`` was set to ``True``).
    """
    from .model.nps import compute_encoding_tensor

    encoding_tensor = compute_encoding_tensor(model, task)
    encoding_tensor = encoding_tensor[batch_idx]

    if isinstance(context_set_idxs, (int, np.integer)):
        context_set_idxs = [context_set_idxs]
    if context_set_idxs is None:
        context_set_idxs = np.array(range(len(task_loader.context_dims)))

    context_var_ID_set_sizes = [
        ndim + 1 for ndim in np.array(task_loader.context_dims)[context_set_idxs]
    ]  # Add density channel to each set size
    ncols = max(context_var_ID_set_sizes)
    nrows = len(context_set_idxs)

    figsize = (ncols * size, nrows * size)
    # `squeeze=False` guarantees a 2D (nrows, ncols) array of axes
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, squeeze=False)

    # Cumulative channel counts (density + data channels per set); entry k-1
    # is the index of the first channel belonging to context set k
    ctx_channel_idxs = np.cumsum(np.array(task_loader.context_dims) + 1)

    for row_i, ctx_i in enumerate(context_set_idxs):
        # Starting (density) channel index of this context set
        channel_i = ctx_channel_idxs[ctx_i - 1] if ctx_i > 0 else 0
        if verbose_titles:
            var_IDs = task_loader.context_var_IDs_and_delta_t[ctx_i]
        else:
            var_IDs = task_loader.context_var_IDs[ctx_i]

        ncols_row_i = task_loader.context_dims[ctx_i] + 1  # Add density channel
        for col_i in range(ncols_row_i):
            ax = axes[row_i, col_i]
            # Need `origin="lower"` because encoding has `x1` increasing from top to bottom,
            # whereas in visualisations we want `x1` increasing from bottom to top.
            im = ax.imshow(
                encoding_tensor[channel_i],
                origin="lower",
                clim=clim,
                cmap=cmap,
            )
            if titles is not None:
                ax.set_title(titles[channel_i])
            elif col_i == 0:
                ax.set_title(f"Density {ctx_i}")
            else:
                ax.set_title(f"{var_IDs[col_i - 1]}")
            if col_i == 0:
                ax.set_ylabel(f"Context set {ctx_i}")
            if cbar:
                divider = make_axes_locatable(ax)
                cax = divider.append_axes("right", size="5%", pad=0.05)
                plt.colorbar(im, cax)
            ax.patch.set_edgecolor("black")
            ax.patch.set_linewidth(1)
            if land_idx is not None:
                ax.contour(
                    encoding_tensor[land_idx],
                    colors="k",
                    levels=[0.5],
                    origin="lower",
                )
            ax.tick_params(
                which="both",
                bottom=False,
                left=False,
                labelbottom=False,
                labelleft=False,
            )
            channel_i += 1
        for col_i in range(ncols_row_i, ncols):
            # Hide unused axes in this row. Index by the subplot row `row_i`,
            # not the context set index `ctx_i`: the two differ whenever only
            # a subset of context sets is plotted.
            axes[row_i, col_i].axis("off")

    plt.tight_layout()
    if return_axes:
        return fig, axes
    return fig
294

295

296
def offgrid_context(
    axes: Union[np.ndarray, List[plt.Axes], Tuple[plt.Axes]],
    task: Task,
    data_processor: Optional[DataProcessor] = None,
    task_loader: Optional[TaskLoader] = None,
    plot_target: bool = False,
    add_legend: bool = True,
    context_set_idxs: Optional[Union[List[int], int]] = None,
    markers: Optional[str] = None,
    colors: Optional[str] = None,
    **scatter_kwargs,
) -> None:
    """
    Plot the off-grid context points on ``axes``.

    Uses a provided :class:`~.data.processor.DataProcessor` to unnormalise the
    context coordinates if provided. Gridded sets (tuple coordinates) are
    skipped. For batched coordinate arrays (3D), only the first batch element
    is plotted.

    Args:
        axes (:class:`numpy:numpy.ndarray` | List[:class:`matplotlib:matplotlib.axes.Axes`] | Tuple[:class:`matplotlib:matplotlib.axes.Axes`]):
            Axes to plot on. A single Axes object is also accepted.
        task (:class:`~.data.task.Task`):
            Task containing the context set to plot.
        data_processor (:class:`~.data.processor.DataProcessor`, optional):
            Data processor used to unnormalise the context set, by default
            None.
        task_loader (:class:`~.data.loader.TaskLoader`, optional):
            Task loader used to load the data, containing context set metadata
            used for plotting, by default None.
        plot_target (bool, optional):
            Whether to plot the target set, by default False.
        add_legend (bool, optional):
            Whether to add a legend to the plot, by default True.
        context_set_idxs (List[int] | int, optional):
            Indices of sets to plot, by default None (plots all sets). When
            ``plot_target`` is True, indices count context sets first, then
            target sets.
        markers (str, optional):
            Marker styles to use for each set (one character per set), by
            default None.
        colors (str, optional):
            Colors to use for each set (one character per set), by default
            None.
        scatter_kwargs:
            Additional keyword arguments to pass to the scatter plot. These
            take precedence over the per-set ``marker``, ``color``,
            ``facecolors`` and ``label`` defaults.

    Returns:
        None
    """
    if markers is None:
        # Default marker cycle, indexed by set number
        markers = "ovs^Dxv<>1234spP*hHd|_"
    if colors is None:
        # one-letter matplotlib colors, repeated so enough are available
        colors = "kbrgy" * 10

    if isinstance(context_set_idxs, (int, np.integer)):
        context_set_idxs = [context_set_idxs]

    if isinstance(axes, np.ndarray):
        axes = axes.ravel()
    elif not isinstance(axes, (list, tuple)):
        axes = [axes]

    n_context = len(task["X_c"])
    if plot_target:
        coord_sets = [*task["X_c"], *task["X_t"]]
    else:
        coord_sets = task["X_c"]

    for set_i, X in enumerate(coord_sets):
        if context_set_idxs is not None and set_i not in context_set_idxs:
            continue

        if isinstance(X, tuple):
            continue  # Don't plot gridded context data locations
        if X.ndim == 3:
            X = X[0]  # select first batch

        if data_processor is not None:
            x1, x2 = data_processor.map_x1_and_x2(X[0], X[1], unnorm=True)
            X = np.stack([x1, x2], axis=0)

        X = X[::-1]  # flip 2D coords for Cartesian fmt

        label = ""
        if plot_target and set_i < n_context:
            label += f"Context set {set_i} "
            if task_loader is not None:
                label += f"({task_loader.context_var_IDs[set_i]})"
        elif plot_target:
            target_i = set_i - n_context
            label += f"Target set {target_i} "
            if task_loader is not None:
                label += f"({task_loader.target_var_IDs[target_i]})"

        # Per-set defaults first, so caller-supplied kwargs override them
        # rather than raising "got multiple values for keyword argument".
        kwargs = {
            "marker": markers[set_i],
            "color": colors[set_i],
            "facecolors": None if markers[set_i] == "x" else "none",
            "label": label,
            **scatter_kwargs,
        }
        for ax in axes:
            ax.scatter(*X, **kwargs)

    if add_legend:
        axes[0].legend(loc="best")
399

400

401
def offgrid_context_observations(
    axes: Union[np.ndarray, List[plt.Axes], Tuple[plt.Axes]],
    task: Task,
    data_processor: DataProcessor,
    task_loader: TaskLoader,
    context_set_idx: int,
    format_str: Optional[str] = None,
    extent: Optional[Tuple[int, int, int, int]] = None,
    color: str = "black",
) -> None:
    """
    Plot unnormalised context observation values.

    Each observation value is drawn as a text label at its unnormalised
    location, with ``x2`` on the horizontal axis and ``x1`` on the vertical
    axis.

    Args:
        axes (:class:`numpy:numpy.ndarray` | List[:class:`matplotlib:matplotlib.axes.Axes`] | Tuple[:class:`matplotlib:matplotlib.axes.Axes`]):
            Axes to plot on. A single Axes object is also accepted.
        task (:class:`~.data.task.Task`):
            Task containing the context set to plot.
        data_processor (:class:`~.data.processor.DataProcessor`):
            Data processor used to unnormalise the context set.
        task_loader (:class:`~.data.loader.TaskLoader`):
            Task loader used to load the data, containing context set metadata
            used for plotting.
        context_set_idx (int):
            Index of the context set to plot.
        format_str (str, optional):
            Format string for the context observation values. By default
            ``"{:.2f}"``.
        extent (Tuple[int, int, int, int], optional):
            Extent of the plot in unnormalised coordinates, in format
            (x2_min, x2_max, x1_min, x1_max). Observations outside it are not
            labelled. By default None (label all observations).
        color (str, optional):
            Color of the text, by default "black".

    Returns:
        None.

    Raises:
        AssertionError:
            If the context set has more than one variable.
        AssertionError:
            If the context set is gridded.
        AssertionError:
            If the task's "Y_c" value for the context set is not 2D.
        AssertionError:
            If the task's "Y_c" value for the context set does not have
            exactly one variable (leading dimension of size 1).
    """
    if isinstance(axes, np.ndarray):
        axes = axes.ravel()
    elif not isinstance(axes, (list, tuple)):
        axes = [axes]

    if format_str is None:
        format_str = "{:.2f}"

    var_IDs = task_loader.context_var_IDs[
        context_set_idx
    ]  # Tuple of variable IDs for the context set
    assert (
        len(var_IDs) == 1
    ), "Plotting context observations only supported for single-variable (1D) context sets"
    var_ID = var_IDs[0]

    X_c = task["X_c"][context_set_idx]
    assert not isinstance(
        X_c, tuple
    ), f"The context set must not be gridded but is of type {type(X_c)} for context set at index {context_set_idx}"
    X_c = data_processor.map_coord_array(X_c, unnorm=True)

    Y_c = task["Y_c"][context_set_idx]
    assert Y_c.ndim == 2
    assert Y_c.shape[0] == 1
    Y_c = data_processor.map_array(Y_c, var_ID, unnorm=True).ravel()

    if extent is not None:
        # Same (x2_min, x2_max, x1_min, x1_max) convention as the rest of
        # this module: the first pair bounds the horizontal (x2) axis.
        x2_min, x2_max, x1_min, x1_max = extent

    # Each column of X_c is one (x1, x2) location
    for x_c, y_c in zip(X_c.T, Y_c):
        x1, x2 = x_c[0], x_c[1]
        if extent is not None and not (
            x2_min <= x2 <= x2_max and x1_min <= x1 <= x1_max
        ):
            continue
        for ax in axes:
            # Cartesian placement: x2 horizontal, x1 vertical
            ax.text(x2, x1, format_str.format(float(y_c)), color=color)
483

484

485
def receptive_field(
    receptive_field,
    data_processor: DataProcessor,
    crs,
    extent: Union[str, Tuple[float, float, float, float]] = "global",
) -> plt.Figure:  # pragma: no cover
    """
    Plot a model's receptive field as a shaded rectangle on a map.

    The rectangle is centred on the midpoint of ``extent``. Its size is
    ``receptive_field`` scaled by the width of each coordinate's ``"map"``
    range in ``data_processor.config``, i.e. converted to raw coordinate
    units.

    Args:
        receptive_field (float):
            Receptive field to plot. This is presumably in normalised
            coordinate units, since it is multiplied by the raw width of each
            coordinate's ``"map"`` range.
        data_processor (:class:`~.data.processor.DataProcessor`):
            Data processor whose coordinate config (``"map"`` ranges and
            ``"name"`` entries) is used to convert the receptive field to raw
            coordinates and to title the plot.
        crs (cartopy CRS):
            Coordinate reference system for the plots.
        extent (str | Tuple[float, float, float, float], optional):
            Extent of the plot, in format (x2_min, x2_max, x1_min, x1_max), e.g. in
            lat-lon format (lon_min, lon_max, lat_min, lat_max). A string is
            converted to a tuple with ``extent_str_to_tuple``. By default
            "global".

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure`:
            The figure containing the receptive field plot.
    """
    fig, ax = plt.subplots(subplot_kw=dict(projection=crs))

    if isinstance(extent, str):
        extent = extent_str_to_tuple(extent)
    else:
        extent = tuple([float(x) for x in extent])
    x2_min, x2_max, x1_min, x1_max = extent
    ax.set_extent(extent, crs=crs)

    # Raw-coordinate bounds used by the data processor's normalisation
    x11, x12 = data_processor.config["coords"]["x1"]["map"]
    x21, x22 = data_processor.config["coords"]["x2"]["map"]

    # Receptive field size in raw coordinate units
    x1_rf_raw = receptive_field * (x12 - x11)
    x2_rf_raw = receptive_field * (x22 - x21)

    x1_midpoint_raw = (x1_max + x1_min) / 2
    x2_midpoint_raw = (x2_max + x2_min) / 2

    # Compute bottom left corner of receptive field
    x1_corner = x1_midpoint_raw - x1_rf_raw / 2
    x2_corner = x2_midpoint_raw - x2_rf_raw / 2

    ax.add_patch(
        mpatches.Rectangle(
            xy=[x2_corner, x1_corner],  # Cartesian fmt: x2, x1
            width=x2_rf_raw,
            height=x1_rf_raw,
            facecolor="black",
            alpha=0.3,
            transform=crs,
        )
    )
    ax.coastlines()
    ax.gridlines(draw_labels=True, alpha=0.2)

    x1_name = data_processor.config["coords"]["x1"]["name"]
    x2_name = data_processor.config["coords"]["x2"]["name"]
    ax.set_title(
        f"Receptive field in raw coords: {x1_name}={x1_rf_raw:.2f}, "
        f"{x2_name}={x2_rf_raw:.2f}"
    )

    return fig
551

552

553
def feature_maps(
    model,
    task: Task,
    n_features_per_layer: int = 1,
    seed: Optional[int] = None,
    figsize: int = 3,
    add_colorbar: bool = False,
    cmap: Union[str, Colormap] = "Greys",
) -> List[plt.Figure]:
    """
    Plot the feature maps of a ``ConvNP`` model's decoder layers after a
    forward pass with a ``Task``.

    The U-Net forward pass is reconstructed manually so that the output of
    every layer can be captured. For each layer, a random subset of channels
    is plotted as images in a separate figure.

    Args:
        model (:class:`~.model.model.convnp.ConvNP`):
            ConvNP model. ``model.model.decoder[0]`` is assumed to be the
            U-Net built by ``neuralprocesses.construct_convgnp`` (with
            ``before_turn_layers``, ``after_turn_layers``, ``activations`` and
            ``final_linear`` attributes).
        task (:class:`~.data.task.Task`):
            Task whose context sets are encoded and passed through the U-Net.
        n_features_per_layer (int, optional):
            Maximum number of channels to plot per layer (capped at the
            layer's channel count), by default 1.
        seed (int, optional):
            Seed for the random choice of channels to plot, by default None.
        figsize (int, optional):
            Size in inches of each subplot, by default 3.
        add_colorbar (bool, optional):
            Whether to add a colorbar to each subplot, by default False.
        cmap (str | matplotlib.colors.Colormap, optional):
            Color map for the feature map images, by default "Greys".

    Returns:
        List[matplotlib.figure.Figure]:
            One figure per layer, in forward-pass order, each containing the
            selected feature maps of that layer.

    Raises:
        ValueError:
            If the backend is not recognised.
    """
    from .model.nps import compute_encoding_tensor

    import deepsensor

    # Hacky way to load the correct __init__.py to get `convert_to_tensor` method
    if deepsensor.backend.str == "tf":
        import deepsensor.tensorflow as backend_module
    elif deepsensor.backend.str == "torch":
        import deepsensor.torch as backend_module
    else:
        raise ValueError(f"Unknown backend: {deepsensor.backend.str}")

    unet = model.model.decoder[0]

    # Produce encoding
    x = backend_module.convert_to_tensor(compute_encoding_tensor(model, task))

    # Manually construct the U-Net forward pass from
    # `neuralprocesses.construct_convgnp` to get the feature maps
    def unet_forward(unet, x):
        feature_maps = []

        h = unet.activations[0](unet.before_turn_layers[0](x))
        hs = [h]
        feature_map = B.to_numpy(h)
        feature_maps.append(feature_map)
        for layer, activation in zip(
            unet.before_turn_layers[1:],
            unet.activations[1:],
        ):
            h = activation(layer(hs[-1]))
            hs.append(h)
            feature_map = B.to_numpy(h)
            feature_maps.append(feature_map)

        # Now make the turn!

        h = unet.activations[-1](unet.after_turn_layers[-1](hs[-1]))
        feature_map = B.to_numpy(h)
        feature_maps.append(feature_map)
        for h_prev, layer, activation in zip(
            reversed(hs[:-1]),
            reversed(unet.after_turn_layers[:-1]),
            reversed(unet.activations[:-1]),
        ):
            # Skip connection: concatenate along axis 1 (presumably the
            # channel axis -- TODO confirm)
            h = activation(layer(B.concat(h_prev, h, axis=1)))
            feature_map = B.to_numpy(h)
            feature_maps.append(feature_map)

        h = unet.final_linear(h)
        feature_map = B.to_numpy(h)
        feature_maps.append(feature_map)

        return feature_maps

    feature_maps = unet_forward(unet, x)

    figs = []
    rng = np.random.default_rng(seed)
    for layer_i, feature_map in enumerate(feature_maps):
        # Axis 1 indexes features/channels; axis 0 is the batch (first
        # element plotted)
        n_features = feature_map.shape[1]
        n_features_to_plot = min(n_features_per_layer, n_features)
        feature_idxs = rng.choice(n_features, n_features_to_plot, replace=False)

        fig, axes = plt.subplots(
            nrows=1,
            ncols=n_features_to_plot,
            figsize=(figsize * n_features_to_plot, figsize),
        )
        if n_features_to_plot == 1:
            axes = [axes]
        for f_i, ax in zip(feature_idxs, axes):
            fm = feature_map[0, f_i]
            im = ax.imshow(fm, origin="lower", cmap=cmap)
            ax.set_title(f"Feature {f_i}", fontsize=figsize * 15 / 4)
            ax.tick_params(
                which="both",
                bottom=False,
                left=False,
                labelbottom=False,
                labelleft=False,
            )
            if add_colorbar:
                ax.figure.colorbar(im, ax=ax, format="%.2f")

        fig.suptitle(
            f"Layer {layer_i} feature map. Shape: {feature_map.shape}. Min={np.min(feature_map):.2f}, Max={np.max(feature_map):.2f}.",
            fontsize=figsize * 15 / 4,
        )
        plt.tight_layout()
        plt.subplots_adjust(top=0.75)
        figs.append(fig)

    return figs
684

685

686
def placements(
    task: Task,
    X_new_df: DataFrame,
    data_processor: DataProcessor,
    crs,
    extent: Optional[Union[Tuple[int, int, int, int], str]] = None,
    figsize: int = 3,
    **scatter_kwargs,
) -> plt.Figure:  # pragma: no cover
    """
    Plot proposed placement locations on a map alongside the task's off-grid
    context points.

    Placements are drawn as red scatter points. Context points are drawn with
    :func:`offgrid_context` after unnormalising with ``data_processor``.

    Args:
        task (:class:`~.data.task.Task`):
            Task containing the context set used to compute the acquisition
            function.
        X_new_df (:class:`pandas.DataFrame`):
            Dataframe containing the placement locations. Its columns are
            reversed before plotting, so the last column goes on the
            horizontal axis (presumably columns are (x1, x2) -- confirm
            against the caller).
        data_processor (:class:`~.data.processor.DataProcessor`):
            Data processor used to unnormalise the context set.
        crs (cartopy CRS):
            Coordinate reference system for the plots.
        extent (Tuple[int, int, int, int] | str, optional):
            Extent of the plots, by default None (leave the map's default
            extent). ``"global"`` shows the whole globe; any other value is
            passed to ``ax.set_extent``.
        figsize (int, optional):
            Figure size in inches, by default 3.
        scatter_kwargs:
            Additional keyword arguments passed to both scatter calls. They
            override the defaults (``c="r"`` for placements and
            ``linewidths=0.5`` for both).

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure`
            A figure containing the placement plots.
    """
    fig, ax = plt.subplots(subplot_kw={"projection": crs}, figsize=(figsize, figsize))

    # Defaults first so caller-supplied kwargs override them instead of
    # raising "got multiple values for keyword argument".
    placement_kwargs = {"c": "r", "linewidths": 0.5, **scatter_kwargs}
    context_kwargs = {"linewidths": 0.5, **scatter_kwargs}

    ax.scatter(*X_new_df.values.T[::-1], **placement_kwargs)
    offgrid_context(ax, task, data_processor, **context_kwargs)

    ax.coastlines()
    # Guard the string comparison so array-like extents don't trigger an
    # element-wise (ambiguous truth value) comparison.
    if isinstance(extent, str) and extent == "global":
        ax.set_global()
    elif extent is not None:
        ax.set_extent(extent, crs=crs)

    return fig
731

732

733
def acquisition_fn(
    task: Task,
    acquisition_fn_ds: np.ndarray,
    X_new_df: DataFrame,
    data_processor: DataProcessor,
    crs,
    col_dim: str = "iteration",
    cmap: Union[str, Colormap] = "Greys_r",
    figsize: int = 3,
    add_colorbar: bool = True,
    max_ncol: int = 5,
) -> plt.Figure:  # pragma: no cover
    """
    Plot an acquisition function as a grid of maps, one panel per ``col_dim`` value.

    Each panel shows the acquisition function for one value of ``col_dim``,
    with placement locations (red points) and the task's off-grid context
    data overlaid. Dimensions other than ``col_dim`` and the raw spatial
    coordinates are averaged out before plotting (only ``"time"`` and
    ``"sample"`` may be averaged).

    Args:
        task (:class:`~.data.task.Task`):
            Task containing the context set used to compute the acquisition
            function.
        acquisition_fn_ds (:class:`xarray.DataArray`):
            Acquisition function values. Although annotated as a NumPy array,
            this must be an xarray ``DataArray``: the code uses ``.dims``,
            ``.sel``, ``.mean(dim=...)``, ``.name`` and ``.plot``.
        X_new_df (:class:`pandas.DataFrame`):
            Dataframe containing the placement locations, indexed such that
            ``X_new_df.loc[0:i]`` selects the placements up to iteration ``i``.
        data_processor (:class:`~.data.processor.DataProcessor`):
            Data processor used to unnormalise the context set and placement
            locations.
        crs (cartopy CRS):
            Coordinate reference system for the plots.
        col_dim (str, optional):
            Dimension to plot across panels, by default "iteration". A scalar
            coordinate of this name is treated as a single panel.
        cmap (str | matplotlib.colors.Colormap, optional):
            Color map to use for the plots, by default "Greys_r".
        figsize (int, optional):
            Size in inches of each panel, by default 3.
        add_colorbar (bool, optional):
            If True (default), all panels share one colour scale and a single
            colorbar is added. If False, each panel uses its own colour scale
            and no colorbar is drawn.
        max_ncol (int, optional):
            Maximum number of panel columns, by default 5. Additional panels
            wrap onto extra rows.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure`
            A figure containing the acquisition function plots.

    Raises:
        ValueError:
            If ``acquisition_fn_ds`` has a dimension other than ``col_dim``,
            the raw spatial coordinates, ``"time"`` or ``"sample"``.
        AssertionError:
            If ``col_dim`` is not ``"iteration"`` and ``acquisition_fn_ds``
            holds more than one iteration value.
    """
    # Dims that appear in the plots: one panel per ``col_dim`` value, with the
    # raw spatial coordinates as the map axes. Everything else is averaged.
    plot_dims = [col_dim, *data_processor.raw_spatial_coord_names]
    non_plot_dims = [dim for dim in acquisition_fn_ds.dims if dim not in plot_dims]
    valid_avg_dims = ["time", "sample"]
    for dim in non_plot_dims:
        if dim not in valid_avg_dims:
            raise ValueError(
                f"Cannot average over dim {dim} for plotting. Must be one of "
                f"{valid_avg_dims}. Select a single value for {dim} using "
                f"`acquisition_fn_ds.sel({dim}=...)`."
            )
    if non_plot_dims:
        print(f"Averaging acquisition function over dims for plotting: {non_plot_dims}")
        acquisition_fn_ds = acquisition_fn_ds.mean(dim=non_plot_dims)

    # A scalar ``col_dim`` coordinate (e.g. left over from ``.sel``) can be
    # neither iterated nor selected on; promote it to a length-1 dimension.
    if acquisition_fn_ds[col_dim].ndim == 0:
        acquisition_fn_ds = acquisition_fn_ds.expand_dims(col_dim)

    col_vals = acquisition_fn_ds[col_dim].values
    n_col_vals = col_vals.size
    ncols = min(max_ncol, n_col_vals)
    nrows = (n_col_vals + ncols - 1) // ncols  # ceiling division

    if col_dim == "iteration":
        fixed_placements = None
    else:
        # Assumed plotting a single iteration: every panel shows the
        # placements made up to and including that iteration. Checked before
        # creating the figure so a failure leaves no orphaned figure behind.
        iteration = acquisition_fn_ds.iteration.values
        assert iteration.size == 1, "Expected single iteration"
        fixed_placements = X_new_df.loc[slice(0, iteration.item())].values.T[::-1]

    # ``squeeze=False`` always yields a 2D array, so no single-axis special case.
    fig, axes = plt.subplots(
        subplot_kw={"projection": crs},
        ncols=ncols,
        nrows=nrows,
        figsize=(figsize * ncols, figsize * nrows),
        squeeze=False,
    )
    axes = axes.ravel()

    if add_colorbar:
        # Shared colour scale so the single colorbar applies to every panel.
        vmin, vmax = acquisition_fn_ds.min(), acquisition_fn_ds.max()
    else:
        # Let each panel pick its own colour scale.
        vmin = vmax = None

    im = None
    for ax, col_val in zip(axes, col_vals):
        im = acquisition_fn_ds.sel(**{col_dim: col_val}).plot(
            ax=ax, cmap=cmap, vmin=vmin, vmax=vmax, add_colorbar=False
        )
        ax.set_title(f"{col_dim}={col_val}")
        ax.coastlines()
        if fixed_placements is None:
            # Cumulative placements up to and including this iteration.
            placements_xy = X_new_df.loc[slice(0, col_val)].values.T[::-1]
        else:
            placements_xy = fixed_placements
        ax.scatter(*placements_xy, c="r", linewidths=0.5)

    if add_colorbar and im is not None:
        # One colorbar in a narrow custom axis on the right-hand side of the figure.
        cax = fig.add_axes([0.93, 0.035, 0.02, 0.91])
        fig.colorbar(im, cax=cax, label=acquisition_fn_ds.name)

    offgrid_context(axes, task, data_processor, linewidths=0.5)

    # Remove any unused axes in the last row.
    for ax in axes[n_col_vals:]:
        ax.remove()

    return fig
866

867

868
def prediction(
    pred: Prediction,
    date: Optional[Union[str, pd.Timestamp]] = None,
    data_processor: Optional[DataProcessor] = None,
    task_loader: Optional[TaskLoader] = None,
    task: Optional[Task] = None,
    prediction_parameters: Union[List[str], str] = "all",
    crs=None,
    colorbar: bool = True,
    cmap: str = "viridis",
    size: int = 5,
    extent: Optional[Union[Tuple[float, float, float, float], str]] = None,
) -> plt.Figure:  # pragma: no cover
    """
    Plot the parameters (e.g. mean and standard deviation) of a prediction.

    One row of axes is drawn per target variable and one column per
    prediction parameter. On-grid predictions are drawn as maps at ``date``.
    Off-grid predictions are drawn as time series, with one line per
    prediction location.

    Args:
        pred (:class:`~.model.pred.Prediction`):
            Prediction to plot.
        date (str | :class:`pandas:pandas.Timestamp`):
            Date of the prediction (on-grid only).
        data_processor (:class:`~.data.processor.DataProcessor`):
            Data processor used to unnormalise the context set.
        task_loader (:class:`~.data.loader.TaskLoader`):
            Task loader used to load the data, containing context set metadata
            used for plotting.
        task (:class:`~.data.task.Task`, optional):
            Task containing the context data to overlay.
        prediction_parameters (List[str] | str, optional):
            Prediction parameters to plot, by default "all" (every parameter
            of every variable). A single parameter name may also be passed as
            a string, e.g. ``"mean"``.
        crs (cartopy CRS, optional):
            Coordinate reference system for the plots, by default None.
        colorbar (bool, optional):
            Whether to add a colorbar to the plots, by default True.
        cmap (str):
            Colormap to use for the plots. By default "viridis".
        size (int, optional):
            Size of the figure in inches per axis, by default 5.
        extent: (tuple | str, optional):
            Sequence of (lon_min, lon_max, lat_min, lat_max) or string of
            region name. Region options are: "global", "north_america", "uk",
            "europe", "germany". Defaults to None (no setting of extent).
            Requires ``crs``.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure`
            A figure containing the prediction plots.

    Raises:
        AssertionError:
            If ``pred`` is off-grid and any of ``date``, ``data_processor``,
            ``task_loader``, ``task`` or ``crs`` is passed.
        ValueError:
            If ``extent`` is given without a ``crs``, or is an unknown region
            name.
    """
    if pred.mode == "off-grid":
        assert date is None, "Cannot pass a `date` for off-grid predictions"
        assert (
            data_processor is None
        ), "Cannot pass a `data_processor` for off-grid predictions"
        assert (
            task_loader is None
        ), "Cannot pass a `task_loader` for off-grid predictions"
        assert task is None, "Cannot pass a `task` for off-grid predictions"
        assert crs is None, "Cannot pass a `crs` for off-grid predictions"

    # ``set_extent`` only exists on cartopy GeoAxes, which are only created
    # when a projection is given. Fail up front instead of mid-plot.
    if extent is not None and crs is None:
        raise ValueError("`extent` can only be set when a `crs` is passed")

    x1_name = pred.x1_name
    x2_name = pred.x2_name

    if isinstance(prediction_parameters, str) and prediction_parameters == "all":
        prediction_parameters = {var_ID: list(pred[var_ID]) for var_ID in pred}
    else:
        if isinstance(prediction_parameters, str):
            # A bare string is one parameter name, not a sequence of characters.
            prediction_parameters = [prediction_parameters]
        prediction_parameters = {
            var_ID: list(prediction_parameters) for var_ID in pred
        }

    n_vars = len(pred.target_var_IDs)
    n_params = max(len(params) for params in prediction_parameters.values())

    if isinstance(extent, str):
        extent = extent_str_to_tuple(extent)
    elif extent is not None:
        extent = tuple(float(x) for x in extent)

    # ``squeeze=False`` always yields a 2D (n_vars, n_params) array of axes.
    fig, axes = plt.subplots(
        n_vars,
        n_params,
        figsize=(size * n_params, size * n_vars),
        subplot_kw=dict(projection=crs),
        squeeze=False,
    )
    for row_i, var_ID in enumerate(pred.target_var_IDs):
        params = prediction_parameters[var_ID]
        for col_i, param in enumerate(params):
            ax = axes[row_i, col_i]

            if pred.mode == "on-grid":
                # Standard deviations are non-negative, so anchor their scale at 0.
                vmin = 0 if param == "std" else None
                im = pred[var_ID][param].sel(time=date).plot(
                    ax=ax,
                    cmap=cmap,
                    vmin=vmin,
                    add_colorbar=False,
                    center=False,
                )
                if colorbar:
                    # Narrow colorbar axis just to the right of this panel.
                    pos = ax.get_position()
                    cax = fig.add_axes([pos.x1 + 0.01, pos.y0, 0.02, pos.height])
                    fig.colorbar(im, cax=cax)
                if task is not None:
                    offgrid_context(
                        ax,
                        task,
                        data_processor,
                        task_loader,
                        linewidths=0.5,
                        add_legend=False,
                    )
                if crs is not None:
                    import cartopy.feature as cfeature

                    ax.coastlines()
                    ax.add_feature(cfeature.BORDERS)

            elif pred.mode == "off-grid":
                import seaborn as sns

                var_df = pred[var_ID]
                # Label each line by its (x1, x2) prediction location.
                hue = var_df.reset_index()[[x1_name, x2_name]].apply(
                    lambda row: f"({row[x1_name]}, {row[x2_name]})", axis=1
                )
                hue.name = f"{x1_name}, {x2_name}"

                sns.lineplot(
                    data=var_df,
                    x="time",
                    y=param,
                    ax=ax,
                    hue=hue.values,
                )
                ax.legend(title=hue.name, loc="best")

                # Rotate date-time tick labels so they don't overlap.
                ax.set_xticklabels(
                    ax.get_xticklabels(),
                    rotation=45,
                    horizontalalignment="right",
                )

            ax.set_title(f"{var_ID} {param}")

            if extent is not None:
                ax.set_extent(extent, crs=crs)

        # Variables with fewer parameters than ``n_params`` leave empty axes.
        for ax in axes[row_i, len(params) :]:
            ax.remove()

    plt.subplots_adjust(wspace=0.3)
    return fig
1036

1037

1038
def extent_str_to_tuple(extent: str) -> Tuple[float, float, float, float]:
    """
    Convert extent string to (lon_min, lon_max, lat_min, lat_max) tuple.

    Args:
        extent: str
            String of region name. Options are: "global", "north_america",
            "uk", "europe", "germany".

    Returns:
        tuple
            Tuple of (lon_min, lon_max, lat_min, lat_max).

    Raises:
        ValueError:
            If ``extent`` is not one of the supported region names.
    """
    regions = {
        "global": (-180, 180, -90, 90),
        "north_america": (-160, -60, 15, 75),
        "uk": (-12, 3, 50, 60),
        "europe": (-15, 40, 35, 70),
        "germany": (5, 15, 47, 55),
    }
    # Type guard keeps non-string (possibly unhashable) inputs on the
    # ValueError path instead of raising TypeError from the dict lookup.
    if isinstance(extent, str) and extent in regions:
        return regions[extent]
    # Derive the options from the table so the message never goes stale.
    options = ", ".join(f"'{name}'" for name in regions)
    raise ValueError(
        f"Region {extent} not in supported list of regions with default bounds. "
        f"Options are: {options}."
    )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc