
alan-turing-institute / deepsensor / build 14311911693

07 Apr 2025 02:30PM UTC. Coverage: 82.532% (+0.9%) from 81.663%

Pull Request #135: Enable patchwise training and prediction
Merge 620677f66 into 38ec5ef26

281 of 317 new or added lines in 4 files covered (88.64%).
26 existing lines in 2 files are now uncovered.
2334 of 2828 relevant lines covered (82.53%).
1.65 hits per line.

Source file: /deepsensor/data/loader.py (92.49% of lines covered)
import copy
import itertools
import json
import operator
import os
import random
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import xarray as xr

from deepsensor.data.task import Task, flatten_X
from deepsensor.errors import InvalidSamplingStrategyError, SamplingTooManyPointsError


class TaskLoader:
    """Generates :class:`~.data.task.Task` objects for training, testing, and inference with DeepSensor models.

    Provides a suite of sampling methods for generating :class:`~.data.task.Task` objects for different kinds of
    predictions, such as: spatial interpolation, forecasting, downscaling, or some combination
    of these.

    The behaviour is the following:
        - If all data is passed as paths, load the data and overwrite the paths with the loaded data
        - Either all data is passed as paths, or all data is passed as loaded data (else ``ValueError``)
        - If all data is passed as paths, the TaskLoader can be saved with the ``save`` method
          (using config)

    Args:
        task_loader_ID:
            If loading a TaskLoader from a config file, this is the folder the
            TaskLoader was saved in (using `.save`). If this argument is passed, all other
            arguments are ignored.
        context (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]):
            Context data. Can be a single :class:`xarray.DataArray`,
            :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a
            list/tuple of these.
        target (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]):
            Target data. Can be a single :class:`xarray.DataArray`,
            :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a
            list/tuple of these.
        aux_at_contexts (Tuple[int, :class:`xarray.DataArray` | :class:`xarray.Dataset`], optional):
            Auxiliary data at context locations. Tuple of two elements, where
            the first element is the index of the context set at which the
            auxiliary data will be sampled, and the second element is the
            auxiliary data, which can be a single :class:`xarray.DataArray` or
            :class:`xarray.Dataset`. Default: None.
        aux_at_targets (:class:`xarray.DataArray` | :class:`xarray.Dataset`, optional):
            Auxiliary data at target locations. Can be a single
            :class:`xarray.DataArray` or :class:`xarray.Dataset`. Default:
            None.
        links (Tuple[int, int] | List[Tuple[int, int]], optional):
            Specifies links between context and target data. Each link is a
            tuple of two integers, where the first integer is the index of the
            context data and the second integer is the index of the target
            data. Can be a single tuple in the case of a single link. If None,
            no links are specified. Default: None.
        context_delta_t (int | List[int], optional):
            Time difference between context data and t=0 (task init time). Can
            be a single int (same for all context data) or a list/tuple of
            ints. Default is 0.
        target_delta_t (int | List[int], optional):
            Time difference between target data and t=0 (task init time). Can
            be a single int (same for all target data) or a list/tuple of ints.
            Default is 0.
        time_freq (str, optional):
            Time frequency of the data. Default: ``'D'`` (daily).
        xarray_interp_method (str, optional):
            Interpolation method to use when interpolating
            :class:`xarray.DataArray`. Default is ``'linear'``.
        discrete_xarray_sampling (bool, optional):
            When randomly sampling xarray variables, whether to sample at
            discrete points defined at grid cell centres, or at continuous
            points within the grid. Default is ``False``.
        dtype (object, optional):
            Data type of the data. Used to cast the data to the specified
            dtype. Default: ``'float32'``.
    """

    config_fname = "task_loader_config.json"

    def __init__(
        self,
        task_loader_ID: Union[str, None] = None,
        context: Union[
            xr.DataArray,
            xr.Dataset,
            pd.DataFrame,
            str,
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
        ] = None,
        target: Union[
            xr.DataArray,
            xr.Dataset,
            pd.DataFrame,
            str,
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
        ] = None,
        aux_at_contexts: Optional[Tuple[int, Union[xr.DataArray, xr.Dataset]]] = None,
        aux_at_targets: Optional[
            Union[
                xr.DataArray,
                xr.Dataset,
            ]
        ] = None,
        links: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
        context_delta_t: Union[int, List[int]] = 0,
        target_delta_t: Union[int, List[int]] = 0,
        time_freq: str = "D",
        xarray_interp_method: str = "linear",
        discrete_xarray_sampling: bool = False,
        dtype: object = "float32",
    ) -> None:
        if task_loader_ID is not None:
            self.task_loader_ID = task_loader_ID
            # Load TaskLoader from config file
            fpath = os.path.join(task_loader_ID, self.config_fname)
            with open(fpath, "r") as f:
                self.config = json.load(f)

            self.context = self.config["context"]
            self.target = self.config["target"]
            self.aux_at_contexts = self.config["aux_at_contexts"]
            self.aux_at_targets = self.config["aux_at_targets"]
            self.links = self.config["links"]
            if self.links is not None:
                self.links = [tuple(link) for link in self.links]
            self.context_delta_t = self.config["context_delta_t"]
            self.target_delta_t = self.config["target_delta_t"]
            self.time_freq = self.config["time_freq"]
            self.xarray_interp_method = self.config["xarray_interp_method"]
            self.discrete_xarray_sampling = self.config["discrete_xarray_sampling"]
            self.dtype = self.config["dtype"]
        else:
            self.context = context
            self.target = target
            self.aux_at_contexts = aux_at_contexts
            self.aux_at_targets = aux_at_targets
            self.links = links
            self.context_delta_t = context_delta_t
            self.target_delta_t = target_delta_t
            self.time_freq = time_freq
            self.xarray_interp_method = xarray_interp_method
            self.discrete_xarray_sampling = discrete_xarray_sampling
            self.dtype = dtype

        if not isinstance(self.context, (tuple, list)):
            self.context = (self.context,)
        if not isinstance(self.target, (tuple, list)):
            self.target = (self.target,)

        if isinstance(self.context_delta_t, int):
            self.context_delta_t = (self.context_delta_t,) * len(self.context)
        else:
            assert len(self.context_delta_t) == len(self.context), (
                f"Length of context_delta_t ({len(self.context_delta_t)}) must be the same as "
                f"the number of context sets ({len(self.context)})"
            )
        if isinstance(self.target_delta_t, int):
            self.target_delta_t = (self.target_delta_t,) * len(self.target)
        else:
            assert len(self.target_delta_t) == len(self.target), (
                f"Length of target_delta_t ({len(self.target_delta_t)}) must be the same as "
                f"the number of target sets ({len(self.target)})"
            )

        all_paths = self._check_if_all_data_passed_as_paths()
        if all_paths:
            self._set_config()
            self._load_data_from_paths()

        self.context = self._cast_to_dtype(self.context)
        self.target = self._cast_to_dtype(self.target)
        self.aux_at_contexts = self._cast_to_dtype(self.aux_at_contexts)
        self.aux_at_targets = self._cast_to_dtype(self.aux_at_targets)

        self.links = self._check_links(self.links)

        (
            self.context_dims,
            self.target_dims,
            self.aux_at_target_dims,
        ) = self.count_context_and_target_data_dims()
        (
            self.context_var_IDs,
            self.target_var_IDs,
            self.context_var_IDs_and_delta_t,
            self.target_var_IDs_and_delta_t,
            self.aux_at_target_var_IDs,
        ) = self.infer_context_and_target_var_IDs()

        self.coord_bounds = self._compute_global_coordinate_bounds()

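    # A minimal usage sketch, assuming hypothetical in-memory data
    # ("era5.nc" and "stations.csv" are illustrative names):
    #
    #   era5_ds = xr.open_dataset("era5.nc")  # gridded context data
    #   station_df = pd.read_csv("stations.csv").set_index(["time", "x1", "x2"])
    #   task_loader = TaskLoader(context=era5_ds, target=station_df)
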
    def _set_config(self):
        """Instantiate a config dictionary for the TaskLoader object."""
        # Take deepcopy to avoid modifying the original config
        self.config = copy.deepcopy(
            dict(
                context=self.context,
                target=self.target,
                aux_at_contexts=self.aux_at_contexts,
                aux_at_targets=self.aux_at_targets,
                links=self.links,
                context_delta_t=self.context_delta_t,
                target_delta_t=self.target_delta_t,
                time_freq=self.time_freq,
                xarray_interp_method=self.xarray_interp_method,
                discrete_xarray_sampling=self.discrete_xarray_sampling,
                dtype=self.dtype,
            )
        )

    def _check_if_all_data_passed_as_paths(self) -> bool:
        """If all data is passed as paths, save the paths to the config and return True."""

        def _check_if_strings(x, mode="all"):
            if x is None:
                return None
            elif isinstance(x, (tuple, list)):
                if mode == "all":
                    return all([isinstance(x_i, str) for x_i in x])
                elif mode == "any":
                    return any([isinstance(x_i, str) for x_i in x])
            else:
                return isinstance(x, str)

        all_paths = all(
            filter(
                lambda x: x is not None,
                [
                    _check_if_strings(self.context),
                    _check_if_strings(self.target),
                    _check_if_strings(self.aux_at_contexts),
                    _check_if_strings(self.aux_at_targets),
                ],
            )
        )
        self._is_saveable = all_paths

        any_paths = any(
            filter(
                lambda x: x is not None,
                [
                    _check_if_strings(self.context, mode="any"),
                    _check_if_strings(self.target, mode="any"),
                    _check_if_strings(self.aux_at_contexts, mode="any"),
                    _check_if_strings(self.aux_at_targets, mode="any"),
                ],
            )
        )
        if any_paths and not all_paths:
            raise ValueError(
                "Data must be passed either all as paths or all as xarray/pandas objects (not a mix)."
            )

        return all_paths

    def _load_data_from_paths(self):
        """Load data from paths and overwrite the paths with the loaded data."""
        loaded_data = {}

        def _load_pandas_or_xarray(path):
            # Need to be careful about this. We need to ensure data gets into the right form
            #  for TaskLoader.
            if path is None:
                return None
            elif path in loaded_data:
                return loaded_data[path]
            elif path.endswith(".nc"):
                data = xr.open_dataset(path)
            elif path.endswith(".csv"):
                df = pd.read_csv(path)
                if "time" in df.columns:
                    df["time"] = pd.to_datetime(df["time"])
                    df = df.set_index(["time", "x1", "x2"]).sort_index()
                else:
                    df = df.set_index(["x1", "x2"]).sort_index()
                data = df
            else:
                raise ValueError(f"Unknown file extension for {path}")
            loaded_data[path] = data
            return data

        def _load_data(data):
            if isinstance(data, (tuple, list)):
                data = tuple([_load_pandas_or_xarray(data_i) for data_i in data])
            else:
                data = _load_pandas_or_xarray(data)
            return data

        self.context = _load_data(self.context)
        self.target = _load_data(self.target)
        self.aux_at_contexts = _load_data(self.aux_at_contexts)
        self.aux_at_targets = _load_data(self.aux_at_targets)

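    # Sketch of the CSV layout assumed by the path-loading logic above
    # ("temperature" is a hypothetical data column):
    #
    #   time,x1,x2,temperature
    #   2020-01-01,0.1,0.3,12.5
    #   2020-01-01,0.2,0.7,11.9
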
    def save(self, folder: str):
        """Save TaskLoader config to JSON in `folder`."""
        if not self._is_saveable:
            raise ValueError(
                "TaskLoader cannot be saved because not all data was passed as paths."
            )

        os.makedirs(folder, exist_ok=True)
        fpath = os.path.join(folder, self.config_fname)
        with open(fpath, "w") as f:
            json.dump(self.config, f, indent=4, sort_keys=False)

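    # Save/reload sketch, assuming all data was passed as (hypothetical) paths
    # so that the config is fully serialisable:
    #
    #   tl = TaskLoader(context="era5.nc", target="stations.csv")
    #   tl.save("tl_folder")
    #   tl = TaskLoader(task_loader_ID="tl_folder")  # restores from config
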
    def _cast_to_dtype(
        self,
        var: Union[
            xr.DataArray,
            xr.Dataset,
            pd.DataFrame,
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
        ],
    ) -> Union[xr.DataArray, xr.Dataset, pd.DataFrame, tuple, None]:
        """Cast context and target data to the default dtype.

        ..
            TODO unit test this by passing in a variety of data types and
            checking that they are cast correctly.

        Args:
            var:
                Data to cast: a single xarray/pandas object or a tuple/list of
                such objects.

        Returns:
            The input data cast to ``self.dtype``, in the same container type
            it was passed in (lists are converted to tuples).
        """

        def cast_to_dtype(var):
            if isinstance(var, xr.DataArray):
                var = var.astype(self.dtype)
                var["x1"] = var["x1"].astype(self.dtype)
                var["x2"] = var["x2"].astype(self.dtype)
            elif isinstance(var, xr.Dataset):
                var = var.astype(self.dtype)
                var["x1"] = var["x1"].astype(self.dtype)
                var["x2"] = var["x2"].astype(self.dtype)
            elif isinstance(var, (pd.DataFrame, pd.Series)):
                var = var.astype(self.dtype)
                # Note: Numeric pandas indexes are always cast to float64, so we have to cast
                #   x1/x2 coord dtypes during task sampling
            else:
                raise ValueError(f"Unknown type {type(var)} for context set {var}")
            return var

        if var is None:
            return var
        elif isinstance(var, (tuple, list)):
            var = tuple([cast_to_dtype(var_i) for var_i in var])
        else:
            var = cast_to_dtype(var)

        return var

    def load_dask(self) -> None:
        """Load any `dask` data into memory.

        This function triggers the computation and loading of any data that
        is represented as dask arrays or datasets into memory.

        Returns:
            None
        """

        def load(datasets):
            if datasets is None:
                return
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]
            for i, var in enumerate(datasets):
                if isinstance(var, (xr.DataArray, xr.Dataset)):
                    # xarray's .load() computes dask-backed data in place,
                    # so reassigning the local name is sufficient here
                    var = var.load()

        load(self.context)
        load(self.target)
        load(self.aux_at_contexts)
        load(self.aux_at_targets)

        return None

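    # Dask sketch: data opened lazily with chunks stays out of memory until
    # computed; load_dask() forces it in up front ("era5.nc" is hypothetical):
    #
    #   ds = xr.open_dataset("era5.nc", chunks={"time": 100})
    #   tl = TaskLoader(context=ds, target=ds)
    #   tl.load_dask()
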
    def count_context_and_target_data_dims(self):
        """Count the number of data dimensions in the context and target data.

        Returns:
            tuple: context_dims, Tuple of data dimensions in the context data.
            tuple: target_dims, Tuple of data dimensions in the target data.

        Raises:
            ValueError: If the context/target data is not a tuple/list of
                        :class:`xarray.DataArray`, :class:`xarray.Dataset` or
                        :class:`pandas.DataFrame`.
        """

        def count_data_dims_of_tuple_of_sets(datasets):
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]

            dims = []
            # Distinguish between xr.DataArray, xr.Dataset and pd.DataFrame
            for i, var in enumerate(datasets):
                if isinstance(var, xr.Dataset):
                    dim = len(var.data_vars)  # Multiple data variables
                elif isinstance(var, xr.DataArray):
                    dim = 1  # Single data variable
                elif isinstance(var, pd.DataFrame):
                    dim = len(var.columns)  # Assumes all columns are data variables
                elif isinstance(var, pd.Series):
                    dim = 1  # Single data variable
                else:
                    raise ValueError(f"Unknown type {type(var)} for context set {var}")
                dims.append(dim)
            return dims

        context_dims = count_data_dims_of_tuple_of_sets(self.context)
        target_dims = count_data_dims_of_tuple_of_sets(self.target)
        if self.aux_at_contexts is not None:
            context_dims += count_data_dims_of_tuple_of_sets(self.aux_at_contexts)
        if self.aux_at_targets is not None:
            aux_at_target_dims = count_data_dims_of_tuple_of_sets(self.aux_at_targets)[
                0
            ]
        else:
            aux_at_target_dims = 0

        return tuple(context_dims), tuple(target_dims), aux_at_target_dims

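    # Worked example: with a 2-variable xr.Dataset as the only context set and
    # a 3-column pd.DataFrame as the only target set (and no auxiliary data),
    # this returns context_dims=(2,), target_dims=(3,), aux_at_target_dims=0.
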
    def infer_context_and_target_var_IDs(self):
        """Infer the variable IDs of the context and target data.

        Returns:
            tuple: context_var_IDs, Tuple of variable IDs in the context data.
            tuple: target_var_IDs, Tuple of variable IDs in the target data.

        Raises:
            ValueError: If the context/target data is not a tuple/list of
                        :class:`xarray.DataArray`, :class:`xarray.Dataset` or
                        :class:`pandas.DataFrame`.
        """

        def infer_var_IDs_of_tuple_of_sets(datasets, delta_ts=None):
            """If delta_ts is not None, then add the delta_t to the variable ID."""
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]

            var_IDs = []
            # Distinguish between xr.DataArray, xr.Dataset and pd.DataFrame
            for i, var in enumerate(datasets):
                if isinstance(var, xr.DataArray):
                    var_ID = (var.name,)  # Single data variable
                elif isinstance(var, xr.Dataset):
                    var_ID = tuple(var.data_vars.keys())  # Multiple data variables
                elif isinstance(var, pd.DataFrame):
                    var_ID = tuple(var.columns)
                elif isinstance(var, pd.Series):
                    var_ID = (var.name,)
                else:
                    raise ValueError(f"Unknown type {type(var)} for context set {var}")

                if delta_ts is not None:
                    # Add delta_t to the variable ID
                    var_ID = tuple(
                        [f"{var_ID_i}_t{delta_ts[i]}" for var_ID_i in var_ID]
                    )
                else:
                    var_ID = tuple([f"{var_ID_i}" for var_ID_i in var_ID])

                var_IDs.append(var_ID)

            return var_IDs

        context_var_IDs = infer_var_IDs_of_tuple_of_sets(self.context)
        context_var_IDs_and_delta_t = infer_var_IDs_of_tuple_of_sets(
            self.context, self.context_delta_t
        )
        target_var_IDs = infer_var_IDs_of_tuple_of_sets(self.target)
        target_var_IDs_and_delta_t = infer_var_IDs_of_tuple_of_sets(
            self.target, self.target_delta_t
        )

        if self.aux_at_contexts is not None:
            context_var_IDs += infer_var_IDs_of_tuple_of_sets(self.aux_at_contexts)
            context_var_IDs_and_delta_t += infer_var_IDs_of_tuple_of_sets(
                self.aux_at_contexts, [0]
            )

        if self.aux_at_targets is not None:
            aux_at_target_var_IDs = infer_var_IDs_of_tuple_of_sets(self.aux_at_targets)[
                0
            ]
        else:
            aux_at_target_var_IDs = None

        return (
            tuple(context_var_IDs),
            tuple(target_var_IDs),
            tuple(context_var_IDs_and_delta_t),
            tuple(target_var_IDs_and_delta_t),
            aux_at_target_var_IDs,
        )

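    # Worked example: a context xr.Dataset with hypothetical data variables
    # "t2m" and "precip" and context_delta_t=-1 yields the delta-t-tagged
    # variable IDs ("t2m_t-1", "precip_t-1").
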
    def _check_links(self, links: Union[Tuple[int, int], List[Tuple[int, int]]]):
        """Check that the context-target links are valid.

        Args:
            links (Tuple[int, int] | List[Tuple[int, int]]):
                Specifies links between context and target data. Each link is a
                tuple of two integers, where the first integer is the index of
                the context data and the second integer is the index of the
                target data. Can be a single tuple in the case of a single
                link. If None, no links are specified. Default: None.

        Returns:
            Tuple[int, int] | List[Tuple[int, int]]
                The input links, if valid.

        Raises:
            ValueError
                If the links are not valid.
        """
        if links is None:
            return None

        assert isinstance(
            links, list
        ), f"Links must be a list of length-2 tuples, but got {type(links)}"
        assert len(links) > 0, "If links is not None, it must be a non-empty list"
        assert all(
            isinstance(link, tuple) for link in links
        ), f"Links must be a list of tuples, but got {[type(link) for link in links]}"
        assert all(
            len(link) == 2 for link in links
        ), f"Links must be a list of length-2 tuples, but got lengths {[len(link) for link in links]}"

        # Check that the links are valid
        for link_i, (context_idx, target_idx) in enumerate(links):
            if context_idx >= len(self.context):
                raise ValueError(
                    f"Invalid context index {context_idx} in link {link_i} of {links}: "
                    f"there are only {len(self.context)} context sets"
                )
            if target_idx >= len(self.target):
                raise ValueError(
                    f"Invalid target index {target_idx} in link {link_i} of {links}: "
                    f"there are only {len(self.target)} target sets"
                )

        return links

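    # Links sketch: links=[(0, 0)] links the first context set to the first
    # target set, as required by the "split"/"gapfill" sampling strategies
    # checked in task_generation below.
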
    def __str__(self):
        """String representation of the TaskLoader object (user-friendly)."""
        s = f"TaskLoader({len(self.context_dims)} context sets, {len(self.target_dims)} target sets)"
        s += f"\nContext variable IDs: {self.context_var_IDs}"
        s += f"\nTarget variable IDs: {self.target_var_IDs}"
        if self.aux_at_targets is not None:
            s += f"\nAuxiliary-at-target variable IDs: {self.aux_at_target_var_IDs}"
        return s

    def __repr__(self):
        """Representation of the TaskLoader object (for developers).

        ..
            TODO make this a more verbose version of __str__
        """
        s = str(self)
        s += "\n"
        s += f"\nContext data dimensions: {self.context_dims}"
        s += f"\nTarget data dimensions: {self.target_dims}"
        if self.aux_at_targets is not None:
            s += f"\nAuxiliary-at-target data dimensions: {self.aux_at_target_dims}"
        return s

    def sample_da(
        self,
        da: Union[xr.DataArray, xr.Dataset],
        sampling_strat: Union[str, int, float, np.ndarray],
        seed: Optional[int] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Sample a DataArray according to a given strategy.

        Args:
            da (:class:`xarray.DataArray` | :class:`xarray.Dataset`):
                DataArray to sample, assumed to be sliced for the task already.
            sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`):
                Sampling strategy: "all" or "gapfill", an integer or float for
                random grid cell sampling, or a numpy coordinate array.
            seed (int, optional):
                Seed for random sampling. Default is None.

        Returns:
            Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]:
                Tuple of sampled coordinates X_c and sampled data Y_c.

        Raises:
            InvalidSamplingStrategyError:
                If the sampling strategy is not valid or if a numpy coordinate
                array is passed to sample an xarray object, but the coordinates
                are out of bounds.
        """
        da = da.load()  # Converts dask -> numpy if not already loaded
        if isinstance(da, xr.Dataset):
            da = da.to_array()

        if isinstance(sampling_strat, float):
            sampling_strat = int(sampling_strat * da.size)

        if isinstance(sampling_strat, (int, np.integer)):
            N = sampling_strat
            rng = np.random.default_rng(seed)
            if self.discrete_xarray_sampling:
                x1 = rng.choice(da.coords["x1"].values, N, replace=True)
                x2 = rng.choice(da.coords["x2"].values, N, replace=True)
                Y_c = da.sel(x1=xr.DataArray(x1), x2=xr.DataArray(x2)).data
            elif not self.discrete_xarray_sampling:
                if N == 0:
                    # Catch zero-context edge case before interp fails
                    X_c = np.zeros((2, 0), dtype=self.dtype)
                    dim = da.shape[0] if da.ndim == 3 else 1
                    Y_c = np.zeros((dim, 0), dtype=self.dtype)
                    return X_c, Y_c
                x1 = rng.uniform(da.coords["x1"].min(), da.coords["x1"].max(), N)
                x2 = rng.uniform(da.coords["x2"].min(), da.coords["x2"].max(), N)
                Y_c = da.sel(x1=xr.DataArray(x1), x2=xr.DataArray(x2), method="nearest")
                Y_c = np.array(Y_c, dtype=self.dtype)
            X_c = np.array([x1, x2], dtype=self.dtype)

        elif isinstance(sampling_strat, np.ndarray):
            X_c = sampling_strat.astype(self.dtype)
            try:
                Y_c = da.sel(
                    x1=xr.DataArray(X_c[0]),
                    x2=xr.DataArray(X_c[1]),
                    method="nearest",
                    tolerance=0.1,  # Maximum distance from observed point to sample
                )
            except KeyError:
                raise InvalidSamplingStrategyError(
                    f"Passed a numpy coordinate array to sample xarray object, "
                    f"but the coordinates are out of bounds."
                )
            Y_c = np.array(Y_c, dtype=self.dtype)

        elif sampling_strat in ["all", "gapfill"]:
            X_c = (
                da.coords["x1"].values[np.newaxis],
                da.coords["x2"].values[np.newaxis],
            )
            Y_c = da.data
            if Y_c.ndim == 2:
                # returned a 2D array, but we need a 3D array of shape (variable, N_x1, N_x2)
                Y_c = Y_c.reshape(1, *Y_c.shape)
        else:
            raise InvalidSamplingStrategyError(
                f"Unknown sampling strategy {sampling_strat}"
            )

        if Y_c.ndim == 1:
            # returned a 1D array, but we need a 2D array of shape (variable, N)
            Y_c = Y_c.reshape(1, *Y_c.shape)

        return X_c, Y_c

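    # Sampling sketch, assuming a hypothetical task_loader and a DataArray
    # `da` already sliced for the task:
    #
    #   X_c, Y_c = task_loader.sample_da(da, sampling_strat=100, seed=42)
    #   # X_c has shape (2, 100) (x1/x2 coords); Y_c has shape (1, 100)
    #   # for a single-variable DataArray
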
    def sample_df(
        self,
        df: Union[pd.DataFrame, pd.Series],
        sampling_strat: Union[str, int, float, np.ndarray],
        seed: Optional[int] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Sample a DataFrame according to a given strategy.

        Args:
            df (:class:`pandas.DataFrame` | :class:`pandas.Series`):
                Dataframe to sample, assumed to be time-sliced for the task
                already.
            sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`):
                Sampling strategy: "all" or "split", an integer or float for
                random sampling, or a numpy coordinate array.
            seed (int, optional):
                Seed for random sampling. Default is None.

        Returns:
            Tuple[X_c, Y_c]:
                Tuple of sampled coordinates X_c and sampled data Y_c.

        Raises:
            InvalidSamplingStrategyError:
                If the sampling strategy is not valid or if a numpy coordinate
                array is passed to sample a pandas object, but the DataFrame
                does not contain all the requested samples.
        """
        df = df.dropna(how="any")  # If any obs are NaN, drop them

        if isinstance(sampling_strat, float):
            sampling_strat = int(sampling_strat * df.shape[0])

        if isinstance(sampling_strat, (int, np.integer)):
            N = sampling_strat
            rng = np.random.default_rng(seed)
            if N > df.index.size:
                raise SamplingTooManyPointsError(requested=N, available=df.index.size)
            idx = rng.choice(df.index, N, replace=False)
            X_c = df.loc[idx].reset_index()[["x1", "x2"]].values.T.astype(self.dtype)
            Y_c = df.loc[idx].values.T
        elif isinstance(sampling_strat, str) and sampling_strat in [
            "all",
            "split",
        ]:
            # NOTE if "split", we assume that the context-target split has already been applied to the df
            # in an earlier scope with access to both the context and target data. This is maybe risky!
            X_c = df.reset_index()[["x1", "x2"]].values.T.astype(self.dtype)
            Y_c = df.values.T
        elif isinstance(sampling_strat, np.ndarray):
            if df.index.get_level_values("x1").dtype != sampling_strat.dtype:
                raise InvalidSamplingStrategyError(
                    "Passed a numpy coordinate array to sample pandas DataFrame, "
                    "but the coordinate array has a different dtype than the DataFrame. "
                    f"Got {sampling_strat.dtype} but expected {df.index.get_level_values('x1').dtype}."
                )
            X_c = sampling_strat.astype(self.dtype)
            try:
                Y_c = df.loc[pd.IndexSlice[:, X_c[0], X_c[1]]].values.T
            except KeyError:
                raise InvalidSamplingStrategyError(
                    "Passed a numpy coordinate array to sample pandas DataFrame, "
                    "but the DataFrame did not contain all the requested samples.\n"
                    f"Indexes: {df.index}\n"
                    f"Sampling coords: {X_c}\n"
                    "If this is unexpected, check that your numpy sampling array matches "
                    "the DataFrame index values *exactly*."
                )
        else:
            raise InvalidSamplingStrategyError(
                f"Unknown sampling strategy {sampling_strat}"
            )

        if Y_c.ndim == 1:
            # returned a 1D array, but we need a 2D array of shape (variable, N)
            Y_c = Y_c.reshape(1, *Y_c.shape)

        return X_c, Y_c

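    # Coordinate-array sketch: sampling a hypothetical station_df at exact
    # index coordinates. The array dtype must match the DataFrame's x1/x2
    # index dtype exactly, or InvalidSamplingStrategyError is raised:
    #
    #   x1_dtype = station_df.index.get_level_values("x1").dtype
    #   coords = np.array([[0.1, 0.2], [0.3, 0.7]], dtype=x1_dtype)
    #   X_c, Y_c = task_loader.sample_df(station_df, sampling_strat=coords)
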
    def sample_offgrid_aux(
        self,
        X_t: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]],
        offgrid_aux: Union[xr.DataArray, xr.Dataset],
    ) -> np.ndarray:
        """Sample auxiliary data at off-grid locations.

        Args:
            X_t (:class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]):
                Off-grid locations at which to sample the auxiliary data. Can
                be a tuple of two numpy arrays, or a single numpy array.
            offgrid_aux (:class:`xarray.DataArray` | :class:`xarray.Dataset`):
                Auxiliary data at off-grid locations.

        Returns:
            :class:`numpy:numpy.ndarray`:
                Auxiliary data sampled at the target locations, of shape
                (variable, *spatial_dims).

        Raises:
            ValueError:
                If the auxiliary data still has a `time` dimension; it must be
                time-sliced before sampling.
        """
        if "time" in offgrid_aux.dims:
            raise ValueError(
                "If `aux_at_targets` data has a `time` dimension, it must be sliced before "
                "passing it to `sample_offgrid_aux`."
            )
        if isinstance(X_t, tuple):
            xt1, xt2 = X_t
            xt1 = xt1.ravel()
            xt2 = xt2.ravel()
        else:
            xt1, xt2 = xr.DataArray(X_t[0]), xr.DataArray(X_t[1])

        Y_t_aux = offgrid_aux.sel(x1=xt1, x2=xt2, method="nearest")
        if isinstance(Y_t_aux, xr.Dataset):
            Y_t_aux = Y_t_aux.to_array()
        Y_t_aux = np.array(Y_t_aux, dtype=np.float32)
        if (isinstance(X_t, tuple) and Y_t_aux.ndim == 2) or (
            isinstance(X_t, np.ndarray) and Y_t_aux.ndim == 1
        ):
            # Reshape to (variable, *spatial_dims)
            Y_t_aux = Y_t_aux.reshape(1, *Y_t_aux.shape)
        return Y_t_aux

    def _compute_global_coordinate_bounds(self) -> List[float]:
        """Compute global coordinate bounds in order to sample spatial bounds if desired.

        Returns
        -------
        bbox : List[float]
            Sequence of global spatial extent as [x1_min, x1_max, x2_min, x2_max].
        """
        x1_min, x1_max, x2_min, x2_max = np.inf, -np.inf, np.inf, -np.inf

        for var in itertools.chain(self.context, self.target):
            if isinstance(var, (xr.Dataset, xr.DataArray)):
                var_x1_min = var.x1.min().item()
                var_x1_max = var.x1.max().item()
                var_x2_min = var.x2.min().item()
                var_x2_max = var.x2.max().item()
            elif isinstance(var, (pd.DataFrame, pd.Series)):
                var_x1_min = var.index.get_level_values("x1").min()
                var_x1_max = var.index.get_level_values("x1").max()
                var_x2_min = var.index.get_level_values("x2").min()
                var_x2_max = var.index.get_level_values("x2").max()

            if var_x1_min < x1_min:
                x1_min = var_x1_min

            if var_x1_max > x1_max:
                x1_max = var_x1_max

            if var_x2_min < x2_min:
                x2_min = var_x2_min

            if var_x2_max > x2_max:
                x2_max = var_x2_max

        return [x1_min, x1_max, x2_min, x2_max]

    def _compute_x1x2_direction(self) -> dict:
        """Compute whether the x1 and x2 coords are ascending or descending.

        Returns:
            dict(bool):
                Dictionary containing two keys, x1 and x2, with boolean values
                defining whether these coordinates increase or decrease from
                the top-left corner.

        Raises:
            ValueError:
                If all datasets are non-gridded, or if the direction of
                ascending coordinates does not match across gridded datasets.

        """
        non_gridded = {"x1": None, "x2": None}  # value to use for non-gridded data
        ascending = []
        for var in itertools.chain(self.context, self.target):
            if isinstance(var, (xr.Dataset, xr.DataArray)):
                coord_x1_left = var.x1[0]
                coord_x1_right = var.x1[-1]
                coord_x2_top = var.x2[0]
                coord_x2_bottom = var.x2[-1]

                ascending.append(
                    {
                        "x1": bool(coord_x1_left <= coord_x1_right),
                        "x2": bool(coord_x2_top <= coord_x2_bottom),
                    }
                )

            elif isinstance(var, (pd.DataFrame, pd.Series)):
                ascending.append(non_gridded)

        if len(list(filter(lambda x: x != non_gridded, ascending))) == 0:
            raise ValueError(
                "All data is non-gridded, cannot proceed with sliding window sampling."
            )

        # get the directions for only the gridded data
        gridded = list(filter(lambda x: x != non_gridded, ascending))
        # raise error if directions don't match across gridded data
        if gridded.count(gridded[0]) != len(gridded):
            raise ValueError(
                "Direction of ascending coordinates does not match across all gridded datasets."
            )

        return gridded[0]

    def sample_random_window(self, patch_size: Tuple[float]) -> Sequence[float]:
        """Sample a random window uniformly from the global coordinates to slice data.

        Parameters
        ----------
        patch_size : Tuple[float]
            Tuple of window extent (x1_extent, x2_extent).

        Returns
        -------
        bbox : List[float]
            Sequence of patch spatial extent as [x1_min, x1_max, x2_min, x2_max].
        """
        x1_extent, x2_extent = patch_size

        x1_side = x1_extent / 2
        x2_side = x2_extent / 2

        # sample a centre point that satisfies the context and target global bounds
        x1_min, x1_max, x2_min, x2_max = self.coord_bounds

        x1_point = random.uniform(x1_min + x1_side, x1_max - x1_side)
        x2_point = random.uniform(x2_min + x2_side, x2_max - x2_side)

        # bbox of x1_min, x1_max, x2_min, x2_max
        bbox = [
            x1_point - x1_side,
            x1_point + x1_side,
            x2_point - x2_side,
            x2_point + x2_side,
        ]

        return bbox

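    # Worked example: with global bounds [0, 1, 0, 1] and patch_size=(0.5, 0.5),
    # the patch centre is drawn uniformly from [0.25, 0.75] in each dimension;
    # a centre of (0.4, 0.6) would give bbox = [0.15, 0.65, 0.35, 0.85].
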
    def time_slice_variable(self, var, date, delta_t=0):
        """Slice a variable by a given time delta.

        Args:
            var (...):
                Variable to slice.
            date (...):
                Task initialisation date to slice at.
            delta_t (...):
                Time delta to slice by.

        Returns:
            var (...)
                Sliced variable.

        Raises:
            ValueError
                If the variable is of an unknown type.
        """
        # TODO: Does this work with instantaneous time?
        delta_t = pd.Timedelta(delta_t, unit=self.time_freq)
        if isinstance(var, (xr.Dataset, xr.DataArray)):
            if "time" in var.dims:
                var = var.sel(time=date + delta_t)
        elif isinstance(var, (pd.DataFrame, pd.Series)):
            if "time" in var.index.names:
                var = var[var.index.get_level_values("time") == date + delta_t]
        else:
            raise ValueError(f"Unknown variable type {type(var)}")
        return var

    def spatial_slice_variable(self, var, window: List[float]):
        """Slice a variable by a given window size.

        Args:
            var (...):
                Variable to slice.
            window (List[float]):
                List of coordinates specifying the window [x1_min, x1_max, x2_min, x2_max].

        Returns:
            var (...)
                Sliced variable.

        Raises:
            ValueError
                If the variable is of an unknown type.
        """
        x1_min, x1_max, x2_min, x2_max = window
        if isinstance(var, (xr.Dataset, xr.DataArray)):
            # we cannot assume that the coordinates are sorted from small to large
            if var.x1[0] > var.x1[-1]:
                x1_slice = slice(x1_max, x1_min)
            else:
                x1_slice = slice(x1_min, x1_max)
            if var.x2[0] > var.x2[-1]:
                x2_slice = slice(x2_max, x2_min)
            else:
                x2_slice = slice(x2_min, x2_max)
            var = var.sel(x1=x1_slice, x2=x2_slice)
        elif isinstance(var, (pd.DataFrame, pd.Series)):
            # retrieve desired patch size
            var = var[
                (var.index.get_level_values("x1") >= x1_min)
                & (var.index.get_level_values("x1") <= x1_max)
                & (var.index.get_level_values("x2") >= x2_min)
                & (var.index.get_level_values("x2") <= x2_max)
            ]
        else:
            raise ValueError(f"Unknown variable type {type(var)}")

        return var

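    # Slicing sketch, assuming a hypothetical gridded variable `era5_ds`
    # and a window in normalised coordinates:
    #
    #   patch = task_loader.spatial_slice_variable(era5_ds, [0.15, 0.65, 0.35, 0.85])
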
    def task_generation(  # noqa: D102
        self,
        date: pd.Timestamp,
        context_sampling: Union[
            str,
            int,
            float,
            np.ndarray,
            List[Union[str, int, float, np.ndarray]],
        ] = "all",
        target_sampling: Optional[
            Union[
                str,
                int,
                float,
                np.ndarray,
                List[Union[str, int, float, np.ndarray]],
            ]
        ] = None,
        split_frac: float = 0.5,
        bbox: Sequence[float] = None,
        patch_size: Union[float, Tuple[float]] = None,
        stride: Union[float, Tuple[float]] = None,
        datewise_deterministic: bool = False,
        seed_override: Optional[int] = None,
    ) -> Task:
        """Generate a task for a given date.

        There are several sampling strategies available for the context and
        target data:

            - "all": Sample all observations.
            - int: Sample N observations uniformly at random.
            - float: Sample a fraction of observations uniformly at random.
            - :class:`numpy:numpy.ndarray`, shape (2, N): Sample N observations
              at the given x1, x2 coordinates. Coords are assumed to be
              unnormalised.

        Parameters
        ----------
        date : :class:`pandas.Timestamp`
            Date for which to generate the task.
        context_sampling : str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`]
            Sampling strategy for the context data, either a list of sampling
            strategies for each context set, or a single strategy applied to
            all context sets. Default is ``"all"``.
        target_sampling : str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`]
            Sampling strategy for the target data, either a list of sampling
            strategies for each target set, or a single strategy applied to all
            target sets. Default is ``None``, in which case no target data is
            sampled.
        split_frac : float
            The fraction of observations to use for the context set with the
            "split" sampling strategy for linked context and target set pairs.
            The remaining observations are used for the target set. Default is
            0.5.
        bbox : Sequence[float], optional
            Bounding box to spatially slice the data, should be of the form
            [x1_min, x1_max, x2_min, x2_max]. Useful when considering the
            entire available region is computationally prohibitive for a model
            forward pass.
        patch_size : float | Tuple[float], optional
            Only used by patchwise inference. Height and width of the patch in
            x1/x2 normalised coordinates.
        stride : float | Tuple[float], optional
            Only used by patchwise inference. Length of the stride between
            adjacent patches in x1/x2 normalised coordinates.
        datewise_deterministic : bool
            Whether random sampling is deterministic based on the date. Default
            is ``False``.
        seed_override : Optional[int]
            Override the seed for random sampling. This can be used to use the
            same random sampling at different ``date``. Default is None.

        Returns
        -------
        task : :class:`~.data.task.Task`
            Task object containing the context and target data.
        """

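        # Usage sketch (hypothetical task_loader):
        #
        #   task = task_loader.task_generation(
        #       date=pd.Timestamp("2020-01-01"),
        #       context_sampling=100,    # 100 random context points per set
        #       target_sampling="all",   # all observations as targets
        #   )
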
        def check_sampling_strat(sampling_strat, set):
            """Check the sampling strategy.

            Ensure ``sampling_strat`` is either a single strategy (broadcast
            to all sets) or a list of length equal to the number of sets.
            Convert to a tuple of length equal to the number of sets and
            return.

            Args:
                sampling_strat:
                    Sampling strategy to check.
                set:
                    Context or target set to check.

            Returns:
                tuple:
                    Tuple of sampling strategies, one for each set.

            Raises:
                InvalidSamplingStrategyError:
                    - If the sampling strategy is invalid.
                    - If the length of the sampling strategy does not match the number of sets.
                    - If the sampling strategy is not a valid type.
                    - If the sampling strategy is a float but not in [0, 1].
                    - If the sampling strategy is an int but not positive.
                    - If the sampling strategy is a numpy array but not of shape (2, N).
            """
            if sampling_strat is None:
                return None
            if not isinstance(sampling_strat, (list, tuple)):
                sampling_strat = tuple([sampling_strat] * len(set))
            elif isinstance(sampling_strat, (list, tuple)) and len(
                sampling_strat
            ) != len(set):
                raise InvalidSamplingStrategyError(
                    f"Length of sampling_strat ({len(sampling_strat)}) must "
                    f"match number of context sets ({len(set)})"
                )

            for strat in sampling_strat:
                if not isinstance(strat, (str, int, np.integer, float, np.ndarray)):
                    raise InvalidSamplingStrategyError(
                        f"Unknown sampling strategy {strat} of type {type(strat)}"
                    )
                elif isinstance(strat, str) and strat == "gapfill":
                    assert all(
                        isinstance(item, (xr.Dataset, xr.DataArray)) for item in set
                    ), (
                        "Gapfill sampling strategy can only be used with xarray "
                        "datasets or data arrays"
                    )
                elif isinstance(strat, str) and strat not in [
                    "all",
                    "split",
                    "gapfill",
                ]:
                    raise InvalidSamplingStrategyError(
                        f"Unknown sampling strategy {strat} for type str"
                    )
                elif isinstance(strat, float) and not 0 <= strat <= 1:
                    raise InvalidSamplingStrategyError(
                        f"If sampling strategy is a float, it must be a "
                        f"fraction in [0, 1], got {strat}"
                    )
                elif isinstance(strat, int) and strat < 0:
                    raise InvalidSamplingStrategyError(
                        f"Sampling N must be positive, got {strat}"
                    )
                elif isinstance(strat, np.ndarray) and strat.shape[0] != 2:
                    raise InvalidSamplingStrategyError(
                        "Sampling coordinates must be of shape (2, N), got "
                        f"{strat.shape}"
                    )

            return sampling_strat

        def sample_variable(var, sampling_strat, seed):
            """Sample a variable by a given sampling strategy to get input and
            output data.

            Args:
                var:
                    Variable to sample.
                sampling_strat:
                    Sampling strategy to use.
                seed:
                    Seed for random sampling.

            Returns:
                Tuple[X, Y]:
                    Tuple of input and output data.

            Raises:
                ValueError:
                    If the variable is of an unknown type.
            """
            if isinstance(var, (xr.Dataset, xr.DataArray)):
                X, Y = self.sample_da(var, sampling_strat, seed)
            elif isinstance(var, (pd.DataFrame, pd.Series)):
                X, Y = self.sample_df(var, sampling_strat, seed)
            else:
                raise ValueError(f"Unknown type {type(var)} for context set {var}")
            return X, Y

1158
        # Check that the sampling strategies are valid
1159
        context_sampling = check_sampling_strat(context_sampling, self.context)
2✔
1160
        target_sampling = check_sampling_strat(target_sampling, self.target)
2✔
1161
        # Check `split_frac
1162
        if split_frac < 0 or split_frac > 1:
2✔
1163
            raise ValueError(f"split_frac must be between 0 and 1, got {split_frac}")
2✔
1164
        if self.links is None:
            # The 'split' and 'gapfill' strategies require linked context-target sets
            context_needs_links = any(
                strat in ["split", "gapfill"]
                for strat in context_sampling
                if isinstance(strat, str)
            )
            if target_sampling is None:
                target_needs_links = False
            else:
                target_needs_links = any(
                    strat in ["split", "gapfill"]
                    for strat in target_sampling
                    if isinstance(strat, str)
                )
            if context_needs_links or target_needs_links:
                raise ValueError(
                    "If using 'split' or 'gapfill' sampling strategies, the context and target "
                    "sets must be linked with the TaskLoader `links` attribute."
                )
        if self.links is not None:
            for context_idx, target_idx in self.links:
                context_sampling_i = context_sampling[context_idx]
                if target_sampling is None:
                    target_sampling_i = None
                else:
                    target_sampling_i = target_sampling[target_idx]
                link_strats = (context_sampling_i, target_sampling_i)
                if any(
                    strat in ["split", "gapfill"]
                    for strat in link_strats
                    if isinstance(strat, str)
                ):
                    # If one of the sampling strategies is "split" or "gapfill",
                    # the other must use the same strategy
                    if link_strats[0] != link_strats[1]:
                        raise ValueError(
                            f"Linked context set {context_idx} and target set {target_idx} "
                            f"must use the same sampling strategy if one of them "
                            f"uses the 'split' or 'gapfill' sampling strategy. "
                            f"Got {link_strats[0]} and {link_strats[1]}."
                        )

        if not isinstance(date, pd.Timestamp):
            date = pd.Timestamp(date)

        if seed_override is not None:
            # Override the seed for random sampling
            seed = seed_override
        elif datewise_deterministic:
            # Generate a deterministic seed, based on the date, for random sampling
            seed = int(date.strftime("%Y%m%d"))
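            # e.g. pd.Timestamp("2020-01-31") gives seed 20200131, so tasks
            # generated for the same date always sample the same points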
        else:
            # 'Truly' random sampling
            seed = None

        task = {}

        task["time"] = date
        task["ops"] = []
        task["bbox"] = bbox
        # Store patch_size and stride in the task for use when stitching
        # patchwise predictions back together
        task["patch_size"] = patch_size
        task["stride"] = stride
        task["X_c"] = []
        task["Y_c"] = []
        if target_sampling is not None:
            task["X_t"] = []
            task["Y_t"] = []
        else:
            task["X_t"] = None
            task["Y_t"] = None

        # Temporal slices
        context_slices = [
            self.time_slice_variable(var, date, delta_t)
            for var, delta_t in zip(self.context, self.context_delta_t)
        ]
        target_slices = [
            self.time_slice_variable(var, date, delta_t)
            for var, delta_t in zip(self.target, self.target_delta_t)
        ]

        # TODO move to method
        if (
            self.links is not None
            and "split" in context_sampling
            and "split" in target_sampling
        ):
            # Perform the split sampling strategy for linked context and target
            # sets at this point, while the full context and target data are in scope

            context_split_idxs = np.where(np.array(context_sampling) == "split")[0]
            target_split_idxs = np.where(np.array(target_sampling) == "split")[0]
            assert len(context_split_idxs) == len(target_split_idxs), (
                f"Number of context sets with 'split' sampling strategy "
                f"({len(context_split_idxs)}) must match number of target sets "
                f"with 'split' sampling strategy ({len(target_split_idxs)})"
            )
            for split_i, (context_idx, target_idx) in enumerate(
                zip(context_split_idxs, target_split_idxs)
            ):
                assert (context_idx, target_idx) in self.links, (
                    f"Context set {context_idx} and target set {target_idx} must be "
                    f"linked with the `links` attribute if using the 'split' "
                    f"sampling strategy"
                )

                context_var = context_slices[context_idx]
                target_var = target_slices[target_idx]

                for var in [context_var, target_var]:
                    assert isinstance(var, (pd.Series, pd.DataFrame)), (
                        f"If using 'split' sampling strategy for linked context and target sets, "
                        f"the context and target sets must be pandas DataFrames or Series, "
                        f"but got {type(var)}."
                    )

                N_obs = len(context_var)
                N_obs_target_check = len(target_var)
                if N_obs != N_obs_target_check:
                    raise ValueError(
                        f"Cannot split context set {context_idx} and target set {target_idx} "
                        f"because they have different numbers of observations: "
                        f"{N_obs} and {N_obs_target_check}"
                    )
                split_seed = seed + split_i if seed is not None else None
                rng = np.random.default_rng(split_seed)

                N_context = int(N_obs * split_frac)
                idxs_context = rng.choice(N_obs, N_context, replace=False)
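                # e.g. N_obs=100 and split_frac=0.5: 50 randomly chosen rows
                # become the context set; the remaining 50 become targets below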

                context_var = context_var.iloc[idxs_context]
                target_var = target_var.drop(context_var.index)

                context_slices[context_idx] = context_var
                target_slices[target_idx] = target_var

        # TODO move to method
        if (
            self.links is not None
            and "gapfill" in context_sampling
            and "gapfill" in target_sampling
        ):
            # Perform the gapfill sampling strategy for linked context and target
            # sets at this point, while the full context and target data are in scope

            context_gapfill_idxs = np.where(np.array(context_sampling) == "gapfill")[0]
            target_gapfill_idxs = np.where(np.array(target_sampling) == "gapfill")[0]
            assert len(context_gapfill_idxs) == len(target_gapfill_idxs), (
                f"Number of context sets with 'gapfill' sampling strategy "
                f"({len(context_gapfill_idxs)}) must match number of target sets "
                f"with 'gapfill' sampling strategy ({len(target_gapfill_idxs)})"
            )
            for gapfill_i, (context_idx, target_idx) in enumerate(
                zip(context_gapfill_idxs, target_gapfill_idxs)
            ):
                assert (context_idx, target_idx) in self.links, (
                    f"Context set {context_idx} and target set {target_idx} must be "
                    f"linked with the `links` attribute if using the 'gapfill' "
                    f"sampling strategy"
                )

                context_var = context_slices[context_idx]
                target_var = target_slices[target_idx]

                for var in [context_var, target_var]:
                    assert isinstance(var, (xr.DataArray, xr.Dataset)), (
                        f"If using 'gapfill' sampling strategy for linked context and target sets, "
                        f"the context and target sets must be xarray DataArrays or Datasets, "
                        f"but got {type(var)}."
                    )

                gapfill_seed = seed + gapfill_i if seed is not None else None
                rng = np.random.default_rng(gapfill_seed)

                # Keep trying until we get a target set with at least one target point
                keep_searching = True
                while keep_searching:
                    added_mask_date = rng.choice(self.context[context_idx].time)
                    added_mask = (
                        self.context[context_idx].sel(time=added_mask_date).isnull()
                    )
                    curr_mask = context_var.isnull()

                    # Mask out added missing values
                    context_var = context_var.where(~added_mask)

                    # TEMP: Inefficient to convert all non-targets to NaNs and then remove NaNs
                    #   when we could just slice the target values here
                    target_mask = added_mask & ~curr_mask
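                    # Target points are those newly masked out by `added_mask`
                    # that held real (non-NaN) values in the original data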
                    if isinstance(target_var, xr.Dataset):
                        keep_searching = np.all(target_mask.to_array().data == False)
                    else:
                        keep_searching = np.all(target_mask.data == False)
                    if keep_searching:
                        continue  # No target points -- use a different `added_mask`

                    # Only keep target locations
                    target_var = target_var.where(target_mask)

                    context_slices[context_idx] = context_var
                    target_slices[target_idx] = target_var

        # Check bbox size
        if bbox is not None:
            assert (
                len(bbox) == 4
            ), "bbox must be a list of length 4 with [x1_min, x1_max, x2_min, x2_max]"

            # Spatial slices
            context_slices = [
                self.spatial_slice_variable(var, bbox) for var in context_slices
            ]
            target_slices = [
                self.spatial_slice_variable(var, bbox) for var in target_slices
            ]

        for i, (var, sampling_strat) in enumerate(
            zip(context_slices, context_sampling)
        ):
            context_seed = seed + i if seed is not None else None
            X_c, Y_c = sample_variable(var, sampling_strat, context_seed)
            task["X_c"].append(X_c)
            task["Y_c"].append(Y_c)
        if target_sampling is not None:
            for j, (var, sampling_strat) in enumerate(
                zip(target_slices, target_sampling)
            ):
                target_seed = seed + i + j if seed is not None else None
                X_t, Y_t = sample_variable(var, sampling_strat, target_seed)
                task["X_t"].append(X_t)
                task["Y_t"].append(Y_t)

        if self.aux_at_contexts is not None:
            # Add auxiliary variable sampled at context set as a new context variable
            X_c_offgrid = [X_c for X_c in task["X_c"] if not isinstance(X_c, tuple)]
            if len(X_c_offgrid) == 0:
                # No off-grid context sets
                X_c_offgrid_all = np.empty((2, 0), dtype=self.dtype)
            else:
                X_c_offgrid_all = np.concatenate(X_c_offgrid, axis=1)
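            # `X_c_offgrid_all` is a (2, N) array of x1/x2 coordinates pooled
            # across all off-grid context sets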
            Y_c_aux = (
                self.sample_offgrid_aux(
                    X_c_offgrid_all,
                    self.time_slice_variable(self.aux_at_contexts, date),
                ),
            )
            task["X_c"].append(X_c_offgrid_all)
            task["Y_c"].append(Y_c_aux)

        if self.aux_at_targets is not None and target_sampling is None:
            task["Y_t_aux"] = None
        elif self.aux_at_targets is not None and target_sampling is not None:
            # Add auxiliary variable to target set
            if len(task["X_t"]) > 1:
                raise ValueError(
                    "Cannot add auxiliary variable to target set when there "
                    "are multiple target variables (not supported by default `ConvNP` model)."
                )
            task["Y_t_aux"] = self.sample_offgrid_aux(
                task["X_t"][0],
                self.time_slice_variable(self.aux_at_targets, date),
            )

        return Task(task)

    def sample_sliding_window(
        self, patch_size: Tuple[float], stride: Tuple[float]
    ) -> List[List[float]]:
        """Sample patch bounding boxes by sliding a window over the global coordinates.

        Parameters
        ----------
        patch_size : Tuple[float]
            Tuple of window extent in x1/x2.

        stride : Tuple[float]
            Tuple of step size between each patch along the x1 and x2 axes.

        Returns
        -------
        List[List[float]]
            Sequence of patch spatial extents as [x1_min, x1_max, x2_min, x2_max].
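
        Examples
        --------
        A sketch of the expected output (hypothetical values; assumes
        normalised coordinate bounds of [0, 1] in both x1 and x2, both
        ascending)::

            bboxes = loader.sample_sliding_window(patch_size=(0.5, 0.5), stride=(0.5, 0.5))
            # -> four non-overlapping patches tiling the domain:
            #    [0.0, 0.5, 0.0, 0.5], [0.0, 0.5, 0.5, 1.0],
            #    [0.5, 1.0, 0.0, 0.5], [0.5, 1.0, 0.5, 1.0]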
        """
        self.coord_directions = self._compute_x1x2_direction()
        # Define patch size in x1/x2
        size = {}
        size["x1"], size["x2"] = patch_size

        # Define stride length in x1/x2, defaulting to patch_size if undefined
        if stride is None:
            stride = patch_size

        step = {}
        step["x1"], step["x2"] = stride

        # Calculate the global bounds of the context and target sets
        coord_min = {}
        coord_max = {}
        coord_min["x1"], coord_max["x1"], coord_min["x2"], coord_max["x2"] = (
            self.coord_bounds
        )

        # Start with the first patch's top left-hand corner at
        # (coord_min["x1"], coord_min["x2"])
        patch_list = []

        # Define some lambda functions for use below.
        # Round to 12 decimal places to avoid floating point error while keeping
        # the likelihood of unintentional rounding low
        r = lambda x: round(x, 12)
        bbox_coords_ascend = lambda a, b: [r(a), r(a + b)]
        bbox_coords_descend = lambda a, b: bbox_coords_ascend(a, b)[::-1]

        compare = {}
        bbox_coords = {}
        # For each coordinate direction, specify the correct operations for patching
        for c in ("x1", "x2"):
            if self.coord_directions[c]:
                compare[c] = operator.gt
                bbox_coords[c] = bbox_coords_ascend
            else:
                step[c] = -step[c]
                coord_min[c], coord_max[c] = coord_max[c], coord_min[c]
                size[c] = -size[c]
                compare[c] = operator.lt
                bbox_coords[c] = bbox_coords_descend

        # Define the bounding boxes for all patches, starting in the top left
        # corner of the DataArray
        for y, x in itertools.product(
            np.arange(coord_min["x1"], coord_max["x1"], step["x1"]),
            np.arange(coord_min["x2"], coord_max["x2"], step["x2"]),
        ):
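            # Patches that would overrun the domain are shifted back so the
            # final patch aligns with the boundary, e.g. with bounds [0, 1],
            # patch_size 0.5 and stride 0.3, x1 starts 0.0/0.3/0.6/0.9 are
            # clamped to 0.0/0.3/0.5/0.5 (duplicates are dropped below)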
            y0 = (
                coord_max["x1"] - size["x1"]
                if compare["x1"](y + size["x1"], coord_max["x1"])
                else y
            )
            x0 = (
                coord_max["x2"] - size["x2"]
                if compare["x2"](x + size["x2"], coord_max["x2"])
                else x
            )

            # bbox of [x1_min, x1_max, x2_min, x2_max] per patch
            bbox = bbox_coords["x1"](y0, size["x1"]) + bbox_coords["x2"](x0, size["x2"])
            patch_list.append(bbox)

        # Remove duplicate patches while preserving order
        seen = set()
        unique_patch_list = []
        for lst in patch_list:
            # Convert list to tuple for immutability
            tuple_lst = tuple(lst)
            if tuple_lst not in seen:
                seen.add(tuple_lst)
                unique_patch_list.append(lst)

        return unique_patch_list

    def __call__(
        self,
        date: Union[pd.Timestamp, Sequence[pd.Timestamp]],
        context_sampling: Union[
            str,
            int,
            float,
            np.ndarray,
            List[Union[str, int, float, np.ndarray]],
        ] = "all",
        target_sampling: Optional[
            Union[
                str,
                int,
                float,
                np.ndarray,
                List[Union[str, int, float, np.ndarray]],
            ]
        ] = None,
        split_frac: float = 0.5,
        patch_size: Optional[Union[float, Tuple[float, float]]] = None,
        patch_strategy: Optional[str] = None,
        stride: Optional[Union[float, Tuple[float, float]]] = None,
        num_samples_per_date: int = 1,
        datewise_deterministic: bool = False,
        seed_override: Optional[int] = None,
    ) -> Union[Task, List[Task]]:
        """Generate a task for a given date (or a list of
1555
        :class:`.data.task.Task` objects for a list of dates).
1556

1557
        There are several sampling strategies available for the context and
1558
        target data:
1559

1560
            - "all": Sample all observations.
1561
            - int: Sample N observations uniformly at random.
1562
            - float: Sample a fraction of observations uniformly at random.
1563
            - :class:`numpy:numpy.ndarray`, shape (2, N):
1564
                Sample N observations at the given x1, x2 coordinates. Coords are assumed to be
1565
                normalised.
1566
            - "split": Split pandas observations into disjoint context and target sets.
1567
                `split_frac` determines the fraction of observations
1568
                to use for the context set. The remaining observations are used
1569
                for the target set.
1570
                The context set and target set must be linked through the ``TaskLoader``
1571
                ``links`` argument. Only valid for pandas data.
1572
            - "gapfill": Generates a training task for filling NaNs in xarray data.
1573
                Randomly samples a missing data (NaN) mask from another timestamp and
1574
                adds it to the context set (i.e. increases the number of NaNs).
1575
                The target set is then true values of the data at the added missing locations.
1576
                The context set and target set must be linked through the ``TaskLoader``
1577
                ``links`` argument. Only valid for xarray data.
1578

1579
        Args:
1580
            date (:class:`pandas.Timestamp`):
1581
                Date for which to generate the task.
1582
            context_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional):
1583
                Sampling strategy for the context data, either a list of
1584
                sampling strategies for each context set, or a single strategy
1585
                applied to all context sets. Default is ``"all"``.
1586
            target_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional):
1587
                Sampling strategy for the target data, either a list of
1588
                sampling strategies for each target set, or a single strategy
1589
                applied to all target sets. Default is ``None``, meaning no target
1590
                data is returned.
1591
            split_frac (float, optional):
1592
                The fraction of observations to use for the context set with
1593
                the "split" sampling strategy for linked context and target set
1594
                pairs. The remaining observations are used for the target set.
1595
                Default is 0.5.
1596
            patch_size : Union[float, tuple[float]], optional
1597
                Desired patch size in x1/x2 used for patchwise task generation. Useful when considering
1598
                the entire available region is computationally prohibitive for model forward pass.
1599
                If passed a single float, will use value for both x1 & x2.
1600
            patch_strategy:
1601
                Patch strategy to use for patchwise task generation. Default is None.
1602
                Possible options are 'random' or 'sliding'.
1603
            stride: Union[float, tuple[float]], optional
1604
                Step size between each sliding window patch along x1 and x2 axis. Default is None.
1605
                If passed a single float, will use value for both x1 & x2.
1606
            datewise_deterministic (bool, optional):
1607
                Whether random sampling is datewise deterministic based on the
1608
                date. Default is ``False``.
1609
            seed_override (Optional[int], optional):
1610
                Override the seed for random sampling. This can be used to use
1611
                the same random sampling at different ``date``. Default is
1612
                None.
1613

1614
        Returns:
1615
            :class:`~.data.task.Task` | List[:class:`~.data.task.Task`]:
1616
                Task object or list of task objects for each date containing
1617
                the context and target data.
1618
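
        A sketch of patchwise task generation (hypothetical values; assumes
        gridded data on normalised coordinates spanning [0, 1] in both x1
        and x2)::

            tasks = task_loader(
                "2016-06-25",
                context_sampling="all",
                target_sampling="all",
                patch_strategy="sliding",
                patch_size=(0.5, 0.5),
                stride=(0.25, 0.25),
            )  # one task per sliding-window patch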
        """
        if patch_strategy not in [None, "random", "sliding"]:
            raise ValueError(
                f"Invalid patch strategy {patch_strategy}. "
                f"Must be one of [None, 'random', 'sliding']."
            )

        if isinstance(patch_size, float):
            patch_size = (patch_size, patch_size)

        if isinstance(stride, float):
            stride = (stride, stride)

        if patch_strategy is None:
            if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)):
                tasks = [
                    self.task_generation(
                        d,
                        context_sampling=context_sampling,
                        target_sampling=target_sampling,
                        split_frac=split_frac,
                        datewise_deterministic=datewise_deterministic,
                        seed_override=seed_override,
                    )
                    for d in date
                ]
            else:
                tasks = self.task_generation(
                    date=date,
                    context_sampling=context_sampling,
                    target_sampling=target_sampling,
                    split_frac=split_frac,
                    datewise_deterministic=datewise_deterministic,
                    seed_override=seed_override,
                )

        elif patch_strategy == "random":
            if patch_size is None:
                raise ValueError(
                    "Patch size must be specified for random patch sampling"
                )

            coord_bounds = [self.coord_bounds[0:2], self.coord_bounds[2:]]
            for i, val in enumerate(patch_size):
                if val < coord_bounds[i][0] or val > coord_bounds[i][1]:
                    raise ValueError(
                        f"Values of patch_size must be between the normalised "
                        f"coordinate bounds of: {self.coord_bounds}. "
                        f"Got: patch_size: {patch_size}."
                    )

            if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)):
                tasks = []
                for d in date:
                    bboxes = [
                        self.sample_random_window(patch_size)
                        for _ in range(num_samples_per_date)
                    ]
                    tasks.extend(
                        [
                            self.task_generation(
                                d,
                                bbox=bbox,
                                context_sampling=context_sampling,
                                target_sampling=target_sampling,
                                split_frac=split_frac,
                                datewise_deterministic=datewise_deterministic,
                                seed_override=seed_override,
                            )
                            for bbox in bboxes
                        ]
                    )

            else:
                bboxes = [
                    self.sample_random_window(patch_size)
                    for _ in range(num_samples_per_date)
                ]
                tasks = [
                    self.task_generation(
                        date,
                        bbox=bbox,
                        context_sampling=context_sampling,
                        target_sampling=target_sampling,
                        split_frac=split_frac,
                        datewise_deterministic=datewise_deterministic,
                        seed_override=seed_override,
                    )
                    for bbox in bboxes
                ]

        elif patch_strategy == "sliding":
            # Sliding window sampling of patches
            for val in (patch_size, stride):
                if val is None:
                    raise ValueError(
                        f"patch_size and stride must be specified for sliding "
                        f"window sampling, got patch_size: {patch_size} and "
                        f"stride: {stride}."
                    )

            if stride[0] > patch_size[0] or stride[1] > patch_size[1]:
                raise Warning(
                    f"stride should generally be smaller than patch_size in the "
                    f"corresponding dimensions. Got: patch_size: {patch_size}, "
                    f"stride: {stride}"
                )

            coord_bounds = [self.coord_bounds[0:2], self.coord_bounds[2:]]
            for i in (0, 1):
                for val in (patch_size[i], stride[i]):
                    if val < coord_bounds[i][0] or val > coord_bounds[i][1]:
                        raise ValueError(
                            f"Values of stride and patch_size must be between "
                            f"the normalised coordinate bounds of: "
                            f"{self.coord_bounds}. Got: patch_size: {patch_size}, "
                            f"stride: {stride}"
                        )

            if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)):
                tasks = []
                for d in date:
                    bboxes = self.sample_sliding_window(patch_size, stride)
                    tasks.extend(
                        [
                            self.task_generation(
                                d,
                                bbox=bbox,
                                context_sampling=context_sampling,
                                target_sampling=target_sampling,
                                split_frac=split_frac,
                                datewise_deterministic=datewise_deterministic,
                                seed_override=seed_override,
                                patch_size=patch_size,
                                stride=stride,
                            )
                            for bbox in bboxes
                        ]
                    )
            else:
                bboxes = self.sample_sliding_window(patch_size, stride)
                tasks = [
                    self.task_generation(
                        date,
                        bbox=bbox,
                        context_sampling=context_sampling,
                        target_sampling=target_sampling,
                        split_frac=split_frac,
                        datewise_deterministic=datewise_deterministic,
                        seed_override=seed_override,
                        patch_size=patch_size,
                        stride=stride,
                    )
                    for bbox in bboxes
                ]
        else:
            raise ValueError(
                f"Invalid patch strategy {patch_strategy}. "
                f"Must be one of [None, 'random', 'sliding']."
            )

        return tasks