alan-turing-institute / deepsensor · build 19842460617 (push, via GitHub, committer web-flow)
08 Oct 2025 10:03AM UTC · coverage: 81.663% (remained the same)
Commit: "Update README.md, adding reference to GIANT project"
2053 of 2514 relevant lines covered (81.66%) · 1.63 hits per line

Source file: /deepsensor/model/model.py (81.12% of file covered)

from deepsensor.data.loader import TaskLoader
from deepsensor.data.processor import (
    DataProcessor,
    process_X_mask_for_X,
    xarray_to_coord_array_normalised,
    mask_coord_array_normalised,
)
from deepsensor.model.pred import (
    Prediction,
    increase_spatial_resolution,
    infer_prediction_modality_from_X_t,
)
from deepsensor.data.task import Task

from typing import List, Union, Optional, Tuple
import copy

import time
from tqdm import tqdm

import numpy as np
import pandas as pd
import xarray as xr
import lab as B

# For dispatching with TF and PyTorch model types when they have not yet been loaded.
# See https://beartype.github.io/plum/types.html#moduletype
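#
# A minimal sketch of that pattern (hypothetical module/class names, kept in a
# comment for illustration only):
#
#     from plum import ModuleType, dispatch
#
#     TFModel = ModuleType("tensorflow.keras", "Model")  # lazy class reference
#
#     @dispatch
#     def convert(model: TFModel):
#         ...  # only matches once tensorflow has actually been imported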


class ProbabilisticModel:
    """Base class for probabilistic models used by DeepSensor.

    Ensures that the methods required by DeepSensor are implemented by the
    specific model classes that inherit from it.
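
    Example:
        A minimal sketch of a conforming subclass (hypothetical and for
        illustration only; a real model would compute these quantities from
        the context data in the ``task``)::

            class ZeroMeanUnitVarModel(ProbabilisticModel):
                def mean(self, task: Task, *args, **kwargs):
                    n_targets = task["X_t"][0].shape[-1]  # assumed task layout
                    return np.zeros(n_targets)

                def variance(self, task: Task, *args, **kwargs):
                    n_targets = task["X_t"][0].shape[-1]  # assumed task layout
                    return np.ones(n_targets)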
    """

    def mean(self, task: Task, *args, **kwargs):
        """Computes the model mean prediction over target points based on the
        given context data.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            :class:`numpy:numpy.ndarray`: Mean prediction over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def variance(self, task: Task, *args, **kwargs):
        """Model marginal variance over target points given context points.
        Shape (N,).

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            :class:`numpy:numpy.ndarray`: Marginal variance over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def std(self, task: Task):
        """Model marginal standard deviation over target points given context
        points. Shape (N,).

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            :class:`numpy:numpy.ndarray`: Marginal standard deviation over target points.
        """
        var = self.variance(task)
        return var**0.5

    def stddev(self, *args, **kwargs):
        """Alias for :meth:`std`."""
        return self.std(*args, **kwargs)

    def covariance(self, task: Task, *args, **kwargs):
        """Computes the model covariance matrix over target points based on the
        given context data. Shape (N, N).

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            :class:`numpy:numpy.ndarray`: Covariance matrix over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def mean_marginal_entropy(self, task: Task, *args, **kwargs):
        """Computes the mean marginal entropy over target points based on the
        given context data.

        .. note::
            Getting a vector of marginal entropies would be useful too.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            float: Mean marginal entropy over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def joint_entropy(self, task: Task, *args, **kwargs):
        """Computes the model joint entropy over target points based on the
        given context data.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            float: Joint entropy over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def logpdf(self, task: Task, *args, **kwargs):
        """Computes the joint model logpdf over target points based on the
        given context data.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            float: Joint logpdf over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def loss(self, task: Task, *args, **kwargs):
        """Computes the model loss over target points based on the given context data.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            float: Loss over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def sample(self, task: Task, n_samples=1, *args, **kwargs):
        """Draws ``n_samples`` joint samples over target points based on the
        given context data. Returned shape is ``(n_samples, n_target)``.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.
            n_samples (int, optional):
                Number of samples to draw. Defaults to 1.

        Returns:
            :class:`numpy:numpy.ndarray`: Joint samples over target points,
            shape ``(n_samples, n_target)``.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()


class DeepSensorModel(ProbabilisticModel):
    """Implements DeepSensor prediction functionality of a ProbabilisticModel.

    Allows for outputting an xarray object containing on-grid predictions or a
    pandas object containing off-grid predictions.

    Args:
        data_processor (:class:`~.data.processor.DataProcessor`):
            DataProcessor object, used to unnormalise predictions.
        task_loader (:class:`~.data.loader.TaskLoader`):
            TaskLoader object, used to determine target variables for unnormalising.
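
    Example:
        Constructing a concrete child class (mirroring the usage sketched in
        ``main()`` at the bottom of this module)::

            model = ConvNP(data_processor, task_loader)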
    """

    N_mixture_components = 1  # Number of mixture components for mixture likelihoods

    def __init__(
        self,
        data_processor: Optional[DataProcessor] = None,
        task_loader: Optional[TaskLoader] = None,
    ):
        self.task_loader = task_loader
        self.data_processor = data_processor

    def predict(
        self,
        tasks: Union[List[Task], Task],
        X_t: Union[
            xr.Dataset,
            xr.DataArray,
            pd.DataFrame,
            pd.Series,
            pd.Index,
            np.ndarray,
        ],
        X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
        X_t_is_normalised: bool = False,
        aux_at_targets_override: Optional[Union[xr.Dataset, xr.DataArray]] = None,
        aux_at_targets_override_is_normalised: bool = False,
        resolution_factor: float = 1,
        pred_params: Tuple[str, ...] = ("mean", "std"),
        n_samples: int = 0,
        ar_sample: bool = False,
        ar_subsample_factor: int = 1,
        unnormalise: bool = True,
        seed: int = 0,
        append_indexes: Optional[dict] = None,
        progress_bar: int = 0,
        verbose: bool = False,
    ) -> Prediction:
        """Predict on a regular grid or at off-grid locations.

        Args:
            tasks (List[Task] | Task):
                List of tasks containing context data.
            X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`):
                Target locations to predict at. Can be an xarray object
                containing on-grid locations or a pandas object containing off-grid locations.
            X_t_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional):
                2D mask to apply to gridded ``X_t`` (zero/False will be NaNs). Will be interpolated
                to the same grid as ``X_t``. Default None (no mask).
            X_t_is_normalised (bool):
                Whether the ``X_t`` coords are normalised. If False, will normalise
                the coords before passing to the model. Default ``False``.
            aux_at_targets_override (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional):
                Optional auxiliary xarray data to use instead of the
                ``aux_at_targets`` data from the task_loader.
            aux_at_targets_override_is_normalised (bool):
                Whether the ``aux_at_targets_override`` coords are normalised.
                If False, the DataProcessor will normalise the coords before passing to the model.
                Default False.
            resolution_factor (float):
                Optional factor to increase the resolution of the target grid
                by. E.g. 2 will double the target resolution, 0.5 will halve
                it. Applies to on-grid predictions only. Default 1.
            pred_params (Tuple[str, ...]):
                Tuple of prediction parameters to return. The strings refer to methods
                of the model class which will be called and stored in the Prediction object.
                Default ("mean", "std").
            n_samples (int):
                Number of joint samples to draw from the model. If 0, will not
                draw samples. Default 0.
            ar_sample (bool):
                Whether to use autoregressive sampling. Default ``False``.
            ar_subsample_factor (int):
                Factor by which to subsample the target locations when drawing
                autoregressive samples. Applies to on-grid predictions only.
                Default 1.
            unnormalise (bool):
                Whether to unnormalise the predictions. Only works if ``self``
                has a ``data_processor`` and ``task_loader`` attribute. Default
                ``True``.
            seed (int):
                Random seed for deterministic sampling. Default 0.
            append_indexes (dict):
                Dictionary of index metadata to append to pandas indexes in the
                off-grid case. Default ``None``.
            progress_bar (int):
                Whether to display a progress bar over tasks. Default 0.
            verbose (bool):
                Whether to print the time taken for prediction. Default ``False``.

        Returns:
            :class:`~.model.pred.Prediction`:
                A `dict`-like object mapping from target variable IDs to xarray or pandas objects
                containing model predictions.

                - If ``X_t`` is a pandas object, returns pandas objects
                  containing off-grid predictions.
                - If ``X_t`` is an xarray object, returns xarray object
                  containing on-grid predictions.
                - If ``n_samples`` == 0, returns only mean and std predictions.
                - If ``n_samples`` > 0, returns mean, std and samples
                  predictions.

        Raises:
            ValueError
                If ``X_t`` is not an xarray object and
                ``resolution_factor`` is not 1 or ``ar_subsample_factor`` is
                not 1.
            ValueError
                If ``X_t`` is not a pandas object and ``append_indexes`` is not
                ``None``.
            ValueError
                If ``X_t`` is not an xarray, pandas or numpy object.
            ValueError
                If ``append_indexes`` are not all the same length as ``X_t``.
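
        Example:
            A minimal usage sketch (hypothetical variable names; assumes a
            trained model and the ``data_processor``/``task_loader`` it was
            built with, as in ``main()`` below)::

                task = task_loader("2014-12-31", context_sampling=40)
                pred = model.predict(task, X_t=ds_raw, n_samples=5)
                mean_da = pred["air"]["mean"]  # assumes a target var ID "air"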
        """
        tic = time.time()
        mode = infer_prediction_modality_from_X_t(X_t)
        if not isinstance(X_t, (xr.DataArray, xr.Dataset)):
            if resolution_factor != 1:
                raise ValueError(
                    "resolution_factor can only be used with on-grid predictions."
                )
            if ar_subsample_factor != 1:
                raise ValueError(
                    "ar_subsample_factor can only be used with on-grid predictions."
                )
        if not isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index, np.ndarray)):
            if append_indexes is not None:
                raise ValueError(
                    "append_indexes can only be used with off-grid predictions."
                )
        if mode == "off-grid" and X_t_mask is not None:
            # TODO: Unit test this
            raise ValueError("X_t_mask can only be used with on-grid predictions.")
        if ar_sample and n_samples < 1:
            raise ValueError("Must pass `n_samples` > 0 to use `ar_sample`.")

        target_delta_t = self.task_loader.target_delta_t
        # Guard against a None target_delta_t before building Timedeltas
        if target_delta_t is None:
            dts_all_zero = True
        else:
            dts = [pd.Timedelta(dt) for dt in target_delta_t]
            dts_all_zero = all(dt == pd.Timedelta(seconds=0) for dt in dts)
        if target_delta_t is not None and dts_all_zero:
            forecasting_mode = False
            lead_times = None
        elif target_delta_t is not None and not dts_all_zero:
            target_var_IDs_set = set(self.task_loader.target_var_IDs)
            msg = f"""
            Got more than one set of target variables in target sets,
            but predictions can only be made with one set of target variables
            to simplify implementation.
            Got {target_var_IDs_set}.
            """
            assert len(target_var_IDs_set) == 1, msg
            # Repeat lead_time for each variable in each target set
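            # (Illustration, an assumed example: with two target sets of dims
            # [2, 1] and target_delta_t = [1, 1], lead_times would get 3
            # entries, one per flattened target variable.)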
            lead_times = []
            for target_set_idx, dt in enumerate(target_delta_t):
                target_set_dim = self.task_loader.target_dims[target_set_idx]
                lead_times += [
                    pd.Timedelta(dt, unit=self.task_loader.time_freq)
                    for _ in range(target_set_dim)
                ]
            forecasting_mode = True
        else:
            forecasting_mode = False
            lead_times = None

        if type(tasks) is Task:
            tasks = [tasks]

        if n_samples >= 1:
            B.set_random_seed(seed)
            np.random.seed(seed)

        init_dates = [task["time"] for task in tasks]

        # Flatten tuple of tuples to single list
        target_var_IDs = [
            var_ID for var_set in self.task_loader.target_var_IDs for var_ID in var_set
        ]
        if lead_times is not None:
            assert len(lead_times) == len(target_var_IDs)

        # TODO consider removing this logic, can we just depend on the dim names in X_t?
        if not unnormalise:
            coord_names = {"x1": "x1", "x2": "x2"}
        else:
            coord_names = {
                "x1": self.data_processor.raw_spatial_coord_names[0],
                "x2": self.data_processor.raw_spatial_coord_names[1],
            }

        ### Pre-process X_t if necessary (TODO consider moving this to Prediction class)
        if isinstance(X_t, pd.Index):
            X_t = pd.DataFrame(index=X_t)
        elif isinstance(X_t, np.ndarray):
            # Convert to empty dataframe with normalised or unnormalised coord names
            if X_t_is_normalised:
                index_names = ["x1", "x2"]
            else:
                index_names = self.data_processor.raw_spatial_coord_names
            X_t = pd.DataFrame(X_t.T, columns=index_names)
            X_t = X_t.set_index(index_names)
        elif isinstance(X_t, (xr.DataArray, xr.Dataset)):
            # Remove time dimension if present
            if "time" in X_t.coords:
                X_t = X_t.isel(time=0).drop_vars("time")

        if mode == "off-grid" and append_indexes is not None:
            # Check that append_indexes are all the same length as X_t
            for idx, vals in append_indexes.items():
                if len(vals) != len(X_t):
                    raise ValueError(
                        f"append_indexes[{idx}] must be same length as X_t, "
                        f"got {len(vals)} and {len(X_t)} respectively."
                    )
            X_t = X_t.reset_index()
            X_t = pd.concat([X_t, pd.DataFrame(append_indexes)], axis=1)
            X_t = X_t.set_index(list(X_t.columns))

        if X_t_is_normalised:
            X_t_normalised = X_t
            # Unnormalise coords to use for xarray/pandas objects for storing predictions
            X_t = self.data_processor.map_coords(X_t, unnorm=True)
        else:
            # Normalise coords to use for model
            X_t_normalised = self.data_processor.map_coords(X_t)

        if mode == "on-grid":
            if resolution_factor != 1:
                X_t_normalised = increase_spatial_resolution(
                    X_t_normalised, resolution_factor
                )
                X_t = increase_spatial_resolution(
                    X_t, resolution_factor, coord_names=coord_names
                )
            if X_t_mask is not None:
                X_t_mask = process_X_mask_for_X(X_t_mask, X_t)
                X_t_mask_normalised = self.data_processor.map_coords(X_t_mask)
                X_t_arr = xarray_to_coord_array_normalised(X_t_normalised)
                # Remove points that lie outside the mask
                X_t_arr = mask_coord_array_normalised(X_t_arr, X_t_mask_normalised)
            else:
                X_t_arr = (
                    X_t_normalised["x1"].values,
                    X_t_normalised["x2"].values,
                )
        elif mode == "off-grid":
            X_t_arr = X_t_normalised.reset_index()[["x1", "x2"]].values.T

        if isinstance(X_t_arr, tuple):
            target_shape = (len(X_t_arr[0]), len(X_t_arr[1]))
        else:
            target_shape = (X_t_arr.shape[1],)

        if not unnormalise:
            X_t = X_t_normalised

        if "mixture_probs" in pred_params:
            # Store each mixture component separately w/o overriding pred_params.
            # Copy to a list so tuple inputs work and pred_params isn't mutated.
            pred_params_to_store = list(pred_params)
            pred_params_to_store.remove("mixture_probs")
            for component_i in range(self.N_mixture_components):
                pred_params_to_store.append(f"mixture_probs_{component_i}")
        else:
            pred_params_to_store = pred_params

        # Dict to store predictions for each target variable
        pred = Prediction(
            target_var_IDs,
            pred_params_to_store,
            init_dates,
            X_t,
            X_t_mask,
            coord_names,
            n_samples=n_samples,
            forecasting_mode=forecasting_mode,
            lead_times=lead_times,
        )

        def unnormalise_pred_array(arr, **kwargs):
            """Unnormalise an (N_dims, N_targets) array of predictions."""
            var_IDs_flattened = [
                var_ID
                for var_IDs in self.task_loader.target_var_IDs
                for var_ID in var_IDs
            ]
            assert arr.shape[0] == len(
                var_IDs_flattened
            ), f"{arr.shape[0]} != {len(var_IDs_flattened)}"
            for i, var_ID in enumerate(var_IDs_flattened):
                arr[i] = self.data_processor.map_array(
                    arr[i],
                    var_ID,
                    method=self.data_processor.config[var_ID]["method"],
                    unnorm=True,
                    **kwargs,
                )
            return arr

        # Don't change tasks by reference when overriding target locations
        # TODO consider not copying tasks by default for efficiency
        tasks = copy.deepcopy(tasks)

        if self.task_loader.aux_at_targets is not None:
            if aux_at_targets_override is not None:
                aux_at_targets = aux_at_targets_override
                if not aux_at_targets_override_is_normalised:
                    # Assumes using default normalisation method
                    aux_at_targets = self.data_processor(
                        aux_at_targets, assert_computed=True
                    )
            else:
                aux_at_targets = self.task_loader.aux_at_targets

        for task in tqdm(tasks, position=0, disable=progress_bar < 1, leave=True):
            task["X_t"] = [X_t_arr for _ in range(len(self.task_loader.target_var_IDs))]

            # If passing auxiliary data, need to sample it at target locations
            if self.task_loader.aux_at_targets is not None:
                aux_at_targets_sliced = self.task_loader.time_slice_variable(
                    aux_at_targets, task["time"]
                )
                task["Y_t_aux"] = self.task_loader.sample_offgrid_aux(
                    X_t_arr, aux_at_targets_sliced
                )

            prediction_arrs = {}
            prediction_methods = {}
            for param in pred_params:
                try:
                    method = getattr(self, param)
                    prediction_methods[param] = method
                except AttributeError:
                    raise AttributeError(
                        f"Prediction method {param} not found in model class."
                    )
            if n_samples >= 1:
                B.set_random_seed(seed)
                np.random.seed(seed)
                if ar_sample:
                    sample_method = getattr(self, "ar_sample")
                    sample_args = {
                        "n_samples": n_samples,
                        "ar_subsample_factor": ar_subsample_factor,
                    }
                else:
                    sample_method = getattr(self, "sample")
                    sample_args = {"n_samples": n_samples}

            # If the `DeepSensorModel` child has been sub-classed with a `__call__` method,
            # we assume this returns a distribution-like object that can be used to compute
            # mean, std and samples. Otherwise, run the model with `Task` for each prediction type.
            if hasattr(self, "__call__"):
                # Run model forwards once to generate output distribution, which we re-use
                dist = self(task, n_samples=n_samples)
                for param, method in prediction_methods.items():
                    prediction_arrs[param] = method(dist)
                if n_samples >= 1 and not ar_sample:
                    samples_arr = sample_method(dist, **sample_args)
                    # samples_arr = samples_arr.reshape((n_samples, len(target_var_IDs), *target_shape))
                    prediction_arrs["samples"] = samples_arr
                elif n_samples >= 1 and ar_sample:
                    # Can't draw AR samples from distribution object, need to re-run with task
                    samples_arr = sample_method(task, **sample_args)
                    samples_arr = samples_arr.reshape(
                        (n_samples, len(target_var_IDs), *target_shape)
                    )
                    prediction_arrs["samples"] = samples_arr
            else:
                # Re-run model for each prediction type
                for param, method in prediction_methods.items():
                    prediction_arrs[param] = method(task)
                if n_samples >= 1:
                    samples_arr = sample_method(task, **sample_args)
                    if ar_sample:
                        samples_arr = samples_arr.reshape(
                            (n_samples, len(target_var_IDs), *target_shape)
                        )
                    prediction_arrs["samples"] = samples_arr

            # Concatenate multi-target predictions
            for param, arr in prediction_arrs.items():
                if isinstance(arr, (list, tuple)):
                    if param != "samples":
                        concat_axis = 0
                    else:
                        # Axis 0 is sample dim, axis 1 is variable dim
                        concat_axis = 1
                    prediction_arrs[param] = np.concatenate(arr, axis=concat_axis)

            # Unnormalise predictions
            for param, arr in prediction_arrs.items():
                # TODO make class attributes?
                scale_and_offset_params = ["mean"]
                scale_only_params = ["std"]
                scale_squared_only_params = ["variance"]
                if unnormalise:
                    if param == "samples":
                        for sample_i in range(n_samples):
                            prediction_arrs["samples"][sample_i] = (
                                unnormalise_pred_array(
                                    prediction_arrs["samples"][sample_i]
                                )
                            )
                    elif param in scale_and_offset_params:
                        prediction_arrs[param] = unnormalise_pred_array(arr)
                    elif param in scale_only_params:
                        prediction_arrs[param] = unnormalise_pred_array(
                            arr, add_offset=False
                        )
                    elif param in scale_squared_only_params:
                        # This is a horrible hack to repeat the scaling operation of the linear
                        #   transform twice s.t. new_var = scale ^ 2 * var
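                        # (If y = scale * y_norm + offset, then
                        #  Var[y] = scale**2 * Var[y_norm]; applying the
                        #  scale-only unnormalisation twice yields the
                        #  scale**2 factor.)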
                        prediction_arrs[param] = unnormalise_pred_array(
                            arr, add_offset=False
                        )
                        prediction_arrs[param] = unnormalise_pred_array(
                            prediction_arrs[param], add_offset=False
                        )
                    else:
                        # Assume prediction parameters not captured above are dimensionless
                        #   quantities like probabilities and should not be unnormalised
                        pass

            # Assign predictions to Prediction object
            for param, arr in prediction_arrs.items():
                if param != "mixture_probs":
                    pred.assign(param, task["time"], arr, lead_times=lead_times)
                else:
                    assert arr.shape[0] == self.N_mixture_components, (
                        f"Number of mixture components ({arr.shape[0]}) does not match "
                        f"model attribute N_mixture_components ({self.N_mixture_components})."
                    )
                    for component_i, probs in enumerate(arr):
                        pred.assign(
                            f"{param}_{component_i}",
                            task["time"],
                            probs,
                            lead_times=lead_times,
                        )

        if forecasting_mode:
            pred = add_valid_time_coord_to_pred_and_move_time_dims(pred)

        if verbose:
            dur = time.time() - tic
            print(f"Done in {dur // 60:.0f}m:{dur % 60:.0f}s.\n")

        return pred


def add_valid_time_coord_to_pred_and_move_time_dims(pred: Prediction) -> Prediction:
    """Add a valid time coordinate "time" to a Prediction object based on the
    initialisation times "init_time" and lead times "lead_time", and
    reorder the time dims from ("lead_time", "init_time") to ("init_time", "lead_time").

    Args:
        pred (:class:`~.model.pred.Prediction`):
            Prediction object to add valid time coordinate to.

    Returns:
        :class:`~.model.pred.Prediction`:
            Prediction object with valid time coordinate added.
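
    Example:
        A hypothetical illustration of the valid-time arithmetic: a forecast
        initialised at 2020-01-01 with a 2-day lead time is assigned the valid
        time ``pd.Timestamp("2020-01-01") + pd.Timedelta("2D")``, i.e.
        2020-01-03.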
    """
    for var_ID in pred.keys():
        if isinstance(pred[var_ID], pd.DataFrame):
            x = pred[var_ID].reset_index()
            pred[var_ID]["time"] = (x["lead_time"] + x["init_time"]).values
            pred[var_ID] = pred[var_ID].swaplevel("init_time", "lead_time")
            pred[var_ID] = pred[var_ID].sort_index()
        elif isinstance(pred[var_ID], xr.Dataset):
            x = pred[var_ID]
            pred[var_ID] = pred[var_ID].assign_coords(
                time=x["lead_time"] + x["init_time"]
            )
            pred[var_ID] = pred[var_ID].transpose("init_time", "lead_time", ...)
        else:
            raise ValueError(f"Unsupported prediction type {type(pred[var_ID])}.")
    return pred


def main():  # pragma: no cover # noqa: D103
    import deepsensor.tensorflow  # noqa: F401 (imported for its backend side effects)
    from deepsensor.data.loader import TaskLoader
    from deepsensor.data.processor import DataProcessor
    from deepsensor.model.convnp import ConvNP

    import xarray as xr
    import pandas as pd
    import numpy as np

    # Load raw data
    ds_raw = xr.tutorial.open_dataset("air_temperature")["air"]
    ds_raw2 = copy.deepcopy(ds_raw)
    ds_raw2.name = "air2"

    # Normalise data
    data_processor = DataProcessor(x1_name="lat", x2_name="lon")
    ds = data_processor(ds_raw)
    ds2 = data_processor(ds_raw2)

    # Set up task loader
    task_loader = TaskLoader(context=ds, target=[ds, ds2])

    # Set up model
    model = ConvNP(data_processor, task_loader)

    # Predict on new tasks with 40 context points and a dense grid of target points
    test_tasks = task_loader(
        pd.date_range("2014-12-25", "2014-12-31"), context_sampling=40
    )
    # print(repr(test_tasks))

    X_t = ds_raw
    pred = model.predict(test_tasks, X_t=X_t, n_samples=5)
    print(pred)

    X_t = np.zeros((2, 1))
    pred = model.predict(test_tasks, X_t=X_t, X_t_is_normalised=True)
    print(pred)

    # DEBUG
    # task = task_loader("2014-12-31", context_sampling=40, target_sampling="all")
    # samples = model.ar_sample(task, 5, ar_subsample_factor=20)


if __name__ == "__main__":  # pragma: no cover
    main()