• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

alan-turing-institute / deepsensor / 19842460617

08 Oct 2025 10:03AM UTC coverage: 81.663%. Remained the same
19842460617

push

github

web-flow
Update README.md, adding reference to GIANT project

2053 of 2514 relevant lines covered (81.66%)

1.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

60.85
/deepsensor/plot.py
1
import numpy as np
2✔
2

3
import matplotlib.pyplot as plt
2✔
4
import pandas as pd
2✔
5
from mpl_toolkits.axes_grid1 import make_axes_locatable
2✔
6
import matplotlib.patches as mpatches
2✔
7

8
import lab as B
2✔
9

10
from typing import Optional, Union, List, Tuple
2✔
11

12
from deepsensor.data.task import Task, flatten_X
2✔
13
from deepsensor.data.loader import TaskLoader
2✔
14
from deepsensor.data.processor import DataProcessor
2✔
15
from deepsensor.model.pred import Prediction
2✔
16
from pandas import DataFrame
2✔
17
from matplotlib.colors import Colormap
2✔
18
from matplotlib.axes import Axes
2✔
19

20

21
def task(
    task: Task,
    task_loader: TaskLoader,
    figsize=3,
    markersize=None,
    equal_aspect=False,
    plot_ticks=False,
    extent=None,
) -> plt.Figure:
    """Plot the context and target sets of a task.

    Each context/target set is drawn in its own column and each variable of a
    set in its own row, as a scatter plot of the set's locations coloured by
    that variable's values. If the task contains ``"Y_t_aux"`` data, an extra
    "Auxiliary at targets" column is appended. Gridded sets (whose ``X`` is a
    tuple) are flattened with :func:`~.data.task.flatten_X` before plotting.

    Args:
        task (:class:`~.data.task.Task`):
            Task to plot.
        task_loader (:class:`~.data.loader.TaskLoader`):
            Task loader used to load ``task``, containing variable IDs used for
            plotting.
        figsize (int, optional):
            Size of each subplot in inches, by default 3. Font sizes are scaled
            proportionally.
        markersize (int, optional):
            Marker size (in units of points squared), by default None. If None,
            the marker size is set to ``(2**2) * figsize / 3``.
        equal_aspect (bool, optional):
            Whether to set the aspect ratio of the plots to be equal, by
            default False.
        plot_ticks (bool, optional):
            Whether to plot the coordinate ticks on the axes, by default False.
        extent (Tuple[int, int, int, int], optional):
            Extent of the plot in format (x2_min, x2_max, x1_min, x1_max).
            Defaults to None (uses the smallest extent that contains all data
            points across all context and target sets).

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure`:
            Figure with a ``(max variables per set, number of sets)`` grid of
            subplots, each with its own colorbar. Unused subplots are hidden.
    """
    if markersize is None:
        markersize = (2**2) * figsize / 3

    # Scale font size with figure size
    fontsize = 10 * figsize / 3
    params = {
        "axes.labelsize": fontsize,
        "axes.titlesize": fontsize,
        "font.size": fontsize,
        "figure.titlesize": fontsize,
        "legend.fontsize": fontsize,
        "xtick.labelsize": fontsize,
        "ytick.labelsize": fontsize,
    }

    var_IDs = task_loader.context_var_IDs + task_loader.target_var_IDs
    Y_c = task["Y_c"]
    X_c = task["X_c"]
    if task["Y_t"] is not None:
        Y_t = task["Y_t"]
        X_t = task["X_t"]
    else:
        Y_t = []
        X_t = []
    n_context = len(Y_c)
    n_target = len(Y_t)
    ncols = n_context + n_target
    if "Y_t_aux" in task and task["Y_t_aux"] is not None:
        # Assumes only 1 target set: the auxiliary data is plotted at the
        # locations of the last target set.
        X_t = X_t + [task["X_t"][-1]]
        Y_t = Y_t + [task["Y_t_aux"]]
        var_IDs = var_IDs + (task_loader.aux_at_target_var_IDs,)
        ncols += 1

    X_all = X_c + X_t
    Y_all = Y_c + Y_t
    nrows = max(Y.shape[0] for Y in Y_all)

    if extent is None:
        x1_min = np.min([np.min(X[0]) for X in X_all])
        x1_max = np.max([np.max(X[0]) for X in X_all])
        x2_min = np.min([np.min(X[1]) for X in X_all])
        x2_max = np.max([np.max(X[1]) for X in X_all])
        extent = (x2_min, x2_max, x1_min, x1_max)

    with plt.rc_context(params):
        # `squeeze=False` always yields a 2D (nrows, ncols) array of axes. Without
        # it, a single-subplot figure returns a bare Axes that cannot be indexed.
        fig, axes = plt.subplots(
            nrows=nrows,
            ncols=ncols,
            figsize=(ncols * figsize, nrows * figsize),
            squeeze=False,
        )
        # j = loop index over columns/sets
        # i = loop index over rows/variables within a set
        for j, (X, Y) in enumerate(zip(X_all, Y_all)):
            n_vars = Y.shape[0]
            if isinstance(X, tuple):
                # Gridded set: flatten coords and values so they can be scattered
                X = flatten_X(X)
                Y = Y.reshape(n_vars, -1)

            for i in range(n_vars):
                ax = axes[i, j]
                if i == 0:
                    if j < n_context:
                        title = f"Context set {j}"
                    elif j < n_context + n_target:
                        title = f"Target set {j - n_context}"
                    else:
                        title = "Auxiliary at targets"
                    ax.set_title(title)

                # Cartesian order: x2 on the horizontal axis, x1 on the vertical
                scatter = ax.scatter(
                    X[1, :], X[0, :], c=Y[i], s=markersize, marker="."
                )
                if equal_aspect:
                    # Don't warp aspect ratio
                    ax.set_aspect("equal")
                if not plot_ticks:
                    ax.set_xticks([])
                    ax.set_yticks([])
                ax.set_ylabel(var_IDs[j][i])

                ax.set_xlim(extent[0], extent[1])
                ax.set_ylim(extent[2], extent[3])

                # Add colorbar with same height as axis
                divider = make_axes_locatable(ax)
                box = ax.get_position()
                ratio = 0.3
                pad = 0.1
                width = box.width * ratio
                cax = divider.append_axes("right", size=width, pad=pad)
                fig.colorbar(scatter, cax=cax)

            # Hide the rows this set has no variable for
            for i in range(n_vars, nrows):
                axes[i, j].axis("off")

        fig.tight_layout()

    return fig
