alan-turing-institute / deepsensor · deepsensor/model/model.py
Coverage (22 Oct 2024, push by davidwilby, "incorporate feedback"): 81.63% (2048 of 2509 relevant lines covered)

from deepsensor.data.loader import TaskLoader
from deepsensor.data.processor import (
    DataProcessor,
    process_X_mask_for_X,
    xarray_to_coord_array_normalised,
    mask_coord_array_normalised,
)
from deepsensor.model.pred import (
    Prediction,
    increase_spatial_resolution,
    infer_prediction_modality_from_X_t,
)
from deepsensor.data.task import Task

from typing import List, Union, Optional, Tuple
import copy

import time
from tqdm import tqdm

import numpy as np
import pandas as pd
import xarray as xr
import lab as B

# For dispatching with TF and PyTorch model types when they have not yet been loaded.
# See https://beartype.github.io/plum/types.html#moduletype


class ProbabilisticModel:
    """
    Base class for probabilistic models used by DeepSensor.
    Ensures that the methods required by DeepSensor are implemented by the
    model classes that inherit from it.
    """

    def mean(self, task: Task, *args, **kwargs):
        """
        Computes the model mean prediction over target points based on given context
        data.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            :class:`numpy:numpy.ndarray`: Mean prediction over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def variance(self, task: Task, *args, **kwargs):
        """
        Model marginal variance over target points given context points.
        Shape (N,).

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            :class:`numpy:numpy.ndarray`: Marginal variance over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def std(self, task: Task):
        """
        Model marginal standard deviation over target points given context
        points. Shape (N,).

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            :class:`numpy:numpy.ndarray`: Marginal standard deviation over target points.
        """
        var = self.variance(task)
        return var**0.5

    def stddev(self, *args, **kwargs):
        """Alias for :meth:`std`."""
        return self.std(*args, **kwargs)

    def covariance(self, task: Task, *args, **kwargs):
        """
        Computes the model covariance matrix over target points based on given
        context data. Shape (N, N).

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            :class:`numpy:numpy.ndarray`: Covariance matrix over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def mean_marginal_entropy(self, task: Task, *args, **kwargs):
        """
        Computes the mean marginal entropy over target points based on given
        context data.

        .. note::
            Getting a vector of marginal entropies would be useful too.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            float: Mean marginal entropy over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def joint_entropy(self, task: Task, *args, **kwargs):
        """
        Computes the model joint entropy over target points based on given
        context data.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            float: Joint entropy over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def logpdf(self, task: Task, *args, **kwargs):
        """
        Computes the joint model logpdf over target points based on given
        context data.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            float: Joint logpdf over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def loss(self, task: Task, *args, **kwargs):
        """
        Computes the model loss over target points based on given context data.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            float: Loss over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def sample(self, task: Task, n_samples=1, *args, **kwargs):
        """
        Draws ``n_samples`` joint samples over target points based on given
        context data. Returned shape is ``(n_samples, n_target)``.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.
            n_samples (int, optional):
                Number of samples to draw. Defaults to 1.

        Returns:
            :class:`numpy:numpy.ndarray`: Joint samples over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()


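# A minimal sketch of the contract above (illustrative only, not part of the
# library): a subclass implements the methods it needs and inherits the rest.
# This hypothetical model predicts a zero mean and unit variance everywhere,
# which is enough for `mean`, `variance`, `std` and `stddev` to work.
class _ConstantGaussianModel(ProbabilisticModel):
    """Illustrative-only model with constant marginal moments."""

    def mean(self, task: Task, *args, **kwargs):
        # Assumes off-grid targets, where task["X_t"][0] is a (2, N) array
        return np.zeros(task["X_t"][0].shape[-1])

    def variance(self, task: Task, *args, **kwargs):
        return np.ones(task["X_t"][0].shape[-1])

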
class DeepSensorModel(ProbabilisticModel):
    """
    Implements DeepSensor prediction functionality of a ProbabilisticModel.
    Allows for outputting an xarray object containing on-grid predictions or a
    pandas object containing off-grid predictions.

    Args:
        data_processor (:class:`~.data.processor.DataProcessor`):
            DataProcessor object, used to unnormalise predictions.
        task_loader (:class:`~.data.loader.TaskLoader`):
            TaskLoader object, used to determine target variables for unnormalising.
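
    Example::

        # A typical concrete subclass is ConvNP (see ``main()`` at the bottom
        # of this module); `data_processor`, `task_loader`, `tasks` and
        # `ds_raw` are assumed to already exist:
        model = ConvNP(data_processor, task_loader)
        pred = model.predict(tasks, X_t=ds_raw)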
    """

    N_mixture_components = 1  # Number of mixture components for mixture likelihoods

    def __init__(
        self,
        data_processor: Optional[DataProcessor] = None,
        task_loader: Optional[TaskLoader] = None,
    ):
        self.task_loader = task_loader
        self.data_processor = data_processor

    def predict(
        self,
        tasks: Union[List[Task], Task],
        X_t: Union[
            xr.Dataset,
            xr.DataArray,
            pd.DataFrame,
            pd.Series,
            pd.Index,
            np.ndarray,
        ],
        X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
        X_t_is_normalised: bool = False,
        aux_at_targets_override: Optional[Union[xr.Dataset, xr.DataArray]] = None,
        aux_at_targets_override_is_normalised: bool = False,
        resolution_factor: float = 1,
        pred_params: Tuple[str, ...] = ("mean", "std"),
        n_samples: int = 0,
        ar_sample: bool = False,
        ar_subsample_factor: int = 1,
        unnormalise: bool = True,
        seed: int = 0,
        append_indexes: Optional[dict] = None,
        progress_bar: int = 0,
        verbose: bool = False,
    ) -> Prediction:
        """
        Predict on a regular grid or at off-grid locations.

        Args:
            tasks (List[Task] | Task):
                List of tasks containing context data.
            X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`):
                Target locations to predict at. Can be an xarray object
                containing on-grid locations or a pandas object containing off-grid locations.
            X_t_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional):
                2D mask to apply to gridded ``X_t`` (zero/False will be NaNs). Will be interpolated
                to the same grid as ``X_t``. Default None (no mask).
            X_t_is_normalised (bool):
                Whether the ``X_t`` coords are normalised. If False, will normalise
                the coords before passing to model. Default ``False``.
            aux_at_targets_override (:class:`xarray.Dataset` | :class:`xarray.DataArray`):
                Optional auxiliary xarray data to override from the task_loader.
            aux_at_targets_override_is_normalised (bool):
                Whether the ``aux_at_targets_override`` coords are normalised.
                If False, the DataProcessor will normalise the coords before passing to model.
                Default False.
            pred_params (Tuple[str]):
                Tuple of prediction parameters to return. The strings refer to methods
                of the model class which will be called and stored in the Prediction object.
                Default ("mean", "std").
            resolution_factor (float):
                Optional factor to increase the resolution of the target grid
                by. E.g. 2 will double the target resolution, 0.5 will halve
                it. Applies to on-grid predictions only. Default 1.
            n_samples (int):
                Number of joint samples to draw from the model. If 0, will not
                draw samples. Default 0.
            ar_sample (bool):
                Whether to use autoregressive sampling. Default ``False``.
            ar_subsample_factor (int):
                Factor by which to subsample the target grid for autoregressive
                sampling. Applies to on-grid predictions only. Default 1.
            unnormalise (bool):
                Whether to unnormalise the predictions. Only works if ``self``
                has a ``data_processor`` and ``task_loader`` attribute. Default
                ``True``.
            seed (int):
                Random seed for deterministic sampling. Default 0.
            append_indexes (dict):
                Dictionary of index metadata to append to pandas indexes in the
                off-grid case. Default ``None``.
            progress_bar (int):
                Whether to display a progress bar over tasks. Default 0.
            verbose (bool):
                Whether to print time taken for prediction. Default ``False``.

        Returns:
            :class:`~.model.pred.Prediction`:
                A `dict`-like object mapping from target variable IDs to xarray or pandas objects
                containing model predictions.
                - If ``X_t`` is a pandas object, returns pandas objects
                containing off-grid predictions.
                - If ``X_t`` is an xarray object, returns xarray object
                containing on-grid predictions.
                - If ``n_samples`` == 0, returns only mean and std predictions.
                - If ``n_samples`` > 0, returns mean, std and samples
                predictions.

        Raises:
            ValueError
                If ``X_t`` is not an xarray object and
                ``resolution_factor`` is not 1 or ``ar_subsample_factor`` is
                not 1.
            ValueError
                If ``X_t`` is not a pandas object and ``append_indexes`` is not
                ``None``.
            ValueError
                If ``X_t`` is not an xarray, pandas or numpy object.
            ValueError
                If ``append_indexes`` are not all the same length as ``X_t``.
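
        Example::

            # A minimal usage sketch (data and variable names are illustrative;
            # see ``main()`` at the bottom of this module for a full script):
            pred = model.predict(tasks, X_t=ds_raw, n_samples=3)
            pred["air"]  # mean/std (and samples) for target variable "air"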
        """
329
        tic = time.time()
2✔
330
        mode = infer_prediction_modality_from_X_t(X_t)
2✔
331
        if not isinstance(X_t, (xr.DataArray, xr.Dataset)):
2✔
332
            if resolution_factor != 1:
2✔
333
                raise ValueError(
×
334
                    "resolution_factor can only be used with on-grid predictions."
335
                )
336
            if ar_subsample_factor != 1:
2✔
337
                raise ValueError(
×
338
                    "ar_subsample_factor can only be used with on-grid predictions."
339
                )
340
        if not isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index, np.ndarray)):
2✔
341
            if append_indexes is not None:
2✔
342
                raise ValueError(
×
343
                    "append_indexes can only be used with off-grid predictions."
344
                )
345
        if mode == "off-grid" and X_t_mask is not None:
2✔
346
            # TODO: Unit test this
347
            raise ValueError("X_t_mask can only be used with on-grid predictions.")
×
348
        if ar_sample and n_samples < 1:
2✔
349
            raise ValueError("Must pass `n_samples` > 0 to use `ar_sample`.")
2✔
350

        target_delta_t = self.task_loader.target_delta_t
        # Only convert to Timedelta when target_delta_t is set, so that a None
        # value cannot raise a TypeError before the None checks below
        if target_delta_t is not None:
            dts = [pd.Timedelta(dt) for dt in target_delta_t]
            dts_all_zero = all(dt == pd.Timedelta(seconds=0) for dt in dts)
        if target_delta_t is not None and dts_all_zero:
            forecasting_mode = False
            lead_times = None
        elif target_delta_t is not None and not dts_all_zero:
            target_var_IDs_set = set(self.task_loader.target_var_IDs)
            msg = f"""
            Got more than one set of target variables in target sets,
            but predictions can only be made with one set of target variables
            to simplify implementation.
            Got {target_var_IDs_set}.
            """
            assert len(target_var_IDs_set) == 1, msg
            # Repeat lead_time for each variable in each target set
            lead_times = []
            for target_set_idx, dt in enumerate(target_delta_t):
                target_set_dim = self.task_loader.target_dims[target_set_idx]
                lead_times += [
                    pd.Timedelta(dt, unit=self.task_loader.time_freq)
                    for _ in range(target_set_dim)
                ]
            forecasting_mode = True
        else:
            forecasting_mode = False
            lead_times = None

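        # For example (illustrative values): with ``target_delta_t=[1, 2]``,
        # ``time_freq="D"`` and two 1D target sets, ``lead_times`` becomes
        # [Timedelta("1 days"), Timedelta("2 days")] and forecasting mode is
        # enabled, so predictions are indexed by ("init_time", "lead_time").
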
        if isinstance(tasks, Task):
            tasks = [tasks]

        if n_samples >= 1:
            B.set_random_seed(seed)
            np.random.seed(seed)

        init_dates = [task["time"] for task in tasks]

        # Flatten tuple of tuples to single list
        target_var_IDs = [
            var_ID for var_IDs in self.task_loader.target_var_IDs for var_ID in var_IDs
        ]
        if lead_times is not None:
            assert len(lead_times) == len(target_var_IDs)

        # TODO consider removing this logic, can we just depend on the dim names in X_t?
        if not unnormalise:
            coord_names = {"x1": "x1", "x2": "x2"}
        else:
            coord_names = {
                "x1": self.data_processor.raw_spatial_coord_names[0],
                "x2": self.data_processor.raw_spatial_coord_names[1],
            }

        ### Pre-process X_t if necessary (TODO consider moving this to Prediction class)
        if isinstance(X_t, pd.Index):
            X_t = pd.DataFrame(index=X_t)
        elif isinstance(X_t, np.ndarray):
            # Convert to empty dataframe with normalised or unnormalised coord names
            if X_t_is_normalised:
                index_names = ["x1", "x2"]
            else:
                index_names = self.data_processor.raw_spatial_coord_names
            X_t = pd.DataFrame(X_t.T, columns=index_names)
            X_t = X_t.set_index(index_names)
        elif isinstance(X_t, (xr.DataArray, xr.Dataset)):
            # Remove time dimension if present
            if "time" in X_t.coords:
                X_t = X_t.isel(time=0).drop_vars("time")

        if mode == "off-grid" and append_indexes is not None:
            # Check append_indexes are all same length as X_t
            for idx, vals in append_indexes.items():
                if len(vals) != len(X_t):
                    raise ValueError(
                        f"append_indexes[{idx}] must be same length as X_t, got {len(vals)} and {len(X_t)} respectively."
                    )
            X_t = X_t.reset_index()
            X_t = pd.concat([X_t, pd.DataFrame(append_indexes)], axis=1)
            X_t = X_t.set_index(list(X_t.columns))

        if X_t_is_normalised:
            X_t_normalised = X_t
            # Unnormalise coords to use for xarray/pandas objects for storing predictions
            X_t = self.data_processor.map_coords(X_t, unnorm=True)
        else:
            # Normalise coords to use for model
            X_t_normalised = self.data_processor.map_coords(X_t)

        if mode == "on-grid":
            if resolution_factor != 1:
                X_t_normalised = increase_spatial_resolution(
                    X_t_normalised, resolution_factor
                )
                X_t = increase_spatial_resolution(
                    X_t, resolution_factor, coord_names=coord_names
                )
            if X_t_mask is not None:
                X_t_mask = process_X_mask_for_X(X_t_mask, X_t)
                X_t_mask_normalised = self.data_processor.map_coords(X_t_mask)
                X_t_arr = xarray_to_coord_array_normalised(X_t_normalised)
                # Remove points that lie outside the mask
                X_t_arr = mask_coord_array_normalised(X_t_arr, X_t_mask_normalised)
            else:
                X_t_arr = (
                    X_t_normalised["x1"].values,
                    X_t_normalised["x2"].values,
                )
        elif mode == "off-grid":
            X_t_arr = X_t_normalised.reset_index()[["x1", "x2"]].values.T

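        # Shape bookkeeping (illustrative): for an unmasked on-grid prediction
        # over a 10 x 20 grid, ``X_t_arr`` is a tuple of coordinate arrays of
        # lengths 10 and 20; for N off-grid (or masked on-grid) points it is a
        # single (2, N) array.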
        if isinstance(X_t_arr, tuple):
            target_shape = (len(X_t_arr[0]), len(X_t_arr[1]))
        else:
            target_shape = (X_t_arr.shape[1],)

        if not unnormalise:
            X_t = X_t_normalised

        if "mixture_probs" in pred_params:
            # Store each mixture component separately w/o overriding pred_params.
            # Copy to a list: `pred_params` may be a tuple, which has no .remove()
            pred_params_to_store = list(pred_params)
            pred_params_to_store.remove("mixture_probs")
            for component_i in range(self.N_mixture_components):
                pred_params_to_store.append(f"mixture_probs_{component_i}")
        else:
            pred_params_to_store = pred_params

        # Dict to store predictions for each target variable
        pred = Prediction(
            target_var_IDs,
            pred_params_to_store,
            init_dates,
            X_t,
            X_t_mask,
            coord_names,
            n_samples=n_samples,
            forecasting_mode=forecasting_mode,
            lead_times=lead_times,
        )

        def unnormalise_pred_array(arr, **kwargs):
            """Unnormalise an (N_dims, N_targets) array of predictions."""
            var_IDs_flattened = [
                var_ID
                for var_IDs in self.task_loader.target_var_IDs
                for var_ID in var_IDs
            ]
            assert arr.shape[0] == len(
                var_IDs_flattened
            ), f"{arr.shape[0]} != {len(var_IDs_flattened)}"
            for i, var_ID in enumerate(var_IDs_flattened):
                arr[i] = self.data_processor.map_array(
                    arr[i],
                    var_ID,
                    method=self.data_processor.config[var_ID]["method"],
                    unnorm=True,
                    **kwargs,
                )
            return arr

        # Don't change tasks by reference when overriding target locations
        # TODO consider not copying tasks by default for efficiency
        tasks = copy.deepcopy(tasks)

        if self.task_loader.aux_at_targets is not None:
            if aux_at_targets_override is not None:
                aux_at_targets = aux_at_targets_override
                if not aux_at_targets_override_is_normalised:
                    # Assumes using default normalisation method
                    aux_at_targets = self.data_processor(
                        aux_at_targets, assert_computed=True
                    )
            else:
                aux_at_targets = self.task_loader.aux_at_targets

        for task in tqdm(tasks, position=0, disable=progress_bar < 1, leave=True):
            task["X_t"] = [X_t_arr for _ in range(len(self.task_loader.target_var_IDs))]

            # If passing auxiliary data, need to sample it at target locations
            if self.task_loader.aux_at_targets is not None:
                aux_at_targets_sliced = self.task_loader.time_slice_variable(
                    aux_at_targets, task["time"]
                )
                task["Y_t_aux"] = self.task_loader.sample_offgrid_aux(
                    X_t_arr, aux_at_targets_sliced
                )

            prediction_arrs = {}
            prediction_methods = {}
            for param in pred_params:
                try:
                    method = getattr(self, param)
                    prediction_methods[param] = method
                except AttributeError:
                    raise AttributeError(
                        f"Prediction method {param} not found in model class."
                    )
            if n_samples >= 1:
                B.set_random_seed(seed)
                np.random.seed(seed)
                if ar_sample:
                    sample_method = self.ar_sample
                    sample_args = {
                        "n_samples": n_samples,
                        "ar_subsample_factor": ar_subsample_factor,
                    }
                else:
                    sample_method = self.sample
                    sample_args = {"n_samples": n_samples}

            # If `DeepSensor` model child has been sub-classed with a `__call__` method,
            # we assume this is a distribution-like object that can be used to compute
            # mean, std and samples. Otherwise, run the model with `Task` for each prediction type.
            if hasattr(self, "__call__"):
                # Run model forwards once to generate output distribution, which we re-use
                dist = self(task, n_samples=n_samples)
                for param, method in prediction_methods.items():
                    prediction_arrs[param] = method(dist)
                if n_samples >= 1 and not ar_sample:
                    samples_arr = sample_method(dist, **sample_args)
                    # samples_arr = samples_arr.reshape((n_samples, len(target_var_IDs), *target_shape))
                    prediction_arrs["samples"] = samples_arr
                elif n_samples >= 1 and ar_sample:
                    # Can't draw AR samples from distribution object, need to re-run with task
                    samples_arr = sample_method(task, **sample_args)
                    samples_arr = samples_arr.reshape(
                        (n_samples, len(target_var_IDs), *target_shape)
                    )
                    prediction_arrs["samples"] = samples_arr
            else:
                # Re-run model for each prediction type
                for param, method in prediction_methods.items():
                    prediction_arrs[param] = method(task)
                if n_samples >= 1:
                    samples_arr = sample_method(task, **sample_args)
                    if ar_sample:
                        samples_arr = samples_arr.reshape(
                            (n_samples, len(target_var_IDs), *target_shape)
                        )
                    prediction_arrs["samples"] = samples_arr

            # Concatenate multi-target predictions
            for param, arr in prediction_arrs.items():
                if isinstance(arr, (list, tuple)):
                    if param != "samples":
                        concat_axis = 0
                    else:
                        # Axis 0 is sample dim, axis 1 is variable dim
                        concat_axis = 1
                    prediction_arrs[param] = np.concatenate(arr, axis=concat_axis)

            # Unnormalise predictions
            for param, arr in prediction_arrs.items():
                # TODO make class attributes?
                scale_and_offset_params = ["mean"]
                scale_only_params = ["std"]
                scale_squared_only_params = ["variance"]
                if unnormalise:
                    if param == "samples":
                        for sample_i in range(n_samples):
                            prediction_arrs["samples"][sample_i] = (
                                unnormalise_pred_array(
                                    prediction_arrs["samples"][sample_i]
                                )
                            )
                    elif param in scale_and_offset_params:
                        prediction_arrs[param] = unnormalise_pred_array(arr)
                    elif param in scale_only_params:
                        prediction_arrs[param] = unnormalise_pred_array(
                            arr, add_offset=False
                        )
                    elif param in scale_squared_only_params:
                        # This is a horrible hack to repeat the scaling operation of the linear
                        #   transform twice s.t. new_var = scale ^ 2 * var
                        prediction_arrs[param] = unnormalise_pred_array(
                            arr, add_offset=False
                        )
                        prediction_arrs[param] = unnormalise_pred_array(
                            prediction_arrs[param], add_offset=False
                        )
                    else:
                        # Assume prediction parameters not captured above are dimensionless
                        #   quantities like probabilities and should not be unnormalised
                        pass

            # Assign predictions to Prediction object
            for param, arr in prediction_arrs.items():
                if param != "mixture_probs":
                    pred.assign(param, task["time"], arr, lead_times=lead_times)
                else:
                    assert arr.shape[0] == self.N_mixture_components, (
                        f"Number of mixture components ({arr.shape[0]}) does not match "
                        f"model attribute N_mixture_components ({self.N_mixture_components})."
                    )
                    for component_i, probs in enumerate(arr):
                        pred.assign(
                            f"{param}_{component_i}",
                            task["time"],
                            probs,
                            lead_times=lead_times,
                        )

        if forecasting_mode:
            pred = add_valid_time_coord_to_pred_and_move_time_dims(pred)

        if verbose:
            dur = time.time() - tic
            print(f"Done in {int(dur // 60)}m:{dur % 60:.0f}s.\n")

        return pred


def add_valid_time_coord_to_pred_and_move_time_dims(pred: Prediction) -> Prediction:
    """
    Add a valid time coordinate "time" to a Prediction object based on the
    initialisation times "init_time" and lead times "lead_time", and
    reorder the time dims from ("lead_time", "init_time") to ("init_time", "lead_time").

    Args:
        pred (:class:`~.model.pred.Prediction`):
            Prediction object to add valid time coordinate to.

    Returns:
        :class:`~.model.pred.Prediction`:
            Prediction object with valid time coordinate added.
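
    Example:
        The valid time is simply ``init_time + lead_time``; e.g. (illustrative):

        >>> import pandas as pd
        >>> pd.Timestamp("2014-12-25") + pd.Timedelta("1D")
        Timestamp('2014-12-26 00:00:00')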
    """
    for var_ID in pred.keys():
        if isinstance(pred[var_ID], pd.DataFrame):
            x = pred[var_ID].reset_index()
            pred[var_ID]["time"] = (x["lead_time"] + x["init_time"]).values
            pred[var_ID] = pred[var_ID].swaplevel("init_time", "lead_time")
            pred[var_ID] = pred[var_ID].sort_index()
        elif isinstance(pred[var_ID], xr.Dataset):
            x = pred[var_ID]
            pred[var_ID] = pred[var_ID].assign_coords(
                time=x["lead_time"] + x["init_time"]
            )
            pred[var_ID] = pred[var_ID].transpose("init_time", "lead_time", ...)
        else:
            raise ValueError(f"Unsupported prediction type {type(pred[var_ID])}.")
    return pred


def main():  # pragma: no cover
    import deepsensor.tensorflow  # Selects the TensorFlow backend on import
    from deepsensor.data.loader import TaskLoader
    from deepsensor.data.processor import DataProcessor
    from deepsensor.model.convnp import ConvNP

    import xarray as xr
    import pandas as pd
    import numpy as np

    # Load raw data
    ds_raw = xr.tutorial.open_dataset("air_temperature")["air"]
    ds_raw2 = copy.deepcopy(ds_raw)
    ds_raw2.name = "air2"

    # Normalise data
    data_processor = DataProcessor(x1_name="lat", x2_name="lon")
    ds = data_processor(ds_raw)
    ds2 = data_processor(ds_raw2)

    # Set up task loader
    task_loader = TaskLoader(context=ds, target=[ds, ds2])

    # Set up model
    model = ConvNP(data_processor, task_loader)

    # Predict on new tasks with 40 context points and a dense grid of target points
    test_tasks = task_loader(
        pd.date_range("2014-12-25", "2014-12-31"), context_sampling=40
    )
    # print(repr(test_tasks))

    X_t = ds_raw
    pred = model.predict(test_tasks, X_t=X_t, n_samples=5)
    print(pred)

    X_t = np.zeros((2, 1))
    pred = model.predict(test_tasks, X_t=X_t, X_t_is_normalised=True)
    print(pred)

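    # Illustrative extra (not part of the original script): off-grid prediction
    # at two unnormalised (lat, lon) points passed as a (2, N) array
    X_t = np.array([[50.0, 60.0], [240.0, 250.0]])
    pred = model.predict(test_tasks, X_t=X_t)
    print(pred)
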
    # DEBUG
    # task = task_loader("2014-12-31", context_sampling=40, target_sampling="all")
    # samples = model.ar_sample(task, 5, ar_subsample_factor=20)


if __name__ == "__main__":  # pragma: no cover
    main()