
alan-turing-institute / deepsensor / 14313171846

07 Apr 2025 03:26PM UTC coverage: 82.511% (+0.8%) from 81.663%
Merge fc0266596 into 38ec5ef26
Pull Request #135: Enable patchwise training and prediction

294 of 329 new or added lines in 4 files covered. (89.36%)

1 existing line in 1 file now uncovered.

2340 of 2836 relevant lines covered (82.51%)

1.65 hits per line

Source File

/deepsensor/data/loader.py (92.43% covered)
import copy
import itertools
import json
import operator
import os
import random
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import xarray as xr

from deepsensor.data.task import Task, flatten_X
from deepsensor.errors import InvalidSamplingStrategyError, SamplingTooManyPointsError


class TaskLoader:
    """Generates :class:`~.data.task.Task` objects for training, testing, and inference with DeepSensor models.

    Provides a suite of sampling methods for generating :class:`~.data.task.Task` objects for different kinds of
    predictions, such as spatial interpolation, forecasting, downscaling, or some combination
    of these.

    The behaviour is as follows:
        - If all data is passed as paths, load the data and overwrite the paths with the loaded data
        - Either all data is passed as paths, or all data is passed as loaded data (else ``ValueError``)
        - If all data is passed as paths, the TaskLoader can be saved with the ``save`` method
          (using config)

    Args:
        task_loader_ID:
            If loading a TaskLoader from a config file, this is the folder the
            TaskLoader was saved in (using `.save`). If this argument is passed, all other
            arguments are ignored.
        context (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]):
            Context data. Can be a single :class:`xarray.DataArray`,
            :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a
            list/tuple of these.
        target (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]):
            Target data. Can be a single :class:`xarray.DataArray`,
            :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a
            list/tuple of these.
        aux_at_contexts (Tuple[int, :class:`xarray.DataArray` | :class:`xarray.Dataset`], optional):
            Auxiliary data at context locations. Tuple of two elements, where
            the first element is the index of the context set at which the
            auxiliary data will be sampled, and the second element is the
            auxiliary data, which can be a single :class:`xarray.DataArray` or
            :class:`xarray.Dataset`. Default: None.
        aux_at_targets (:class:`xarray.DataArray` | :class:`xarray.Dataset`, optional):
            Auxiliary data at target locations. Can be a single
            :class:`xarray.DataArray` or :class:`xarray.Dataset`. Default:
            None.
        links (Tuple[int, int] | List[Tuple[int, int]], optional):
            Specifies links between context and target data. Each link is a
            tuple of two integers, where the first integer is the index of the
            context data and the second integer is the index of the target
            data. Can be a single tuple in the case of a single link. If None,
            no links are specified. Default: None.
        context_delta_t (int | List[int], optional):
            Time difference between context data and t=0 (task init time). Can
            be a single int (same for all context data) or a list/tuple of
            ints. Default is 0.
        target_delta_t (int | List[int], optional):
            Time difference between target data and t=0 (task init time). Can
            be a single int (same for all target data) or a list/tuple of ints.
            Default is 0.
        time_freq (str, optional):
            Time frequency of the data. Default: ``'D'`` (daily).
        xarray_interp_method (str, optional):
            Interpolation method to use when interpolating
            :class:`xarray.DataArray`. Default is ``'linear'``.
        discrete_xarray_sampling (bool, optional):
            When randomly sampling xarray variables, whether to sample at
            discrete points defined at grid cell centres, or at continuous
            points within the grid. Default is ``False``.
        dtype (object, optional):
            Data type of the data. Used to cast the data to the specified
            dtype. Default: ``'float32'``.
    """

    config_fname = "task_loader_config.json"

    def __init__(
        self,
        task_loader_ID: Union[str, None] = None,
        context: Union[
            xr.DataArray,
            xr.Dataset,
            pd.DataFrame,
            str,
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
        ] = None,
        target: Union[
            xr.DataArray,
            xr.Dataset,
            pd.DataFrame,
            str,
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
        ] = None,
        aux_at_contexts: Optional[Tuple[int, Union[xr.DataArray, xr.Dataset]]] = None,
        aux_at_targets: Optional[
            Union[
                xr.DataArray,
                xr.Dataset,
            ]
        ] = None,
        links: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
        context_delta_t: Union[int, List[int]] = 0,
        target_delta_t: Union[int, List[int]] = 0,
        time_freq: str = "D",
        xarray_interp_method: str = "linear",
        discrete_xarray_sampling: bool = False,
        dtype: object = "float32",
    ) -> None:
        if task_loader_ID is not None:
            self.task_loader_ID = task_loader_ID
            # Load TaskLoader from config file
            fpath = os.path.join(task_loader_ID, self.config_fname)
            with open(fpath, "r") as f:
                self.config = json.load(f)

            self.context = self.config["context"]
            self.target = self.config["target"]
            self.aux_at_contexts = self.config["aux_at_contexts"]
            self.aux_at_targets = self.config["aux_at_targets"]
            self.links = self.config["links"]
            if self.links is not None:
                self.links = [tuple(link) for link in self.links]
            self.context_delta_t = self.config["context_delta_t"]
            self.target_delta_t = self.config["target_delta_t"]
            self.time_freq = self.config["time_freq"]
            self.xarray_interp_method = self.config["xarray_interp_method"]
            self.discrete_xarray_sampling = self.config["discrete_xarray_sampling"]
            self.dtype = self.config["dtype"]
        else:
            self.context = context
            self.target = target
            self.aux_at_contexts = aux_at_contexts
            self.aux_at_targets = aux_at_targets
            self.links = links
            self.context_delta_t = context_delta_t
            self.target_delta_t = target_delta_t
            self.time_freq = time_freq
            self.xarray_interp_method = xarray_interp_method
            self.discrete_xarray_sampling = discrete_xarray_sampling
            self.dtype = dtype

        if not isinstance(self.context, (tuple, list)):
            self.context = (self.context,)
        if not isinstance(self.target, (tuple, list)):
            self.target = (self.target,)

        if isinstance(self.context_delta_t, int):
            self.context_delta_t = (self.context_delta_t,) * len(self.context)
        else:
            assert len(self.context_delta_t) == len(self.context), (
                f"Length of context_delta_t ({len(self.context_delta_t)}) must be the same as "
                f"the number of context sets ({len(self.context)})"
            )
        if isinstance(self.target_delta_t, int):
            self.target_delta_t = (self.target_delta_t,) * len(self.target)
        else:
            assert len(self.target_delta_t) == len(self.target), (
                f"Length of target_delta_t ({len(self.target_delta_t)}) must be the same as "
                f"the number of target sets ({len(self.target)})"
            )

        all_paths = self._check_if_all_data_passed_as_paths()
        if all_paths:
            self._set_config()
            self._load_data_from_paths()

        self.context = self._cast_to_dtype(self.context)
        self.target = self._cast_to_dtype(self.target)
        self.aux_at_contexts = self._cast_to_dtype(self.aux_at_contexts)
        self.aux_at_targets = self._cast_to_dtype(self.aux_at_targets)

        self.links = self._check_links(self.links)

        (
            self.context_dims,
            self.target_dims,
            self.aux_at_target_dims,
        ) = self.count_context_and_target_data_dims()
        (
            self.context_var_IDs,
            self.target_var_IDs,
            self.context_var_IDs_and_delta_t,
            self.target_var_IDs_and_delta_t,
            self.aux_at_target_var_IDs,
        ) = self.infer_context_and_target_var_IDs()

    def _set_config(self):
        """Instantiate a config dictionary for the TaskLoader object."""
        # Take a deepcopy to avoid modifying the original config
        self.config = copy.deepcopy(
            dict(
                context=self.context,
                target=self.target,
                aux_at_contexts=self.aux_at_contexts,
                aux_at_targets=self.aux_at_targets,
                links=self.links,
                context_delta_t=self.context_delta_t,
                target_delta_t=self.target_delta_t,
                time_freq=self.time_freq,
                xarray_interp_method=self.xarray_interp_method,
                discrete_xarray_sampling=self.discrete_xarray_sampling,
                dtype=self.dtype,
            )
        )

    def _check_if_all_data_passed_as_paths(self) -> bool:
        """If all data is passed as paths, mark the TaskLoader as saveable and return True."""

        def _check_if_strings(x, mode="all"):
            if x is None:
                return None
            elif isinstance(x, (tuple, list)):
                if mode == "all":
                    return all([isinstance(x_i, str) for x_i in x])
                elif mode == "any":
                    return any([isinstance(x_i, str) for x_i in x])
            else:
                return isinstance(x, str)

        all_paths = all(
            filter(
                lambda x: x is not None,
                [
                    _check_if_strings(self.context),
                    _check_if_strings(self.target),
                    _check_if_strings(self.aux_at_contexts),
                    _check_if_strings(self.aux_at_targets),
                ],
            )
        )
        self._is_saveable = all_paths

        any_paths = any(
            filter(
                lambda x: x is not None,
                [
                    _check_if_strings(self.context, mode="any"),
                    _check_if_strings(self.target, mode="any"),
                    _check_if_strings(self.aux_at_contexts, mode="any"),
                    _check_if_strings(self.aux_at_targets, mode="any"),
                ],
            )
        )
        if any_paths and not all_paths:
            raise ValueError(
                "Data must be passed either all as paths or all as xarray/pandas objects (not a mix)."
            )

        return all_paths

    def _load_data_from_paths(self):
        """Load data from paths and overwrite the paths with the loaded data."""
        loaded_data = {}

        def _load_pandas_or_xarray(path):
            # Need to be careful about this. We need to ensure data gets into the right form
            #  for TaskLoader.
            if path is None:
                return None
            elif path in loaded_data:
                return loaded_data[path]
            elif path.endswith(".nc"):
                data = xr.open_dataset(path)
            elif path.endswith(".csv"):
                df = pd.read_csv(path)
                if "time" in df.columns:
                    df["time"] = pd.to_datetime(df["time"])
                    df = df.set_index(["time", "x1", "x2"]).sort_index()
                else:
                    df = df.set_index(["x1", "x2"]).sort_index()
                data = df
            else:
                raise ValueError(f"Unknown file extension for {path}")
            loaded_data[path] = data
            return data

        def _load_data(data):
            if isinstance(data, (tuple, list)):
                data = tuple([_load_pandas_or_xarray(data_i) for data_i in data])
            else:
                data = _load_pandas_or_xarray(data)
            return data

        self.context = _load_data(self.context)
        self.target = _load_data(self.target)
        self.aux_at_contexts = _load_data(self.aux_at_contexts)
        self.aux_at_targets = _load_data(self.aux_at_targets)
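
    # Illustrative note (assumed CSV layout, not from the original source):
    # a loadable CSV must provide "x1" and "x2" columns and optionally a
    # "time" column, e.g.
    #
    #   time,x1,x2,temperature
    #   2020-01-01,0.1,0.2,281.5
    #
    # which ``_load_pandas_or_xarray`` re-indexes to a ("time", "x1", "x2")
    # MultiIndex on load.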

    def save(self, folder: str):
        """Save TaskLoader config to JSON in `folder`."""
        if not self._is_saveable:
            raise ValueError(
                "TaskLoader cannot be saved because not all data was passed as paths."
            )

        os.makedirs(folder, exist_ok=True)
        fpath = os.path.join(folder, self.config_fname)
        with open(fpath, "w") as f:
            json.dump(self.config, f, indent=4, sort_keys=False)

307
    def _cast_to_dtype(
2✔
308
        self,
309
        var: Union[
310
            xr.DataArray,
311
            xr.Dataset,
312
            pd.DataFrame,
313
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
314
        ],
315
    ) -> (List, List):
316
        """Cast context and target data to the default dtype.
317

318
        ..
319
            TODO unit test this by passing in a variety of data types and
320
            checking that they are cast correctly.
321

322
        Args:
323
            var : ...
324
                ...
325

326
        Returns:
327
            tuple: Tuple of context data with specified dtype.
328
            tuple: Tuple of target data with specified dtype.
329
        """
330

331
        def cast_to_dtype(var):
2✔
332
            if isinstance(var, xr.DataArray):
2✔
333
                var = var.astype(self.dtype)
2✔
334
                var["x1"] = var["x1"].astype(self.dtype)
2✔
335
                var["x2"] = var["x2"].astype(self.dtype)
2✔
336
            elif isinstance(var, xr.Dataset):
2✔
337
                var = var.astype(self.dtype)
2✔
338
                var["x1"] = var["x1"].astype(self.dtype)
2✔
339
                var["x2"] = var["x2"].astype(self.dtype)
2✔
340
            elif isinstance(var, (pd.DataFrame, pd.Series)):
2✔
341
                var = var.astype(self.dtype)
2✔
342
                # Note: Numeric pandas indexes are always cast to float64, so we have to cast
343
                #   x1/x2 coord dtypes during task sampling
344
            else:
345
                raise ValueError(f"Unknown type {type(var)} for context set {var}")
×
346
            return var
2✔
347

348
        if var is None:
2✔
349
            return var
2✔
350
        elif isinstance(var, (tuple, list)):
2✔
351
            var = tuple([cast_to_dtype(var_i) for var_i in var])
2✔
352
        else:
353
            var = cast_to_dtype(var)
2✔
354

355
        return var
2✔
356

    def load_dask(self) -> None:
        """Load any `dask` data into memory.

        This function triggers the computation and loading of any data that
        is represented as dask arrays or datasets into memory.

        Returns:
            None
        """

        def load(datasets):
            if datasets is None:
                return
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]
            for var in datasets:
                if isinstance(var, (xr.DataArray, xr.Dataset)):
                    var.load()

        load(self.context)
        load(self.target)
        load(self.aux_at_contexts)
        load(self.aux_at_targets)

        return None

    def count_context_and_target_data_dims(self):
        """Count the number of data dimensions in the context and target data.

        Returns:
            tuple: context_dims, Tuple of data dimensions in the context data.
            tuple: target_dims, Tuple of data dimensions in the target data.

        Raises:
            ValueError: If the context/target data is not a tuple/list of
                        :class:`xarray.DataArray`, :class:`xarray.Dataset` or
                        :class:`pandas.DataFrame`.
        """

        def count_data_dims_of_tuple_of_sets(datasets):
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]

            dims = []
            # Distinguish between xr.DataArray, xr.Dataset and pd.DataFrame
            for var in datasets:
                if isinstance(var, xr.Dataset):
                    dim = len(var.data_vars)  # Multiple data variables
                elif isinstance(var, xr.DataArray):
                    dim = 1  # Single data variable
                elif isinstance(var, pd.DataFrame):
                    dim = len(var.columns)  # Assumes all columns are data variables
                elif isinstance(var, pd.Series):
                    dim = 1  # Single data variable
                else:
                    raise ValueError(f"Unknown type {type(var)} for context set {var}")
                dims.append(dim)
            return dims

        context_dims = count_data_dims_of_tuple_of_sets(self.context)
        target_dims = count_data_dims_of_tuple_of_sets(self.target)
        if self.aux_at_contexts is not None:
            context_dims += count_data_dims_of_tuple_of_sets(self.aux_at_contexts)
        if self.aux_at_targets is not None:
            aux_at_target_dims = count_data_dims_of_tuple_of_sets(self.aux_at_targets)[
                0
            ]
        else:
            aux_at_target_dims = 0

        return tuple(context_dims), tuple(target_dims), aux_at_target_dims
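
    # Illustration (assumed data): for context = [xr.Dataset with data_vars
    # ("temp", "precip"), pd.DataFrame with 3 columns], the counting above
    # yields context_dims == (2, 3).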

    def infer_context_and_target_var_IDs(self):
        """Infer the variable IDs of the context and target data.

        Returns:
            tuple: context_var_IDs, Tuple of variable IDs in the context data.
            tuple: target_var_IDs, Tuple of variable IDs in the target data.

        Raises:
            ValueError: If the context/target data is not a tuple/list of
                        :class:`xarray.DataArray`, :class:`xarray.Dataset` or
                        :class:`pandas.DataFrame`.
        """

        def infer_var_IDs_of_tuple_of_sets(datasets, delta_ts=None):
            """If delta_ts is not None, then add the delta_t to the variable ID."""
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]

            var_IDs = []
            # Distinguish between xr.DataArray, xr.Dataset and pd.DataFrame
            for i, var in enumerate(datasets):
                if isinstance(var, xr.DataArray):
                    var_ID = (var.name,)  # Single data variable
                elif isinstance(var, xr.Dataset):
                    var_ID = tuple(var.data_vars.keys())  # Multiple data variables
                elif isinstance(var, pd.DataFrame):
                    var_ID = tuple(var.columns)
                elif isinstance(var, pd.Series):
                    var_ID = (var.name,)
                else:
                    raise ValueError(f"Unknown type {type(var)} for context set {var}")

                if delta_ts is not None:
                    # Add delta_t to the variable ID
                    var_ID = tuple(
                        [f"{var_ID_i}_t{delta_ts[i]}" for var_ID_i in var_ID]
                    )
                else:
                    var_ID = tuple([f"{var_ID_i}" for var_ID_i in var_ID])

                var_IDs.append(var_ID)

            return var_IDs

        context_var_IDs = infer_var_IDs_of_tuple_of_sets(self.context)
        context_var_IDs_and_delta_t = infer_var_IDs_of_tuple_of_sets(
            self.context, self.context_delta_t
        )
        target_var_IDs = infer_var_IDs_of_tuple_of_sets(self.target)
        target_var_IDs_and_delta_t = infer_var_IDs_of_tuple_of_sets(
            self.target, self.target_delta_t
        )

        if self.aux_at_contexts is not None:
            context_var_IDs += infer_var_IDs_of_tuple_of_sets(self.aux_at_contexts)
            context_var_IDs_and_delta_t += infer_var_IDs_of_tuple_of_sets(
                self.aux_at_contexts, [0]
            )

        if self.aux_at_targets is not None:
            aux_at_target_var_IDs = infer_var_IDs_of_tuple_of_sets(self.aux_at_targets)[
                0
            ]
        else:
            aux_at_target_var_IDs = None

        return (
            tuple(context_var_IDs),
            tuple(target_var_IDs),
            tuple(context_var_IDs_and_delta_t),
            tuple(target_var_IDs_and_delta_t),
            aux_at_target_var_IDs,
        )
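
    # Illustration (assumed variable name): a context xr.DataArray named
    # "temp" with context_delta_t=-1 yields the variable ID "temp" and the
    # delta-t-tagged ID "temp_t-1" via the f-string tagging above.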

    def _check_links(self, links: Union[Tuple[int, int], List[Tuple[int, int]]]):
        """Check that the context-target links are valid.

        Args:
            links (Tuple[int, int] | List[Tuple[int, int]]):
                Specifies links between context and target data. Each link is a
                tuple of two integers, where the first integer is the index of
                the context data and the second integer is the index of the
                target data. Can be a single tuple in the case of a single
                link. If None, no links are specified. Default: None.

        Returns:
            Tuple[int, int] | List[Tuple[int, int]]
                The input links, if valid.

        Raises:
            ValueError
                If the links are not valid.
        """
        if links is None:
            return None

        assert isinstance(
            links, list
        ), f"Links must be a list of length-2 tuples, but got {type(links)}"
        assert len(links) > 0, "If links is not None, it must be a non-empty list"
        assert all(
            isinstance(link, tuple) for link in links
        ), f"Links must be a list of tuples, but got {[type(link) for link in links]}"
        assert all(
            len(link) == 2 for link in links
        ), f"Links must be a list of length-2 tuples, but got lengths {[len(link) for link in links]}"

        # Check that the links are valid
        for link_i, (context_idx, target_idx) in enumerate(links):
            if context_idx >= len(self.context):
                raise ValueError(
                    f"Invalid context index {context_idx} in link {link_i} of {links}: "
                    f"there are only {len(self.context)} context sets"
                )
            if target_idx >= len(self.target):
                raise ValueError(
                    f"Invalid target index {target_idx} in link {link_i} of {links}: "
                    f"there are only {len(self.target)} target sets"
                )

        return links
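
    # Illustration: links=[(0, 0)] declares that context set 0 and target set 0
    # refer to the same underlying observations, which the "split" and
    # "gapfill" sampling strategies in ``task_generation`` rely on.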

    def __str__(self):
        """String representation of the TaskLoader object (user-friendly)."""
        s = f"TaskLoader({len(self.context_dims)} context sets, {len(self.target_dims)} target sets)"
        s += f"\nContext variable IDs: {self.context_var_IDs}"
        s += f"\nTarget variable IDs: {self.target_var_IDs}"
        if self.aux_at_targets is not None:
            s += f"\nAuxiliary-at-target variable IDs: {self.aux_at_target_var_IDs}"
        return s

    def __repr__(self):
        """Representation of the TaskLoader object (for developers).

        ..
            TODO make this a more verbose version of __str__
        """
        s = str(self)
        s += "\n"
        s += f"\nContext data dimensions: {self.context_dims}"
        s += f"\nTarget data dimensions: {self.target_dims}"
        if self.aux_at_targets is not None:
            s += f"\nAuxiliary-at-target data dimensions: {self.aux_at_target_dims}"
        return s

    def sample_da(
        self,
        da: Union[xr.DataArray, xr.Dataset],
        sampling_strat: Union[str, int, float, np.ndarray],
        seed: Optional[int] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Sample a DataArray according to a given strategy.

        Args:
            da (:class:`xarray.DataArray` | :class:`xarray.Dataset`):
                DataArray to sample, assumed to be sliced for the task already.
            sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`):
                Sampling strategy: "all", an int or float for random sampling,
                or a numpy coordinate array of shape (2, N).
            seed (int, optional):
                Seed for random sampling. Default is None.

        Returns:
            Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]:
                Tuple of sampled spatial coordinates and sampled data values.

        Raises:
            InvalidSamplingStrategyError:
                If the sampling strategy is not valid or if a numpy coordinate
                array is passed to sample an xarray object, but the coordinates
                are out of bounds.
        """
        da = da.load()  # Converts dask -> numpy if not already loaded
        if isinstance(da, xr.Dataset):
            da = da.to_array()

        if isinstance(sampling_strat, float):
            sampling_strat = int(sampling_strat * da.size)

        if isinstance(sampling_strat, (int, np.integer)):
            N = sampling_strat
            rng = np.random.default_rng(seed)
            if self.discrete_xarray_sampling:
                x1 = rng.choice(da.coords["x1"].values, N, replace=True)
                x2 = rng.choice(da.coords["x2"].values, N, replace=True)
                Y_c = da.sel(x1=xr.DataArray(x1), x2=xr.DataArray(x2)).data
            else:
                if N == 0:
                    # Catch zero-context edge case before interp fails
                    X_c = np.zeros((2, 0), dtype=self.dtype)
                    dim = da.shape[0] if da.ndim == 3 else 1
                    Y_c = np.zeros((dim, 0), dtype=self.dtype)
                    return X_c, Y_c
                x1 = rng.uniform(da.coords["x1"].min(), da.coords["x1"].max(), N)
                x2 = rng.uniform(da.coords["x2"].min(), da.coords["x2"].max(), N)
                Y_c = da.sel(x1=xr.DataArray(x1), x2=xr.DataArray(x2), method="nearest")
                Y_c = np.array(Y_c, dtype=self.dtype)
            X_c = np.array([x1, x2], dtype=self.dtype)

        elif isinstance(sampling_strat, np.ndarray):
            X_c = sampling_strat.astype(self.dtype)
            try:
                Y_c = da.sel(
                    x1=xr.DataArray(X_c[0]),
                    x2=xr.DataArray(X_c[1]),
                    method="nearest",
                    tolerance=0.1,  # Maximum distance from observed point to sample
                )
            except KeyError:
                raise InvalidSamplingStrategyError(
                    "Passed a numpy coordinate array to sample an xarray object, "
                    "but the coordinates are out of bounds."
                )
            Y_c = np.array(Y_c, dtype=self.dtype)

        elif sampling_strat in ["all", "gapfill"]:
            X_c = (
                da.coords["x1"].values[np.newaxis],
                da.coords["x2"].values[np.newaxis],
            )
            Y_c = da.data
            if Y_c.ndim == 2:
                # Returned a 2D array, but we need a 3D array of shape (variable, N_x1, N_x2)
                Y_c = Y_c.reshape(1, *Y_c.shape)
        else:
            raise InvalidSamplingStrategyError(
                f"Unknown sampling strategy {sampling_strat}"
            )

        if Y_c.ndim == 1:
            # Returned a 1D array, but we need a 2D array of shape (variable, N)
            Y_c = Y_c.reshape(1, *Y_c.shape)

        return X_c, Y_c
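
    # Sampling sketch (assumes a gridded DataArray `da` with "x1"/"x2" coords):
    #
    #   X_c, Y_c = tl.sample_da(da, 100, seed=42)  # 100 random continuous points
    #   X_c, Y_c = tl.sample_da(da, 0.1)           # int(0.1 * da.size) points
    #   X_c, Y_c = tl.sample_da(da, "all")         # full grid; X_c is a tuple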

    def sample_df(
        self,
        df: Union[pd.DataFrame, pd.Series],
        sampling_strat: Union[str, int, float, np.ndarray],
        seed: Optional[int] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Sample a DataFrame according to a given strategy.

        Args:
            df (:class:`pandas.DataFrame` | :class:`pandas.Series`):
                Dataframe to sample, assumed to be time-sliced for the task
                already.
            sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`):
                Sampling strategy: "all", an int or float for random sampling,
                or a numpy coordinate array of shape (2, N).
            seed (int, optional):
                Seed for random sampling. Default is None.

        Returns:
            Tuple[X_c, Y_c]:
                Tuple of sampled spatial coordinates and sampled data values.

        Raises:
            InvalidSamplingStrategyError:
                If the sampling strategy is not valid or if a numpy coordinate
                array is passed to sample a pandas object, but the DataFrame
                does not contain all the requested samples.
        """
        df = df.dropna(how="any")  # If any obs are NaN, drop them

        if isinstance(sampling_strat, float):
            sampling_strat = int(sampling_strat * df.shape[0])

        if isinstance(sampling_strat, (int, np.integer)):
            N = sampling_strat
            rng = np.random.default_rng(seed)
            if N > df.index.size:
                raise SamplingTooManyPointsError(requested=N, available=df.index.size)
            idx = rng.choice(df.index, N, replace=False)
            X_c = df.loc[idx].reset_index()[["x1", "x2"]].values.T.astype(self.dtype)
            Y_c = df.loc[idx].values.T
        elif isinstance(sampling_strat, str) and sampling_strat in [
            "all",
            "split",
        ]:
            # NOTE if "split", we assume that the context-target split has already been applied to the df
            # in an earlier scope with access to both the context and target data. This is maybe risky!
            X_c = df.reset_index()[["x1", "x2"]].values.T.astype(self.dtype)
            Y_c = df.values.T
        elif isinstance(sampling_strat, np.ndarray):
            if df.index.get_level_values("x1").dtype != sampling_strat.dtype:
                raise InvalidSamplingStrategyError(
                    "Passed a numpy coordinate array to sample a pandas DataFrame, "
                    "but the coordinate array has a different dtype than the DataFrame. "
                    f"Got {sampling_strat.dtype} but expected {df.index.get_level_values('x1').dtype}."
                )
            X_c = sampling_strat.astype(self.dtype)
            try:
                Y_c = df.loc[pd.IndexSlice[:, X_c[0], X_c[1]]].values.T
            except KeyError:
                raise InvalidSamplingStrategyError(
                    "Passed a numpy coordinate array to sample a pandas DataFrame, "
                    "but the DataFrame did not contain all the requested samples.\n"
                    f"Indexes: {df.index}\n"
                    f"Sampling coords: {X_c}\n"
                    "If this is unexpected, check that your numpy sampling array matches "
                    "the DataFrame index values *exactly*."
                )
        else:
            raise InvalidSamplingStrategyError(
                f"Unknown sampling strategy {sampling_strat}"
            )

        if Y_c.ndim == 1:
            # Returned a 1D array, but we need a 2D array of shape (variable, N)
            Y_c = Y_c.reshape(1, *Y_c.shape)

        return X_c, Y_c
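
    # Illustrative note: unlike xarray sampling, array-based sampling of a
    # DataFrame is an exact index lookup, so the (2, N) coordinate array must
    # match the "x1"/"x2" index values and dtype exactly, e.g.
    #
    #   dtype = df.index.get_level_values("x1").dtype
    #   X = np.array([[0.1], [0.2]], dtype=dtype)  # hypothetical coords
    #   X_c, Y_c = tl.sample_df(df, X)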

    def sample_offgrid_aux(
        self,
        X_t: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]],
        offgrid_aux: Union[xr.DataArray, xr.Dataset],
    ) -> np.ndarray:
        """Sample auxiliary data at off-grid locations.

        Args:
            X_t (:class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]):
                Off-grid locations at which to sample the auxiliary data. Can
                be a tuple of two numpy arrays, or a single numpy array.
            offgrid_aux (:class:`xarray.DataArray` | :class:`xarray.Dataset`):
                Auxiliary data at off-grid locations.

        Returns:
            :class:`numpy:numpy.ndarray`:
                Auxiliary data sampled at the target locations, of shape
                (variable, *spatial_dims).

        Raises:
            ValueError:
                If the auxiliary data has a `time` dimension (it must be
                time-sliced before sampling).
        """
        if "time" in offgrid_aux.dims:
            raise ValueError(
                "If `aux_at_targets` data has a `time` dimension, it must be sliced before "
                "passing it to `sample_offgrid_aux`."
            )
        if isinstance(X_t, tuple):
            xt1, xt2 = X_t
            xt1 = xt1.ravel()
            xt2 = xt2.ravel()
        else:
            xt1, xt2 = xr.DataArray(X_t[0]), xr.DataArray(X_t[1])
        Y_t_aux = offgrid_aux.sel(x1=xt1, x2=xt2, method="nearest")
        if isinstance(Y_t_aux, xr.Dataset):
            Y_t_aux = Y_t_aux.to_array()
        Y_t_aux = np.array(Y_t_aux, dtype=np.float32)
        if (isinstance(X_t, tuple) and Y_t_aux.ndim == 2) or (
            isinstance(X_t, np.ndarray) and Y_t_aux.ndim == 1
        ):
            # Reshape to (variable, *spatial_dims)
            Y_t_aux = Y_t_aux.reshape(1, *Y_t_aux.shape)
        return Y_t_aux

    def time_slice_variable(self, var, date, delta_t=0):
        """Slice a variable by a given time delta.

        Args:
            var (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | :class:`pandas.Series`):
                Variable to slice.
            date (:class:`pandas.Timestamp`):
                Task init time t=0 to slice relative to.
            delta_t (int):
                Time delta to slice by, in units of ``time_freq``.

        Returns:
            var:
                Sliced variable.

        Raises:
            ValueError:
                If the variable is of an unknown type.
        """
        # TODO: Does this work with instantaneous time?
        delta_t = pd.Timedelta(delta_t, unit=self.time_freq)
        if isinstance(var, (xr.Dataset, xr.DataArray)):
            if "time" in var.dims:
                var = var.sel(time=date + delta_t)
        elif isinstance(var, (pd.DataFrame, pd.Series)):
            if "time" in var.index.names:
                var = var[var.index.get_level_values("time") == date + delta_t]
        else:
            raise ValueError(f"Unknown variable type {type(var)}")
        return var
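
    # Illustration: with time_freq="D" and delta_t=-1, a task at date
    # 2020-01-02 slices this variable at 2020-01-01, i.e. one day before t=0.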

    def task_generation(  # noqa: D102
        self,
        date: pd.Timestamp,
        context_sampling: Union[
            str,
            int,
            float,
            np.ndarray,
            List[Union[str, int, float, np.ndarray]],
        ] = "all",
        target_sampling: Optional[
            Union[
                str,
                int,
                float,
                np.ndarray,
                List[Union[str, int, float, np.ndarray]],
            ]
        ] = None,
        split_frac: float = 0.5,
        bbox: Optional[Sequence[float]] = None,
        patch_size: Optional[Union[float, Tuple[float]]] = None,
        stride: Optional[Union[float, Tuple[float]]] = None,
        datewise_deterministic: bool = False,
        seed_override: Optional[int] = None,
    ) -> Task:
        """Generate a task for a given date.

        There are several sampling strategies available for the context and
        target data:

            - "all": Sample all observations.
            - int: Sample N observations uniformly at random.
            - float: Sample a fraction of observations uniformly at random.
            - :class:`numpy:numpy.ndarray`, shape (2, N): Sample N observations
              at the given x1, x2 coordinates. Coords are assumed to be
              unnormalised.

        Parameters
        ----------
        date : :class:`pandas.Timestamp`
            Date for which to generate the task.
        context_sampling : str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`]
            Sampling strategy for the context data, either a list of sampling
            strategies for each context set, or a single strategy applied to
            all context sets. Default is ``"all"``.
        target_sampling : str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional
            Sampling strategy for the target data, either a list of sampling
            strategies for each target set, or a single strategy applied to all
            target sets. Default is ``None``, in which case no target data is
            sampled.
        split_frac : float
            The fraction of observations to use for the context set with the
            "split" sampling strategy for linked context and target set pairs.
            The remaining observations are used for the target set. Default is
            0.5.
        bbox : Sequence[float], optional
            Bounding box to spatially slice the data, should be of the form
            [x1_min, x1_max, x2_min, x2_max]. Useful when considering the
            entire available region would be computationally prohibitive for
            the model forward pass.
        patch_size : float | Tuple[float], optional
            Only used by patchwise inference. Height and width of the patch in
            x1/x2 normalised coordinates.
        stride : float | Tuple[float], optional
            Only used by patchwise inference. Length of the stride between
            adjacent patches in x1/x2 normalised coordinates.
        datewise_deterministic : bool
            Whether random sampling is deterministic based on the date.
            Default is ``False``.
        seed_override : Optional[int]
            Override the seed for random sampling. This can be used to use the
            same random sampling at different ``date``. Default is None.

        Returns
        -------
        task : :class:`~.data.task.Task`
            Task object containing the context and target data.
        """
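
        # Usage sketch (hypothetical data; assumes a TaskLoader `tl` with one
        # context set and one target set):
        #
        #   task = tl.task_generation(
        #       pd.Timestamp("2020-01-01"),
        #       context_sampling=100,   # 100 random context observations
        #       target_sampling="all",  # all target observations
        #   )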

        def check_sampling_strat(sampling_strat, set):
            """Check the sampling strategy.

            Ensure ``sampling_strat`` is either a single strategy (broadcast
            to all sets) or a list of length equal to the number of sets.
            Convert to a tuple of length equal to the number of sets and
            return.

            Args:
                sampling_strat:
                    Sampling strategy to check.
                set:
                    Context or target set to check.

            Returns:
                tuple:
                    Tuple of sampling strategies, one for each set.

            Raises:
                InvalidSamplingStrategyError:
                    - If the length of the sampling strategy does not match the number of sets.
                    - If the sampling strategy is not a valid type.
                    - If the sampling strategy is a float but not in [0, 1].
                    - If the sampling strategy is a negative int.
                    - If the sampling strategy is a numpy array but not of shape (2, N).
            """
            if sampling_strat is None:
                return None
            if not isinstance(sampling_strat, (list, tuple)):
                sampling_strat = tuple([sampling_strat] * len(set))
            elif isinstance(sampling_strat, (list, tuple)) and len(
                sampling_strat
            ) != len(set):
                raise InvalidSamplingStrategyError(
                    f"Length of sampling_strat ({len(sampling_strat)}) must "
                    f"match number of context sets ({len(set)})"
                )

            for strat in sampling_strat:
                if not isinstance(strat, (str, int, np.integer, float, np.ndarray)):
                    raise InvalidSamplingStrategyError(
                        f"Unknown sampling strategy {strat} of type {type(strat)}"
                    )
                elif isinstance(strat, str) and strat == "gapfill":
                    assert all(
                        isinstance(item, (xr.Dataset, xr.DataArray)) for item in set
                    ), (
                        "Gapfill sampling strategy can only be used with xarray "
                        "datasets or data arrays"
                    )
                elif isinstance(strat, str) and strat not in [
                    "all",
                    "split",
                    "gapfill",
                ]:
                    raise InvalidSamplingStrategyError(
                        f"Unknown sampling strategy {strat} for type str"
                    )
                elif isinstance(strat, float) and not 0 <= strat <= 1:
                    raise InvalidSamplingStrategyError(
                        f"If the sampling strategy is a float, it must be a "
                        f"fraction in [0, 1], got {strat}"
                    )
                elif isinstance(strat, int) and strat < 0:
                    raise InvalidSamplingStrategyError(
                        f"Sampling N must be non-negative, got {strat}"
                    )
                elif isinstance(strat, np.ndarray) and strat.shape[0] != 2:
                    raise InvalidSamplingStrategyError(
                        "Sampling coordinates must be of shape (2, N), got "
                        f"{strat.shape}"
                    )

            return sampling_strat

        def sample_variable(var, sampling_strat, seed):
            """Sample a variable by a given sampling strategy to get input and
            output data.

            Args:
                var:
                    Variable to sample.
                sampling_strat:
                    Sampling strategy to use.
                seed:
                    Seed for random sampling.

            Returns:
                Tuple[X, Y]:
                    Tuple of input and output data.

            Raises:
                ValueError:
                    If the variable is of an unknown type.
            """
            if isinstance(var, (xr.Dataset, xr.DataArray)):
                X, Y = self.sample_da(var, sampling_strat, seed)
            elif isinstance(var, (pd.DataFrame, pd.Series)):
                X, Y = self.sample_df(var, sampling_strat, seed)
            else:
                raise ValueError(f"Unknown type {type(var)} for context set {var}")
            return X, Y

        # Check that the sampling strategies are valid
        context_sampling = check_sampling_strat(context_sampling, self.context)
        target_sampling = check_sampling_strat(target_sampling, self.target)
        # Check `split_frac`
        if split_frac < 0 or split_frac > 1:
            raise ValueError(f"split_frac must be between 0 and 1, got {split_frac}")
        if self.links is None:
            b1 = any(
                [
                    strat in ["split", "gapfill"]
                    for strat in context_sampling
                    if isinstance(strat, str)
                ]
            )
            if target_sampling is None:
                b2 = False
            else:
                b2 = any(
                    [
                        strat in ["split", "gapfill"]
                        for strat in target_sampling
                        if isinstance(strat, str)
                    ]
                )
            if b1 or b2:
                raise ValueError(
                    "If using 'split' or 'gapfill' sampling strategies, the context and target "
                    "sets must be linked with the TaskLoader `links` attribute."
                )
        if self.links is not None:
            for context_idx, target_idx in self.links:
                context_sampling_i = context_sampling[context_idx]
                if target_sampling is None:
                    target_sampling_i = None
                else:
                    target_sampling_i = target_sampling[target_idx]
                link_strats = (context_sampling_i, target_sampling_i)
                if any(
                    [
                        strat in ["split", "gapfill"]
                        for strat in link_strats
                        if isinstance(strat, str)
                    ]
                ):
                    # If one of the sampling strategies is "split" or "gapfill", the other must
                    # use the same splitting strategy
                    if link_strats[0] != link_strats[1]:
                        raise ValueError(
                            f"Linked context set {context_idx} and target set {target_idx} "
                            f"must use the same sampling strategy if one of them "
                            f"uses the 'split' or 'gapfill' sampling strategy. "
                            f"Got {link_strats[0]} and {link_strats[1]}."
                        )

        if not isinstance(date, pd.Timestamp):
            date = pd.Timestamp(date)

        if seed_override is not None:
            # Override the seed for random sampling
            seed = seed_override
        elif datewise_deterministic:
            # Generate a deterministic seed, based on the date, for random sampling
            seed = int(date.strftime("%Y%m%d"))
        else:
            # 'Truly' random sampling
            seed = None

        task = {}

        task["time"] = date
        task["ops"] = []
        if bbox:
            task["bbox"] = bbox
        if patch_size:
            # Store patch_size and stride in the task for use when stitching
            # patchwise predictions
            task["patch_size"] = patch_size
        if stride:
            task["stride"] = stride
        task["X_c"] = []
        task["Y_c"] = []
        if target_sampling is not None:
            task["X_t"] = []
            task["Y_t"] = []
        else:
            task["X_t"] = None
            task["Y_t"] = None

        # Temporal slices
        context_slices = [
            self.time_slice_variable(var, date, delta_t)
            for var, delta_t in zip(self.context, self.context_delta_t)
        ]
        target_slices = [
            self.time_slice_variable(var, date, delta_t)
            for var, delta_t in zip(self.target, self.target_delta_t)
        ]

        # TODO move to method
        if (
            self.links is not None
            and "split" in context_sampling
            and "split" in target_sampling
        ):
            # Perform the split sampling strategy for linked context and target sets at this point
            # while we have the full context and target data in scope

            context_split_idxs = np.where(np.array(context_sampling) == "split")[0]
            target_split_idxs = np.where(np.array(target_sampling) == "split")[0]
            assert len(context_split_idxs) == len(target_split_idxs), (
                f"Number of context sets with 'split' sampling strategy "
                f"({len(context_split_idxs)}) must match number of target sets "
                f"with 'split' sampling strategy ({len(target_split_idxs)})"
            )
            for split_i, (context_idx, target_idx) in enumerate(
                zip(context_split_idxs, target_split_idxs)
            ):
                assert (context_idx, target_idx) in self.links, (
                    f"Context set {context_idx} and target set {target_idx} must be linked "
                    f"with the `links` attribute if using the 'split' sampling strategy"
                )

                context_var = context_slices[context_idx]
                target_var = target_slices[target_idx]

                for var in [context_var, target_var]:
                    assert isinstance(var, (pd.Series, pd.DataFrame)), (
                        f"If using 'split' sampling strategy for linked context and target sets, "
                        f"the context and target sets must be pandas DataFrames or Series, "
                        f"but got {type(var)}."
                    )

                N_obs = len(context_var)
                N_obs_target_check = len(target_var)
                if N_obs != N_obs_target_check:
                    raise ValueError(
                        f"Cannot split context set {context_idx} and target set {target_idx} "
                        f"because they have different numbers of observations: "
                        f"{N_obs} and {N_obs_target_check}"
                    )
                split_seed = seed + split_i if seed is not None else None
                rng = np.random.default_rng(split_seed)

                N_context = int(N_obs * split_frac)
                idxs_context = rng.choice(N_obs, N_context, replace=False)

                context_var = context_var.iloc[idxs_context]
                target_var = target_var.drop(context_var.index)

                context_slices[context_idx] = context_var
                target_slices[target_idx] = target_var
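
        # Illustration: for one linked context/target pair sharing 200 station
        # observations and split_frac=0.5, 100 observations are drawn for the
        # context set and the remaining 100 become the target set.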

        # TODO move to method
        if (
            self.links is not None
            and "gapfill" in context_sampling
            and "gapfill" in target_sampling
        ):
            # Perform the gapfill sampling strategy for linked context and target sets at this point
            # while we have the full context and target data in scope

            context_gapfill_idxs = np.where(np.array(context_sampling) == "gapfill")[0]
            target_gapfill_idxs = np.where(np.array(target_sampling) == "gapfill")[0]
            assert len(context_gapfill_idxs) == len(target_gapfill_idxs), (
                f"Number of context sets with 'gapfill' sampling strategy "
                f"({len(context_gapfill_idxs)}) must match number of target sets "
                f"with 'gapfill' sampling strategy ({len(target_gapfill_idxs)})"
            )
            for gapfill_i, (context_idx, target_idx) in enumerate(
                zip(context_gapfill_idxs, target_gapfill_idxs)
            ):
                assert (context_idx, target_idx) in self.links, (
                    f"Context set {context_idx} and target set {target_idx} must be linked "
                    f"with the `links` attribute if using the 'gapfill' sampling strategy"
                )

                context_var = context_slices[context_idx]
                target_var = target_slices[target_idx]

                for var in [context_var, target_var]:
                    assert isinstance(var, (xr.DataArray, xr.Dataset)), (
                        f"If using 'gapfill' sampling strategy for linked context and target sets, "
                        f"the context and target sets must be xarray DataArrays or Datasets, "
                        f"but got {type(var)}."
                    )

                split_seed = seed + gapfill_i if seed is not None else None
                rng = np.random.default_rng(split_seed)

                # Keep trying until we get a target set with at least one target point
                keep_searching = True
                while keep_searching:
                    added_mask_date = rng.choice(self.context[context_idx].time)
                    added_mask = (
                        self.context[context_idx].sel(time=added_mask_date).isnull()
                    )
                    curr_mask = context_var.isnull()

                    # Mask out added missing values
                    context_var = context_var.where(~added_mask)

                    # TEMP: Inefficient to convert all non-targets to NaNs and then remove NaNs
                    #   when we could just slice the target values here
                    target_mask = added_mask & ~curr_mask
                    if isinstance(target_var, xr.Dataset):
                        keep_searching = np.all(target_mask.to_array().data == False)
                    else:
                        keep_searching = np.all(target_mask.data == False)
                    if keep_searching:
                        continue  # No target points -- use a different `added_mask`

                    target_var = target_var.where(
                        target_mask
                    )  # Only keep target locations

                    context_slices[context_idx] = context_var
                    target_slices[target_idx] = target_var

1213
        # check bbox size
1214
        if bbox is not None:
2✔
1215
            assert (
2✔
1216
                len(bbox) == 4
1217
            ), "bbox must be a list of length 4 with [x1_min, x1_max, x2_min, x2_max]"
1218

1219
            # spatial slices
1220
            context_slices = [
2✔
1221
                self.spatial_slice_variable(var, bbox) for var in context_slices
1222
            ]
1223
            target_slices = [
2✔
1224
                self.spatial_slice_variable(var, bbox) for var in target_slices
1225
            ]
1226

1227
        for i, (var, sampling_strat) in enumerate(
2✔
1228
            zip(context_slices, context_sampling)
1229
        ):
1230
            context_seed = seed + i if seed is not None else None
2✔
1231
            X_c, Y_c = sample_variable(var, sampling_strat, context_seed)
2✔
1232
            task[f"X_c"].append(X_c)
2✔
1233
            task[f"Y_c"].append(Y_c)
2✔
1234
        if target_sampling is not None:
2✔
1235
            for j, (var, sampling_strat) in enumerate(
2✔
1236
                zip(target_slices, target_sampling)
1237
            ):
1238
                target_seed = seed + i + j if seed is not None else None
2✔
1239
                X_t, Y_t = sample_variable(var, sampling_strat, target_seed)
2✔
1240
                task[f"X_t"].append(X_t)
2✔
1241
                task[f"Y_t"].append(Y_t)
2✔
1242

1243
        if self.aux_at_contexts is not None:
2✔
1244
            # Add auxiliary variable sampled at context set as a new context variable
1245
            X_c_offgrid = [X_c for X_c in task["X_c"] if not isinstance(X_c, tuple)]
2✔
1246
            if len(X_c_offgrid) == 0:
2✔
1247
                # No offgrid context sets
1248
                X_c_offrid_all = np.empty((2, 0), dtype=self.dtype)
×
1249
            else:
1250
                X_c_offrid_all = np.concatenate(X_c_offgrid, axis=1)
2✔
1251
            Y_c_aux = (
2✔
1252
                self.sample_offgrid_aux(
1253
                    X_c_offrid_all,
1254
                    self.time_slice_variable(self.aux_at_contexts, date),
1255
                ),
1256
            )
1257
            task["X_c"].append(X_c_offrid_all)
2✔
1258
            task["Y_c"].append(Y_c_aux)
2✔
1259

1260
        if self.aux_at_targets is not None and target_sampling is None:
2✔
1261
            task["Y_t_aux"] = None
2✔
1262
        elif self.aux_at_targets is not None and target_sampling is not None:
2✔
1263
            # Add auxiliary variable to target set
1264
            if len(task["X_t"]) > 1:
2✔
1265
                raise ValueError(
×
1266
                    "Cannot add auxiliary variable to target set when there "
1267
                    "are multiple target variables (not supported by default `ConvNP` model)."
1268
                )
1269
            task["Y_t_aux"] = self.sample_offgrid_aux(
2✔
1270
                task["X_t"][0],
1271
                self.time_slice_variable(self.aux_at_targets, date),
1272
            )
1273

1274
        return Task(task)
2✔
1275

1276
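
    # Illustrative sketch (not part of the loader): the "gapfill" strategy above
    # derives its target mask as `added_mask & ~curr_mask`, i.e. the targets are
    # the points that the sampled mask newly hides but that were observed in the
    # original data. A minimal numpy analogue with made-up 1D masks:
    #
    #   >>> import numpy as np
    #   >>> curr_mask = np.array([True, False, False, False])  # already missing
    #   >>> added_mask = np.array([True, True, False, True])   # sampled NaN mask
    #   >>> added_mask & ~curr_mask  # newly hidden, originally observed
    #   array([False,  True, False,  True])
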
    def __call__(
        self,
        date: pd.Timestamp,
        context_sampling: Union[
            str,
            int,
            float,
            np.ndarray,
            List[Union[str, int, float, np.ndarray]],
        ] = "all",
        target_sampling: Optional[
            Union[
                str,
                int,
                float,
                np.ndarray,
                List[Union[str, int, float, np.ndarray]],
            ]
        ] = None,
        split_frac: float = 0.5,
        datewise_deterministic: bool = False,
        seed_override: Optional[int] = None,
    ) -> Union[Task, List[Task]]:
        """Generate a task for a given date (or a list of
        :class:`~.data.task.Task` objects for a list of dates).

        There are several sampling strategies available for the context and
        target data:

            - "all": Sample all observations.
            - int: Sample N observations uniformly at random.
            - float: Sample a fraction of observations uniformly at random.
            - :class:`numpy:numpy.ndarray`, shape (2, N):
                Sample N observations at the given x1, x2 coordinates. Coords are assumed to be
                normalised.
            - "split": Split pandas observations into disjoint context and target sets.
                ``split_frac`` determines the fraction of observations
                to use for the context set. The remaining observations are used
                for the target set.
                The context set and target set must be linked through the ``TaskLoader``
                ``links`` argument. Only valid for pandas data.
            - "gapfill": Generates a training task for filling NaNs in xarray data.
                Randomly samples a missing data (NaN) mask from another timestamp and
                adds it to the context set (i.e. increases the number of NaNs).
                The target set is then the true values of the data at the added
                missing locations.
                The context set and target set must be linked through the ``TaskLoader``
                ``links`` argument. Only valid for xarray data.

        Args:
            date (:class:`pandas.Timestamp`):
                Date (or list of dates) for which to generate the task(s).
            context_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional):
                Sampling strategy for the context data, either a list of
                sampling strategies for each context set, or a single strategy
                applied to all context sets. Default is ``"all"``.
            target_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional):
                Sampling strategy for the target data, either a list of
                sampling strategies for each target set, or a single strategy
                applied to all target sets. Default is ``None``, meaning no target
                data is returned.
            split_frac (float, optional):
                The fraction of observations to use for the context set with
                the "split" sampling strategy for linked context and target set
                pairs. The remaining observations are used for the target set.
                Default is 0.5.
            datewise_deterministic (bool, optional):
                Whether random sampling is datewise deterministic based on the
                date. Default is ``False``.
            seed_override (Optional[int], optional):
                Override the seed for random sampling. This can be used to
                apply the same random sampling at different ``date`` values.
                Default is None.

        Returns:
            :class:`~.data.task.Task` | List[:class:`~.data.task.Task`]:
                Task object or list of task objects for each date containing
                the context and target data.
        """
        if isinstance(date, (list, tuple, pd.DatetimeIndex)):
            return [
                self.task_generation(
                    date=d,
                    context_sampling=context_sampling,
                    target_sampling=target_sampling,
                    split_frac=split_frac,
                    datewise_deterministic=datewise_deterministic,
                    seed_override=seed_override,
                )
                for d in date
            ]
        else:
            return self.task_generation(
                date=date,
                context_sampling=context_sampling,
                target_sampling=target_sampling,
                split_frac=split_frac,
                datewise_deterministic=datewise_deterministic,
                seed_override=seed_override,
            )
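
# Illustrative usage of `TaskLoader.__call__` (a sketch; `era5_ds` and `station_df`
# are hypothetical pre-normalised xarray/pandas objects, not defined in this module):
#
#   >>> task_loader = TaskLoader(context=era5_ds, target=station_df)
#   >>> task = task_loader("2020-01-01", context_sampling="all", target_sampling=100)
#   >>> tasks = task_loader(pd.date_range("2020-01-01", "2020-01-07"))  # list of Tasks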


class PatchwiseTaskLoader(TaskLoader):
    """Generates :class:`~.data.task.Task` objects for training, testing, and inference with DeepSensor models using a patchwise approach."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.coord_bounds = self._compute_global_coordinate_bounds()

    def _compute_global_coordinate_bounds(self) -> List[float]:
        """Compute global coordinate bounds in order to sample spatial bounds if desired.

        Returns:
            List[float]:
                Sequence of global spatial extent as [x1_min, x1_max, x2_min, x2_max].
        """
        x1_min, x1_max, x2_min, x2_max = np.inf, -np.inf, np.inf, -np.inf

        for var in itertools.chain(self.context, self.target):
            if isinstance(var, (xr.Dataset, xr.DataArray)):
                var_x1_min = var.x1.min().item()
                var_x1_max = var.x1.max().item()
                var_x2_min = var.x2.min().item()
                var_x2_max = var.x2.max().item()
            elif isinstance(var, (pd.DataFrame, pd.Series)):
                var_x1_min = var.index.get_level_values("x1").min()
                var_x1_max = var.index.get_level_values("x1").max()
                var_x2_min = var.index.get_level_values("x2").min()
                var_x2_max = var.index.get_level_values("x2").max()

            if var_x1_min < x1_min:
                x1_min = var_x1_min

            if var_x1_max > x1_max:
                x1_max = var_x1_max

            if var_x2_min < x2_min:
                x2_min = var_x2_min

            if var_x2_max > x2_max:
                x2_max = var_x2_max

        return [x1_min, x1_max, x2_min, x2_max]

    def _compute_x1x2_direction(self) -> dict:
        """Compute whether the x1 and x2 coordinates are ascending or descending.

        Returns:
            dict:
                Dictionary with keys "x1" and "x2" and boolean values
                indicating whether each coordinate ascends from the top-left
                corner of the grid.

        Raises:
            ValueError:
                If all datasets are non-gridded, or if the direction of
                ascending coordinates does not match across gridded datasets.
        """
        non_gridded = {"x1": None, "x2": None}  # value to use for non-gridded data
        ascending = []
        for var in itertools.chain(self.context, self.target):
            if isinstance(var, (xr.Dataset, xr.DataArray)):
                coord_x1_left = var.x1[0]
                coord_x1_right = var.x1[-1]
                coord_x2_top = var.x2[0]
                coord_x2_bottom = var.x2[-1]

                ascending.append(
                    {
                        "x1": bool(coord_x1_left <= coord_x1_right),
                        "x2": bool(coord_x2_top <= coord_x2_bottom),
                    }
                )

            elif isinstance(var, (pd.DataFrame, pd.Series)):
                ascending.append(non_gridded)

        # Get the directions for only the gridded data
        gridded = list(filter(lambda x: x != non_gridded, ascending))
        if len(gridded) == 0:
            raise ValueError(
                "All data is non-gridded; cannot proceed with sliding window sampling."
            )

        # Raise an error if directions don't match across gridded data
        if gridded.count(gridded[0]) != len(gridded):
            raise ValueError(
                "Direction of ascending coordinates does not match across all gridded datasets."
            )

        return gridded[0]
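
    # Expected shape of the result (a sketch): for data whose x1 coordinate (e.g.
    # latitude) is stored north-to-south (descending) and whose x2 coordinate is
    # stored west-to-east (ascending), this would return
    #
    #   >>> loader._compute_x1x2_direction()  # `loader` is hypothetical
    #   {'x1': False, 'x2': True}
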
    def sample_random_window(self, patch_size: Tuple[float]) -> Sequence[float]:
        """Sample a random window uniformly from the global coordinate bounds to slice data.

        Args:
            patch_size (Tuple[float]):
                Tuple of window extent in x1/x2.

        Returns:
            List[float]:
                Sequence of patch spatial extent as [x1_min, x1_max, x2_min, x2_max].
        """
        x1_extent, x2_extent = patch_size

        x1_side = x1_extent / 2
        x2_side = x2_extent / 2

        # Sample a point that satisfies the context and target global bounds
        x1_min, x1_max, x2_min, x2_max = self.coord_bounds

        x1_point = random.uniform(x1_min + x1_side, x1_max - x1_side)
        x2_point = random.uniform(x2_min + x2_side, x2_max - x2_side)

        # bbox of x1_min, x1_max, x2_min, x2_max
        bbox = [
            x1_point - x1_side,
            x1_point + x1_side,
            x2_point - x2_side,
            x2_point + x2_side,
        ]

        return bbox
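
    # Worked example of the window arithmetic (a sketch): with normalised global
    # bounds [0, 1, 0, 1] and patch_size=(0.2, 0.2), the centre point is drawn
    # uniformly from [0.1, 0.9] in each dimension, so the patch never leaves the
    # bounds. A draw of x1_point=0.5, x2_point=0.3 would give
    # bbox = [0.4, 0.6, 0.2, 0.4]. Note this uses the global `random` module, so
    # it is not controlled by `seed_override`.
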
    def spatial_slice_variable(self, var, window: List[float]):
        """Slice a variable by a given window size.

        Args:
            var (...):
                Variable to slice.
            window (List[float]):
                List of coordinates specifying the window [x1_min, x1_max, x2_min, x2_max].

        Returns:
            var (...):
                Sliced variable.

        Raises:
            ValueError:
                If the variable is of an unknown type.
        """
        x1_min, x1_max, x2_min, x2_max = window
        if isinstance(var, (xr.Dataset, xr.DataArray)):
            # We cannot assume that the coordinates are sorted from small to large
            if var.x1[0] > var.x1[-1]:
                x1_slice = slice(x1_max, x1_min)
            else:
                x1_slice = slice(x1_min, x1_max)
            if var.x2[0] > var.x2[-1]:
                x2_slice = slice(x2_max, x2_min)
            else:
                x2_slice = slice(x2_min, x2_max)
            var = var.sel(x1=x1_slice, x2=x2_slice)
        elif isinstance(var, (pd.DataFrame, pd.Series)):
            # Retrieve the desired patch
            var = var[
                (var.index.get_level_values("x1") >= x1_min)
                & (var.index.get_level_values("x1") <= x1_max)
                & (var.index.get_level_values("x2") >= x2_min)
                & (var.index.get_level_values("x2") <= x2_max)
            ]
        else:
            raise ValueError(f"Unknown variable type {type(var)}")

        return var
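
    # Illustrative sketch of why the slice order is reversed above: xarray label
    # slicing follows the stored coordinate order, so descending coordinates need
    # a (max, min) slice.
    #
    #   >>> import numpy as np, xarray as xr
    #   >>> da = xr.DataArray(
    #   ...     np.arange(4.0), dims="x1", coords={"x1": [0.9, 0.6, 0.3, 0.0]}
    #   ... )
    #   >>> da.sel(x1=slice(0.7, 0.2)).x1.values  # (x1_max, x1_min) for descending
    #   array([0.6, 0.3])
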
    def sample_sliding_window(
        self, patch_size: Tuple[float], stride: Tuple[float]
    ) -> List[List[float]]:
        """Sample data using a sliding window over the global coordinate bounds.

        Args:
            patch_size (Tuple[float]):
                Tuple of window extent in x1/x2.
            stride (Tuple[float]):
                Tuple of step size between each patch along the x1 and x2 axes.

        Returns:
            List[List[float]]:
                List of patch spatial extents, each as [x1_min, x1_max, x2_min, x2_max].
        """
        self.coord_directions = self._compute_x1x2_direction()
        # Define patch size in x1/x2
        size = {}
        size["x1"], size["x2"] = patch_size

        # Define stride length in x1/x2, or set to patch_size if undefined
        if stride is None:
            stride = patch_size

        step = {}
        step["x1"], step["x2"] = stride

        # Calculate the global bounds of the context and target sets
        coord_min = {}
        coord_max = {}
        coord_min["x1"], coord_max["x1"], coord_min["x2"], coord_max["x2"] = (
            self.coord_bounds
        )

        # Start with the first patch's top left-hand corner at coord_min["x1"], coord_min["x2"]
        patch_list = []

        # Define some lambda functions for use below.
        # Round to 12 decimal places to avoid floating point error while keeping
        # the likelihood of unintentional rounding low.
        r = lambda x: round(x, 12)
        bbox_coords_ascend = lambda a, b: [r(a), r(a + b)]
        bbox_coords_descend = lambda a, b: bbox_coords_ascend(a, b)[::-1]

        compare = {}
        bbox_coords = {}
        # For each coordinate direction, specify the correct operations for patching
        for c in ("x1", "x2"):
            if self.coord_directions[c]:
                compare[c] = operator.gt
                bbox_coords[c] = bbox_coords_ascend
            else:
                step[c] = -step[c]
                coord_min[c], coord_max[c] = coord_max[c], coord_min[c]
                size[c] = -size[c]
                compare[c] = operator.lt
                bbox_coords[c] = bbox_coords_descend

        # Define the bounding boxes for all patches, starting in the top left corner of the DataArray
        for y, x in itertools.product(
            np.arange(coord_min["x1"], coord_max["x1"], step["x1"]),
            np.arange(coord_min["x2"], coord_max["x2"], step["x2"]),
        ):
            y0 = (
                coord_max["x1"] - size["x1"]
                if compare["x1"](y + size["x1"], coord_max["x1"])
                else y
            )
            x0 = (
                coord_max["x2"] - size["x2"]
                if compare["x2"](x + size["x2"], coord_max["x2"])
                else x
            )

            # bbox of x1_min, x1_max, x2_min, x2_max per patch
            bbox = bbox_coords["x1"](y0, size["x1"]) + bbox_coords["x2"](x0, size["x2"])
            patch_list.append(bbox)

        # Remove duplicate patches while preserving order
        seen = set()
        unique_patch_list = []
        for lst in patch_list:
            # Convert list to tuple for immutability
            tuple_lst = tuple(lst)
            if tuple_lst not in seen:
                seen.add(tuple_lst)
                unique_patch_list.append(lst)

        return unique_patch_list
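
    # Worked example (a sketch): with ascending normalised bounds [0, 1, 0, 1],
    # patch_size=(0.5, 0.5) and stride=(0.5, 0.5), the start points are {0.0, 0.5}
    # along each axis, giving four non-overlapping patches:
    #
    #   >>> loader.sample_sliding_window((0.5, 0.5), (0.5, 0.5))  # `loader` hypothetical
    #   [[0.0, 0.5, 0.0, 0.5], [0.0, 0.5, 0.5, 1.0],
    #    [0.5, 1.0, 0.0, 0.5], [0.5, 1.0, 0.5, 1.0]]
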
    def __call__(
        self,
        date: Union[pd.Timestamp, Sequence[pd.Timestamp]],
        context_sampling: Union[
            str,
            int,
            float,
            np.ndarray,
            List[Union[str, int, float, np.ndarray]],
        ] = "all",
        target_sampling: Optional[
            Union[
                str,
                int,
                float,
                np.ndarray,
                List[Union[str, int, float, np.ndarray]],
            ]
        ] = None,
        split_frac: float = 0.5,
        patch_size: Optional[Union[float, Tuple[float]]] = None,
        patch_strategy: Optional[str] = None,
        stride: Optional[Union[float, Tuple[float]]] = None,
        num_patch_tasks: int = 1,
        datewise_deterministic: bool = False,
        seed_override: Optional[int] = None,
    ) -> Union[Task, List[Task]]:
        """Generate a task for a given date (or a list of
        :class:`~.data.task.Task` objects for a list of dates).

        There are several sampling strategies available for the context and
        target data:

            - "all": Sample all observations.
            - int: Sample N observations uniformly at random.
            - float: Sample a fraction of observations uniformly at random.
            - :class:`numpy:numpy.ndarray`, shape (2, N):
                Sample N observations at the given x1, x2 coordinates. Coords are assumed to be
                normalised.
            - "split": Split pandas observations into disjoint context and target sets.
                ``split_frac`` determines the fraction of observations
                to use for the context set. The remaining observations are used
                for the target set.
                The context set and target set must be linked through the ``TaskLoader``
                ``links`` argument. Only valid for pandas data.
            - "gapfill": Generates a training task for filling NaNs in xarray data.
                Randomly samples a missing data (NaN) mask from another timestamp and
                adds it to the context set (i.e. increases the number of NaNs).
                The target set is then the true values of the data at the added
                missing locations.
                The context set and target set must be linked through the ``TaskLoader``
                ``links`` argument. Only valid for xarray data.

        Args:
            date (:class:`pandas.Timestamp`):
                Date (or list of dates) for which to generate the task(s).
            context_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional):
                Sampling strategy for the context data, either a list of
                sampling strategies for each context set, or a single strategy
                applied to all context sets. Default is ``"all"``.
            target_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional):
                Sampling strategy for the target data, either a list of
                sampling strategies for each target set, or a single strategy
                applied to all target sets. Default is ``None``, meaning no target
                data is returned.
            split_frac (float, optional):
                The fraction of observations to use for the context set with
                the "split" sampling strategy for linked context and target set
                pairs. The remaining observations are used for the target set.
                Default is 0.5.
            patch_size (float | Tuple[float], optional):
                Desired patch size in x1/x2 used for patchwise task generation. Useful
                when a model forward pass over the entire available region is
                computationally prohibitive. If passed a single float, the value is
                used for both x1 and x2.
            patch_strategy (str, optional):
                Patch strategy to use for patchwise task generation. Default is None.
                Possible options are "random" or "sliding".
            stride (float | Tuple[float], optional):
                Step size between each sliding window patch along the x1 and x2 axes.
                Default is None. If passed a single float, the value is used for both
                x1 and x2.
            num_patch_tasks (int, optional):
                The number of patches to generate per date when using the "random"
                patching strategy. Default is 1.
            datewise_deterministic (bool, optional):
                Whether random sampling is datewise deterministic based on the
                date. Default is ``False``.
            seed_override (Optional[int], optional):
                Override the seed for random sampling. This can be used to
                apply the same random sampling at different ``date`` values.
                Default is None.

        Returns:
            :class:`~.data.task.Task` | List[:class:`~.data.task.Task`]:
                Task object or list of task objects for each date containing
                the context and target data.
        """
        if patch_strategy not in [None, "random", "sliding"]:
            raise ValueError(
                f"Invalid patch strategy {patch_strategy}. "
                f"Must be one of [None, 'random', 'sliding']."
            )

        if isinstance(patch_size, float):
            patch_size = (patch_size, patch_size)

        if isinstance(stride, float):
            stride = (stride, stride)

        if patch_strategy is None:
            return super().__call__(
                date=date,
                context_sampling=context_sampling,
                target_sampling=target_sampling,
                split_frac=split_frac,
                datewise_deterministic=datewise_deterministic,
                seed_override=seed_override,
            )

        elif patch_strategy == "random":
            if patch_size is None:
                raise ValueError(
                    "Patch size must be specified for random patch sampling"
                )

            coord_bounds = [self.coord_bounds[0:2], self.coord_bounds[2:]]
            for i, val in enumerate(patch_size):
                if val < coord_bounds[i][0] or val > coord_bounds[i][1]:
                    raise ValueError(
                        f"Values of patch_size must be between the normalised "
                        f"coordinate bounds of: {self.coord_bounds}. "
                        f"Got: patch_size: {patch_size}."
                    )

            if isinstance(date, (list, tuple, pd.DatetimeIndex)):
                # Accumulate tasks across all dates (rather than overwriting
                # `tasks` on each iteration)
                tasks = []
                for d in date:
                    bboxes = [
                        self.sample_random_window(patch_size)
                        for _ in range(num_patch_tasks)
                    ]
                    tasks.extend(
                        [
                            self.task_generation(
                                date=d,
                                bbox=bbox,
                                context_sampling=context_sampling,
                                target_sampling=target_sampling,
                                split_frac=split_frac,
                                datewise_deterministic=datewise_deterministic,
                                seed_override=seed_override,
                            )
                            for bbox in bboxes
                        ]
                    )

            else:
                bboxes = [
                    self.sample_random_window(patch_size)
                    for _ in range(num_patch_tasks)
                ]
                tasks = [
                    self.task_generation(
                        date=date,
                        bbox=bbox,
                        context_sampling=context_sampling,
                        target_sampling=target_sampling,
                        split_frac=split_frac,
                        datewise_deterministic=datewise_deterministic,
                        seed_override=seed_override,
                    )
                    for bbox in bboxes
                ]

        elif patch_strategy == "sliding":
            # Sliding window sampling of patches
            for val in (patch_size, stride):
                if val is None:
                    raise ValueError(
                        f"patch_size and stride must be specified for sliding window "
                        f"sampling, got patch_size: {patch_size} and stride: {stride}."
                    )

            if stride[0] > patch_size[0] or stride[1] > patch_size[1]:
                raise Warning(
                    f"stride should generally be smaller than patch_size in the "
                    f"corresponding dimensions. Got: patch_size: {patch_size}, "
                    f"stride: {stride}"
                )

            coord_bounds = [self.coord_bounds[0:2], self.coord_bounds[2:]]
            for i in (0, 1):
                for val in (patch_size[i], stride[i]):
                    if val < coord_bounds[i][0] or val > coord_bounds[i][1]:
                        raise ValueError(
                            f"Values of stride and patch_size must be between the "
                            f"normalised coordinate bounds of: {self.coord_bounds}. "
                            f"Got: patch_size: {patch_size}, stride: {stride}"
                        )

            if isinstance(date, (list, tuple, pd.DatetimeIndex)):
                tasks = []
                for d in date:
                    bboxes = self.sample_sliding_window(patch_size, stride)
                    tasks.extend(
                        [
                            self.task_generation(
                                date=d,
                                bbox=bbox,
                                patch_size=patch_size,
                                stride=stride,
                                context_sampling=context_sampling,
                                target_sampling=target_sampling,
                                split_frac=split_frac,
                                datewise_deterministic=datewise_deterministic,
                                seed_override=seed_override,
                            )
                            for bbox in bboxes
                        ]
                    )
            else:
                bboxes = self.sample_sliding_window(patch_size, stride)
                tasks = [
                    self.task_generation(
                        date=date,
                        bbox=bbox,
                        context_sampling=context_sampling,
                        target_sampling=target_sampling,
                        split_frac=split_frac,
                        datewise_deterministic=datewise_deterministic,
                        seed_override=seed_override,
                        patch_size=patch_size,
                        stride=stride,
                    )
                    for bbox in bboxes
                ]
        else:
            raise ValueError(
                f"Invalid patch strategy {patch_strategy}. "
                f"Must be one of [None, 'random', 'sliding']."
            )

        return tasks
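

# Illustrative usage of `PatchwiseTaskLoader` (a sketch; `era5_ds` and `dates` are
# hypothetical, pre-normalised inputs, not defined in this module):
#
#   >>> task_loader = PatchwiseTaskLoader(context=era5_ds, target=era5_ds)
#   >>> # Random patches for training (one patch task per date here):
#   >>> train_tasks = task_loader(
#   ...     dates, context_sampling="all", target_sampling="all",
#   ...     patch_strategy="random", patch_size=(0.3, 0.3), num_patch_tasks=1,
#   ... )
#   >>> # Overlapping sliding-window tiling for patchwise prediction:
#   >>> eval_tasks = task_loader(
#   ...     dates, context_sampling="all", target_sampling="all",
#   ...     patch_strategy="sliding", patch_size=(0.5, 0.5), stride=(0.25, 0.25),
#   ... )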