150

151

152
def context_encoding(
    model,
    task: Task,
    task_loader: TaskLoader,
    batch_idx: int = 0,
    context_set_idxs: Optional[Union[List[int], int]] = None,
    land_idx: Optional[int] = None,
    cbar: bool = True,
    clim: Optional[Tuple] = None,
    cmap: Union[str, Colormap] = "viridis",
    verbose_titles: bool = True,
    titles: Optional[dict] = None,
    size: int = 3,
    return_axes: bool = False,
):
    """Plot the ``ConvNP`` SetConv encoding of a context set in a task.

    Each selected context set is drawn as one row of subplots: its density
    channel followed by one channel per context variable. Rows with fewer
    channels than the widest row have their trailing axes hidden.

    Args:
        model (:class:`~.model.convnp.ConvNP`):
            ConvNP model.
        task (:class:`~.data.task.Task`):
            Task containing the context sets whose encoding is plotted.
        task_loader (:class:`~.data.loader.TaskLoader`):
            Task loader used to load the data, containing context set metadata
            used for plotting.
        batch_idx (int, optional):
            Batch index in encoding to plot, by default 0.
        context_set_idxs (List[int] | int, optional):
            Indices of context sets to plot, by default None (plots all context
            sets).
        land_idx (int, optional):
            Index of the land mask channel in the encoding, used to overlay a
            land contour (level 0.5) on every plot, by default None.
        cbar (bool, optional):
            Whether to add a colorbar to the plots, by default True.
        clim (tuple, optional):
            Colorbar limits, by default None.
        cmap (str | matplotlib.colors.Colormap, optional):
            Color map to use for the plots, by default "viridis".
        verbose_titles (bool, optional):
            Whether to include verbose titles for the variable IDs in the
            context set (including the time index), by default True.
        titles (dict, optional):
            Dict mapping encoding channel index to subplot title, overriding
            the generated titles, by default None. If None, titles are
            generated from context set metadata.
        size (int, optional):
            Size of each subplot in inches, by default 3.
        return_axes (bool, optional):
            Whether to return the axes of the figure, by default False.

    Returns:
        :obj:`matplotlib.figure.Figure` | Tuple[:obj:`matplotlib.figure.Figure`, :obj:`matplotlib.pyplot.Axes`]:
            Either a figure containing the context set encoding plots, or a
            tuple containing the :obj:`figure <matplotlib.figure.Figure>` and
            the 2D array of :obj:`axes <matplotlib.axes.Axes>` of the figure (if
            ``return_axes`` was set to ``True``).
    """
    from .model.nps import compute_encoding_tensor

    encoding_tensor = compute_encoding_tensor(model, task)
    encoding_tensor = encoding_tensor[batch_idx]

    if isinstance(context_set_idxs, int):
        context_set_idxs = [context_set_idxs]
    if context_set_idxs is None:
        context_set_idxs = np.array(range(len(task_loader.context_dims)))

    context_dims = np.array(task_loader.context_dims)
    # Add density channel to each set size
    context_var_ID_set_sizes = [ndim + 1 for ndim in context_dims[context_set_idxs]]
    ncols = max(context_var_ID_set_sizes)
    nrows = len(context_set_idxs)

    # `squeeze=False` guarantees a 2D (nrows, ncols) axes array for any shape
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(ncols * size, nrows * size),
        squeeze=False,
    )

    # Cumulative channel counts (variables + density per set): context set k's
    # channels end just before index ctx_channel_idxs[k]
    ctx_channel_idxs = np.cumsum(context_dims + 1)

    for row_i, ctx_i in enumerate(context_set_idxs):
        # Starting channel index of this context set within the encoding
        channel_i = ctx_channel_idxs[ctx_i - 1] if ctx_i > 0 else 0
        if verbose_titles:
            var_IDs = task_loader.context_var_IDs_and_delta_t[ctx_i]
        else:
            var_IDs = task_loader.context_var_IDs[ctx_i]

        ncols_row_i = task_loader.context_dims[ctx_i] + 1  # Add density channel
        for col_i in range(ncols_row_i):
            ax = axes[row_i, col_i]
            # Need `origin="lower"` because encoding has `x1` increasing from top
            # to bottom, whereas in visualisations we want `x1` increasing from
            # bottom to top.
            im = ax.imshow(
                encoding_tensor[channel_i],
                origin="lower",
                clim=clim,
                cmap=cmap,
            )
            if titles is not None:
                ax.set_title(titles[channel_i])
            elif col_i == 0:
                ax.set_title(f"Density {ctx_i}")
            else:
                ax.set_title(f"{var_IDs[col_i - 1]}")
            if col_i == 0:
                ax.set_ylabel(f"Context set {ctx_i}")
            if cbar:
                divider = make_axes_locatable(ax)
                cax = divider.append_axes("right", size="5%", pad=0.05)
                fig.colorbar(im, cax=cax)
            ax.patch.set_edgecolor("black")
            ax.patch.set_linewidth(1)
            if land_idx is not None:
                ax.contour(
                    encoding_tensor[land_idx],
                    colors="k",
                    levels=[0.5],
                    origin="lower",
                )
            ax.tick_params(
                which="both",
                bottom=False,
                left=False,
                labelbottom=False,
                labelleft=False,
            )
            channel_i += 1

        for col_i in range(ncols_row_i, ncols):
            # Hide unused axes. Index by the subplot row (`row_i`), not the
            # context set index (`ctx_i`): they differ whenever only a subset of
            # context sets is plotted.
            axes[row_i, col_i].axis("off")

    fig.tight_layout()
    if return_axes:
        return fig, axes
    return fig
