alan-turing-institute / deepsensor · build 11455747995 (push, via GitHub)
22 Oct 2024 07:56AM UTC · coverage: 81.626% (+0.3%) from 81.333%
davidwilby: "incorporate feedback"
2048 of 2509 relevant lines covered (81.63%) · 1.63 hits per line

Source file: /deepsensor/model/convnp.py · 86.81% file coverage

import copy
import os.path
import json
from typing import Union, List, Literal, Optional
import warnings

import lab as B
import numpy as np
from matrix import Diagonal
from plum import ModuleType, dispatch

import deepsensor.data.task  # Binds the `deepsensor` name for the `concat_tasks` shim below
from deepsensor import backend
from deepsensor.data.loader import TaskLoader
from deepsensor.data.processor import DataProcessor
from deepsensor.data.task import Task
from deepsensor.model.defaults import (
    compute_greatest_data_density,
    gen_encoder_scales,
    gen_decoder_scale,
)
from deepsensor.model.model import DeepSensorModel
from deepsensor.model.nps import (
    construct_neural_process,
    convert_task_to_nps_args,
    run_nps_model,
    run_nps_model_ar,
)

from neuralprocesses.dist import AbstractMultiOutputDistribution


TFModel = ModuleType("tensorflow.keras", "Model")
TorchModel = ModuleType("torch.nn", "Module")


class ConvNP(DeepSensorModel):
    """
    A Convolutional Neural Process (ConvNP) probabilistic regression model (by default a ConvCNP).

    Wraps around the ``neuralprocesses`` package to construct a ConvNP model.
    See: https://github.com/wesselb/neuralprocesses/blob/main/neuralprocesses/architectures/convgnp.py.
    Init kwargs passed to the ``ConvNP`` are passed to the ``neuralprocesses.construct_convgnp``
    function and can be used to specify hyperparameters (see the parameter list below). In
    particular, the ``likelihood`` parameter can be used to specify the likelihood of the model,
    which dictates whether the model outputs marginal distributions at each target point (a
    ConvCNP) or a joint Gaussian distribution over all target points (a ConvGNP). By default, a
    ConvCNP with Gaussian likelihoods is constructed.

    Additionally, the ``ConvNP`` can optionally be instantiated with:
        - a ``DataProcessor`` object to auto-unnormalise the data at inference time with the ``.predict`` method.
        - a ``TaskLoader`` object to infer sensible default model parameters from the data.

    Many of the ``ConvNP`` class methods use multiple dispatch (implemented with ``plum``)
    so that they can be run either with a ``Task`` object of data from the ``TaskLoader``
    or with a ``neuralprocesses`` distribution object. This allows the model's forward
    prediction object to be re-used when computing the logpdf, entropy, etc.

    Dimension shapes are expressed in method docstrings in terms of:
        - ``N_features``: number of features/dimensions in the target set.
        - ``N_targets``: number of target points (1D for off-grid targets, 2D for gridded targets).
        - ``N_components``: number of mixture components in the likelihood (for mixture likelihoods only).
        - ``N_samples``: number of samples drawn from the distribution.

    If the model has multiple target sets and the ``Task`` object
    has different target locations for each set, a list of arrays is returned
    for each target set. Otherwise, a single array is returned.

    Examples:
        Instantiate a ``ConvNP`` with all hyperparameters set to their default values:
            >>> ConvNP(data_processor, task_loader)
        Instantiate a ``ConvNP`` and override some hyperparameters:
            >>> ConvNP(data_processor, task_loader, internal_density=250, unet_channels=(128,) * 6)
        Instantiate a ``ConvNP`` with a pre-trained model saved in the folder ``my_trained_model``:
            >>> ConvNP(data_processor, task_loader, model_ID="my_trained_model")
        Instantiate a ``ConvNP`` with an existing ``neuralprocesses`` model object:
            >>> ConvNP(data_processor, task_loader, neural_process=my_neural_process_model)

    Args:
        data_processor (:class:`~.data.processor.DataProcessor`, optional):
            Used for unnormalising model predictions in the
            ``.predict`` method.
        task_loader (:class:`~.data.loader.TaskLoader`, optional):
            Used for inferring sensible defaults for hyperparameters
            that are not set by the user.
        model_ID (str, optional):
            Folder to load the model config and weights from. This argument can only
            be used alongside the ``data_processor`` and ``task_loader`` arguments.
        neural_process (TFModel | TorchModel, optional):
            Pre-defined neural process PyTorch/TensorFlow model object. This argument can
            only be used alongside the ``data_processor`` and ``task_loader`` arguments.
        internal_density (int, optional):
            Density of the ConvNP's internal grid (in terms of number of points
            per 1x1 unit square). Defaults to 100.
        likelihood (str, optional):
            Likelihood. Must be one of ``"cnp"`` (equivalently ``"het"``),
            ``"gnp"`` (equivalently ``"lowrank"``), ``"cnp-spikes-beta"``
            (equivalently ``"spikes-beta"``), or ``"bernoulli-gamma"``.
            Defaults to ``"cnp"``.
        dim_x (int, optional):
            Dimensionality of the inputs. Defaults to 1.
        dim_y (int, optional):
            Dimensionality of the outputs. Defaults to 1.
        dim_yc (int or tuple[int], optional):
            Dimensionality of the outputs of the context set. You should set this
            if the dimensionality of the outputs of the context set is not equal
            to the dimensionality of the outputs of the target set. You should
            also set this if you want to use multiple context sets. In that case,
            set this equal to a tuple of integers indicating the respective output
            dimensionalities.
        dim_yt (int, optional):
            Dimensionality of the outputs of the target set. You should set this
            if the dimensionality of the outputs of the target set is not equal to
            the dimensionality of the outputs of the context set.
        dim_aux_t (int, optional):
            Dimensionality of target-specific auxiliary variables.
        conv_arch (str, optional):
            Convolutional architecture to use. Must be one of
            ``"unet[-res][-sep]"`` or ``"conv[-res][-sep]"``. Defaults to
            ``"unet"``.
        unet_channels (tuple[int], optional):
            Number of channels in the downsampling path of the UNet (including the bottleneck).
            Defaults to four downsampling layers, each with 64 channels, i.e. ``(64, 64, 64, 64)``.
            Note: The downsampling path is followed by an upsampling path with the same number of
            channels in the reverse order (plus extra channels for the skip connections).
        unet_kernels (int or tuple[int], optional):
            Sizes of the kernels in the UNet. Defaults to 5.
        unet_resize_convs (bool, optional):
            Use resize convolutions rather than transposed convolutions in the
            UNet. Defaults to ``False``.
        unet_resize_conv_interp_method (str, optional):
            Interpolation method for the resize convolutions in the UNet. Can be
            set to ``"bilinear"``. Defaults to ``"bilinear"``.
        num_basis_functions (int, optional):
            Number of basis functions for the low-rank likelihood. Defaults to
            64.
        dim_lv (int, optional):
            Dimensionality of the latent variable. Setting to >0 constructs a
            latent neural process. Defaults to 0.
        encoder_scales (float or tuple[float], optional):
            Initial value for the length scales of the set convolutions for the
            context set embeddings. Set to a tuple of length equal to the number
            of context sets to use different values for each set, or to a single
            value to use the same value for all context sets. Defaults to
            ``1 / internal_density``.
        encoder_scales_learnable (bool, optional):
            Whether the encoder SetConv length scale(s) are learnable. Defaults to
            ``False``.
        decoder_scale (float, optional):
            Initial value for the length scale of the set convolution in the
            decoder. Defaults to ``1 / internal_density``.
        decoder_scale_learnable (bool, optional):
            Whether the decoder SetConv length scale(s) are learnable. Defaults to
            ``False``.
        aux_t_mlp_layers (tuple[int], optional):
            Widths of the layers of the MLP for the target-specific auxiliary
            variable. Defaults to three layers of width 128.
        epsilon (float, optional):
            Epsilon added by the set convolutions before dividing by the density
            channel. Defaults to ``1e-2``.
        dtype (dtype, optional):
            Data type.
    """

    @dispatch
    def __init__(self, *args, **kwargs):
        """
        Generate a new model using ``construct_neural_process`` with default or
        specified parameters.

        This method does not take a ``TaskLoader`` or ``DataProcessor`` object,
        so the model will not auto-unnormalise predictions at inference time.
        """
        super().__init__()

        self.model, self.config = construct_neural_process(*args, **kwargs)

    @dispatch
    def __init__(
        self,
        data_processor: DataProcessor,
        task_loader: TaskLoader,
        *args,
        verbose: bool = True,
        **kwargs,
    ):
        """
        Instantiate model from TaskLoader, using data to infer model parameters
        (unless overridden).

        Args:
            data_processor (:class:`~.data.processor.DataProcessor`):
                DataProcessor object. Used for unnormalising model predictions in
                the ``.predict`` method.
            task_loader (:class:`~.data.loader.TaskLoader`):
                TaskLoader object. Used for inferring sensible defaults for hyperparameters
                that are not set by the user.
            verbose (bool, optional):
                Whether to print inferred model parameters, by default True.
        """
        super().__init__(data_processor, task_loader)

        if "dim_yc" not in kwargs:
            dim_yc = task_loader.context_dims
            if verbose:
                print(f"dim_yc inferred from TaskLoader: {dim_yc}")
            kwargs["dim_yc"] = dim_yc
        if "dim_yt" not in kwargs:
            dim_yt = sum(task_loader.target_dims)  # Must be an int
            if verbose:
                print(f"dim_yt inferred from TaskLoader: {dim_yt}")
            kwargs["dim_yt"] = dim_yt
        if "dim_aux_t" not in kwargs:
            dim_aux_t = task_loader.aux_at_target_dims
            if verbose:
                print(f"dim_aux_t inferred from TaskLoader: {dim_aux_t}")
            kwargs["dim_aux_t"] = dim_aux_t
        if "aux_t_mlp_layers" not in kwargs and kwargs["dim_aux_t"] > 0:
            kwargs["aux_t_mlp_layers"] = (64,) * 3
            if verbose:
                print(f"Setting aux_t_mlp_layers: {kwargs['aux_t_mlp_layers']}")
        if "internal_density" not in kwargs:
            internal_density = compute_greatest_data_density(task_loader)
            if verbose:
                print(f"internal_density inferred from TaskLoader: {internal_density}")
            kwargs["internal_density"] = internal_density
        if "encoder_scales" not in kwargs:
            encoder_scales = gen_encoder_scales(kwargs["internal_density"], task_loader)
            if verbose:
                print(f"encoder_scales inferred from TaskLoader: {encoder_scales}")
            kwargs["encoder_scales"] = encoder_scales
        if "decoder_scale" not in kwargs:
            decoder_scale = gen_decoder_scale(kwargs["internal_density"])
            if verbose:
                print(f"decoder_scale inferred from TaskLoader: {decoder_scale}")
            kwargs["decoder_scale"] = decoder_scale

        self.model, self.config = construct_neural_process(*args, **kwargs)
        self._set_num_mixture_components()

    @dispatch
    def __init__(
        self,
        data_processor: DataProcessor,
        task_loader: TaskLoader,
        neural_process: Union[TFModel, TorchModel],
    ):
        """
        Instantiate with a pre-defined neural process model.

        Args:
            data_processor (:class:`~.data.processor.DataProcessor`):
                DataProcessor object. Used for unnormalising model predictions in
                the ``.predict`` method.
            task_loader (:class:`~.data.loader.TaskLoader`):
                TaskLoader object. Used for inferring sensible defaults for hyperparameters
                that are not set by the user.
            neural_process (TFModel | TorchModel):
                Pre-defined neural process PyTorch/TensorFlow model object.
        """
        super().__init__(data_processor, task_loader)

        self.model = neural_process
        self.config = None

    @dispatch
    def __init__(self, model_ID: str):
        """Instantiate a model from a folder containing model weights and config."""
        super().__init__()

        self.load(model_ID)
        self._set_num_mixture_components()

    @dispatch
    def __init__(
        self,
        data_processor: DataProcessor,
        task_loader: TaskLoader,
        model_ID: str,
    ):
        """Instantiate a model from a folder containing model weights and config.

        Args:
            data_processor (:class:`~.data.processor.DataProcessor`):
                DataProcessor object. Used for unnormalising model predictions in
                the ``.predict`` method.
            task_loader (:class:`~.data.loader.TaskLoader`):
                TaskLoader object. Used for inferring sensible defaults for hyperparameters
                that are not set by the user.
            model_ID (str):
                Folder to load the model config and weights from.
        """
        super().__init__(data_processor, task_loader)

        self.load(model_ID)
        self._set_num_mixture_components()

    def _set_num_mixture_components(self):
        """
        Set the number of mixture components for the model based on the likelihood.
        """
        if self.config["likelihood"] in ["spikes-beta"]:
            self.N_mixture_components = 3
        elif self.config["likelihood"] in ["bernoulli-gamma"]:
            self.N_mixture_components = 2
        else:
            self.N_mixture_components = 1

    def save(self, model_ID: str):
        """
        Save the model weights and config to a folder.

        Args:
            model_ID (str):
                Folder to save the model to.

        Returns:
            None.
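
        Example:
            A minimal sketch of a save/load round trip (``model`` is assumed to
            be a trained ``ConvNP``; the folder name is hypothetical):
                >>> model.save("my_trained_model")
                >>> model = ConvNP(data_processor, task_loader, model_ID="my_trained_model")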
        """
324
        os.makedirs(model_ID, exist_ok=True)
2✔
325

326
        if backend.str == "torch":
2✔
327
            import torch
2✔
328

329
            torch.save(self.model.state_dict(), os.path.join(model_ID, "model.pt"))
2✔
330
        elif backend.str == "tf":
×
331
            self.model.save_weights(os.path.join(model_ID, "model"))
×
332
        else:
333
            raise NotImplementedError(f"Backend {backend.str} not supported.")
×
334

335
        config_fpath = os.path.join(model_ID, "model_config.json")
2✔
336
        with open(config_fpath, "w") as f:
2✔
337
            json.dump(self.config, f, indent=4, sort_keys=False)
2✔
338


    def load(self, model_ID: str):
        """
        Load a model from a folder containing model weights and config.

        Args:
            model_ID (str):
                Folder to load the model from.

        Returns:
            None.
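
        Example:
            A minimal sketch (the folder name is hypothetical; it must contain
            ``model_config.json`` and the saved weights):
                >>> model.load("my_trained_model")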
        """
350
        config_fpath = os.path.join(model_ID, "model_config.json")
2✔
351
        with open(config_fpath, "r") as f:
2✔
352
            self.config = json.load(f)
2✔
353

354
        self.model, _ = construct_neural_process(**self.config)
2✔
355

356
        if backend.str == "torch":
2✔
357
            import torch
2✔
358

359
            self.model.load_state_dict(torch.load(os.path.join(model_ID, "model.pt")))
2✔
360
        elif backend.str == "tf":
×
361
            self.model.load_weights(os.path.join(model_ID, "model"))
×
362
        else:
363
            raise NotImplementedError(f"Backend {backend.str} not supported.")
×

    def __str__(self):
        return (
            "ConvNP with config:"
            + "\n"
            + json.dumps(self.config, indent=4, sort_keys=False)
        )

    @classmethod
    def modify_task(cls, task: Task):
        """
        Cast numpy arrays to TensorFlow or PyTorch tensors, add a batch dim, and
        mask NaNs.

        Each operation is recorded in ``task["ops"]`` and is only applied if it
        has not already been applied.

        Args:
            task (:class:`~.data.task.Task`):
                The task to modify.

        Returns:
            :class:`~.data.task.Task`:
                The modified task.
        """
        if "batch_dim" not in task["ops"]:
            task = task.add_batch_dim()
        if "float32" not in task["ops"]:
            task = task.cast_to_float32()
        if "numpy_mask" not in task["ops"]:
            task = task.mask_nans_numpy()
        if "nps_mask" not in task["ops"]:
            task = task.mask_nans_nps()
        if "tensor" not in task["ops"]:
            task = task.convert_to_tensor()

        return task

    def __call__(self, task, n_samples=10, requires_grad=False):
        """
        Compute the ConvNP distribution for a task.

        Args:
            task (:class:`~.data.task.Task`):
                The task containing the context and target data.
            n_samples (int, optional):
                Number of samples to draw from the distribution, by default 10.
            requires_grad (bool, optional):
                Whether to compute gradients, by default False.

        Returns:
            neuralprocesses.dist.AbstractMultiOutputDistribution:
                The ConvNP distribution.
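
        Example:
            A minimal sketch of re-using the forward prediction object for
            several statistics (``model`` and ``task`` are assumed to exist):
                >>> dist = model(task)
                >>> mean = model.mean(dist)
                >>> std = model.std(dist)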
        """
414
        task = ConvNP.modify_task(task)
2✔
415
        dist = run_nps_model(self.model, task, n_samples, requires_grad)
2✔
416
        return dist
2✔
417

418
    def _cast_numpy_and_squeeze(
2✔
419
        self,
420
        x: Union[B.Numeric, List[B.Numeric]],
421
        squeeze_axes: List[int] = (0, 1),
422
    ):
423
        """TODO docstring"""
424
        if isinstance(x, backend.nps.Aggregate):
2✔
425
            return [np.squeeze(B.to_numpy(xi), axis=squeeze_axes) for xi in x]
2✔
426
        else:
427
            return np.squeeze(B.to_numpy(x), axis=squeeze_axes)
2✔
428

429

    def _maybe_concat_multi_targets(
        self,
        x: Union[np.ndarray, List[np.ndarray]],
        concat_axis: int = 0,
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """
        Concatenate multiple target sets into a single tensor along the feature
        dimension.

        Args:
            x (:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]):
                List of target sets.
            concat_axis (int, optional):
                Axis to concatenate along (*after* squeezing arrays) when
                merging multiple target sets. Defaults to 0.

        Returns:
            (:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]):
                Concatenated target sets.
        """
        if isinstance(x, (list, tuple)):
            new_list = []
            pos = 0
            for dim in self.task_loader.target_dims:
                new_list.append(x[pos : pos + dim])
                pos += dim
            return [
                B.concat(*[xi for xi in sub_list], axis=concat_axis)
                for sub_list in new_list
            ]
        else:
            return x

    @dispatch
    def mean(self, dist: AbstractMultiOutputDistribution):
        mean = dist.mean
        mean = self._cast_numpy_and_squeeze(mean)
        return self._maybe_concat_multi_targets(mean)

    @dispatch
    def mean(self, task: Task):
        """
        Mean values of the model's distribution at target locations in the task.

        Returned numpy arrays have shape ``(N_features, *N_targets)``.

        Args:
            task (:class:`~.data.task.Task`):
                The task containing the context and target data.

        Returns:
            :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
                Mean values.
        """
        dist = self(task)
        return self.mean(dist)

    @dispatch
    def variance(self, dist: AbstractMultiOutputDistribution):
        variance = dist.var
        variance = self._cast_numpy_and_squeeze(variance)
        return self._maybe_concat_multi_targets(variance)

    @dispatch
    def variance(self, task: Task):
        """
        Variance values of the model's distribution at target locations in the task.

        Returned numpy arrays have shape ``(N_features, *N_targets)``.

        Args:
            task (:class:`~.data.task.Task`):
                The task containing the context and target data.

        Returns:
            :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
                Variance values.
        """
        dist = self(task)
        return self.variance(dist)

    @dispatch
    def std(self, dist: AbstractMultiOutputDistribution):
        variance = self.variance(dist)
        if isinstance(variance, (list, tuple)):
            return [np.sqrt(v) for v in variance]
        else:
            return np.sqrt(variance)

    @dispatch
    def std(self, task: Task):
        """
        Standard deviation values of the model's distribution at target locations
        in the task.

        Returned numpy arrays have shape ``(N_features, *N_targets)``.

        Args:
            task (:class:`~.data.task.Task`):
                The task containing the context and target data.

        Returns:
            :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
                Standard deviation values.
        """
        dist = self(task)
        return self.std(dist)

    @dispatch
    def alpha(
        self, dist: AbstractMultiOutputDistribution
    ) -> Union[np.ndarray, List[np.ndarray]]:
        if self.config["likelihood"] not in ["spikes-beta"]:
            raise NotImplementedError(
                f"ConvNP.alpha method not supported for likelihood {self.config['likelihood']}. "
                "Valid likelihoods: 'spikes-beta'."
            )
        alpha = dist.slab.alpha
        alpha = self._cast_numpy_and_squeeze(alpha)
        return self._maybe_concat_multi_targets(alpha)

    @dispatch
    def alpha(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
        """
        Alpha parameter values of the model's distribution at target locations
        in the task.

        Returned numpy arrays have shape ``(N_features, *N_targets)``.

        .. note::
            This method only works for models that return a distribution with
            a ``dist.slab.alpha`` attribute, e.g. models with a Beta or
            Bernoulli-Gamma likelihood, where it returns the alpha values of
            the slab component of the mixture model.

        Args:
            task (:class:`~.data.task.Task`):
                The task containing the context and target data.

        Returns:
            :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
                Alpha values.
        """
        dist = self(task)
        return self.alpha(dist)

    @dispatch
    def beta(
        self, dist: AbstractMultiOutputDistribution
    ) -> Union[np.ndarray, List[np.ndarray]]:
        if self.config["likelihood"] not in ["spikes-beta"]:
            raise NotImplementedError(
                f"ConvNP.beta method not supported for likelihood {self.config['likelihood']}. "
                "Valid likelihoods: 'spikes-beta'."
            )
        beta = dist.slab.beta
        beta = self._cast_numpy_and_squeeze(beta)
        return self._maybe_concat_multi_targets(beta)

    @dispatch
    def beta(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
        """
        Beta parameter values of the model's distribution at target locations
        in the task.

        Returned numpy arrays have shape ``(N_features, *N_targets)``.

        .. note::
            This method only works for models that return a distribution with
            a ``dist.slab.beta`` attribute, e.g. models with a Beta or
            Bernoulli-Gamma likelihood.

        Args:
            task (:class:`~.data.task.Task`):
                The task containing the context and target data.

        Returns:
            :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
                Beta values.
        """
        dist = self(task)
        return self.beta(dist)

    @dispatch
    def k(
        self, dist: AbstractMultiOutputDistribution
    ) -> Union[np.ndarray, List[np.ndarray]]:
        if self.config["likelihood"] not in ["bernoulli-gamma"]:
            raise NotImplementedError(
                f"ConvNP.k method not supported for likelihood {self.config['likelihood']}. "
                "Valid likelihoods: 'bernoulli-gamma'."
            )
        k = dist.slab.k
        k = self._cast_numpy_and_squeeze(k)
        return self._maybe_concat_multi_targets(k)

    @dispatch
    def k(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
        """
        k parameter values of the model's distribution at target locations in
        the task.

        Returned numpy arrays have shape ``(N_features, *N_targets)``.

        .. note::
            This method only works for models that return a distribution with
            a ``dist.slab.k`` attribute, e.g. models with a Beta or
            Bernoulli-Gamma likelihood, where it returns the k values of
            the slab component of the mixture model.

        Args:
            task (:class:`~.data.task.Task`):
                The task containing the context and target data.

        Returns:
            :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
                k values.
        """
        dist = self(task)
        return self.k(dist)

    @dispatch
    def scale(
        self, dist: AbstractMultiOutputDistribution
    ) -> Union[np.ndarray, List[np.ndarray]]:
        if self.config["likelihood"] not in ["bernoulli-gamma"]:
            raise NotImplementedError(
                f"ConvNP.scale method not supported for likelihood {self.config['likelihood']}. "
                "Valid likelihoods: 'bernoulli-gamma'."
            )
        scale = dist.slab.scale
        scale = self._cast_numpy_and_squeeze(scale)
        return self._maybe_concat_multi_targets(scale)

    @dispatch
    def scale(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
        """
        Scale parameter values of the model's distribution at target locations
        in the task.

        Returned numpy arrays have shape ``(N_features, *N_targets)``.

        .. note::
            This method only works for models that return a distribution with
            a ``dist.slab.scale`` attribute, e.g. models with a Beta or
            Bernoulli-Gamma likelihood, where it returns the scale values of
            the slab component of the mixture model.

        Args:
            task (:class:`~.data.task.Task`):
                The task containing the context and target data.

        Returns:
            :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
                Scale values.
        """
        dist = self(task)
        return self.scale(dist)

    @dispatch
    def mixture_probs(self, dist: AbstractMultiOutputDistribution):
        if self.N_mixture_components == 1:
            raise NotImplementedError(
                "mixture_probs not supported if model attribute N_mixture_components == 1. "
                "Try changing the likelihood to a mixture model, e.g. 'spikes-beta'."
            )
        mixture_probs = dist.logprobs
        mixture_probs = self._cast_numpy_and_squeeze(mixture_probs)
        mixture_probs = self._maybe_concat_multi_targets(mixture_probs)
        if isinstance(mixture_probs, (list, tuple)):
            return [np.moveaxis(np.exp(m), -1, 0) for m in mixture_probs]
        else:
            return np.moveaxis(np.exp(mixture_probs), -1, 0)

    @dispatch
    def mixture_probs(self, task: Task):
        """
        Mixture probabilities of the model's distribution at target locations
        in the task.

        Returned numpy arrays have shape ``(N_components, N_features, *N_targets)``.

        Args:
            task (:class:`~.data.task.Task`):
                The task containing the context and target data.

        Returns:
            :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
                Mixture probabilities.
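
        Example:
            A minimal sketch (assuming ``model`` was constructed with a mixture
            likelihood, e.g. ``likelihood="spikes-beta"``; ``task`` is assumed
            to exist):
                >>> probs = model.mixture_probs(task)  # shape (N_components, N_features, *N_targets)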
        """
715
        dist = self(task)
2✔
716
        return self.mixture_probs(dist)
2✔
717

718

    @dispatch
    def covariance(self, dist: AbstractMultiOutputDistribution):
        return B.to_numpy(B.dense(dist.vectorised_normal.var))[0, 0]

    @dispatch
    def covariance(self, task: Task):
        """
        Covariance matrix of the model's distribution over all target points
        in the task.

        Args:
            task (:class:`~.data.task.Task`):
                The task containing the context and target data.

        Returns:
            :class:`numpy:numpy.ndarray`:
                The covariance matrix.
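
        Example:
            A minimal sketch (``model`` and ``task`` are assumed to exist; a
            joint Gaussian likelihood such as ``likelihood="gnp"`` is assumed,
            so that the off-diagonal entries carry correlation structure):
                >>> cov = model.covariance(task)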
        """
734
        dist = self(task)
2✔
735
        return self.covariance(dist)
2✔
736

737

    @dispatch
    def sample(
        self,
        dist: AbstractMultiOutputDistribution,
        n_samples: int = 1,
    ):
        if self.config["likelihood"] in ["gnp", "lowrank"]:
            samples = dist.noiseless.sample(n_samples)
        else:
            samples = dist.sample(n_samples)
        # Be careful to keep the sample dimension in position 0
        samples = self._cast_numpy_and_squeeze(samples, squeeze_axes=(1, 2))
        return self._maybe_concat_multi_targets(samples, concat_axis=1)

    @dispatch
    def sample(self, task: Task, n_samples: int = 1):
        """
        Create samples from a ConvNP distribution.

        Returned numpy arrays have shape ``(N_samples, N_features, *N_targets)``.

        Args:
            task (:class:`~.data.task.Task`):
                The task containing the context and target data.
            n_samples (int, optional):
                The number of samples to draw from the distribution, by
                default 1.

        Returns:
            :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
                The samples as an array or list of arrays.
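
        Example:
            A minimal sketch (``model`` and ``task`` are assumed to exist):
                >>> samples = model.sample(task, n_samples=5)  # shape (5, N_features, *N_targets)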
        """
769
        dist = self(task)
2✔
770
        return self.sample(dist, n_samples)
2✔
771

772

    @dispatch
    def slice_diag(self, task: Task):
        """
        Slice out the ConvCNP part of the ConvNP distribution.

        Args:
            task (:class:`~.data.task.Task`):
                The task to slice.

        Returns:
            neuralprocesses.dist.AbstractMultiOutputDistribution:
                The marginal (diagonal) part of the distribution.
        """
        dist = self(task)
        if self.config["likelihood"] in ["spikes-beta"]:
            dist_diag = dist
        else:
            dist_diag = backend.nps.MultiOutputNormal(
                dist._mean,
                B.zeros(dist._var),
                Diagonal(B.diag(dist._noise + dist._var)),
                dist.shape,
            )
        return dist_diag

    @dispatch
    def slice_diag(self, dist: AbstractMultiOutputDistribution):
        """
        Slice out the ConvCNP part of the ConvNP distribution.

        Args:
            dist (neuralprocesses.dist.AbstractMultiOutputDistribution):
                The distribution to slice.

        Returns:
            neuralprocesses.dist.AbstractMultiOutputDistribution:
                The marginal (diagonal) part of the distribution.
        """
        if self.config["likelihood"] in ["spikes-beta"]:
            dist_diag = dist
        else:
            dist_diag = backend.nps.MultiOutputNormal(
                dist._mean,
                B.zeros(dist._var),
                Diagonal(B.diag(dist._noise + dist._var)),
                dist.shape,
            )
        return dist_diag

    @dispatch
    def mean_marginal_entropy(self, dist: AbstractMultiOutputDistribution):
        """
        Mean marginal entropy over target points given context points.

        Args:
            dist (neuralprocesses.dist.AbstractMultiOutputDistribution):
                The distribution to compute the entropy of.

        Returns:
            float: The mean marginal entropy.
        """
        dist_diag = self.slice_diag(dist)
        return B.mean(B.to_numpy(dist_diag.entropy())[0, 0])

    @dispatch
    def mean_marginal_entropy(self, task: Task):
        """
        Mean marginal entropy over target points given context points.

        Args:
            task (:class:`~.data.task.Task`):
                The task to compute the entropy of.

        Returns:
            float: The mean marginal entropy.
        """
        dist_diag = self.slice_diag(task)
        return B.mean(B.to_numpy(dist_diag.entropy())[0, 0])

    @dispatch
    def joint_entropy(self, dist: AbstractMultiOutputDistribution):
        """
        Model entropy over target points given context points.

        Args:
            dist (neuralprocesses.dist.AbstractMultiOutputDistribution):
                The distribution to compute the entropy of.

        Returns:
            float: The model entropy.
        """
        return B.to_numpy(dist.entropy())[0, 0]

    @dispatch
    def joint_entropy(self, task: Task):
        """
        Model entropy over target points given context points.

        Args:
            task (:class:`~.data.task.Task`):
                The task to compute the entropy of.

        Returns:
            float: The model entropy.
        """
        return B.to_numpy(self(task).entropy())[0, 0]

    @dispatch
    def logpdf(self, dist: AbstractMultiOutputDistribution, task: Task):
        """
        Joint logpdf over all target sets.

        .. note::
            If the model has multiple target sets, the returned logpdf is the
            mean logpdf over all target sets.

        Args:
            dist (neuralprocesses.dist.AbstractMultiOutputDistribution):
                The distribution to compute the logpdf of.
            task (:class:`~.data.task.Task`):
                The task to compute the logpdf of.

        Returns:
            float: The logpdf.
        """
        # Need to ensure `Y_t` is a tensor and, if multiple target sets,
        #   an nps.Aggregate object
        task = ConvNP.modify_task(task)
        _, _, Y_t, _ = convert_task_to_nps_args(task)
        return B.to_numpy(dist.logpdf(Y_t)).mean()

    @dispatch
    def logpdf(self, task: Task):
        """
        Joint logpdf over all target sets.

        .. note::
            If the model has multiple target sets, the returned logpdf is the
            mean logpdf over all target sets.

        Args:
            task (:class:`~.data.task.Task`):
                The task to compute the logpdf of.

        Returns:
            float: The logpdf.
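
        Example:
            A minimal sketch of evaluating a trained model on a held-out task
            (``model`` and ``val_task`` are assumed to exist):
                >>> logpdf = model.logpdf(val_task)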
        """
917
        dist = self(task)
2✔
918
        return self.logpdf(dist, task)
2✔
919

920

    def loss_fn(
        self,
        task: Task,
        fix_noise=None,
        num_lv_samples: int = 8,
        normalise: bool = False,
    ):
        """
        Compute the loss of a task.

        Args:
            task (:class:`~.data.task.Task`):
                The task to compute the loss of.
            fix_noise (...):
                Whether to fix the noise to the value specified in the model
                config.
            num_lv_samples (int, optional):
                If a latent variable model, number of latent variable samples
                for evaluating the loss, by default 8.
            normalise (bool, optional):
                Whether to normalise the loss by the number of target points,
                by default False.

        Returns:
            float: The loss.
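
        Example:
            A minimal sketch of a single training step with the PyTorch backend
            (``model``, ``task``, and ``opt`` are assumed to exist, where ``opt``
            is a hypothetical ``torch.optim`` optimiser over
            ``model.model.parameters()``):
                >>> opt.zero_grad()
                >>> loss = model.loss_fn(task, normalise=True)
                >>> loss.backward()
                >>> opt.step()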
        """
946
        task = ConvNP.modify_task(task)
2✔
947

948
        context_data, xt, yt, model_kwargs = convert_task_to_nps_args(task)
2✔
949

950
        logpdfs = backend.nps.loglik(
2✔
951
            self.model,
952
            context_data,
953
            xt,
954
            yt,
955
            **model_kwargs,
956
            fix_noise=fix_noise,
957
            num_samples=num_lv_samples,
958
            normalise=normalise,
959
        )
960

961
        loss = -B.mean(logpdfs)
2✔
962

963
        return loss
2✔
964

965

    def ar_sample(
        self,
        task: Task,
        n_samples: int = 1,
        X_target_AR: Optional[np.ndarray] = None,
        ar_subsample_factor: int = 1,
        fill_type: Literal["mean", "sample"] = "mean",
    ):
        """
        Autoregressive sampling from the model.

        AR sampling with optional functionality to only draw AR samples over a
        subset of the target set and then infill the rest of the sample with
        the model mean or a joint sample conditioned on the AR samples.

        Returned numpy arrays have shape ``(N_samples, N_features, *N_targets)``.

        .. note::
            AR sampling only works for the 0th context/target set, and only for
            models with a single target set.

        Args:
            task (:class:`~.data.task.Task`):
                The task to sample from.
            n_samples (int, optional):
                The number of samples to draw from the distribution, by
                default 1.
            X_target_AR (:class:`numpy:numpy.ndarray`, optional):
                Locations to draw AR samples over. If None, AR samples will be
                drawn over the target locations in the task. Defaults to None.
            ar_subsample_factor (int, optional):
                Subsample target locations to draw AR samples over. Defaults
                to 1.
            fill_type (Literal["mean", "sample"], optional):
                How to infill the rest of the sample. Must be one of "mean" or
                "sample". Defaults to "mean".

        Returns:
            :class:`numpy:numpy.ndarray`:
                The samples.
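
        Example:
            A minimal sketch (``model`` and ``task`` are assumed to exist): draw
            AR samples over every 4th target location and infill the remaining
            targets with the model mean:
                >>> samples = model.ar_sample(task, n_samples=3, ar_subsample_factor=4)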
        """
1006
        if len(task["X_t"]) > 1 or (task["Y_t"] is not None and len(task["Y_t"]) > 1):
2✔
1007
            raise NotImplementedError(
×
1008
                "AR sampling with multiple target sets is not supported."
1009
            )
1010

1011
        # AR sampling requires gridded data to be flattened, not coordinate tuples
1012
        task_arsample = copy.deepcopy(task)
2✔
1013
        task = copy.deepcopy(task)
2✔
1014

1015
        if X_target_AR is not None:
2✔
1016
            # User has specified a set of locations to draw AR samples over
1017
            task_arsample["X_t"][0] = X_target_AR
×
1018
        elif ar_subsample_factor > 1:
2✔
1019
            # Subsample target locations to draw AR samples over
1020
            xt = task["X_t"][0]
2✔
1021
            if isinstance(xt, tuple):
2✔
1022
                # Targets on a grid: subsample targets for AR along spatial dimension
1023
                xt = (
2✔
1024
                    xt[0][..., ::ar_subsample_factor],
1025
                    xt[1][..., ::ar_subsample_factor],
1026
                )
1027
            else:
1028
                xt = xt[..., ::ar_subsample_factor]
×
1029
            task_arsample["X_t"][0] = xt
2✔
1030
        else:
1031
            task_arsample = copy.deepcopy(task)
×
1032

1033
        task = task.flatten_gridded_data()
2✔
1034
        task_arsample = task_arsample.flatten_gridded_data()
2✔
1035

1036
        task_arsample = ConvNP.modify_task(task_arsample)
2✔
1037
        task = ConvNP.modify_task(task)
2✔
1038

1039
        if backend.str == "torch":
2✔
1040
            import torch
2✔
1041

1042
            # Run AR sampling with torch.no_grad() to avoid prohibitive backprop computation for AR
1043
            with torch.no_grad():
2✔
1044
                (
2✔
1045
                    mean,
1046
                    variance,
1047
                    noiseless_samples,
1048
                    noisy_samples,
1049
                ) = run_nps_model_ar(self.model, task_arsample, num_samples=n_samples)
1050
        else:
1051
            (
×
1052
                mean,
1053
                variance,
1054
                noiseless_samples,
1055
                noisy_samples,
1056
            ) = run_nps_model_ar(self.model, task_arsample, num_samples=n_samples)
1057

1058
        # Slice out first (and assumed only) target entry in nps.Aggregate object
1059
        noiseless_samples = B.to_numpy(noiseless_samples)
2✔
1060

1061
        if ar_subsample_factor > 1 or X_target_AR is not None:
2✔
1062
            # AR sample locations not equal to target locations - infill the rest of the
1063
            # sample with the model mean conditioned on the AR samples
1064
            full_samples = []
2✔
1065
            for sample in noiseless_samples:
2✔
1066
                task_with_sample = copy.deepcopy(task)
2✔
1067
                task_with_sample["X_c"][0] = B.concat(
2✔
1068
                    task["X_c"][0], task_arsample["X_t"][0], axis=-1
1069
                )
1070
                task_with_sample["Y_c"][0] = B.concat(task["Y_c"][0], sample, axis=-1)
2✔
1071

1072
                if fill_type == "mean":
2✔
1073
                    # Compute the mean conditioned on the AR samples
1074
                    # Should this be a `.sample` call?
1075
                    pred = self.mean(task_with_sample)
2✔
1076
                elif fill_type == "sample":
×
1077
                    # Sample from joint distribution over all target locations
1078
                    pred = self.sample(task_with_sample, n_samples=1)
×
1079

1080
                full_samples.append(pred)
2✔
1081
            full_samples = np.stack(full_samples, axis=0)
2✔
1082

1083
            return full_samples
2✔
1084
        else:
1085
            return noiseless_samples[:, 0]  # Slice out batch dim
×
1086

1087

1088
def concat_tasks(tasks: List[Task], multiple: int = 1) -> Task:
    warnings.warn(
        "concat_tasks has been moved to deepsensor.data.task and will be removed from "
        "deepsensor.model.convnp in a future release.",
        FutureWarning,
    )
    # Relies on the module-level `import deepsensor.data.task` above to bind
    # the `deepsensor` name here
    return deepsensor.data.task.concat_tasks(tasks, multiple)