alan-turing-institute / deepsensor / build 14313118307

07 Apr 2025 03:23PM UTC. Coverage: 82.511% (+0.8%) from 81.663%.

Pull Request #135: Enable patchwise training and prediction
Merge ea3987b38 into 38ec5ef26

294 of 329 new or added lines in 4 files covered (89.36%)
1 existing line in 1 file now uncovered
2340 of 2836 relevant lines covered (82.51%)
1.65 hits per line

Source File: /deepsensor/model/model.py (80.97% covered)
from deepsensor.data.loader import TaskLoader
from deepsensor.data.processor import (
    DataProcessor,
    process_X_mask_for_X,
    xarray_to_coord_array_normalised,
    mask_coord_array_normalised,
)
from deepsensor.model.pred import (
    Prediction,
    increase_spatial_resolution,
    infer_prediction_modality_from_X_t,
    stitch_clipped_predictions,
)
from deepsensor.data.task import Task

from typing import List, Union, Optional, Tuple
import copy

import time
from tqdm import tqdm

import numpy as np
import pandas as pd
import xarray as xr
import lab as B

# For dispatching with TF and PyTorch model types when they have not yet been loaded.
# See https://beartype.github.io/plum/types.html#moduletype
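# A sketch of what that dispatch setup could look like (hypothetical, not used
# below; plum's ModuleType references a class in a not-yet-imported module):
#   from plum import ModuleType
#   TFModel = ModuleType("tensorflow.keras", "Model")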


class ProbabilisticModel:
    """Base class for probabilistic models used by DeepSensor.

    Ensures that the methods required by DeepSensor are implemented by the
    model classes that inherit from it.
    """

    def mean(self, task: Task, *args, **kwargs):
        """Computes the model mean prediction over target points based on given
        context data.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            :class:`numpy:numpy.ndarray`: Mean prediction over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def variance(self, task: Task, *args, **kwargs):
        """Model marginal variance over target points given context points.
        Shape (N,).

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            :class:`numpy:numpy.ndarray`: Marginal variance over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def std(self, task: Task):
        """Model marginal standard deviation over target points given context
        points. Shape (N,).

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            :class:`numpy:numpy.ndarray`: Marginal standard deviation over target points.
        """
        var = self.variance(task)
        return var**0.5

    def stddev(self, *args, **kwargs):  # noqa
        """Alias for :meth:`std`."""
        return self.std(*args, **kwargs)

    def covariance(self, task: Task, *args, **kwargs):
        """Computes the model covariance matrix over target points based on given
        context data. Shape (N, N).

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            :class:`numpy:numpy.ndarray`: Covariance matrix over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def mean_marginal_entropy(self, task: Task, *args, **kwargs):
        """Computes the mean marginal entropy over target points based on given
        context data.

        .. note::
            Getting a vector of marginal entropies would be useful too.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            float: Mean marginal entropy over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def joint_entropy(self, task: Task, *args, **kwargs):
        """Computes the model joint entropy over target points based on given
        context data.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            float: Joint entropy over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def logpdf(self, task: Task, *args, **kwargs):
        """Computes the joint model logpdf over target points based on given
        context data.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            float: Joint logpdf over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def loss(self, task: Task, *args, **kwargs):
        """Computes the model loss over target points based on given context data.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.

        Returns:
            float: Loss over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()

    def sample(self, task: Task, n_samples=1, *args, **kwargs):
        """Draws ``n_samples`` joint samples over target points based on given
        context data. Returned shape is ``(n_samples, n_target)``.

        Args:
            task (:class:`~.data.task.Task`):
                Task containing context data.
            n_samples (int, optional):
                Number of samples to draw. Defaults to 1.

        Returns:
            :class:`numpy:numpy.ndarray`: Joint samples over target points.

        Raises:
            NotImplementedError
                If not implemented by child class.
        """
        raise NotImplementedError()


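# A minimal concrete ProbabilisticModel, as an illustrative sketch only (not
# part of the library; real implementations such as ConvNP wrap a trained
# neural process and implement many more of the methods above):
class _ConstantModel(ProbabilisticModel):
    """Toy model predicting zero mean and unit variance at every target point."""

    def mean(self, task: Task, *args, **kwargs):
        # Assumes off-grid targets, i.e. task["X_t"][0] has shape (2, N).
        return np.zeros(task["X_t"][0].shape[-1])

    def variance(self, task: Task, *args, **kwargs):
        return np.ones(task["X_t"][0].shape[-1])
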

class DeepSensorModel(ProbabilisticModel):
    """Implements DeepSensor prediction functionality of a ProbabilisticModel.

    Allows for outputting an xarray object containing on-grid predictions or a
    pandas object containing off-grid predictions.

    Args:
        data_processor (:class:`~.data.processor.DataProcessor`):
            DataProcessor object, used to unnormalise predictions.
        task_loader (:class:`~.data.loader.TaskLoader`):
            TaskLoader object, used to determine target variables for unnormalising.
    """

    N_mixture_components = 1  # Number of mixture components for mixture likelihoods

    def __init__(
        self,
        data_processor: Optional[DataProcessor] = None,
        task_loader: Optional[TaskLoader] = None,
    ):
        self.task_loader = task_loader
        self.data_processor = data_processor

    def predict(
        self,
        tasks: Union[List[Task], Task],
        X_t: Union[
            xr.Dataset,
            xr.DataArray,
            pd.DataFrame,
            pd.Series,
            pd.Index,
            np.ndarray,
        ],
        X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
        X_t_is_normalised: bool = False,
        aux_at_targets_override: Optional[Union[xr.Dataset, xr.DataArray]] = None,
        aux_at_targets_override_is_normalised: bool = False,
        resolution_factor: float = 1,
        pred_params: Tuple[str, ...] = ("mean", "std"),
        n_samples: int = 0,
        ar_sample: bool = False,
        ar_subsample_factor: int = 1,
        unnormalise: bool = True,
        seed: int = 0,
        append_indexes: Optional[dict] = None,
        progress_bar: int = 0,
        verbose: bool = False,
    ) -> Prediction:
        """Predict on a regular grid or at off-grid locations.

        Args:
            tasks (List[Task] | Task):
                List of tasks containing context data.
            X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`):
                Target locations to predict at. Can be an xarray object
                containing on-grid locations or a pandas object containing off-grid locations.
            X_t_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional):
                2D mask to apply to gridded ``X_t`` (zero/False will be NaNs). Will be interpolated
                to the same grid as ``X_t``. Default None (no mask).
            X_t_is_normalised (bool):
                Whether the ``X_t`` coords are normalised. If False, will normalise
                the coords before passing to the model. Default ``False``.
            aux_at_targets_override (:class:`xarray.Dataset` | :class:`xarray.DataArray`):
                Optional auxiliary xarray data to override from the task_loader.
            aux_at_targets_override_is_normalised (bool):
                Whether the ``aux_at_targets_override`` coords are normalised.
                If False, the DataProcessor will normalise the coords before passing to the model.
                Default False.
            resolution_factor (float):
                Optional factor to increase the resolution of the target grid
                by. E.g. 2 will double the target resolution, 0.5 will halve
                it. Applies to on-grid predictions only. Default 1.
            pred_params (Tuple[str]):
                Tuple of prediction parameters to return. The strings refer to methods
                of the model class which will be called and stored in the Prediction object.
                Default ("mean", "std").
            n_samples (int):
                Number of joint samples to draw from the model. If 0, will not
                draw samples. Default 0.
            ar_sample (bool):
                Whether to use autoregressive sampling. Default ``False``.
            ar_subsample_factor (int):
                Subsampling factor for the AR sampling process, passed to the
                model's ``ar_sample`` method. Applies to on-grid predictions
                only. Default 1.
            unnormalise (bool):
                Whether to unnormalise the predictions. Only works if ``self``
                has a ``data_processor`` and ``task_loader`` attribute. Default
                ``True``.
            seed (int):
                Random seed for deterministic sampling. Default 0.
            append_indexes (dict):
                Dictionary of index metadata to append to pandas indexes in the
                off-grid case. Default ``None``.
            progress_bar (int):
                Whether to display a progress bar over tasks. Default 0.
            verbose (bool):
                Whether to print time taken for prediction. Default ``False``.

        Returns:
            :class:`~.model.pred.Prediction`:
                A `dict`-like object mapping from target variable IDs to xarray or pandas objects
                containing model predictions.
                - If ``X_t`` is a pandas object, returns pandas objects
                containing off-grid predictions.
                - If ``X_t`` is an xarray object, returns xarray object
                containing on-grid predictions.
                - If ``n_samples`` == 0, returns only mean and std predictions.
                - If ``n_samples`` > 0, returns mean, std and samples
                predictions.

        Raises:
            ValueError
                If ``X_t`` is not an xarray object and
                ``resolution_factor`` is not 1 or ``ar_subsample_factor`` is
                not 1.
            ValueError
                If ``X_t`` is not a pandas object and ``append_indexes`` is not
                ``None``.
            ValueError
                If ``X_t`` is not an xarray, pandas or numpy object.
            ValueError
                If ``append_indexes`` are not all the same length as ``X_t``.
        """
        tic = time.time()
        mode = infer_prediction_modality_from_X_t(X_t)
        if not isinstance(X_t, (xr.DataArray, xr.Dataset)):
            if resolution_factor != 1:
                raise ValueError(
                    "resolution_factor can only be used with on-grid predictions."
                )
            if ar_subsample_factor != 1:
                raise ValueError(
                    "ar_subsample_factor can only be used with on-grid predictions."
                )
        if not isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index, np.ndarray)):
            if append_indexes is not None:
                raise ValueError(
                    "append_indexes can only be used with off-grid predictions."
                )
        if mode == "off-grid" and X_t_mask is not None:
            # TODO: Unit test this
            raise ValueError("X_t_mask can only be used with on-grid predictions.")
        if ar_sample and n_samples < 1:
            raise ValueError("Must pass `n_samples` > 0 to use `ar_sample`.")

        # Check for a None target_delta_t before building timedeltas to avoid a TypeError.
        target_delta_t = self.task_loader.target_delta_t
        if target_delta_t is None:
            forecasting_mode = False
            lead_times = None
        else:
            dts = [pd.Timedelta(dt) for dt in target_delta_t]
            dts_all_zero = all(dt == pd.Timedelta(seconds=0) for dt in dts)
            if dts_all_zero:
                forecasting_mode = False
                lead_times = None
            else:
                target_var_IDs_set = set(self.task_loader.target_var_IDs)
                msg = f"""
                Got more than one set of target variables in target sets,
                but predictions can only be made with one set of target variables
                to simplify implementation.
                Got {target_var_IDs_set}.
                """
                assert len(target_var_IDs_set) == 1, msg
                # Repeat lead_time for each variable in each target set
                lead_times = []
                for target_set_idx, dt in enumerate(target_delta_t):
                    target_set_dim = self.task_loader.target_dims[target_set_idx]
                    lead_times += [
                        pd.Timedelta(dt, unit=self.task_loader.time_freq)
                        for _ in range(target_set_dim)
                    ]
                forecasting_mode = True
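        # e.g. with two target sets, target_delta_t = [1, 1], time_freq "D" and
        # target dims (2, 1), lead_times ends up with three one-day Timedeltas,
        # one per flattened target variable.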

        if type(tasks) is Task:
            tasks = [tasks]

        if n_samples >= 1:
            B.set_random_seed(seed)
            np.random.seed(seed)

        init_dates = [task["time"] for task in tasks]

        # Flatten tuple of tuples to a single list
        target_var_IDs = [
            var_ID for set in self.task_loader.target_var_IDs for var_ID in set
        ]
        if lead_times is not None:
            assert len(lead_times) == len(target_var_IDs)

        # TODO consider removing this logic, can we just depend on the dim names in X_t?
        if not unnormalise:
            coord_names = {"x1": "x1", "x2": "x2"}
        else:
            coord_names = {
                "x1": self.data_processor.raw_spatial_coord_names[0],
                "x2": self.data_processor.raw_spatial_coord_names[1],
            }

        ### Pre-process X_t if necessary (TODO consider moving this to Prediction class)
        if isinstance(X_t, pd.Index):
            X_t = pd.DataFrame(index=X_t)
        elif isinstance(X_t, np.ndarray):
            # Convert to an empty dataframe with normalised or unnormalised coord names
            if X_t_is_normalised:
                index_names = ["x1", "x2"]
            else:
                index_names = self.data_processor.raw_spatial_coord_names
            X_t = pd.DataFrame(X_t.T, columns=index_names)
            X_t = X_t.set_index(index_names)
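            # e.g. X_t = np.array([[50.0], [10.0]]) is one target at
            # (x1, x2) = (50, 10): rows are coords, columns are target points.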
        elif isinstance(X_t, (xr.DataArray, xr.Dataset)):
            # Remove time dimension if present
            if "time" in X_t.coords:
                X_t = X_t.isel(time=0).drop_vars("time")

        if mode == "off-grid" and append_indexes is not None:
            # Check append_indexes are all the same length as X_t
            for idx, vals in append_indexes.items():
                if len(vals) != len(X_t):
                    raise ValueError(
                        f"append_indexes[{idx}] must be same length as X_t, got {len(vals)} and {len(X_t)} respectively."
                    )
            X_t = X_t.reset_index()
            X_t = pd.concat([X_t, pd.DataFrame(append_indexes)], axis=1)
            X_t = X_t.set_index(list(X_t.columns))

        if X_t_is_normalised:
            X_t_normalised = X_t
            # Unnormalise coords to use for xarray/pandas objects for storing predictions
            X_t = self.data_processor.map_coords(X_t, unnorm=True)
        else:
            # Normalise coords to use for the model
            X_t_normalised = self.data_processor.map_coords(X_t)

        if mode == "on-grid":
            if resolution_factor != 1:
                X_t_normalised = increase_spatial_resolution(
                    X_t_normalised, resolution_factor
                )
                X_t = increase_spatial_resolution(
                    X_t, resolution_factor, coord_names=coord_names
                )
            if X_t_mask is not None:
                X_t_mask = process_X_mask_for_X(X_t_mask, X_t)
                X_t_mask_normalised = self.data_processor.map_coords(X_t_mask)
                X_t_arr = xarray_to_coord_array_normalised(X_t_normalised)
                # Remove points that lie outside the mask
                X_t_arr = mask_coord_array_normalised(X_t_arr, X_t_mask_normalised)
            else:
                X_t_arr = (
                    X_t_normalised["x1"].values,
                    X_t_normalised["x2"].values,
                )
        elif mode == "off-grid":
            X_t_arr = X_t_normalised.reset_index()[["x1", "x2"]].values.T

        if isinstance(X_t_arr, tuple):
            target_shape = (len(X_t_arr[0]), len(X_t_arr[1]))
        else:
            target_shape = (X_t_arr.shape[1],)

        if not unnormalise:
            X_t = X_t_normalised

        if "mixture_probs" in pred_params:
            # Store each mixture component separately w/o overriding pred_params
            # (as a list, since the default pred_params is a tuple)
            pred_params_to_store = list(pred_params)
            pred_params_to_store.remove("mixture_probs")
            for component_i in range(self.N_mixture_components):
                pred_params_to_store.append(f"mixture_probs_{component_i}")
        else:
            pred_params_to_store = pred_params

        # Dict to store predictions for each target variable
        pred = Prediction(
            target_var_IDs,
            pred_params_to_store,
            init_dates,
            X_t,
            X_t_mask,
            coord_names,
            n_samples=n_samples,
            forecasting_mode=forecasting_mode,
            lead_times=lead_times,
        )

        def unnormalise_pred_array(arr, **kwargs):
            """Unnormalise an (N_dims, N_targets) array of predictions."""
            var_IDs_flattened = [
                var_ID
                for var_IDs in self.task_loader.target_var_IDs
                for var_ID in var_IDs
            ]
            assert arr.shape[0] == len(
                var_IDs_flattened
            ), f"{arr.shape[0]} != {len(var_IDs_flattened)}"
            for i, var_ID in enumerate(var_IDs_flattened):
                arr[i] = self.data_processor.map_array(
                    arr[i],
                    var_ID,
                    method=self.data_processor.config[var_ID]["method"],
                    unnorm=True,
                    **kwargs,
                )
            return arr

        # Don't change tasks by reference when overriding target locations
        # TODO consider not copying tasks by default for efficiency
        tasks = copy.deepcopy(tasks)

        if self.task_loader.aux_at_targets is not None:
            if aux_at_targets_override is not None:
                aux_at_targets = aux_at_targets_override
                if not aux_at_targets_override_is_normalised:
                    # Assumes using default normalisation method
                    aux_at_targets = self.data_processor(
                        aux_at_targets, assert_computed=True
                    )
            else:
                aux_at_targets = self.task_loader.aux_at_targets

        for task in tqdm(tasks, position=0, disable=progress_bar < 1, leave=True):
            task["X_t"] = [X_t_arr for _ in range(len(self.task_loader.target_var_IDs))]

            # If passing auxiliary data, need to sample it at target locations
            if self.task_loader.aux_at_targets is not None:
                aux_at_targets_sliced = self.task_loader.time_slice_variable(
                    aux_at_targets, task["time"]
                )
                task["Y_t_aux"] = self.task_loader.sample_offgrid_aux(
                    X_t_arr, aux_at_targets_sliced
                )

            prediction_arrs = {}
            prediction_methods = {}
            for param in pred_params:
                try:
                    method = getattr(self, param)
                    prediction_methods[param] = method
                except AttributeError:
                    raise AttributeError(
                        f"Prediction method {param} not found in model class."
                    )
            if n_samples >= 1:
                B.set_random_seed(seed)
                np.random.seed(seed)
                if ar_sample:
                    sample_method = getattr(self, "ar_sample")
                    sample_args = {
                        "n_samples": n_samples,
                        "ar_subsample_factor": ar_subsample_factor,
                    }
                else:
                    sample_method = getattr(self, "sample")
                    sample_args = {"n_samples": n_samples}

            # If the `DeepSensor` model child has been sub-classed with a `__call__` method,
            # we assume this returns a distribution-like object that can be used to compute
            # mean, std and samples. Otherwise, run the model with `Task` for each prediction type.
            if hasattr(self, "__call__"):
                # Run the model forwards once to generate an output distribution, which we re-use
                dist = self(task, n_samples=n_samples)
                for param, method in prediction_methods.items():
                    prediction_arrs[param] = method(dist)
                if n_samples >= 1 and not ar_sample:
                    samples_arr = sample_method(dist, **sample_args)
                    # samples_arr = samples_arr.reshape((n_samples, len(target_var_IDs), *target_shape))
                    prediction_arrs["samples"] = samples_arr
                elif n_samples >= 1 and ar_sample:
                    # Can't draw AR samples from the distribution object, need to re-run with the task
                    samples_arr = sample_method(task, **sample_args)
                    samples_arr = samples_arr.reshape(
                        (n_samples, len(target_var_IDs), *target_shape)
                    )
                    prediction_arrs["samples"] = samples_arr
            else:
                # Re-run the model for each prediction type
                for param, method in prediction_methods.items():
                    prediction_arrs[param] = method(task)
                if n_samples >= 1:
                    samples_arr = sample_method(task, **sample_args)
                    if ar_sample:
                        samples_arr = samples_arr.reshape(
                            (n_samples, len(target_var_IDs), *target_shape)
                        )
                    prediction_arrs["samples"] = samples_arr

            # Concatenate multi-target predictions
            for param, arr in prediction_arrs.items():
                if isinstance(arr, (list, tuple)):
                    if param != "samples":
                        concat_axis = 0
                    else:
                        # Axis 0 is the sample dim, axis 1 is the variable dim
                        concat_axis = 1
                    prediction_arrs[param] = np.concatenate(arr, axis=concat_axis)

            # Unnormalise predictions
            for param, arr in prediction_arrs.items():
                # TODO make class attributes?
                scale_and_offset_params = ["mean"]
                scale_only_params = ["std"]
                scale_squared_only_params = ["variance"]
                if unnormalise:
                    if param == "samples":
                        for sample_i in range(n_samples):
                            prediction_arrs["samples"][sample_i] = (
                                unnormalise_pred_array(
                                    prediction_arrs["samples"][sample_i]
                                )
                            )
                    elif param in scale_and_offset_params:
                        prediction_arrs[param] = unnormalise_pred_array(arr)
                    elif param in scale_only_params:
                        prediction_arrs[param] = unnormalise_pred_array(
                            arr, add_offset=False
                        )
                    elif param in scale_squared_only_params:
                        # This is a horrible hack to repeat the scaling operation of the linear
                        #   transform twice s.t. new_var = scale ^ 2 * var
                        prediction_arrs[param] = unnormalise_pred_array(
                            arr, add_offset=False
                        )
                        prediction_arrs[param] = unnormalise_pred_array(
                            prediction_arrs[param], add_offset=False
                        )
                    else:
                        # Assume prediction parameters not captured above are dimensionless
                        #   quantities like probabilities and should not be unnormalised
                        pass

            # Assign predictions to the Prediction object
            for param, arr in prediction_arrs.items():
                if param != "mixture_probs":
                    pred.assign(param, task["time"], arr, lead_times=lead_times)
                else:
                    assert arr.shape[0] == self.N_mixture_components, (
                        f"Number of mixture components ({arr.shape[0]}) does not match "
                        f"model attribute N_mixture_components ({self.N_mixture_components})."
                    )
                    for component_i, probs in enumerate(arr):
                        pred.assign(
                            f"{param}_{component_i}",
                            task["time"],
                            probs,
                            lead_times=lead_times,
                        )

        if forecasting_mode:
            pred = add_valid_time_coord_to_pred_and_move_time_dims(pred)

        if verbose:
            dur = time.time() - tic
            print(f"Done in {int(dur // 60)}m:{dur % 60:.0f}s.\n")

        return pred

    def predict_patchwise(
        self,
        tasks: Union[List[Task], Task],
        X_t: Union[
            xr.Dataset,
            xr.DataArray,
            pd.DataFrame,
            pd.Series,
            pd.Index,
            np.ndarray,
        ],
        X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
        **kwargs,
    ) -> Prediction:
        """Predict using tasks loaded with a sliding-window patching strategy. Uses the `predict` method.

        .. versionadded:: 0.4.3
            :py:func:`predict_patchwise()` method.

        Args:
            tasks (List[Task] | Task):
                List of tasks containing context data. Tasks for patchwise prediction must be generated by a task loader using the "sliding" patching strategy.
            X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`):
                Target locations to predict at. Can be an xarray object
                containing on-grid locations or a pandas object containing off-grid locations.
            X_t_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional):
                2D mask to apply to gridded ``X_t`` (zero/False will be NaNs). Will be interpolated
                to the same grid as ``X_t`` and patched in the same way. Default None (no mask).
            **kwargs:
                Keyword arguments as per ``predict``.

        Returns:
            :class:`~.model.pred.Prediction`:
                A `dict`-like object mapping from target variable IDs to xarray or pandas objects
                containing model predictions.
                - If ``X_t`` is a pandas object, returns pandas objects
                containing off-grid predictions.
                - If ``X_t`` is an xarray object, returns xarray object
                containing on-grid predictions.
                - If ``n_samples`` == 0, returns only mean and std predictions.
                - If ``n_samples`` > 0, returns mean, std and samples
                predictions.

        Raises:
            AttributeError
                If ``tasks`` are not generated using the "sliding" patching strategy of TaskLoader,
                i.e. if they do not have a ``bbox`` attribute.
            Errors
                See :meth:`~.model.model.DeepSensorModel.predict`.
        """
        # Get the coordinate names of the original unnormalised dataset.
        orig_x1_name = self.data_processor.x1_name
        orig_x2_name = self.data_processor.x2_name

        def get_patches_per_row(preds) -> int:
            """Calculate the number of patches per row.
            Required to stitch patches back together.

            Args:
                preds (List[:class:`~.model.pred.Prediction`]):
                    A list of `dict`-like objects containing patchwise predictions.

            Returns:
                patches_per_row: int
                    Number of patches per row.
            """
            patches_per_row = 0
            var_names = list(preds[0][0].data_vars)
            var = var_names[0]
            x1_val = preds[0][0][var].coords[orig_x1_name].min()

            for pred in preds:
                if pred[0][var].coords[orig_x1_name].min() == x1_val:
                    patches_per_row = patches_per_row + 1

            return patches_per_row

        def get_patch_overlap(
            overlap_norm, data_processor, X_t_ds, x1_ascend, x2_ascend
        ) -> Tuple[int, int]:
            """Calculate the overlap between adjacent patches in pixels.

            Args:
                overlap_norm (tuple[float]):
                    Normalised size of the overlap in x1/x2.
                data_processor (:class:`~.data.processor.DataProcessor`):
                    Used for unnormalising the coordinates of the bounding boxes of patches.
                X_t_ds (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`):
                    Data array containing target locations to predict at.
                x1_ascend (bool):
                    Whether the x1 coords ascend (increase) from top to bottom. Default True.
                x2_ascend (bool):
                    Whether the x2 coords ascend (increase) from left to right. Default True.

            Returns:
                patch_overlap (tuple[int]):
                    Unnormalised size of the overlap between adjacent patches.
            """
            # TODO: check if there is a simpler and more robust way to convert the overlap into pixels.
            # Place the x1/x2 overlap values in an xarray object to pass into unnormalise()
            overlap_list = [0, overlap_norm[0], 0, overlap_norm[1]]
            x1 = xr.DataArray([overlap_list[0], overlap_list[1]], dims="x1", name="x1")
            x2 = xr.DataArray([overlap_list[2], overlap_list[3]], dims="x2", name="x2")
            overlap_norm_xr = xr.Dataset(coords={"x1": x1, "x2": x2})

            # Unnormalise the coordinates of the bounding boxes
            overlap_unnorm_xr = data_processor.unnormalise(overlap_norm_xr)

            unnorm_overlap_x1 = overlap_unnorm_xr.coords[orig_x1_name].values[1]
            unnorm_overlap_x2 = overlap_unnorm_xr.coords[orig_x2_name].values[1]

            def overlap_index(
                coords: np.ndarray, ascend: bool, unnorm_overlap: float
            ) -> int:
                """Find the size of the overlap in a single coordinate direction, in units of pixels.

                Args:
                    coords (np.ndarray):
                        Coordinate values along the given direction.
                    ascend (bool):
                        Whether the coords ascend (increase) from top to bottom or left to right.
                    unnorm_overlap (float):
                        The patch overlap in unnormalised coordinates.

                Returns:
                    int: The number of pixels in the overlap.
                """
                pixel_coords_overlap_diffs = np.abs(coords - unnorm_overlap)
                if ascend:
                    trim_size = np.argmin(pixel_coords_overlap_diffs) / 2
                    # Always round the trim size down, as the stitching method
                    # can handle slight overlaps
                    trim_size_rounded = int(np.floor(trim_size))
                    return trim_size_rounded
                else:
                    overlap_pixel_size = np.argmin(pixel_coords_overlap_diffs)
                    overlap_pixel_size_rounded = np.ceil(overlap_pixel_size)
                    # This extra step gets the overlap with respect to the largest
                    # value (e.g. if the number of pixels is 360, coords.size = 360)
                    trim_size = (coords.size - int(overlap_pixel_size_rounded)) / 2
                    trim_size_rounded = int(np.floor(trim_size))
                    return trim_size_rounded

            return (
                overlap_index(
                    X_t_ds.coords[orig_x1_name].values, x1_ascend, unnorm_overlap_x1
                ),
                overlap_index(
                    X_t_ds.coords[orig_x2_name].values, x2_ascend, unnorm_overlap_x2
                ),
            )

        # Tasks should be iterable; if only one is provided, make it a list.
        # (This must happen before indexing into tasks below.)
        if type(tasks) is Task:
            tasks = [tasks]

        # Load patch_size and stride from the first task
        patch_size = tasks[0]["patch_size"]
        stride = tasks[0]["stride"]

        # Sanitise the patch_size and stride arguments
        if isinstance(patch_size, float):
            patch_size = (patch_size, patch_size)

        if isinstance(stride, float):
            stride = (stride, stride)

        if stride[0] > patch_size[0] or stride[1] > patch_size[1]:
            raise ValueError(
                f"stride must be smaller than patch_size in the corresponding dimensions for patchwise prediction. Got: patch_size: {patch_size}, stride: {stride}"
            )

        # Patchwise prediction does not yet support more than a single date
        num_task_dates = len(set([t["time"] for t in tasks]))
        if num_task_dates > 1:
            raise NotImplementedError(
                f"Patchwise prediction does not yet support more than a single date at a time, got {num_task_dates}."
            )

        # Perform patchwise predictions
        preds = []
        for task in tasks:
            bbox = task["bbox"]

            if bbox is None:
                raise AttributeError(
                    "For patchwise prediction, only tasks generated using a patch_strategy of "
                    "'sliding' are valid. This task has a bbox value of None, indicating that "
                    "it was generated with a patch_strategy of 'random' or None."
                )

            # Unnormalise the coordinates of the patch's bounding box
            x1 = xr.DataArray([bbox[0], bbox[1]], dims="x1", name="x1")
            x2 = xr.DataArray([bbox[2], bbox[3]], dims="x2", name="x2")
            bbox_norm = xr.Dataset(coords={"x1": x1, "x2": x2})
            bbox_unnorm = self.data_processor.unnormalise(bbox_norm)
            unnorm_bbox_x1 = (
                bbox_unnorm[orig_x1_name].values.min(),
                bbox_unnorm[orig_x1_name].values.max(),
            )
            unnorm_bbox_x2 = (
                bbox_unnorm[orig_x2_name].values.min(),
                bbox_unnorm[orig_x2_name].values.max(),
            )

            # Determine X_t for the patch. We cannot assume min/max ordering of
            # the slice coordinates: check the order of coordinates in X_t,
            # which may be increasing or decreasing.
            x1_coords = X_t.coords[orig_x1_name].values
            x2_coords = X_t.coords[orig_x2_name].values

            if x1_coords[0] < x1_coords[-1]:
                x1_slice = slice(unnorm_bbox_x1[0], unnorm_bbox_x1[1])
                x1_ascending = True
            else:
                x1_slice = slice(unnorm_bbox_x1[1], unnorm_bbox_x1[0])
                x1_ascending = False

            if x2_coords[0] < x2_coords[-1]:
                x2_slice = slice(unnorm_bbox_x2[0], unnorm_bbox_x2[1])
                x2_ascending = True
            else:
                x2_slice = slice(unnorm_bbox_x2[1], unnorm_bbox_x2[0])
                x2_ascending = False

            # Determine X_t for the patch with the correct slice direction
            task_X_t = X_t.sel(**{orig_x1_name: x1_slice, orig_x2_name: x2_slice})
            task_X_t_mask = (
                X_t_mask.sel(**{orig_x1_name: x1_slice, orig_x2_name: x2_slice})
                if X_t_mask is not None
                else None
            )

            # Patchwise prediction
            pred = self.predict(task, task_X_t, task_X_t_mask, **kwargs)
            # Append the patchwise DeepSensor prediction object to the list
            preds.append(pred)

        overlap_norm = tuple(
            patch - stride for patch, stride in zip(patch_size, stride)
        )
        patch_overlap_unnorm = get_patch_overlap(
            overlap_norm, self.data_processor, X_t, x1_ascending, x2_ascending
        )

        patches_per_row = get_patches_per_row(preds)
        prediction = stitch_clipped_predictions(
            preds,
            patch_overlap_unnorm,
            patches_per_row,
            X_t,
            orig_x1_name,
            orig_x2_name,
            x1_ascending,
            x2_ascending,
        )

        return prediction


def add_valid_time_coord_to_pred_and_move_time_dims(pred: Prediction) -> Prediction:
    """Add a valid time coordinate "time" to a Prediction object based on the
    initialisation times "init_time" and lead times "lead_time", and
    reorder the time dims from ("lead_time", "init_time") to ("init_time", "lead_time").

    Args:
        pred (:class:`~.model.pred.Prediction`):
            Prediction object to add the valid time coordinate to.

    Returns:
        :class:`~.model.pred.Prediction`:
            Prediction object with the valid time coordinate added.
    """
    for var_ID in pred.keys():
2✔
935
        if isinstance(pred[var_ID], pd.DataFrame):
2✔
936
            x = pred[var_ID].reset_index()
2✔
937
            pred[var_ID]["time"] = (x["lead_time"] + x["init_time"]).values
2✔
938
            pred[var_ID] = pred[var_ID].swaplevel("init_time", "lead_time")
2✔
939
            pred[var_ID] = pred[var_ID].sort_index()
2✔
940
        elif isinstance(pred[var_ID], xr.Dataset):
2✔
941
            x = pred[var_ID]
2✔
942
            pred[var_ID] = pred[var_ID].assign_coords(
2✔
943
                time=x["lead_time"] + x["init_time"]
944
            )
945
            pred[var_ID] = pred[var_ID].transpose("init_time", "lead_time", ...)
2✔
946
        else:
UNCOV
947
            raise ValueError(f"Unsupported prediction type {type(pred[var_ID])}.")
×
948
    return pred
2✔


def main():  # pragma: no cover # noqa: D103
    import deepsensor.tensorflow
    from deepsensor.data.loader import TaskLoader
    from deepsensor.data.processor import DataProcessor
    from deepsensor.model.convnp import ConvNP

    import xarray as xr
    import pandas as pd
    import numpy as np

    # Load raw data
    ds_raw = xr.tutorial.open_dataset("air_temperature")["air"]
    ds_raw2 = copy.deepcopy(ds_raw)
    ds_raw2.name = "air2"

    # Normalise data
    data_processor = DataProcessor(x1_name="lat", x2_name="lon")
    ds = data_processor(ds_raw)
    ds2 = data_processor(ds_raw2)

    # Set up the task loader
    task_loader = TaskLoader(context=ds, target=[ds, ds2])

    # Set up the model
    model = ConvNP(data_processor, task_loader)

    # Predict on new tasks with 40 context points and a dense grid of target points
    test_tasks = task_loader(
        pd.date_range("2014-12-25", "2014-12-31"), context_sampling=40
    )
    # print(repr(test_tasks))

    X_t = ds_raw
    pred = model.predict(test_tasks, X_t=X_t, n_samples=5)
    print(pred)

    X_t = np.zeros((2, 1))
    pred = model.predict(test_tasks, X_t=X_t, X_t_is_normalised=True)
    print(pred)

    # DEBUG
    # task = task_loader("2014-12-31", context_sampling=40, target_sampling="all")
    # samples = model.ar_sample(task, 5, ar_subsample_factor=20)
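
    # Patchwise prediction demo (a sketch; the "sliding" task arguments are
    # assumptions, not confirmed by this file):
    # tasks = task_loader("2014-12-31", context_sampling=40, patch_strategy="sliding",
    #                     patch_size=(0.5, 0.5), stride=(0.4, 0.4))
    # pred = model.predict_patchwise(tasks, X_t=ds_raw)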


if __name__ == "__main__":  # pragma: no cover
    main()