293

294

295
def offgrid_context(
    axes: Union[np.ndarray, List[plt.Axes], Tuple[plt.Axes]],
    task: Task,
    data_processor: Optional[DataProcessor] = None,
    task_loader: Optional[TaskLoader] = None,
    plot_target: bool = False,
    add_legend: bool = True,
    context_set_idxs: Optional[Union[List[int], int]] = None,
    markers: Optional[str] = None,
    colors: Optional[str] = None,
    **scatter_kwargs,
) -> None:
    """Plot the off-grid context points on ``axes``.

    Uses a provided :class:`~.data.processor.DataProcessor` to unnormalise the
    context coordinates if provided. Gridded sets (whose ``X`` is a tuple) are
    skipped. Points are plotted in Cartesian ``(x2, x1)`` order on every axis.

    Args:
        axes (:class:`numpy:numpy.ndarray` | List[:class:`matplotlib:matplotlib.axes.Axes`] | Tuple[:class:`matplotlib:matplotlib.axes.Axes`]):
            Axes to plot on. A single Axes is also accepted.
        task (:class:`~.data.task.Task`):
            Task containing the context set to plot.
        data_processor (:class:`~.data.processor.DataProcessor`, optional):
            Data processor used to unnormalise the context set, by default
            None.
        task_loader (:class:`~.data.loader.TaskLoader`, optional):
            Task loader used to load the data, containing context set metadata
            used for plotting, by default None.
        plot_target (bool, optional):
            Whether to plot the target set, by default False. Legend labels
            are only generated when this is True.
        add_legend (bool, optional):
            Whether to add a legend to the first axis, by default True.
        context_set_idxs (List[int] | int, optional):
            Indices of context sets to plot, by default None (plots all context
            sets).
        markers (str, optional):
            Marker styles to use for each set (cycled if there are more sets
            than markers), by default None.
        colors (str, optional):
            Colors to use for each set (cycled if there are more sets than
            colors), by default None.
        scatter_kwargs:
            Additional keyword arguments to pass to the scatter plot. These
            take precedence over the per-set defaults (``marker``, ``color``,
            ``facecolors``, ``label``).

    Returns:
        None
    """
    from matplotlib.markers import MarkerStyle

    if markers is None:
        # all matplotlib markers
        markers = "ovs^Dxv<>1234spP*hHd|_"
    if colors is None:
        # all one-letter matplotlib colors
        colors = "kbrgy" * 10

    if isinstance(context_set_idxs, int):
        context_set_idxs = [context_set_idxs]

    if isinstance(axes, np.ndarray):
        axes = axes.ravel()
    elif not isinstance(axes, (list, tuple)):
        axes = [axes]

    n_context = len(task["X_c"])
    if plot_target:
        X_sets = [*task["X_c"], *task["X_t"]]
    else:
        X_sets = task["X_c"]

    for set_i, X in enumerate(X_sets):
        if context_set_idxs is not None and set_i not in context_set_idxs:
            continue

        if isinstance(X, tuple):
            continue  # Don't plot gridded context data locations
        if X.ndim == 3:
            X = X[0]  # select first batch

        if data_processor is not None:
            x1, x2 = data_processor.map_x1_and_x2(X[0], X[1], unnorm=True)
            X = np.stack([x1, x2], axis=0)

        X = X[::-1]  # flip 2D coords for Cartesian fmt

        label = ""
        if plot_target and set_i < n_context:
            label += f"Context set {set_i} "
            if task_loader is not None:
                label += f"({task_loader.context_var_IDs[set_i]})"
        elif plot_target:
            label += f"Target set {set_i - n_context} "
            if task_loader is not None:
                label += f"({task_loader.target_var_IDs[set_i - n_context]})"

        style = {
            "marker": markers[set_i % len(markers)],
            "color": colors[set_i % len(colors)],
            "label": label,
        }
        # User-supplied kwargs override the defaults instead of raising a
        # duplicate-keyword TypeError
        style.update(scatter_kwargs)
        if "facecolors" not in style:
            # Draw filled markers hollow. Unfilled markers (e.g. "x", "1", "|")
            # are drawn by matplotlib using their face colour, so "none" would
            # make them invisible; leave their face colour unset instead.
            style["facecolors"] = (
                "none" if MarkerStyle(style["marker"]).is_filled() else None
            )

        for ax in axes:
            ax.scatter(*X, **style)

    if add_legend:
        axes[0].legend(loc="best")
397

398

