# deepsensor/data/loader.py

from deepsensor.data.task import Task, flatten_X

import os
import json
import copy

import numpy as np
import xarray as xr
import pandas as pd

from typing import List, Tuple, Union, Optional

from deepsensor.errors import InvalidSamplingStrategyError


class TaskLoader:
    """
    Generates :class:`~.data.task.Task` objects for training, testing, and inference with DeepSensor models.

    Provides a suite of sampling methods for generating :class:`~.data.task.Task` objects for different kinds of
    predictions, such as spatial interpolation, forecasting, downscaling, or some combination
    of these.

    The behaviour is the following:
        - If all data is passed as paths, load the data and overwrite the paths with the loaded data
        - Either all data is passed as paths, or all data is passed as loaded data (else ``ValueError``)
        - If all data is passed as paths, the TaskLoader can be saved with the ``save`` method
          (using the config)

    Args:
        task_loader_ID:
            If loading a TaskLoader from a config file, this is the folder the
            TaskLoader was saved in (using `.save`). If this argument is passed, all other
            arguments are ignored.
        context (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]):
            Context data. Can be a single :class:`xarray.DataArray`,
            :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a
            list/tuple of these.
        target (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]):
            Target data. Can be a single :class:`xarray.DataArray`,
            :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a
            list/tuple of these.
        aux_at_contexts (Tuple[int, :class:`xarray.DataArray` | :class:`xarray.Dataset`], optional):
            Auxiliary data at context locations. Tuple of two elements, where
            the first element is the index of the context set at which the
            auxiliary data will be sampled, and the second element is the
            auxiliary data, which can be a single :class:`xarray.DataArray` or
            :class:`xarray.Dataset`. Default: None.
        aux_at_targets (:class:`xarray.DataArray` | :class:`xarray.Dataset`, optional):
            Auxiliary data at target locations. Can be a single
            :class:`xarray.DataArray` or :class:`xarray.Dataset`. Default:
            None.
        links (Tuple[int, int] | List[Tuple[int, int]], optional):
            Specifies links between context and target data. Each link is a
            tuple of two integers, where the first integer is the index of the
            context data and the second integer is the index of the target
            data. Can be a single tuple in the case of a single link. If None,
            no links are specified. Default: None.
        context_delta_t (int | List[int], optional):
            Time difference between context data and t=0 (task init time). Can
            be a single int (same for all context data) or a list/tuple of
            ints. Default is 0.
        target_delta_t (int | List[int], optional):
            Time difference between target data and t=0 (task init time). Can
            be a single int (same for all target data) or a list/tuple of ints.
            Default is 0.
        time_freq (str, optional):
            Time frequency of the data. Default: ``'D'`` (daily).
        xarray_interp_method (str, optional):
            Interpolation method to use when interpolating
            :class:`xarray.DataArray`. Default is ``'linear'``.
        discrete_xarray_sampling (bool, optional):
            When randomly sampling xarray variables, whether to sample at
            discrete points defined at grid cell centres, or at continuous
            points within the grid. Default is ``False``.
        dtype (object, optional):
            Data type of the data. Used to cast the data to the specified
            dtype. Default: ``'float32'``.
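
    Examples:
        A minimal, illustrative sketch (the data here is synthetic and the
        variable names are hypothetical, not part of the library API)::

            import numpy as np
            import pandas as pd
            import xarray as xr
            from deepsensor.data.loader import TaskLoader

            # Hypothetical gridded variable with normalised spatial coords
            da = xr.DataArray(
                np.random.rand(10, 20, 20).astype("float32"),
                dims=["time", "x1", "x2"],
                coords={
                    "time": pd.date_range("2020-01-01", periods=10),
                    "x1": np.linspace(0, 1, 20),
                    "x2": np.linspace(0, 1, 20),
                },
                name="temperature",
            )
            task_loader = TaskLoader(context=da, target=da)
            task = task_loader(
                "2020-01-05", context_sampling=50, target_sampling="all"
            )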
    """
80

81
    config_fname = "task_loader_config.json"
2✔
82

83
    def __init__(
2✔
84
        self,
85
        task_loader_ID: Union[str, None] = None,
86
        context: Union[
87
            xr.DataArray,
88
            xr.Dataset,
89
            pd.DataFrame,
90
            str,
91
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
92
        ] = None,
93
        target: Union[
94
            xr.DataArray,
95
            xr.Dataset,
96
            pd.DataFrame,
97
            str,
98
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
99
        ] = None,
100
        aux_at_contexts: Optional[Tuple[int, Union[xr.DataArray, xr.Dataset]]] = None,
101
        aux_at_targets: Optional[
102
            Union[
103
                xr.DataArray,
104
                xr.Dataset,
105
            ]
106
        ] = None,
107
        links: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
108
        context_delta_t: Union[int, List[int]] = 0,
109
        target_delta_t: Union[int, List[int]] = 0,
110
        time_freq: str = "D",
111
        xarray_interp_method: str = "linear",
112
        discrete_xarray_sampling: bool = False,
113
        dtype: object = "float32",
114
    ) -> None:
115
        if task_loader_ID is not None:
2✔
116
            self.task_loader_ID = task_loader_ID
2✔
117
            # Load TaskLoader from config file
118
            fpath = os.path.join(task_loader_ID, self.config_fname)
2✔
119
            with open(fpath, "r") as f:
2✔
120
                self.config = json.load(f)
2✔
121

122
            self.context = self.config["context"]
2✔
123
            self.target = self.config["target"]
2✔
124
            self.aux_at_contexts = self.config["aux_at_contexts"]
2✔
125
            self.aux_at_targets = self.config["aux_at_targets"]
2✔
126
            self.links = self.config["links"]
2✔
127
            if self.links is not None:
2✔
128
                self.links = [tuple(link) for link in self.links]
2✔
129
            self.context_delta_t = self.config["context_delta_t"]
2✔
130
            self.target_delta_t = self.config["target_delta_t"]
2✔
131
            self.time_freq = self.config["time_freq"]
2✔
132
            self.xarray_interp_method = self.config["xarray_interp_method"]
2✔
133
            self.discrete_xarray_sampling = self.config["discrete_xarray_sampling"]
2✔
134
            self.dtype = self.config["dtype"]
2✔
135
        else:
136
            self.context = context
2✔
137
            self.target = target
2✔
138
            self.aux_at_contexts = aux_at_contexts
2✔
139
            self.aux_at_targets = aux_at_targets
2✔
140
            self.links = links
2✔
141
            self.context_delta_t = context_delta_t
2✔
142
            self.target_delta_t = target_delta_t
2✔
143
            self.time_freq = time_freq
2✔
144
            self.xarray_interp_method = xarray_interp_method
2✔
145
            self.discrete_xarray_sampling = discrete_xarray_sampling
2✔
146
            self.dtype = dtype
2✔
147

148
        if not isinstance(self.context, (tuple, list)):
2✔
149
            self.context = (self.context,)
2✔
150
        if not isinstance(self.target, (tuple, list)):
2✔
151
            self.target = (self.target,)
2✔
152

153
        if isinstance(self.context_delta_t, int):
2✔
154
            self.context_delta_t = (self.context_delta_t,) * len(self.context)
2✔
155
        else:
156
            assert len(self.context_delta_t) == len(self.context), (
2✔
157
                f"Length of context_delta_t ({len(self.context_delta_t)}) must be the same as "
158
                f"the number of context sets ({len(self.context)})"
159
            )
160
        if isinstance(self.target_delta_t, int):
2✔
161
            self.target_delta_t = (self.target_delta_t,) * len(self.target)
2✔
162
        else:
163
            assert len(self.target_delta_t) == len(self.target), (
2✔
164
                f"Length of target_delta_t ({len(self.target_delta_t)}) must be the same as "
165
                f"the number of target sets ({len(self.target)})"
166
            )
167

168
        all_paths = self._check_if_all_data_passed_as_paths()
2✔
169
        if all_paths:
2✔
170
            self._set_config()
2✔
171
            self._load_data_from_paths()
2✔
172

173
        self.context = self._cast_to_dtype(self.context)
2✔
174
        self.target = self._cast_to_dtype(self.target)
2✔
175
        self.aux_at_contexts = self._cast_to_dtype(self.aux_at_contexts)
2✔
176
        self.aux_at_targets = self._cast_to_dtype(self.aux_at_targets)
2✔
177

178
        self.links = self._check_links(self.links)
2✔
179

180
        (
2✔
181
            self.context_dims,
182
            self.target_dims,
183
            self.aux_at_target_dims,
184
        ) = self.count_context_and_target_data_dims()
185
        (
2✔
186
            self.context_var_IDs,
187
            self.target_var_IDs,
188
            self.context_var_IDs_and_delta_t,
189
            self.target_var_IDs_and_delta_t,
190
            self.aux_at_target_var_IDs,
191
        ) = self.infer_context_and_target_var_IDs()
192
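
    # Illustrative note (not part of the original source): an integer
    # `context_delta_t` is broadcast to every context set above. For example,
    # with two hypothetical context sets `da_sst` and `da_wind`,
    #
    #   TaskLoader(context=[da_sst, da_wind], target=da_sst, context_delta_t=-1)
    #
    # is equivalent to passing `context_delta_t=[-1, -1]`: both context sets
    # are sampled one `time_freq` step (by default one day) before the task
    # init time t=0.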

    def _set_config(self):
        """Instantiate a config dictionary for the TaskLoader object"""
        # Take deepcopy to avoid modifying the original config
        self.config = copy.deepcopy(
            dict(
                context=self.context,
                target=self.target,
                aux_at_contexts=self.aux_at_contexts,
                aux_at_targets=self.aux_at_targets,
                links=self.links,
                context_delta_t=self.context_delta_t,
                target_delta_t=self.target_delta_t,
                time_freq=self.time_freq,
                xarray_interp_method=self.xarray_interp_method,
                discrete_xarray_sampling=self.discrete_xarray_sampling,
                dtype=self.dtype,
            )
        )

    def _check_if_all_data_passed_as_paths(self) -> bool:
        """If all data is passed as paths, save paths to config and return True."""

        def _check_if_strings(x, mode="all"):
            if x is None:
                return None
            elif isinstance(x, (tuple, list)):
                if mode == "all":
                    return all([isinstance(x_i, str) for x_i in x])
                elif mode == "any":
                    return any([isinstance(x_i, str) for x_i in x])
            else:
                return isinstance(x, str)

        all_paths = all(
            filter(
                lambda x: x is not None,
                [
                    _check_if_strings(self.context),
                    _check_if_strings(self.target),
                    _check_if_strings(self.aux_at_contexts),
                    _check_if_strings(self.aux_at_targets),
                ],
            )
        )
        self._is_saveable = all_paths

        any_paths = any(
            filter(
                lambda x: x is not None,
                [
                    _check_if_strings(self.context, mode="any"),
                    _check_if_strings(self.target, mode="any"),
                    _check_if_strings(self.aux_at_contexts, mode="any"),
                    _check_if_strings(self.aux_at_targets, mode="any"),
                ],
            )
        )
        if any_paths and not all_paths:
            raise ValueError(
                "Data must be passed either all as paths or all as xarray/pandas objects (not a mix)."
            )

        return all_paths

    def _load_data_from_paths(self):
        """Load data from paths and overwrite paths with loaded data."""

        loaded_data = {}

        def _load_pandas_or_xarray(path):
            # Need to be careful about this. We need to ensure data gets into the right form
            #  for TaskLoader.
            if path is None:
                return None
            elif path in loaded_data:
                return loaded_data[path]
            elif path.endswith(".nc"):
                data = xr.open_dataset(path)
            elif path.endswith(".csv"):
                df = pd.read_csv(path)
                if "time" in df.columns:
                    df["time"] = pd.to_datetime(df["time"])
                    df = df.set_index(["time", "x1", "x2"]).sort_index()
                else:
                    df = df.set_index(["x1", "x2"]).sort_index()
                data = df
            else:
                raise ValueError(f"Unknown file extension for {path}")
            loaded_data[path] = data
            return data

        def _load_data(data):
            if isinstance(data, (tuple, list)):
                data = tuple([_load_pandas_or_xarray(data_i) for data_i in data])
            else:
                data = _load_pandas_or_xarray(data)
            return data

        self.context = _load_data(self.context)
        self.target = _load_data(self.target)
        self.aux_at_contexts = _load_data(self.aux_at_contexts)
        self.aux_at_targets = _load_data(self.aux_at_targets)
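
    # Illustrative note on path-based construction (file names hypothetical):
    # ".nc" paths are opened with `xr.open_dataset` and ".csv" paths with
    # `pd.read_csv`, where CSVs must contain "x1"/"x2" (and optionally "time")
    # columns to build the index, e.g.:
    #
    #   task_loader = TaskLoader(context="era5.nc", target="stations.csv")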

    def save(self, folder: str):
        """Save TaskLoader config to JSON in `folder`"""
        if not self._is_saveable:
            raise ValueError(
                "TaskLoader cannot be saved because not all data was passed as paths."
            )

        os.makedirs(folder, exist_ok=True)
        fpath = os.path.join(folder, self.config_fname)
        with open(fpath, "w") as f:
            json.dump(self.config, f, indent=4, sort_keys=False)
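
    # Illustrative sketch of the save/load round trip (folder name
    # hypothetical); only possible when all data was passed as paths:
    #
    #   task_loader.save("saved_task_loader")
    #   task_loader_2 = TaskLoader(task_loader_ID="saved_task_loader")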

    def _cast_to_dtype(
        self,
        var: Union[
            xr.DataArray,
            xr.Dataset,
            pd.DataFrame,
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
        ],
    ):
        """
        Cast context and target data to the default dtype.

        ..
            TODO unit test this by passing in a variety of data types and
            checking that they are cast correctly.

        Args:
            var (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List of these):
                Data, or a tuple/list of data, to cast.

        Returns:
            The input data cast to the TaskLoader's dtype, in the same
            container type as the input.
        """

        def cast_to_dtype(var):
            if isinstance(var, (xr.DataArray, xr.Dataset)):
                var = var.astype(self.dtype)
                var["x1"] = var["x1"].astype(self.dtype)
                var["x2"] = var["x2"].astype(self.dtype)
            elif isinstance(var, (pd.DataFrame, pd.Series)):
                var = var.astype(self.dtype)
                # Note: Numeric pandas indexes are always cast to float64, so we have to cast
                #   x1/x2 coord dtypes during task sampling
            else:
                raise ValueError(f"Unknown type {type(var)} for context set {var}")
            return var

        if var is None:
            return var
        elif isinstance(var, (tuple, list)):
            var = tuple([cast_to_dtype(var_i) for var_i in var])
        else:
            var = cast_to_dtype(var)

        return var

    def load_dask(self) -> None:
        """
        Load any `dask` data into memory.

        This function triggers the computation and loading of any data that
        is represented as dask arrays or datasets into memory.

        Returns:
            None
        """

        def load(datasets):
            if datasets is None:
                return
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]
            for var in datasets:
                if isinstance(var, (xr.DataArray, xr.Dataset)):
                    var.load()  # xarray's .load() computes and stores in-place

        load(self.context)
        load(self.target)
        load(self.aux_at_contexts)
        load(self.aux_at_targets)

        return None

    def count_context_and_target_data_dims(self):
        """
        Count the number of data dimensions in the context and target data.

        Returns:
            tuple: context_dims, Tuple of data dimensions in the context data.
            tuple: target_dims, Tuple of data dimensions in the target data.
            int: aux_at_target_dims, Number of data dimensions in the
                aux-at-target data (0 if ``aux_at_targets`` is None).

        Raises:
            ValueError: If the context/target data is not a tuple/list of
                        :class:`xarray.DataArray`, :class:`xarray.Dataset` or
                        :class:`pandas.DataFrame`.
        """

        def count_data_dims_of_tuple_of_sets(datasets):
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]

            dims = []
            # Distinguish between xr.DataArray, xr.Dataset and pd.DataFrame
            for i, var in enumerate(datasets):
                if isinstance(var, xr.Dataset):
                    dim = len(var.data_vars)  # Multiple data variables
                elif isinstance(var, xr.DataArray):
                    dim = 1  # Single data variable
                elif isinstance(var, pd.DataFrame):
                    dim = len(var.columns)  # Assumes all columns are data variables
                elif isinstance(var, pd.Series):
                    dim = 1  # Single data variable
                else:
                    raise ValueError(f"Unknown type {type(var)} for context set {var}")
                dims.append(dim)
            return dims

        context_dims = count_data_dims_of_tuple_of_sets(self.context)
        target_dims = count_data_dims_of_tuple_of_sets(self.target)
        if self.aux_at_contexts is not None:
            context_dims += count_data_dims_of_tuple_of_sets(self.aux_at_contexts)
        if self.aux_at_targets is not None:
            aux_at_target_dims = count_data_dims_of_tuple_of_sets(self.aux_at_targets)[
                0
            ]
        else:
            aux_at_target_dims = 0

        return tuple(context_dims), tuple(target_dims), aux_at_target_dims

    def infer_context_and_target_var_IDs(self):
        """
        Infer the variable IDs of the context and target data.

        Returns:
            tuple: context_var_IDs, Tuple of variable IDs in the context data.
            tuple: target_var_IDs, Tuple of variable IDs in the target data.
            tuple: context_var_IDs_and_delta_t, As context_var_IDs, with the
                time delta appended to each ID.
            tuple: target_var_IDs_and_delta_t, As target_var_IDs, with the
                time delta appended to each ID.
            tuple | None: aux_at_target_var_IDs, Variable IDs of the
                aux-at-target data (None if ``aux_at_targets`` is None).

        Raises:
            ValueError: If the context/target data is not a tuple/list of
                        :class:`xarray.DataArray`, :class:`xarray.Dataset` or
                        :class:`pandas.DataFrame`.
        """

        def infer_var_IDs_of_tuple_of_sets(datasets, delta_ts=None):
            """If delta_ts is not None, then add the delta_t to the variable ID"""
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]

            var_IDs = []
            # Distinguish between xr.DataArray, xr.Dataset and pd.DataFrame
            for i, var in enumerate(datasets):
                if isinstance(var, xr.DataArray):
                    var_ID = (var.name,)  # Single data variable
                elif isinstance(var, xr.Dataset):
                    var_ID = tuple(var.data_vars.keys())  # Multiple data variables
                elif isinstance(var, pd.DataFrame):
                    var_ID = tuple(var.columns)
                elif isinstance(var, pd.Series):
                    var_ID = (var.name,)
                else:
                    raise ValueError(f"Unknown type {type(var)} for context set {var}")

                if delta_ts is not None:
                    # Add delta_t to the variable ID
                    var_ID = tuple(
                        [f"{var_ID_i}_t{delta_ts[i]}" for var_ID_i in var_ID]
                    )
                else:
                    var_ID = tuple([f"{var_ID_i}" for var_ID_i in var_ID])

                var_IDs.append(var_ID)

            return var_IDs

        context_var_IDs = infer_var_IDs_of_tuple_of_sets(self.context)
        context_var_IDs_and_delta_t = infer_var_IDs_of_tuple_of_sets(
            self.context, self.context_delta_t
        )
        target_var_IDs = infer_var_IDs_of_tuple_of_sets(self.target)
        target_var_IDs_and_delta_t = infer_var_IDs_of_tuple_of_sets(
            self.target, self.target_delta_t
        )

        if self.aux_at_contexts is not None:
            context_var_IDs += infer_var_IDs_of_tuple_of_sets(self.aux_at_contexts)
            context_var_IDs_and_delta_t += infer_var_IDs_of_tuple_of_sets(
                self.aux_at_contexts, [0]
            )

        if self.aux_at_targets is not None:
            aux_at_target_var_IDs = infer_var_IDs_of_tuple_of_sets(self.aux_at_targets)[
                0
            ]
        else:
            aux_at_target_var_IDs = None

        return (
            tuple(context_var_IDs),
            tuple(target_var_IDs),
            tuple(context_var_IDs_and_delta_t),
            tuple(target_var_IDs_and_delta_t),
            aux_at_target_var_IDs,
        )
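
    # Illustrative sketch (synthetic example): for a single Dataset context
    # set with data variables "u" and "v" and context_delta_t=-1, the
    # inferred IDs would be:
    #
    #   task_loader.context_var_IDs              # (("u", "v"),)
    #   task_loader.context_var_IDs_and_delta_t  # (("u_t-1", "v_t-1"),)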

    def _check_links(self, links: Union[Tuple[int, int], List[Tuple[int, int]]]):
        """
        Check that the context-target links are valid.

        Args:
            links (Tuple[int, int] | List[Tuple[int, int]]):
                Specifies links between context and target data. Each link is a
                tuple of two integers, where the first integer is the index of
                the context data and the second integer is the index of the
                target data. Can be a single tuple in the case of a single
                link. If None, no links are specified. Default: None.

        Returns:
            Tuple[int, int] | List[Tuple[int, int]]
                The input links, if valid.

        Raises:
            ValueError
                If the links are not valid.
        """
        if links is None:
            return None

        assert isinstance(
            links, list
        ), f"Links must be a list of length-2 tuples, but got {type(links)}"
        assert len(links) > 0, "If links is not None, it must be a non-empty list"
        assert all(
            isinstance(link, tuple) for link in links
        ), f"Links must be a list of tuples, but got {[type(link) for link in links]}"
        assert all(
            len(link) == 2 for link in links
        ), f"Links must be a list of length-2 tuples, but got lengths {[len(link) for link in links]}"

        # Check that the links are valid
        for link_i, (context_idx, target_idx) in enumerate(links):
            if context_idx >= len(self.context):
                raise ValueError(
                    f"Invalid context index {context_idx} in link {link_i} of {links}: "
                    f"there are only {len(self.context)} context sets"
                )
            if target_idx >= len(self.target):
                raise ValueError(
                    f"Invalid target index {target_idx} in link {link_i} of {links}: "
                    f"there are only {len(self.target)} target sets"
                )

        return links
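
    # Illustrative sketch (hypothetical DataFrame `df_stations`): linking
    # context set 0 to target set 0 lets the "split" strategy divide one
    # station DataFrame into disjoint context/target observations:
    #
    #   task_loader = TaskLoader(
    #       context=df_stations, target=df_stations, links=[(0, 0)]
    #   )
    #   task = task_loader(
    #       "2020-01-01",
    #       context_sampling="split",
    #       target_sampling="split",
    #       split_frac=0.8,
    #   )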

    def __str__(self):
        """
        String representation of the TaskLoader object (user-friendly).
        """
        s = f"TaskLoader({len(self.context_dims)} context sets, {len(self.target_dims)} target sets)"
        s += f"\nContext variable IDs: {self.context_var_IDs}"
        s += f"\nTarget variable IDs: {self.target_var_IDs}"
        if self.aux_at_targets is not None:
            s += f"\nAuxiliary-at-target variable IDs: {self.aux_at_target_var_IDs}"
        return s

    def __repr__(self):
        """
        Representation of the TaskLoader object (for developers).

        ..
            TODO make this a more verbose version of __str__
        """
        s = str(self)
        s += "\n"
        s += f"\nContext data dimensions: {self.context_dims}"
        s += f"\nTarget data dimensions: {self.target_dims}"
        if self.aux_at_targets is not None:
            s += f"\nAuxiliary-at-target data dimensions: {self.aux_at_target_dims}"
        return s

    def sample_da(
        self,
        da: Union[xr.DataArray, xr.Dataset],
        sampling_strat: Union[str, int, float, np.ndarray],
        seed: Optional[int] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Sample a DataArray according to a given strategy.

        Args:
            da (:class:`xarray.DataArray` | :class:`xarray.Dataset`):
                DataArray to sample, assumed to be sliced for the task already.
            sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`):
                Sampling strategy, e.g. "all" or an integer for random grid
                cell sampling.
            seed (int, optional):
                Seed for random sampling. Default is None.

        Returns:
            Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]:
                Tuple of sampled spatial coordinates and sampled data values.

        Raises:
            InvalidSamplingStrategyError:
                If the sampling strategy is not valid or if a numpy coordinate
                array is passed to sample an xarray object, but the coordinates
                are out of bounds.
        """
        da = da.load()  # Converts dask -> numpy if not already loaded
        if isinstance(da, xr.Dataset):
            da = da.to_array()

        if isinstance(sampling_strat, float):
            sampling_strat = int(sampling_strat * da.size)

        if isinstance(sampling_strat, (int, np.integer)):
            N = sampling_strat
            rng = np.random.default_rng(seed)
            if self.discrete_xarray_sampling:
                x1 = rng.choice(da.coords["x1"].values, N, replace=True)
                x2 = rng.choice(da.coords["x2"].values, N, replace=True)
                Y_c = da.sel(x1=xr.DataArray(x1), x2=xr.DataArray(x2)).data
            else:
                if N == 0:
                    # Catch zero-context edge case before interp fails
                    X_c = np.zeros((2, 0), dtype=self.dtype)
                    dim = da.shape[0] if da.ndim == 3 else 1
                    Y_c = np.zeros((dim, 0), dtype=self.dtype)
                    return X_c, Y_c
                x1 = rng.uniform(da.coords["x1"].min(), da.coords["x1"].max(), N)
                x2 = rng.uniform(da.coords["x2"].min(), da.coords["x2"].max(), N)
                Y_c = da.sel(x1=xr.DataArray(x1), x2=xr.DataArray(x2), method="nearest")
                Y_c = np.array(Y_c, dtype=self.dtype)
            X_c = np.array([x1, x2], dtype=self.dtype)

        elif isinstance(sampling_strat, np.ndarray):
            X_c = sampling_strat.astype(self.dtype)
            try:
                Y_c = da.sel(
                    x1=xr.DataArray(X_c[0]),
                    x2=xr.DataArray(X_c[1]),
                    method="nearest",
                    tolerance=0.1,  # Maximum distance from observed point to sample
                )
            except KeyError:
                raise InvalidSamplingStrategyError(
                    f"Passed a numpy coordinate array to sample xarray object, "
                    f"but the coordinates are out of bounds."
                )
            Y_c = np.array(Y_c, dtype=self.dtype)

        elif sampling_strat in ["all", "gapfill"]:
            X_c = (
                da.coords["x1"].values[np.newaxis],
                da.coords["x2"].values[np.newaxis],
            )
            Y_c = da.data
            if Y_c.ndim == 2:
                # returned a 2D array, but we need a 3D array of shape (variable, N_x1, N_x2)
                Y_c = Y_c.reshape(1, *Y_c.shape)
        else:
            raise InvalidSamplingStrategyError(
                f"Unknown sampling strategy {sampling_strat}"
            )

        if Y_c.ndim == 1:
            # returned a 1D array, but we need a 2D array of shape (variable, N)
            Y_c = Y_c.reshape(1, *Y_c.shape)

        return X_c, Y_c
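
    # Illustrative sketch (using the synthetic `da` from the class docstring
    # example): sampling 10 continuous locations from a single time slice
    # returns X_c with shape (2, 10) and Y_c with shape (1, 10):
    #
    #   da_slice = da.sel(time="2020-01-05")
    #   X_c, Y_c = task_loader.sample_da(da_slice, sampling_strat=10, seed=42)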

    def sample_df(
        self,
        df: Union[pd.DataFrame, pd.Series],
        sampling_strat: Union[str, int, float, np.ndarray],
        seed: Optional[int] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Sample a DataFrame according to a given strategy.

        Args:
            df (:class:`pandas.DataFrame` | :class:`pandas.Series`):
                Dataframe to sample, assumed to be time-sliced for the task
                already.
            sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`):
                Sampling strategy, e.g. "all" or an integer for random
                observation sampling.
            seed (int, optional):
                Seed for random sampling. Default is None.

        Returns:
            Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]:
                Tuple of sampled spatial coordinates and sampled data values.

        Raises:
            InvalidSamplingStrategyError:
                If the sampling strategy is not valid or if a numpy coordinate
                array is passed to sample a pandas object, but the DataFrame
                does not contain all the requested samples.
        """
        df = df.dropna(how="any")  # If any obs are NaN, drop them

        if isinstance(sampling_strat, float):
            sampling_strat = int(sampling_strat * df.shape[0])

        if isinstance(sampling_strat, (int, np.integer)):
            N = sampling_strat
            rng = np.random.default_rng(seed)
            idx = rng.choice(df.index, N)
            X_c = df.loc[idx].reset_index()[["x1", "x2"]].values.T.astype(self.dtype)
            Y_c = df.loc[idx].values.T
        elif isinstance(sampling_strat, str) and sampling_strat in [
            "all",
            "split",
        ]:
            # NOTE if "split", we assume that the context-target split has already been applied to the df
            # in an earlier scope with access to both the context and target data. This is maybe risky!
            X_c = df.reset_index()[["x1", "x2"]].values.T.astype(self.dtype)
            Y_c = df.values.T
        elif isinstance(sampling_strat, np.ndarray):
            if df.index.get_level_values("x1").dtype != sampling_strat.dtype:
                raise InvalidSamplingStrategyError(
                    "Passed a numpy coordinate array to sample pandas DataFrame, "
                    "but the coordinate array has a different dtype than the DataFrame. "
                    f"Got {sampling_strat.dtype} but expected {df.index.get_level_values('x1').dtype}."
                )
            X_c = sampling_strat.astype(self.dtype)
            try:
                Y_c = df.loc[pd.IndexSlice[:, X_c[0], X_c[1]]].values.T
            except KeyError:
                raise InvalidSamplingStrategyError(
                    "Passed a numpy coordinate array to sample pandas DataFrame, "
                    "but the DataFrame did not contain all the requested samples.\n"
                    f"Indexes: {df.index}\n"
                    f"Sampling coords: {X_c}\n"
                    "If this is unexpected, check that your numpy sampling array matches "
                    "the DataFrame index values *exactly*."
                )
        else:
            raise InvalidSamplingStrategyError(
                f"Unknown sampling strategy {sampling_strat}"
            )

        if Y_c.ndim == 1:
            # returned a 1D array, but we need a 2D array of shape (variable, N)
            Y_c = Y_c.reshape(1, *Y_c.shape)

        return X_c, Y_c
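
    # Illustrative sketch (hypothetical time-sliced DataFrame `df_sliced`
    # with a float32 x1/x2 index): sampling at explicit coordinates requires
    # a (2, N) array whose values match the index *exactly*:
    #
    #   X_sample = np.array([[0.25, 0.75], [0.5, 0.5]], dtype="float32")
    #   X_c, Y_c = task_loader.sample_df(df_sliced, X_sample)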

    def sample_offgrid_aux(
        self,
        X_t: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]],
        offgrid_aux: Union[xr.DataArray, xr.Dataset],
    ) -> np.ndarray:
        """
        Sample auxiliary data at off-grid locations.

        Args:
            X_t (:class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]):
                Off-grid locations at which to sample the auxiliary data. Can
                be a tuple of two numpy arrays, or a single numpy array.
            offgrid_aux (:class:`xarray.DataArray` | :class:`xarray.Dataset`):
                Auxiliary data at off-grid locations.

        Returns:
            :class:`numpy:numpy.ndarray`:
                Auxiliary data sampled at the target locations, reshaped to
                (variable, *spatial_dims).

        Raises:
            ValueError:
                If the ``offgrid_aux`` data has a ``time`` dimension (it must
                be time-sliced before calling this method).
        """
        if "time" in offgrid_aux.dims:
            raise ValueError(
                "If `aux_at_targets` data has a `time` dimension, it must be sliced before "
                "passing it to `sample_offgrid_aux`."
            )
        if isinstance(X_t, tuple):
            xt1, xt2 = X_t
            xt1 = xt1.ravel()
            xt2 = xt2.ravel()
        else:
            xt1, xt2 = xr.DataArray(X_t[0]), xr.DataArray(X_t[1])
        Y_t_aux = offgrid_aux.sel(x1=xt1, x2=xt2, method="nearest")
        if isinstance(Y_t_aux, xr.Dataset):
            Y_t_aux = Y_t_aux.to_array()
        Y_t_aux = np.array(Y_t_aux, dtype=np.float32)
        if (isinstance(X_t, tuple) and Y_t_aux.ndim == 2) or (
            isinstance(X_t, np.ndarray) and Y_t_aux.ndim == 1
        ):
            # Reshape to (variable, *spatial_dims)
            Y_t_aux = Y_t_aux.reshape(1, *Y_t_aux.shape)
        return Y_t_aux

    def time_slice_variable(self, var, date, delta_t=0):
        """
        Slice a variable by a given time delta.

        Args:
            var (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | :class:`pandas.Series`):
                Variable to slice.
            date (:class:`pandas.Timestamp`):
                Task init time t=0 to slice relative to.
            delta_t (int, optional):
                Time delta to slice by, in units of ``time_freq``. Default is 0.

        Returns:
            var:
                Sliced variable.

        Raises:
            ValueError:
                If the variable is of an unknown type.
        """
        # TODO: Does this work with instantaneous time?
        delta_t = pd.Timedelta(delta_t, unit=self.time_freq)
        if isinstance(var, (xr.Dataset, xr.DataArray)):
            if "time" in var.dims:
                var = var.sel(time=date + delta_t)
        elif isinstance(var, (pd.DataFrame, pd.Series)):
            if "time" in var.index.names:
                var = var[var.index.get_level_values("time") == date + delta_t]
        else:
            raise ValueError(f"Unknown variable type {type(var)}")
        return var
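
    # Illustrative sketch: with the default time_freq="D", delta_t=-1 selects
    # the data one day before the task init time, e.g. 2020-01-04 here:
    #
    #   sliced = task_loader.time_slice_variable(
    #       da, pd.Timestamp("2020-01-05"), delta_t=-1
    #   )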

    def task_generation(
        self,
        date: pd.Timestamp,
        context_sampling: Union[
            str,
            int,
            float,
            np.ndarray,
            List[Union[str, int, float, np.ndarray]],
        ] = "all",
        target_sampling: Optional[
            Union[
                str,
                int,
                float,
                np.ndarray,
                List[Union[str, int, float, np.ndarray]],
            ]
        ] = None,
        split_frac: float = 0.5,
        datewise_deterministic: bool = False,
        seed_override: Optional[int] = None,
    ) -> Task:
        def check_sampling_strat(sampling_strat, sets):
            """
            Check the sampling strategy.

            Ensure ``sampling_strat`` is either a single strategy (broadcast
            to all sets) or a list of length equal to the number of sets.
            Convert to a tuple of length equal to the number of sets and
            return.

            Args:
                sampling_strat:
                    Sampling strategy to check.
                sets:
                    Context or target sets to check against.

            Returns:
                tuple:
                    Tuple of sampling strategies, one for each set.

            Raises:
                InvalidSamplingStrategyError:
                    - If the sampling strategy is invalid.
                    - If the length of the sampling strategy does not match the number of sets.
                    - If the sampling strategy is not a valid type.
                    - If the sampling strategy is a float but not in [0, 1].
                    - If the sampling strategy is an int but negative.
                    - If the sampling strategy is a numpy array but not of shape (2, N).
            """
            if sampling_strat is None:
                return None
            if not isinstance(sampling_strat, (list, tuple)):
                sampling_strat = tuple([sampling_strat] * len(sets))
            elif isinstance(sampling_strat, (list, tuple)) and len(
                sampling_strat
            ) != len(sets):
                raise InvalidSamplingStrategyError(
                    f"Length of sampling_strat ({len(sampling_strat)}) must "
                    f"match number of sets ({len(sets)})"
                )

            for strat in sampling_strat:
                if not isinstance(strat, (str, int, np.integer, float, np.ndarray)):
                    raise InvalidSamplingStrategyError(
                        f"Unknown sampling strategy {strat} of type {type(strat)}"
                    )
                elif isinstance(strat, str) and strat not in [
                    "all",
                    "split",
                    "gapfill",
                ]:
                    raise InvalidSamplingStrategyError(
                        f"Unknown sampling strategy {strat} for type str"
                    )
                elif isinstance(strat, float) and not 0 <= strat <= 1:
                    raise InvalidSamplingStrategyError(
                        f"If the sampling strategy is a float, it must be a "
                        f"fraction in [0, 1], got {strat}"
                    )
                elif isinstance(strat, int) and strat < 0:
                    raise InvalidSamplingStrategyError(
                        f"Sampling N must be non-negative, got {strat}"
                    )
                elif isinstance(strat, np.ndarray) and strat.shape[0] != 2:
                    raise InvalidSamplingStrategyError(
                        "Sampling coordinates must be of shape (2, N), got "
                        f"{strat.shape}"
                    )

            return sampling_strat

        def sample_variable(var, sampling_strat, seed):
            """
            Sample a variable by a given sampling strategy to get input and
            output data.

            Args:
                var:
                    Variable to sample.
                sampling_strat:
                    Sampling strategy to use.
                seed:
                    Seed for random sampling.

            Returns:
                Tuple[X, Y]:
                    Tuple of input and output data.

            Raises:
                ValueError:
                    If the variable is of an unknown type.
            """
            if isinstance(var, (xr.Dataset, xr.DataArray)):
                X, Y = self.sample_da(var, sampling_strat, seed)
            elif isinstance(var, (pd.DataFrame, pd.Series)):
                X, Y = self.sample_df(var, sampling_strat, seed)
            else:
                raise ValueError(f"Unknown type {type(var)} for context set {var}")
            return X, Y

        # Check that the sampling strategies are valid
        context_sampling = check_sampling_strat(context_sampling, self.context)
        target_sampling = check_sampling_strat(target_sampling, self.target)
        # Check `split_frac`
        if split_frac < 0 or split_frac > 1:
            raise ValueError(f"split_frac must be between 0 and 1, got {split_frac}")
        if self.links is None:
            b1 = any(
                [
                    strat in ["split", "gapfill"]
                    for strat in context_sampling
                    if isinstance(strat, str)
                ]
            )
            if target_sampling is None:
                b2 = False
            else:
                b2 = any(
                    [
                        strat in ["split", "gapfill"]
                        for strat in target_sampling
                        if isinstance(strat, str)
                    ]
                )
            if b1 or b2:
                raise ValueError(
                    "If using 'split' or 'gapfill' sampling strategies, the context and target "
                    "sets must be linked with the TaskLoader `links` attribute."
                )
        if self.links is not None:
            for context_idx, target_idx in self.links:
                context_sampling_i = context_sampling[context_idx]
                if target_sampling is None:
                    target_sampling_i = None
                else:
                    target_sampling_i = target_sampling[target_idx]
                link_strats = (context_sampling_i, target_sampling_i)
                if any(
                    [
                        strat in ["split", "gapfill"]
                        for strat in link_strats
                        if isinstance(strat, str)
                    ]
                ):
                    # If one of the sampling strategies is "split" or "gapfill", the other must
                    # use the same splitting strategy
                    if link_strats[0] != link_strats[1]:
                        raise ValueError(
                            f"Linked context set {context_idx} and target set {target_idx} "
                            f"must use the same sampling strategy if one of them "
                            f"uses the 'split' or 'gapfill' sampling strategy. "
                            f"Got {link_strats[0]} and {link_strats[1]}."
                        )

        if not isinstance(date, pd.Timestamp):
            date = pd.Timestamp(date)

        if seed_override is not None:
            # Override the seed for random sampling
            seed = seed_override
        elif datewise_deterministic:
            # Generate a deterministic seed, based on the date, for random sampling
            seed = int(date.strftime("%Y%m%d"))
        else:
            # 'Truly' random sampling
            seed = None

        task = {}

        task["time"] = date
        task["ops"] = []
        task["X_c"] = []
        task["Y_c"] = []
        if target_sampling is not None:
            task["X_t"] = []
            task["Y_t"] = []
        else:
            task["X_t"] = None
            task["Y_t"] = None

        context_slices = [
            self.time_slice_variable(var, date, delta_t)
            for var, delta_t in zip(self.context, self.context_delta_t)
        ]
        target_slices = [
            self.time_slice_variable(var, date, delta_t)
            for var, delta_t in zip(self.target, self.target_delta_t)
        ]

        # TODO move to method
        if (
            self.links is not None
            and "split" in context_sampling
            and "split" in target_sampling
        ):
            # Perform the split sampling strategy for linked context and target sets at this point
            # while we have the full context and target data in scope

            context_split_idxs = np.where(np.array(context_sampling) == "split")[0]
            target_split_idxs = np.where(np.array(target_sampling) == "split")[0]
            assert len(context_split_idxs) == len(target_split_idxs), (
                f"Number of context sets with 'split' sampling strategy "
                f"({len(context_split_idxs)}) must match number of target sets "
                f"with 'split' sampling strategy ({len(target_split_idxs)})"
            )
            for split_i, (context_idx, target_idx) in enumerate(
                zip(context_split_idxs, target_split_idxs)
            ):
                assert (context_idx, target_idx) in self.links, (
                    f"Context set {context_idx} and target set {target_idx} must be linked "
                    f"with the `links` attribute if using the 'split' sampling strategy"
                )

                context_var = context_slices[context_idx]
                target_var = target_slices[target_idx]

                for var in [context_var, target_var]:
                    assert isinstance(var, (pd.Series, pd.DataFrame)), (
                        f"If using 'split' sampling strategy for linked context and target sets, "
                        f"the context and target sets must be pandas DataFrames or Series, "
                        f"but got {type(var)}."
                    )

                N_obs = len(context_var)
                N_obs_target_check = len(target_var)
                if N_obs != N_obs_target_check:
                    raise ValueError(
                        f"Cannot split context set {context_idx} and target set {target_idx} "
                        f"because they have different numbers of observations: "
                        f"{N_obs} and {N_obs_target_check}"
                    )
                split_seed = seed + split_i if seed is not None else None
                rng = np.random.default_rng(split_seed)

                N_context = int(N_obs * split_frac)
                idxs_context = rng.choice(N_obs, N_context, replace=False)

                context_var = context_var.iloc[idxs_context]
                target_var = target_var.drop(context_var.index)

                context_slices[context_idx] = context_var
                target_slices[target_idx] = target_var

        # TODO move to method
        if (
            self.links is not None
            and "gapfill" in context_sampling
            and "gapfill" in target_sampling
        ):
            # Perform the gapfill sampling strategy for linked context and target sets at this point
            # while we have the full context and target data in scope

            context_gapfill_idxs = np.where(np.array(context_sampling) == "gapfill")[0]
            target_gapfill_idxs = np.where(np.array(target_sampling) == "gapfill")[0]
            assert len(context_gapfill_idxs) == len(target_gapfill_idxs), (
                f"Number of context sets with 'gapfill' sampling strategy "
                f"({len(context_gapfill_idxs)}) must match number of target sets "
                f"with 'gapfill' sampling strategy ({len(target_gapfill_idxs)})"
            )
            for gapfill_i, (context_idx, target_idx) in enumerate(
                zip(context_gapfill_idxs, target_gapfill_idxs)
            ):
                assert (context_idx, target_idx) in self.links, (
                    f"Context set {context_idx} and target set {target_idx} must be linked "
                    f"with the `links` attribute if using the 'gapfill' sampling strategy"
                )

                context_var = context_slices[context_idx]
                target_var = target_slices[target_idx]

                for var in [context_var, target_var]:
                    assert isinstance(var, (xr.DataArray, xr.Dataset)), (
                        f"If using 'gapfill' sampling strategy for linked context and target sets, "
                        f"the context and target sets must be xarray DataArrays or Datasets, "
                        f"but got {type(var)}."
                    )

                split_seed = seed + gapfill_i if seed is not None else None
                rng = np.random.default_rng(split_seed)

                # Keep trying until we get a target set with at least one target point
                keep_searching = True
                while keep_searching:
                    added_mask_date = rng.choice(self.context[context_idx].time)
                    added_mask = (
                        self.context[context_idx].sel(time=added_mask_date).isnull()
                    )
                    curr_mask = context_var.isnull()

                    # Mask out added missing values
                    context_var = context_var.where(~added_mask)

                    # TEMP: Inefficient to convert all non-targets to NaNs and then remove NaNs
                    #   when we could just slice the target values here
                    target_mask = added_mask & ~curr_mask
                    if isinstance(target_var, xr.Dataset):
                        keep_searching = np.all(target_mask.to_array().data == False)
                    else:
                        keep_searching = np.all(target_mask.data == False)
                    if keep_searching:
                        continue  # No target points -- use a different `added_mask`

                    target_var = target_var.where(
                        target_mask
                    )  # Only keep target locations

                    context_slices[context_idx] = context_var
                    target_slices[target_idx] = target_var

        for i, (var, sampling_strat) in enumerate(
            zip(context_slices, context_sampling)
        ):
            context_seed = seed + i if seed is not None else None
            X_c, Y_c = sample_variable(var, sampling_strat, context_seed)
            task["X_c"].append(X_c)
            task["Y_c"].append(Y_c)
        if target_sampling is not None:
            for j, (var, sampling_strat) in enumerate(
                zip(target_slices, target_sampling)
            ):
                target_seed = seed + i + j if seed is not None else None
                X_t, Y_t = sample_variable(var, sampling_strat, target_seed)
                task["X_t"].append(X_t)
                task["Y_t"].append(Y_t)

        if self.aux_at_contexts is not None:
            # Add auxiliary variable sampled at context set as a new context variable
            X_c_offgrid = [X_c for X_c in task["X_c"] if not isinstance(X_c, tuple)]
            if len(X_c_offgrid) == 0:
                # No offgrid context sets
                X_c_offgrid_all = np.empty((2, 0), dtype=self.dtype)
            else:
                X_c_offgrid_all = np.concatenate(X_c_offgrid, axis=1)
            Y_c_aux = self.sample_offgrid_aux(
                X_c_offgrid_all,
                self.time_slice_variable(self.aux_at_contexts, date),
            )
            task["X_c"].append(X_c_offgrid_all)
            task["Y_c"].append(Y_c_aux)

        if self.aux_at_targets is not None and target_sampling is None:
            task["Y_t_aux"] = None
        elif self.aux_at_targets is not None and target_sampling is not None:
            # Add auxiliary variable to target set
            if len(task["X_t"]) > 1:
                raise ValueError(
                    "Cannot add auxiliary variable to target set when there "
                    "are multiple target variables (not supported by default `ConvNP` model)."
                )
            task["Y_t_aux"] = self.sample_offgrid_aux(
                task["X_t"][0],
                self.time_slice_variable(self.aux_at_targets, date),
            )

        return Task(task)

    def __call__(
        self,
        date: pd.Timestamp,
        context_sampling: Union[
            str,
            int,
            float,
            np.ndarray,
            List[Union[str, int, float, np.ndarray]],
        ] = "all",
        target_sampling: Optional[
            Union[
                str,
                int,
                float,
                np.ndarray,
                List[Union[str, int, float, np.ndarray]],
            ]
        ] = None,
        split_frac: float = 0.5,
        datewise_deterministic: bool = False,
        seed_override: Optional[int] = None,
    ) -> Union[Task, List[Task]]:
        """
        Generate a task for a given date (or a list of
        :class:`.data.task.Task` objects for a list of dates).

        There are several sampling strategies available for the context and
        target data:

            - "all": Sample all observations.
            - int: Sample N observations uniformly at random.
            - float: Sample a fraction of observations uniformly at random.
            - :class:`numpy:numpy.ndarray`, shape (2, N):
                Sample N observations at the given x1, x2 coordinates. Coords are assumed to be
                normalised.
            - "split": Split pandas observations into disjoint context and target sets.
                `split_frac` determines the fraction of observations
                to use for the context set. The remaining observations are used
                for the target set.
                The context set and target set must be linked through the ``TaskLoader``
                ``links`` argument. Only valid for pandas data.
            - "gapfill": Generates a training task for filling NaNs in xarray data.
                Randomly samples a missing data (NaN) mask from another timestamp and
                adds it to the context set (i.e. increases the number of NaNs).
                The target set is then the true values of the data at the added
                missing locations.
                The context set and target set must be linked through the ``TaskLoader``
                ``links`` argument. Only valid for xarray data.

        Args:
            date (:class:`pandas.Timestamp`):
                Date for which to generate the task.
            context_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional):
                Sampling strategy for the context data, either a list of
                sampling strategies for each context set, or a single strategy
                applied to all context sets. Default is ``"all"``.
            target_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional):
                Sampling strategy for the target data, either a list of
                sampling strategies for each target set, or a single strategy
                applied to all target sets. Default is ``None``, meaning no target
                data is returned.
            split_frac (float, optional):
                The fraction of observations to use for the context set with
                the "split" sampling strategy for linked context and target set
                pairs. The remaining observations are used for the target set.
                Default is 0.5.
            datewise_deterministic (bool, optional):
                Whether random sampling is datewise deterministic based on the
                date. Default is ``False``.
            seed_override (Optional[int], optional):
                Override the seed for random sampling. This can be used to
                apply the same random sampling at different ``date`` values.
                Default is None.

        Returns:
            :class:`~.data.task.Task` | List[:class:`~.data.task.Task`]:
                Task object or list of task objects for each date containing
                the context and target data.
        """
        if isinstance(date, (list, tuple, pd.DatetimeIndex)):
            return [
                self.task_generation(
                    d,
                    context_sampling,
                    target_sampling,
                    split_frac,
                    datewise_deterministic,
                    seed_override,
                )
                for d in date
            ]
        else:
            return self.task_generation(
                date,
                context_sampling,
                target_sampling,
                split_frac,
                datewise_deterministic,
                seed_override,
            )
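

# Illustrative end-to-end sketch (synthetic data; not part of the original
# module). This mirrors the sampling strategies documented in
# `TaskLoader.__call__`:
#
#   import numpy as np
#   import pandas as pd
#   import xarray as xr
#   from deepsensor.data.loader import TaskLoader
#
#   da = xr.DataArray(
#       np.random.rand(5, 10, 10).astype("float32"),
#       dims=["time", "x1", "x2"],
#       coords={
#           "time": pd.date_range("2020-01-01", periods=5),
#           "x1": np.linspace(0, 1, 10),
#           "x2": np.linspace(0, 1, 10),
#       },
#       name="temperature",
#   )
#   task_loader = TaskLoader(context=da, target=da)
#
#   # Single date: 20 random context points, all grid cells as targets
#   task = task_loader("2020-01-03", context_sampling=20, target_sampling="all")
#
#   # A list of dates returns a list of Tasks; datewise_deterministic=True
#   # makes the random sampling reproducible per date
#   tasks = task_loader(
#       pd.date_range("2020-01-01", "2020-01-05"),
#       context_sampling=0.5,
#       target_sampling="all",
#       datewise_deterministic=True,
#   )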