399
def offgrid_context_observations(
    axes: Union[np.ndarray, List[plt.Axes], Tuple[plt.Axes]],
    task: Task,
    data_processor: DataProcessor,
    task_loader: TaskLoader,
    context_set_idx: int,
    format_str: Optional[str] = None,
    extent: Optional[Tuple[int, int, int, int]] = None,
    color: str = "black",
) -> None:
    """Annotate ``axes`` with the unnormalised value of each context observation.

    Each off-grid observation of the chosen single-variable context set is
    unnormalised with ``data_processor`` and written as a text label at its
    unnormalised location, placed in Cartesian ``(x2, x1)`` order.

    Args:
        axes (:class:`numpy:numpy.ndarray` | List[:class:`matplotlib:matplotlib.axes.Axes`] | Tuple[:class:`matplotlib:matplotlib.axes.Axes`]):
            Axes to annotate. A single Axes is also accepted.
        task (:class:`~.data.task.Task`):
            Task containing the context set to plot.
        data_processor (:class:`~.data.processor.DataProcessor`):
            Data processor used to unnormalise the context locations and values.
        task_loader (:class:`~.data.loader.TaskLoader`):
            Task loader used to load the data; its ``context_var_IDs`` give the
            variable ID used for unnormalising values.
        context_set_idx (int):
            Index of the context set to plot.
        format_str (str, optional):
            Format string for the context observation values. By default
            ``"{:.2f}"``.
        extent (Tuple[int, int, int, int], optional):
            If given, only observations with ``extent[0] <= x1 <= extent[1]``
            and ``extent[2] <= x2 <= extent[3]`` are annotated. By default None
            (annotate all). NOTE(review): this compares x1 against the first
            pair, whereas other functions in this module document extents as
            (x2_min, x2_max, x1_min, x1_max) — confirm which is intended.
        color (str, optional):
            Color of the text, by default "black".

    Returns:
        None.

    Raises:
        AssertionError:
            If the context set does not have exactly one variable ID.
        AssertionError:
            If the context set is gridded.
        AssertionError:
            If the task's "Y_c" value for the context set is not 2D.
        AssertionError:
            If the task's "Y_c" value for the context set does not have
            exactly one variable.
    """
    # Normalise `axes` into something iterable
    if type(axes) is np.ndarray:
        axes = axes.ravel()
    elif not isinstance(axes, (list, tuple)):
        axes = [axes]

    fmt = "{:.2f}" if format_str is None else format_str

    # Variable IDs of the chosen context set: exactly one is supported
    set_var_IDs = task_loader.context_var_IDs[context_set_idx]
    assert (
        len(set_var_IDs) == 1
    ), "Plotting context observations only supported for single-variable (1D) context sets"
    var_ID = set_var_IDs[0]

    locations = task["X_c"][context_set_idx]
    assert not isinstance(
        locations, tuple
    ), f"The context set must not be gridded but is of type {type(locations)} for context set at index {context_set_idx}"
    locations = data_processor.map_coord_array(locations, unnorm=True)

    values = task["Y_c"][context_set_idx]
    assert values.ndim == 2
    assert values.shape[0] == 1
    values = data_processor.map_array(values, var_ID, unnorm=True).ravel()

    for loc, val in zip(locations.T, values):
        within_extent = extent is None or (
            extent[0] <= loc[0] <= extent[1] and extent[2] <= loc[1] <= extent[3]
        )
        if not within_extent:
            continue
        for ax in axes:
            # Reverse (x1, x2) into Cartesian (x2, x1) text position
            ax.text(*loc[::-1], fmt.format(float(val)), color=color)
480

481

482
def receptive_field(
    receptive_field,
    data_processor: DataProcessor,
    crs,
    extent: Union[str, Tuple[float, float, float, float]] = "global",
) -> plt.Figure:  # pragma: no cover
    """Plot a receptive field, in raw (unnormalised) coordinates, as a shaded
    box centred on the plot extent.

    The receptive field is converted to raw units by multiplying it by the raw
    span of each coordinate's normalisation map
    (``data_processor.config["coords"][...]["map"]``).

    Args:
        receptive_field (float):
            Receptive field size in normalised coordinate units.
        data_processor (:class:`~.data.processor.DataProcessor`):
            Data processor whose coordinate config is used to convert the
            receptive field to raw units and to name the coordinates.
        crs (cartopy CRS):
            Coordinate reference system for the plot projection; also used as
            the CRS of the extent and of the receptive field box.
        extent (str | Tuple[float, float, float, float], optional):
            Extent of the plot, in format (x2_min, x2_max, x1_min, x1_max), e.g. in
            lat-lon format (lon_min, lon_max, lat_min, lat_max). String extents
            are converted with ``extent_str_to_tuple``. By default "global".

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure`:
            Figure containing the receptive field box over coastlines, titled
            with the receptive field size in raw coordinates.
    """
    fig, ax = plt.subplots(subplot_kw=dict(projection=crs))

    if isinstance(extent, str):
        extent = extent_str_to_tuple(extent)
    else:
        extent = tuple(float(x) for x in extent)
    x2_min, x2_max, x1_min, x1_max = extent
    ax.set_extent(extent, crs=crs)

    x11, x12 = data_processor.config["coords"]["x1"]["map"]
    x21, x22 = data_processor.config["coords"]["x2"]["map"]

    # Convert normalised receptive field size into raw coordinate units
    x1_rf_raw = receptive_field * (x12 - x11)
    x2_rf_raw = receptive_field * (x22 - x21)

    x1_midpoint_raw = (x1_max + x1_min) / 2
    x2_midpoint_raw = (x2_max + x2_min) / 2

    # Compute bottom left corner of receptive field
    x1_corner = x1_midpoint_raw - x1_rf_raw / 2
    x2_corner = x2_midpoint_raw - x2_rf_raw / 2

    ax.add_patch(
        mpatches.Rectangle(
            xy=[x2_corner, x1_corner],  # Cartesian fmt: x2, x1
            width=x2_rf_raw,
            height=x1_rf_raw,
            facecolor="black",
            alpha=0.3,
            transform=crs,
        )
    )
    ax.coastlines()
    ax.gridlines(draw_labels=True, alpha=0.2)

    x1_name = data_processor.config["coords"]["x1"]["name"]
    x2_name = data_processor.config["coords"]["x2"]["name"]
    ax.set_title(
        f"Receptive field in raw coords: {x1_name}={x1_rf_raw:.2f}, "
        f"{x2_name}={x2_rf_raw:.2f}"
    )

    return fig
547

548

549
def feature_maps(
    model,
    task: Task,
    n_features_per_layer: int = 1,
    seed: Optional[int] = None,
    figsize: int = 3,
    add_colorbar: bool = False,
    cmap: Union[str, Colormap] = "Greys",
) -> List[plt.Figure]:
    """Plot the feature maps of a ``ConvNP`` model's decoder layers after a
    forward pass with a ``Task``.

    The forward pass of ``model.model.decoder[0]`` (assumed to be the U-Net
    built by ``neuralprocesses.construct_convgnp``) is re-run manually on the
    task's SetConv encoding so that the output of every layer can be captured.
    One figure is produced per captured layer, showing a random subset of that
    layer's channels for the first batch element.

    Args:
        model (:class:`~.model.convnp.ConvNP`):
            ConvNP model whose decoder feature maps are plotted.
        task (:class:`~.data.task.Task`):
            Task used to compute the SetConv encoding fed to the U-Net.
        n_features_per_layer (int, optional):
            Maximum number of randomly chosen channels to plot per layer
            (capped at the layer's number of channels), by default 1.
        seed (int, optional):
            Seed for the random generator that chooses channels, by default
            None (non-deterministic).
        figsize (int, optional):
            Size of each subplot in inches; also scales the title font size,
            by default 3.
        add_colorbar (bool, optional):
            Whether to add a colorbar next to each feature map, by default False.
        cmap (str | matplotlib.colors.Colormap, optional):
            Color map for the feature maps, by default "Greys".

    Returns:
        List[matplotlib.figure.Figure]:
            One figure per captured layer, in forward-pass order.

    Raises:
        ValueError:
            If the backend is not recognised.
    """
    from .model.nps import compute_encoding_tensor

    import deepsensor

    # Hacky way to load the correct __init__.py to get `convert_to_tensor` method
    if deepsensor.backend.str == "tf":
        import deepsensor.tensorflow as deepsensor
    elif deepsensor.backend.str == "torch":
        import deepsensor.torch as deepsensor
    else:
        raise ValueError(f"Unknown backend: {deepsensor.backend.str}")

    unet = model.model.decoder[0]

    # Produce encoding
    x = deepsensor.convert_to_tensor(compute_encoding_tensor(model, task))

    # Manually construct the U-Net forward pass from
    # `neuralprocesses.construct_convgnp` to get the feature maps
    def unet_forward(unet, x):
        feature_maps = []

        h = unet.activations[0](unet.before_turn_layers[0](x))
        hs = [h]
        feature_maps.append(B.to_numpy(h))
        for layer, activation in zip(
            unet.before_turn_layers[1:],
            unet.activations[1:],
        ):
            h = activation(layer(hs[-1]))
            hs.append(h)
            feature_maps.append(B.to_numpy(h))

        # Now make the turn!
        h = unet.activations[-1](unet.after_turn_layers[-1](hs[-1]))
        feature_maps.append(B.to_numpy(h))
        for h_prev, layer, activation in zip(
            reversed(hs[:-1]),
            reversed(unet.after_turn_layers[:-1]),
            reversed(unet.activations[:-1]),
        ):
            # Skip connection: concatenate along the channel axis
            h = activation(layer(B.concat(h_prev, h, axis=1)))
            feature_maps.append(B.to_numpy(h))

        h = unet.final_linear(h)
        feature_maps.append(B.to_numpy(h))

        return feature_maps

    feature_maps = unet_forward(unet, x)

    figs = []
    rng = np.random.default_rng(seed)
    for layer_i, feature_map in enumerate(feature_maps):
        n_features = feature_map.shape[1]
        n_features_to_plot = min(n_features_per_layer, n_features)
        feature_idxs = rng.choice(n_features, n_features_to_plot, replace=False)

        # `squeeze=False` gives a (1, n) axes array even when n == 1
        fig, axes = plt.subplots(
            nrows=1,
            ncols=n_features_to_plot,
            figsize=(figsize * n_features_to_plot, figsize),
            squeeze=False,
        )
        for f_i, ax in zip(feature_idxs, axes[0]):
            fm = feature_map[0, f_i]  # first batch element, channel f_i
            im = ax.imshow(fm, origin="lower", cmap=cmap)
            ax.set_title(f"Feature {f_i}", fontsize=figsize * 15 / 4)
            ax.tick_params(
                which="both",
                bottom=False,
                left=False,
                labelbottom=False,
                labelleft=False,
            )
            if add_colorbar:
                fig.colorbar(im, ax=ax, format="%.2f")

        fig.suptitle(
            f"Layer {layer_i} feature map. Shape: {feature_map.shape}. Min={np.min(feature_map):.2f}, Max={np.max(feature_map):.2f}.",
            fontsize=figsize * 15 / 4,
        )
        fig.tight_layout()
        # Leave room at the top for the suptitle
        fig.subplots_adjust(top=0.75)
        figs.append(fig)

    return figs
679

680

681
def placements(
    task: Task,
    X_new_df: DataFrame,
    data_processor: DataProcessor,
    crs,
    extent: Optional[Union[Tuple[int, int, int, int], str]] = None,
    figsize: int = 3,
    **scatter_kwargs,
) -> plt.Figure:  # pragma: no cover
    """Plot proposed placement locations on a map, together with the task's
    off-grid context locations.

    Args:
        task (:class:`~.data.task.Task`):
            Task containing the context set used to compute the acquisition
            function.
        X_new_df (:class:`pandas.DataFrame`):
            Dataframe containing the placement locations. Its columns are
            reversed before plotting, so they are presumably ordered (x1, x2);
            TODO confirm against the caller.
        data_processor (:class:`~.data.processor.DataProcessor`):
            Data processor used to unnormalise the context set and placement
            locations.
        crs (cartopy CRS):
            Coordinate reference system for the plots.
        extent (Tuple[int, int, int, int] | str, optional):
            Extent of the plots, by default None (leave cartopy's default).
            ``"global"`` shows the whole globe; other strings are converted with
            ``extent_str_to_tuple``.
        figsize (int, optional):
            Figure size in inches, by default 3.
        scatter_kwargs:
            Extra keyword arguments for both the placement scatter and the
            context scatter. They override the default ``linewidths=0.5``.

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure`
            A figure containing the placement plots.
    """
    fig, ax = plt.subplots(subplot_kw={"projection": crs}, figsize=(figsize, figsize))
    # Defaults that user-supplied kwargs may override (previously a duplicate
    # `linewidths` raised a TypeError)
    kwargs = {"linewidths": 0.5, **scatter_kwargs}
    ax.scatter(*X_new_df.values.T[::-1], c="r", **kwargs)
    offgrid_context(ax, task, data_processor, **kwargs)

    ax.coastlines()
    if isinstance(extent, str):
        if extent == "global":
            ax.set_global()
        else:
            ax.set_extent(extent_str_to_tuple(extent), crs=crs)
    elif extent is not None:
        ax.set_extent(extent, crs=crs)

    return fig
725

726

727
def acquisition_fn(
    task: Task,
    acquisition_fn_ds: np.ndarray,
    X_new_df: DataFrame,
    data_processor: DataProcessor,
    crs,
    col_dim: str = "iteration",
    cmap: Union[str, Colormap] = "Greys_r",
    figsize: int = 3,
    add_colorbar: bool = True,
    max_ncol: int = 5,
) -> plt.Figure:  # pragma: no cover
    """Plot an acquisition function as a grid of maps, one panel per ``col_dim`` value.

    Any dimension other than ``col_dim`` and the raw spatial coordinates is
    first averaged out (only ``"time"`` and ``"sample"`` may be averaged).
    Each panel shows the acquisition function, coastlines, and the placements
    in ``X_new_df`` made up to the relevant iteration. The task's context
    points are overlaid on all panels.

    Args:
        task (:class:`~.data.task.Task`):
            Task containing the context set used to compute the acquisition
            function.
        acquisition_fn_ds (:class:`xarray.DataArray`):
            Acquisition function values. Although annotated as
            :class:`numpy.ndarray`, this must be an xarray DataArray: the
            function uses ``.dims``, ``.sel``, ``.mean(dim=...)``, ``.plot``
            and ``.name``. ``.name`` is used as the colorbar label.
        X_new_df (:class:`pandas.DataFrame`):
            Dataframe containing the placement locations. It is sliced with
            ``X_new_df.loc[0:i]`` to select placements up to iteration ``i``,
            and its columns are plotted in reverse order as (x, y).
        data_processor (:class:`~.data.processor.DataProcessor`):
            Data processor used to unnormalise the context set and placement
            locations; its ``raw_spatial_coord_names`` are never averaged.
        crs (cartopy CRS):
            Coordinate reference system for the plots.
        col_dim (str, optional):
            Dimension (or scalar coordinate) to plot over, one panel per
            value, by default "iteration".
        cmap (str | matplotlib.colors.Colormap, optional):
            Color map to use for the plots, by default "Greys_r".
        figsize (int, optional):
            Size in inches of each panel, by default 3.
        add_colorbar (bool, optional):
            Whether to add a single shared colorbar, by default True. When
            True all panels share one colour scale; when False each panel
            uses its own.
        max_ncol (int, optional):
            Maximum number of panel columns; extra panels wrap onto new rows.
            By default 5.

    Returns:
        matplotlib.pyplot.Figure
            A figure containing the acquisition function plots.

    Raises:
        ValueError:
            If ``acquisition_fn_ds`` has a dimension that is neither
            ``col_dim``, a raw spatial coordinate, nor one of
            ``["time", "sample"]``; if there are no ``col_dim`` values; or if
            ``max_ncol`` is less than 1.
        AssertionError:
            If ``col_dim`` is not ``"iteration"`` and ``acquisition_fn_ds``
            does not hold exactly one ``iteration`` value.
    """
    # Every dimension that is neither the panel dimension nor spatial must be
    # averaged away before a 2D map can be drawn.
    plot_dims = [col_dim, *data_processor.raw_spatial_coord_names]
    non_plot_dims = [dim for dim in acquisition_fn_ds.dims if dim not in plot_dims]
    valid_avg_dims = ["time", "sample"]
    for dim in non_plot_dims:
        if dim not in valid_avg_dims:
            raise ValueError(
                f"Cannot average over dim {dim} for plotting. Must be one of "
                f"{valid_avg_dims}. Select a single value for {dim} using "
                f"`acquisition_fn_ds.sel({dim}=...)`."
            )
    if len(non_plot_dims) > 0:
        print(
            "Averaging acquisition function over dims for plotting: " f"{non_plot_dims}"
        )
        acquisition_fn_ds = acquisition_fn_ds.mean(dim=non_plot_dims)

    col_coord = acquisition_fn_ds[col_dim]
    # A scalar (0-d) coordinate can neither be iterated nor selected along,
    # so it is drawn as a single panel containing the whole array.
    col_is_scalar = col_coord.ndim == 0
    col_vals = np.atleast_1d(col_coord.values)
    n_col_vals = len(col_vals)
    if n_col_vals == 0:
        raise ValueError(f"No values along `{col_dim}` to plot.")
    if max_ncol < 1:
        raise ValueError(f"`max_ncol` must be at least 1, got {max_ncol}.")

    ncols = min(max_ncol, n_col_vals)
    nrows = int(np.ceil(n_col_vals / ncols))

    fig, axes = plt.subplots(
        subplot_kw={"projection": crs},
        ncols=ncols,
        nrows=nrows,
        figsize=(figsize * ncols, figsize * nrows),
    )
    if nrows == 1 and ncols == 1:
        axes = [axes]
    else:
        axes = axes.ravel()

    if add_colorbar:
        # One shared colour scale so a single colorbar describes every panel.
        vmin = float(acquisition_fn_ds.min())
        vmax = float(acquisition_fn_ds.max())
    else:
        # Let each panel pick its own colour scale.
        vmin, vmax = None, None

    final_i = n_col_vals - 1
    for i, (ax, col_val) in enumerate(zip(axes, col_vals)):
        panel_da = (
            acquisition_fn_ds
            if col_is_scalar
            else acquisition_fn_ds.sel(**{col_dim: col_val})
        )
        # Use the artist returned by ``.plot`` as the colorbar mappable rather
        # than relying on the ordering of the axes' children.
        im = panel_da.plot(
            ax=ax, cmap=cmap, vmin=vmin, vmax=vmax, add_colorbar=False
        )
        if add_colorbar and i == final_i:
            # Narrow custom axis along the right edge of the figure.
            cax = fig.add_axes([0.93, 0.035, 0.02, 0.91])
            fig.colorbar(im, cax=cax, label=acquisition_fn_ds.name)
        ax.set_title(f"{col_dim}={col_val}")
        ax.coastlines()

        if col_dim == "iteration":
            last_iteration = col_val
        else:
            # Plotting along another dimension assumes a single iteration.
            iteration_vals = acquisition_fn_ds.iteration.values
            assert iteration_vals.size == 1, "Expected single iteration"
            last_iteration = iteration_vals.item()
        # All placements made up to and including this iteration, with the
        # columns reversed so they map to (x, y).
        X_new_plot = X_new_df.loc[slice(0, last_iteration)].values.T[::-1]
        ax.scatter(
            *X_new_plot,
            c="r",
            linewidths=0.5,
        )

    offgrid_context(axes, task, data_processor, linewidths=0.5)

    # Drop grid cells left over when the panels do not fill the last row.
    for ax in axes[n_col_vals:]:
        ax.remove()

    return fig
858

859

860
def prediction(
    pred: Prediction,
    date: Optional[Union[str, pd.Timestamp]] = None,
    data_processor: Optional[DataProcessor] = None,
    task_loader: Optional[TaskLoader] = None,
    task: Optional[Task] = None,
    prediction_parameters: Union[List[str], str] = "all",
    crs=None,
    colorbar: bool = True,
    cmap: str = "viridis",
    size: int = 5,
    extent: Optional[Union[Tuple[float, float, float, float], str]] = None,
) -> plt.Figure:  # pragma: no cover
    """Plot the parameters (e.g. mean and standard deviation) of a prediction.

    The figure has one row per target variable and one column per prediction
    parameter. On-grid predictions are drawn as maps at ``date``. Off-grid
    predictions are drawn as time series with one line per prediction
    location.

    Args:
        pred (:class:`~.model.pred.Prediction`):
            Prediction to plot.
        date (str | :class:`pandas:pandas.Timestamp`, optional):
            Date of the on-grid prediction to plot (selected with
            ``.sel(time=date)``). Must be None for off-grid predictions.
        data_processor (:class:`~.data.processor.DataProcessor`, optional):
            Data processor used to unnormalise the context set when ``task``
            is overlaid. Must be None for off-grid predictions.
        task_loader (:class:`~.data.loader.TaskLoader`, optional):
            Task loader used to load the data, containing context set metadata
            used for plotting ``task``. Must be None for off-grid predictions.
        task (:class:`~.data.task.Task`, optional):
            Task containing the context data to overlay on on-grid plots.
            Must be None for off-grid predictions.
        prediction_parameters (List[str] | str, optional):
            Prediction parameters to plot, by default "all" (every parameter
            of every variable). A single parameter name such as ``"mean"`` is
            also accepted.
        crs (cartopy CRS, optional):
            Coordinate reference system for on-grid plots, by default None.
            When given, coastlines and country borders are drawn. Must be
            None for off-grid predictions.
        colorbar (bool, optional):
            Whether to add a colorbar to each on-grid plot, by default True.
        cmap (str, optional):
            Colormap to use for on-grid plots, by default "viridis".
        size (int, optional):
            Size of each subplot in inches, by default 5.
        extent (tuple | str, optional):
            Tuple of (lon_min, lon_max, lat_min, lat_max) or a region name
            accepted by :func:`extent_str_to_tuple` ("global",
            "north_america", "uk", "europe", "germany"). Requires ``crs``.
            Defaults to None (no setting of extent).

    Returns:
        :class:`matplotlib:matplotlib.figure.Figure`
            A figure containing the prediction plots.

    Raises:
        AssertionError:
            If ``date``, ``data_processor``, ``task_loader``, ``task`` or
            ``crs`` is passed for an off-grid prediction.
        ValueError:
            If ``extent`` is given without ``crs``, if ``extent`` is an
            unknown region name, or if the prediction is a forecast (has an
            ``init_time`` index).
    """
    if pred.mode == "off-grid":
        assert date is None, "Cannot pass a `date` for off-grid predictions"
        assert (
            data_processor is None
        ), "Cannot pass a `data_processor` for off-grid predictions"
        assert (
            task_loader is None
        ), "Cannot pass a `task_loader` for off-grid predictions"
        assert task is None, "Cannot pass a `task` for off-grid predictions"
        assert crs is None, "Cannot pass a `crs` for off-grid predictions"

    if extent is not None and crs is None:
        # ``set_extent`` only exists on cartopy GeoAxes, which need a projection.
        raise ValueError("`extent` can only be set when a cartopy `crs` is passed.")

    x1_name = pred.x1_name
    x2_name = pred.x2_name

    # Map each variable to the parameters to plot for it. A bare string must
    # be wrapped, otherwise it would be iterated character by character.
    if isinstance(prediction_parameters, str):
        if prediction_parameters == "all":
            params_per_var = {var_ID: list(pred[var_ID]) for var_ID in pred}
        else:
            params_per_var = {var_ID: [prediction_parameters] for var_ID in pred}
    else:
        params_per_var = {var_ID: list(prediction_parameters) for var_ID in pred}

    n_vars = len(pred.target_var_IDs)
    n_params = max(len(params) for params in params_per_var.values())

    if isinstance(extent, str):
        extent = extent_str_to_tuple(extent)
    elif isinstance(extent, tuple):
        extent = tuple(float(x) for x in extent)

    # squeeze=False always yields a 2D (n_vars, n_params) array of axes.
    fig, axes = plt.subplots(
        n_vars,
        n_params,
        figsize=(size * n_params, size * n_vars),
        subplot_kw=dict(projection=crs),
        squeeze=False,
    )

    if crs is not None:
        import cartopy.feature as cfeature
    if pred.mode == "off-grid":
        import seaborn as sns

    for row_i, var_ID in enumerate(pred.target_var_IDs):
        var_pred = pred[var_ID]
        for col_i, param in enumerate(params_per_var[var_ID]):
            ax = axes[row_i, col_i]

            if pred.mode == "on-grid":
                if "init_time" in var_pred.indexes:
                    raise ValueError("Plotting forecasts not currently supported.")
                # Standard deviations are non-negative, so anchor their scale at 0.
                vmin = 0 if param == "std" else None
                im = var_pred[param].sel(time=date).plot(
                    ax=ax,
                    cmap=cmap,
                    vmin=vmin,
                    add_colorbar=False,
                    center=False,
                )
                if colorbar:
                    # Thin colorbar axis just to the right of this subplot.
                    pos = ax.get_position()
                    cax = fig.add_axes([pos.x1 + 0.01, pos.y0, 0.02, pos.height])
                    fig.colorbar(im, cax=cax)
                if task is not None:
                    offgrid_context(
                        ax,
                        task,
                        data_processor,
                        task_loader,
                        linewidths=0.5,
                        add_legend=False,
                    )
                if crs is not None:
                    ax.coastlines()
                    ax.add_feature(cfeature.BORDERS)

            elif pred.mode == "off-grid":
                if "init_time" in var_pred.index.names:
                    raise ValueError("Plotting forecasts not currently supported.")
                # One line per location, labelled "(x1, x2)".
                hue = var_pred.reset_index()[[x1_name, x2_name]].apply(
                    lambda row: f"({row[x1_name]}, {row[x2_name]})", axis=1
                )
                hue_title = f"{x1_name}, {x2_name}"

                sns.lineplot(
                    data=var_pred,
                    x="time",
                    y=param,
                    ax=ax,
                    hue=hue.values,
                )
                ax.legend(title=hue_title, loc="best")

                # Rotate the date labels in place. Re-setting them with
                # set_xticklabels would freeze them and trigger matplotlib's
                # FixedLocator warning.
                plt.setp(
                    ax.get_xticklabels(), rotation=45, horizontalalignment="right"
                )

            ax.set_title(f"{var_ID} {param}")

            if extent is not None:
                ax.set_extent(extent, crs=crs)

    plt.subplots_adjust(wspace=0.3)
    return fig
1031

1032

1033
def extent_str_to_tuple(extent: str) -> Tuple[float, float, float, float]:
    """Convert a region name to a (lon_min, lon_max, lat_min, lat_max) tuple.

    Args:
        extent (str):
            Region name. Options are: "global", "north_america", "uk",
            "europe", "germany".

    Returns:
        tuple:
            Tuple of (lon_min, lon_max, lat_min, lat_max) as floats.

    Raises:
        ValueError:
            If ``extent`` is not one of the supported region names.
    """
    region_bounds = {
        "global": (-180.0, 180.0, -90.0, 90.0),
        "north_america": (-160.0, -60.0, 15.0, 75.0),
        "uk": (-12.0, 3.0, 50.0, 60.0),
        "europe": (-15.0, 40.0, 35.0, 70.0),
        "germany": (5.0, 15.0, 47.0, 55.0),
    }
    # Check the type first so unhashable inputs still raise ValueError
    # rather than a TypeError from the dict lookup.
    if not isinstance(extent, str) or extent not in region_bounds:
        options = ", ".join(f"'{name}'" for name in region_bounds)
        raise ValueError(
            f"Region {extent} not in supported list of regions with default bounds. "
            f"Options are: {options}."
        )
    return region_bounds[extent]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc