alan-turing-institute / deepsensor: /deepsensor/data/loader.py

from deepsensor.data.task import Task, flatten_X

import os
import json
import copy

import numpy as np
import xarray as xr
import pandas as pd

from typing import List, Tuple, Union, Optional

from deepsensor.errors import InvalidSamplingStrategyError


class TaskLoader:
    """
18
    Generates :class:`~.data.task.Task` objects for training, testing, and inference with DeepSensor models.
19

20
    Provides a suite of sampling methods for generating :class:`~.data.task.Task` objects for different kinds of
21
    predictions, such as: spatial interpolation, forecasting, downscaling, or some combination
22
    of these.
23

24
    The behaviour is the following:
25
        - If all data passed as paths, load the data and overwrite the paths with the loaded data
26
        - Either all data is passed as paths, or all data is passed as loaded data (else ``ValueError``)
27
        - If all data passed as paths, the TaskLoader can be saved with the ``save`` method
28
        (using config)
29

30
    Args:
31
        task_loader_ID:
32
            If loading a TaskLoader from a config file, this is the folder the
33
            TaskLoader was saved in (using `.save`). If this argument is passed, all other
34
            arguments are ignored.
35
        context (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset`, :class:`pandas.DataFrame`]):
36
            Context data. Can be a single :class:`xarray.DataArray`,
37
            :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a
38
            list/tuple of these.
39
        target (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset`, :class:`pandas.DataFrame`]):
40
            Target data. Can be a single :class:`xarray.DataArray`,
41
            :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a
42
            list/tuple of these.
43
        aux_at_contexts (Tuple[int, :class:`xarray.DataArray` | :class:`xarray.Dataset`], optional):
44
            Auxiliary data at context locations. Tuple of two elements, where
45
            the first element is the index of the context set for which the
46
            auxiliary data will be sampled at, and the second element is the
47
            auxiliary data, which can be a single :class:`xarray.DataArray` or
48
            :class:`xarray.Dataset`. Default: None.
49
        aux_at_targets (:class:`xarray.DataArray` | :class:`xarray.Dataset`, optional):
50
            Auxiliary data at target locations. Can be a single
51
            :class:`xarray.DataArray` or :class:`xarray.Dataset`. Default:
52
            None.
53
        links (Tuple[int, int] | List[Tuple[int, int]], optional):
54
            Specifies links between context and target data. Each link is a
55
            tuple of two integers, where the first integer is the index of the
56
            context data and the second integer is the index of the target
57
            data. Can be a single tuple in the case of a single link. If None,
58
            no links are specified. Default: None.
59
        context_delta_t (int | List[int], optional):
60
            Time difference between context data and t=0 (task init time). Can
61
            be a single int (same for all context data) or a list/tuple of
62
            ints. Default is 0.
63
        target_delta_t (int | List[int], optional):
64
            Time difference between target data and t=0 (task init time). Can
65
            be a single int (same for all target data) or a list/tuple of ints.
66
            Default is 0.
67
        time_freq (str, optional):
68
            Time frequency of the data. Default: ``'D'`` (daily).
69
        xarray_interp_method (str, optional):
70
            Interpolation method to use when interpolating
71
            :class:`xarray.DataArray`. Default is ``'linear'``.
72
        discrete_xarray_sampling (bool, optional):
73
            When randomly sampling xarray variables, whether to sample at
74
            discrete points defined at grid cell centres, or at continuous
75
            points within the grid. Default is ``False``.
76
        dtype (object, optional):
77
            Data type of the data. Used to cast the data to the specified
78
            dtype. Default: ``'float32'``.
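
    Example:
        A minimal sketch of typical usage (``era5_ds``, an
        :class:`xarray.Dataset`, and ``station_df``, a
        :class:`pandas.DataFrame` indexed by ``time``/``x1``/``x2``, are
        hypothetical, already-normalised inputs):

        >>> task_loader = TaskLoader(
        ...     context=era5_ds,  # gridded context data
        ...     target=station_df,  # off-grid target observations
        ...     target_delta_t=1,  # predict one `time_freq` step ahead
        ... )
        >>> task = task_loader("2020-01-01", context_sampling="all", target_sampling="all")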
    """

    config_fname = "task_loader_config.json"

    def __init__(
        self,
        task_loader_ID: Union[str, None] = None,
        context: Union[
            xr.DataArray,
            xr.Dataset,
            pd.DataFrame,
            str,
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
        ] = None,
        target: Union[
            xr.DataArray,
            xr.Dataset,
            pd.DataFrame,
            str,
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
        ] = None,
        aux_at_contexts: Optional[Tuple[int, Union[xr.DataArray, xr.Dataset]]] = None,
        aux_at_targets: Optional[
            Union[
                xr.DataArray,
                xr.Dataset,
            ]
        ] = None,
        links: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
        context_delta_t: Union[int, List[int]] = 0,
        target_delta_t: Union[int, List[int]] = 0,
        time_freq: str = "D",
        xarray_interp_method: str = "linear",
        discrete_xarray_sampling: bool = False,
        dtype: object = "float32",
    ) -> None:
        if task_loader_ID is not None:
            self.task_loader_ID = task_loader_ID
            # Load TaskLoader from config file
            fpath = os.path.join(task_loader_ID, self.config_fname)
            with open(fpath, "r") as f:
                self.config = json.load(f)

            self.context = self.config["context"]
            self.target = self.config["target"]
            self.aux_at_contexts = self.config["aux_at_contexts"]
            self.aux_at_targets = self.config["aux_at_targets"]
            self.links = self.config["links"]
            if self.links is not None:
                self.links = [tuple(link) for link in self.links]
            self.context_delta_t = self.config["context_delta_t"]
            self.target_delta_t = self.config["target_delta_t"]
            self.time_freq = self.config["time_freq"]
            self.xarray_interp_method = self.config["xarray_interp_method"]
            self.discrete_xarray_sampling = self.config["discrete_xarray_sampling"]
            self.dtype = self.config["dtype"]
        else:
            self.context = context
            self.target = target
            self.aux_at_contexts = aux_at_contexts
            self.aux_at_targets = aux_at_targets
            self.links = links
            self.context_delta_t = context_delta_t
            self.target_delta_t = target_delta_t
            self.time_freq = time_freq
            self.xarray_interp_method = xarray_interp_method
            self.discrete_xarray_sampling = discrete_xarray_sampling
            self.dtype = dtype

        if not isinstance(self.context, (tuple, list)):
            self.context = (self.context,)
        if not isinstance(self.target, (tuple, list)):
            self.target = (self.target,)

        if isinstance(self.context_delta_t, int):
            self.context_delta_t = (self.context_delta_t,) * len(self.context)
        else:
            assert len(self.context_delta_t) == len(self.context), (
                f"Length of context_delta_t ({len(self.context_delta_t)}) must be the same as "
                f"the number of context sets ({len(self.context)})"
            )
        if isinstance(self.target_delta_t, int):
            self.target_delta_t = (self.target_delta_t,) * len(self.target)
        else:
            assert len(self.target_delta_t) == len(self.target), (
                f"Length of target_delta_t ({len(self.target_delta_t)}) must be the same as "
                f"the number of target sets ({len(self.target)})"
            )

        all_paths = self._check_if_all_data_passed_as_paths()
        if all_paths:
            self._set_config()
            self._load_data_from_paths()

        self.context = self._cast_to_dtype(self.context)
        self.target = self._cast_to_dtype(self.target)
        self.aux_at_contexts = self._cast_to_dtype(self.aux_at_contexts)
        self.aux_at_targets = self._cast_to_dtype(self.aux_at_targets)

        self.links = self._check_links(self.links)

        (
            self.context_dims,
            self.target_dims,
            self.aux_at_target_dims,
        ) = self.count_context_and_target_data_dims()
        (
            self.context_var_IDs,
            self.target_var_IDs,
            self.context_var_IDs_and_delta_t,
            self.target_var_IDs_and_delta_t,
            self.aux_at_target_var_IDs,
        ) = self.infer_context_and_target_var_IDs()

    def _set_config(self):
        """Instantiate a config dictionary for the TaskLoader object"""
        # Take deepcopy to avoid modifying the original config
        self.config = copy.deepcopy(
            dict(
                context=self.context,
                target=self.target,
                aux_at_contexts=self.aux_at_contexts,
                aux_at_targets=self.aux_at_targets,
                links=self.links,
                context_delta_t=self.context_delta_t,
                target_delta_t=self.target_delta_t,
                time_freq=self.time_freq,
                xarray_interp_method=self.xarray_interp_method,
                discrete_xarray_sampling=self.discrete_xarray_sampling,
                dtype=self.dtype,
            )
        )

    def _check_if_all_data_passed_as_paths(self) -> bool:
        """If all data passed as paths, save paths to config and return True."""

        def _check_if_strings(x, mode="all"):
            if x is None:
                return None
            elif isinstance(x, (tuple, list)):
                if mode == "all":
                    return all([isinstance(x_i, str) for x_i in x])
                elif mode == "any":
                    return any([isinstance(x_i, str) for x_i in x])
            else:
                return isinstance(x, str)

        all_paths = all(
            filter(
                lambda x: x is not None,
                [
                    _check_if_strings(self.context),
                    _check_if_strings(self.target),
                    _check_if_strings(self.aux_at_contexts),
                    _check_if_strings(self.aux_at_targets),
                ],
            )
        )
        self._is_saveable = all_paths

        any_paths = any(
            filter(
                lambda x: x is not None,
                [
                    _check_if_strings(self.context, mode="any"),
                    _check_if_strings(self.target, mode="any"),
                    _check_if_strings(self.aux_at_contexts, mode="any"),
                    _check_if_strings(self.aux_at_targets, mode="any"),
                ],
            )
        )
        if any_paths and not all_paths:
            raise ValueError(
                "Data must be passed either all as paths or all as xarray/pandas objects (not a mix)."
            )

        return all_paths

    def _load_data_from_paths(self):
        """Load data from paths and overwrite paths with loaded data."""

        loaded_data = {}

        def _load_pandas_or_xarray(path):
            # Need to be careful about this. We need to ensure data gets into the right form
            #  for TaskLoader.
            if path is None:
                return None
            elif path in loaded_data:
                return loaded_data[path]
            elif path.endswith(".nc"):
                data = xr.open_dataset(path)
            elif path.endswith(".csv"):
                df = pd.read_csv(path)
                if "time" in df.columns:
                    df["time"] = pd.to_datetime(df["time"])
                    df = df.set_index(["time", "x1", "x2"]).sort_index()
                else:
                    df = df.set_index(["x1", "x2"]).sort_index()
                data = df
            else:
                raise ValueError(f"Unknown file extension for {path}")
            loaded_data[path] = data
            return data

        def _load_data(data):
            if isinstance(data, (tuple, list)):
                data = tuple([_load_pandas_or_xarray(data_i) for data_i in data])
            else:
                data = _load_pandas_or_xarray(data)
            return data

        self.context = _load_data(self.context)
        self.target = _load_data(self.target)
        self.aux_at_contexts = _load_data(self.aux_at_contexts)
        self.aux_at_targets = _load_data(self.aux_at_targets)

    def save(self, folder: str):
        """Save TaskLoader config to JSON in `folder`"""
        if not self._is_saveable:
            raise ValueError(
                "TaskLoader cannot be saved because not all data was passed as paths."
            )

        os.makedirs(folder, exist_ok=True)
        fpath = os.path.join(folder, self.config_fname)
        with open(fpath, "w") as f:
            json.dump(self.config, f, indent=4, sort_keys=False)
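
    # Example (a sketch; the file and folder names below are hypothetical):
    # if all data was passed as paths, e.g.
    #   tl = TaskLoader(context="era5.nc", target="stations.csv")
    #   tl.save("task_loader_folder")
    # then the same TaskLoader can later be restored from its config via
    #   tl = TaskLoader(task_loader_ID="task_loader_folder")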

    def _cast_to_dtype(
        self,
        var: Union[
            xr.DataArray,
            xr.Dataset,
            pd.DataFrame,
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
        ],
    ):
        """
        Cast context and target data to the default dtype.

        ..
            TODO unit test this by passing in a variety of data types and
            checking that they are cast correctly.

        Args:
            var (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List of these):
                Data (or list/tuple of data) to cast.

        Returns:
            The input data cast to ``self.dtype``, with the same container
            structure as the input (lists/tuples are returned as tuples).
        """

        def cast_to_dtype(var):
            if isinstance(var, xr.DataArray):
                var = var.astype(self.dtype)
                var["x1"] = var["x1"].astype(self.dtype)
                var["x2"] = var["x2"].astype(self.dtype)
            elif isinstance(var, xr.Dataset):
                var = var.astype(self.dtype)
                var["x1"] = var["x1"].astype(self.dtype)
                var["x2"] = var["x2"].astype(self.dtype)
            elif isinstance(var, (pd.DataFrame, pd.Series)):
                var = var.astype(self.dtype)
                # Note: Numeric pandas indexes are always cast to float64, so we have to cast
                #   x1/x2 coord dtypes during task sampling
            else:
                raise ValueError(f"Unknown type {type(var)} for context set {var}")
            return var

        if var is None:
            return var
        elif isinstance(var, (tuple, list)):
            var = tuple([cast_to_dtype(var_i) for var_i in var])
        else:
            var = cast_to_dtype(var)

        return var

    def load_dask(self) -> None:
        """
        Load any `dask` data into memory.

        This function triggers the computation and loading of any data that
        is represented as dask arrays or datasets into memory.

        Returns:
            None
        """

        def load(datasets):
            if datasets is None:
                return
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]
            for i, var in enumerate(datasets):
                if isinstance(var, (xr.DataArray, xr.Dataset)):
                    var = var.load()

        load(self.context)
        load(self.target)
        load(self.aux_at_contexts)
        load(self.aux_at_targets)

        return None

    def count_context_and_target_data_dims(self):
        """
        Count the number of data dimensions in the context and target data.

        Returns:
            tuple: context_dims, Tuple of data dimensions in the context data.
            tuple: target_dims, Tuple of data dimensions in the target data.
            int: aux_at_target_dims, Number of data dimensions in the
                auxiliary-at-target data (0 if no such data).

        Raises:
            ValueError: If the context/target data is not a tuple/list of
                        :class:`xarray.DataArray`, :class:`xarray.Dataset` or
                        :class:`pandas.DataFrame`.
        """

        def count_data_dims_of_tuple_of_sets(datasets):
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]

            dims = []
            # Distinguish between xr.DataArray, xr.Dataset and pd.DataFrame
            for i, var in enumerate(datasets):
                if isinstance(var, xr.Dataset):
                    dim = len(var.data_vars)  # Multiple data variables
                elif isinstance(var, xr.DataArray):
                    dim = 1  # Single data variable
                elif isinstance(var, pd.DataFrame):
                    dim = len(var.columns)  # Assumes all columns are data variables
                elif isinstance(var, pd.Series):
                    dim = 1  # Single data variable
                else:
                    raise ValueError(f"Unknown type {type(var)} for context set {var}")
                dims.append(dim)
            return dims

        context_dims = count_data_dims_of_tuple_of_sets(self.context)
        target_dims = count_data_dims_of_tuple_of_sets(self.target)
        if self.aux_at_contexts is not None:
            context_dims += count_data_dims_of_tuple_of_sets(self.aux_at_contexts)
        if self.aux_at_targets is not None:
            aux_at_target_dims = count_data_dims_of_tuple_of_sets(self.aux_at_targets)[
                0
            ]
        else:
            aux_at_target_dims = 0

        return tuple(context_dims), tuple(target_dims), aux_at_target_dims
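
    # For example (a sketch): an xr.Dataset with data variables
    # ("temp", "precip") counts as 2 data dimensions, while a DataArray or
    # Series counts as 1, so context=[era5_ds, station_series] (hypothetical
    # inputs) would give context_dims == (2, 1).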

    def infer_context_and_target_var_IDs(self):
        """
        Infer the variable IDs of the context and target data.

        Returns:
            tuple: context_var_IDs, Tuple of variable IDs in the context data.
            tuple: target_var_IDs, Tuple of variable IDs in the target data.
            tuple: context_var_IDs_and_delta_t, As above, with the delta_t
                appended to each variable ID.
            tuple: target_var_IDs_and_delta_t, As above, with the delta_t
                appended to each variable ID.
            tuple: aux_at_target_var_IDs, Variable IDs of the
                auxiliary-at-target data (None if no such data).

        Raises:
            ValueError: If the context/target data is not a tuple/list of
                        :class:`xarray.DataArray`, :class:`xarray.Dataset` or
                        :class:`pandas.DataFrame`.
        """

        def infer_var_IDs_of_tuple_of_sets(datasets, delta_ts=None):
            """If delta_ts is not None, then add the delta_t to the variable ID"""
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]

            var_IDs = []
            # Distinguish between xr.DataArray, xr.Dataset and pd.DataFrame
            for i, var in enumerate(datasets):
                if isinstance(var, xr.DataArray):
                    var_ID = (var.name,)  # Single data variable
                elif isinstance(var, xr.Dataset):
                    var_ID = tuple(var.data_vars.keys())  # Multiple data variables
                elif isinstance(var, pd.DataFrame):
                    var_ID = tuple(var.columns)
                elif isinstance(var, pd.Series):
                    var_ID = (var.name,)
                else:
                    raise ValueError(f"Unknown type {type(var)} for context set {var}")

                if delta_ts is not None:
                    # Add delta_t to the variable ID
                    var_ID = tuple(
                        [f"{var_ID_i}_t{delta_ts[i]}" for var_ID_i in var_ID]
                    )
                else:
                    var_ID = tuple([f"{var_ID_i}" for var_ID_i in var_ID])

                var_IDs.append(var_ID)

            return var_IDs

        context_var_IDs = infer_var_IDs_of_tuple_of_sets(self.context)
        context_var_IDs_and_delta_t = infer_var_IDs_of_tuple_of_sets(
            self.context, self.context_delta_t
        )
        target_var_IDs = infer_var_IDs_of_tuple_of_sets(self.target)
        target_var_IDs_and_delta_t = infer_var_IDs_of_tuple_of_sets(
            self.target, self.target_delta_t
        )

        if self.aux_at_contexts is not None:
            context_var_IDs += infer_var_IDs_of_tuple_of_sets(self.aux_at_contexts)
            context_var_IDs_and_delta_t += infer_var_IDs_of_tuple_of_sets(
                self.aux_at_contexts, [0]
            )

        if self.aux_at_targets is not None:
            aux_at_target_var_IDs = infer_var_IDs_of_tuple_of_sets(self.aux_at_targets)[
                0
            ]
        else:
            aux_at_target_var_IDs = None

        return (
            tuple(context_var_IDs),
            tuple(target_var_IDs),
            tuple(context_var_IDs_and_delta_t),
            tuple(target_var_IDs_and_delta_t),
            aux_at_target_var_IDs,
        )
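
    # For example (a sketch): a context xr.Dataset with a single data variable
    # "temp" and context_delta_t=-1 would give
    # context_var_IDs == (("temp",),) and
    # context_var_IDs_and_delta_t == (("temp_t-1",),).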

    def _check_links(self, links: Union[Tuple[int, int], List[Tuple[int, int]]]):
        """
        Check that the context-target links are valid.

        Args:
            links (Tuple[int, int] | List[Tuple[int, int]]):
                Specifies links between context and target data. Each link is a
                tuple of two integers, where the first integer is the index of
                the context data and the second integer is the index of the
                target data. Can be a single tuple in the case of a single
                link. If None, no links are specified. Default: None.

        Returns:
            Tuple[int, int] | List[Tuple[int, int]]
                The input links, if valid.

        Raises:
            ValueError
                If the links are not valid.
        """
        if links is None:
            return None

        assert isinstance(
            links, list
        ), f"Links must be a list of length-2 tuples, but got {type(links)}"
        assert len(links) > 0, "If links is not None, it must be a non-empty list"
        assert all(
            isinstance(link, tuple) for link in links
        ), f"Links must be a list of tuples, but got {[type(link) for link in links]}"
        assert all(
            len(link) == 2 for link in links
        ), f"Links must be a list of length-2 tuples, but got lengths {[len(link) for link in links]}"

        # Check that the links are valid
        for link_i, (context_idx, target_idx) in enumerate(links):
            if context_idx >= len(self.context):
                raise ValueError(
                    f"Invalid context index {context_idx} in link {link_i} of {links}: "
                    f"there are only {len(self.context)} context sets"
                )
            if target_idx >= len(self.target):
                raise ValueError(
                    f"Invalid target index {target_idx} in link {link_i} of {links}: "
                    f"there are only {len(self.target)} target sets"
                )

        return links
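
    # Example (a sketch): links=[(0, 0)] links context set 0 with target set 0,
    # as required by the "split" and "gapfill" sampling strategies, e.g.
    #   TaskLoader(context=station_df, target=station_df, links=[(0, 0)])
    # (station_df is a hypothetical pandas DataFrame).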

    def __str__(self):
        """
        String representation of the TaskLoader object (user-friendly).
        """
        s = f"TaskLoader({len(self.context_dims)} context sets, {len(self.target_dims)} target sets)"
        s += f"\nContext variable IDs: {self.context_var_IDs}"
        s += f"\nTarget variable IDs: {self.target_var_IDs}"
        if self.aux_at_targets is not None:
            s += f"\nAuxiliary-at-target variable IDs: {self.aux_at_target_var_IDs}"
        return s

    def __repr__(self):
        """
        Representation of the TaskLoader object (for developers).

        ..
            TODO make this a more verbose version of __str__
        """
        s = str(self)
        s += "\n"
        s += f"\nContext data dimensions: {self.context_dims}"
        s += f"\nTarget data dimensions: {self.target_dims}"
        if self.aux_at_targets is not None:
            s += f"\nAuxiliary-at-target data dimensions: {self.aux_at_target_dims}"
        return s

    def sample_da(
        self,
        da: Union[xr.DataArray, xr.Dataset],
        sampling_strat: Union[str, int, float, np.ndarray],
        seed: Optional[int] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Sample a DataArray according to a given strategy.

        Args:
            da (:class:`xarray.DataArray` | :class:`xarray.Dataset`):
                DataArray to sample, assumed to be sliced for the task already.
            sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`):
                Sampling strategy: "all"/"gapfill" to sample the whole grid, an
                integer to sample N random points, a float to sample a fraction
                of the grid points, or a numpy array of shape (2, N) giving the
                x1/x2 coordinates to sample at.
            seed (int, optional):
                Seed for random sampling. Default is None.

        Returns:
            Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]:
                Tuple of sampled coordinates ``X_c`` and sampled data values
                ``Y_c``.

        Raises:
            InvalidSamplingStrategyError:
                If the sampling strategy is not valid or if a numpy coordinate
                array is passed to sample an xarray object, but the coordinates
                are out of bounds.
        """
        da = da.load()  # Converts dask -> numpy if not already loaded
        if isinstance(da, xr.Dataset):
            da = da.to_array()

        if isinstance(sampling_strat, float):
            sampling_strat = int(sampling_strat * da.size)

        if isinstance(sampling_strat, (int, np.integer)):
            N = sampling_strat
            rng = np.random.default_rng(seed)
            if self.discrete_xarray_sampling:
                x1 = rng.choice(da.coords["x1"].values, N, replace=True)
                x2 = rng.choice(da.coords["x2"].values, N, replace=True)
                Y_c = da.sel(x1=xr.DataArray(x1), x2=xr.DataArray(x2)).data
            elif not self.discrete_xarray_sampling:
                if N == 0:
                    # Catch zero-context edge case before interp fails
                    X_c = np.zeros((2, 0), dtype=self.dtype)
                    dim = da.shape[0] if da.ndim == 3 else 1
                    Y_c = np.zeros((dim, 0), dtype=self.dtype)
                    return X_c, Y_c
                x1 = rng.uniform(da.coords["x1"].min(), da.coords["x1"].max(), N)
                x2 = rng.uniform(da.coords["x2"].min(), da.coords["x2"].max(), N)
                Y_c = da.sel(x1=xr.DataArray(x1), x2=xr.DataArray(x2), method="nearest")
                Y_c = np.array(Y_c, dtype=self.dtype)
            X_c = np.array([x1, x2], dtype=self.dtype)

        elif isinstance(sampling_strat, np.ndarray):
            X_c = sampling_strat.astype(self.dtype)
            try:
                Y_c = da.sel(
                    x1=xr.DataArray(X_c[0]),
                    x2=xr.DataArray(X_c[1]),
                    method="nearest",
                    tolerance=0.1,  # Maximum distance from observed point to sample
                )
            except KeyError:
                raise InvalidSamplingStrategyError(
                    f"Passed a numpy coordinate array to sample xarray object, "
                    f"but the coordinates are out of bounds."
                )
            Y_c = np.array(Y_c, dtype=self.dtype)

        elif sampling_strat in ["all", "gapfill"]:
            X_c = (
                da.coords["x1"].values[np.newaxis],
                da.coords["x2"].values[np.newaxis],
            )
            Y_c = da.data
            if Y_c.ndim == 2:
                # returned a 2D array, but we need a 3D array of shape (variable, N_x1, N_x2)
                Y_c = Y_c.reshape(1, *Y_c.shape)
        else:
            raise InvalidSamplingStrategyError(
                f"Unknown sampling strategy {sampling_strat}"
            )

        if Y_c.ndim == 1:
            # returned a 1D array, but we need a 2D array of shape (variable, N)
            Y_c = Y_c.reshape(1, *Y_c.shape)

        return X_c, Y_c
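
    # For example (a sketch): for a 2D (x1, x2) DataArray `da`,
    # sample_da(da, 100, seed=42) returns X_c of shape (2, 100) holding random
    # continuous x1/x2 coordinates and Y_c of shape (1, 100) holding the
    # nearest-gridpoint values of `da` at those coordinates.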

    def sample_df(
        self,
        df: Union[pd.DataFrame, pd.Series],
        sampling_strat: Union[str, int, float, np.ndarray],
        seed: Optional[int] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Sample a DataFrame according to a given strategy.

        Args:
            df (:class:`pandas.DataFrame` | :class:`pandas.Series`):
                DataFrame to sample, assumed to be time-sliced for the task
                already.
            sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`):
                Sampling strategy: "all"/"split" to sample all observations, an
                integer to sample N random observations, a float to sample a
                fraction of observations, or a numpy array of shape (2, N)
                giving the x1/x2 coordinates to sample at.
            seed (int, optional):
                Seed for random sampling. Default is None.

        Returns:
            Tuple[X_c, Y_c]:
                Tuple of sampled coordinates ``X_c`` and sampled data values
                ``Y_c``.

        Raises:
            InvalidSamplingStrategyError:
                If the sampling strategy is not valid or if a numpy coordinate
                array is passed to sample a pandas object, but the DataFrame
                does not contain all the requested samples.
        """
        df = df.dropna(how="any")  # If any obs are NaN, drop them

        if isinstance(sampling_strat, float):
            sampling_strat = int(sampling_strat * df.shape[0])

        if isinstance(sampling_strat, (int, np.integer)):
            N = sampling_strat
            rng = np.random.default_rng(seed)
            idx = rng.choice(df.index, N)
            X_c = df.loc[idx].reset_index()[["x1", "x2"]].values.T.astype(self.dtype)
            Y_c = df.loc[idx].values.T
        elif isinstance(sampling_strat, str) and sampling_strat in [
            "all",
            "split",
        ]:
            # NOTE if "split", we assume that the context-target split has already been applied to the df
            # in an earlier scope with access to both the context and target data. This is maybe risky!
            X_c = df.reset_index()[["x1", "x2"]].values.T.astype(self.dtype)
            Y_c = df.values.T
        elif isinstance(sampling_strat, np.ndarray):
            X_c = sampling_strat.astype(self.dtype)
            x1match = np.in1d(df.index.get_level_values("x1"), X_c[0])
            x2match = np.in1d(df.index.get_level_values("x2"), X_c[1])
            num_matches = np.sum(x1match & x2match)

            # Check that we got all the samples we asked for
            if num_matches != X_c.shape[1]:
                raise InvalidSamplingStrategyError(
                    f"Passed a numpy coordinate array to sample pandas DataFrame, "
                    f"but the DataFrame did not contain all the requested samples. "
                    f"Requested {X_c.shape[1]} samples but only got {num_matches}."
                )

            Y_c = df[x1match & x2match].values.T
        else:
            raise InvalidSamplingStrategyError(
                f"Unknown sampling strategy {sampling_strat}"
            )

        if Y_c.ndim == 1:
            # returned a 1D array, but we need a 2D array of shape (variable, N)
            Y_c = Y_c.reshape(1, *Y_c.shape)

        return X_c, Y_c
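
    # For example (a sketch): if `station_df` (hypothetical) holds one data
    # column at 50 stations for the task's date, sample_df(station_df, "all")
    # returns X_c of shape (2, 50) and Y_c of shape (1, 50).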

    def sample_offgrid_aux(
        self,
        X_t: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]],
        offgrid_aux: Union[xr.DataArray, xr.Dataset],
    ) -> np.ndarray:
        """
        Sample auxiliary data at off-grid locations.

        Args:
            X_t (:class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]):
                Off-grid locations at which to sample the auxiliary data. Can
                be a tuple of two numpy arrays, or a single numpy array.
            offgrid_aux (:class:`xarray.DataArray` | :class:`xarray.Dataset`):
                Auxiliary data at off-grid locations.

        Returns:
            :class:`numpy:numpy.ndarray`:
                Auxiliary data values at the target locations, of shape
                (variable, *spatial_dims).

        Raises:
            ValueError:
                If the auxiliary data has a `time` dimension (it must be
                time-sliced before being passed to this method).
        """
        if "time" in offgrid_aux.dims:
            raise ValueError(
                "If `aux_at_targets` data has a `time` dimension, it must be sliced before "
                "passing it to `sample_offgrid_aux`."
            )
        if isinstance(X_t, tuple):
            xt1, xt2 = X_t
            xt1 = xt1.ravel()
            xt2 = xt2.ravel()
        else:
            xt1, xt2 = xr.DataArray(X_t[0]), xr.DataArray(X_t[1])
        Y_t_aux = offgrid_aux.sel(x1=xt1, x2=xt2, method="nearest")
        if isinstance(Y_t_aux, xr.Dataset):
            Y_t_aux = Y_t_aux.to_array()
        Y_t_aux = np.array(Y_t_aux, dtype=np.float32)
        if (isinstance(X_t, tuple) and Y_t_aux.ndim == 2) or (
            isinstance(X_t, np.ndarray) and Y_t_aux.ndim == 1
        ):
            # Reshape to (variable, *spatial_dims)
            Y_t_aux = Y_t_aux.reshape(1, *Y_t_aux.shape)
        return Y_t_aux

    def time_slice_variable(self, var, date, delta_t=0):
        """
        Slice a variable by a given time delta.

        Args:
            var (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | :class:`pandas.Series`):
                Variable to slice.
            date (:class:`pandas.Timestamp`):
                Task initialisation date (t=0).
            delta_t (int, optional):
                Time delta to slice by, in units of ``self.time_freq``.
                Default is 0.

        Returns:
            var:
                Sliced variable.

        Raises:
            ValueError
                If the variable is of an unknown type.
        """
        # TODO: Does this work with instantaneous time?
        delta_t = pd.Timedelta(delta_t, unit=self.time_freq)
        if isinstance(var, (xr.Dataset, xr.DataArray)):
            if "time" in var.dims:
                var = var.sel(time=date + delta_t)
        elif isinstance(var, (pd.DataFrame, pd.Series)):
            if "time" in var.index.names:
                var = var[var.index.get_level_values("time") == date + delta_t]
        else:
            raise ValueError(f"Unknown variable type {type(var)}")
        return var
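
    # For example (a sketch): with time_freq="D", date=2020-01-05 and
    # delta_t=-1, this selects the variable's data at 2020-01-04.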

    def task_generation(
        self,
        date: pd.Timestamp,
        context_sampling: Union[
            str,
            int,
            float,
            np.ndarray,
            List[Union[str, int, float, np.ndarray]],
        ] = "all",
        target_sampling: Optional[
            Union[
                str,
                int,
                float,
                np.ndarray,
                List[Union[str, int, float, np.ndarray]],
            ]
        ] = None,
        split_frac: float = 0.5,
        datewise_deterministic: bool = False,
        seed_override: Optional[int] = None,
    ) -> Task:
        """Generate a :class:`~.data.task.Task` for a given date.

        See :meth:`__call__` for a description of the arguments and the
        available sampling strategies.
        """

        def check_sampling_strat(sampling_strat, set):
            """
            Check the sampling strategy.

            Ensure ``sampling_strat`` is either a single strategy (broadcast
            to all sets) or a list of length equal to the number of sets.
            Convert to a tuple of length equal to the number of sets and
            return.

            Args:
                sampling_strat:
                    Sampling strategy to check.
                set:
                    Context or target set to check.

            Returns:
                tuple:
                    Tuple of sampling strategies, one for each set.

            Raises:
                InvalidSamplingStrategyError:
                    - If the sampling strategy is invalid.
                    - If the length of the sampling strategy does not match the number of sets.
                    - If the sampling strategy is not a valid type.
                    - If the sampling strategy is a float but not in [0, 1].
                    - If the sampling strategy is an int but not positive.
                    - If the sampling strategy is a numpy array but not of shape (2, N).
            """
            if sampling_strat is None:
                return None
            if not isinstance(sampling_strat, (list, tuple)):
                sampling_strat = tuple([sampling_strat] * len(set))
            elif isinstance(sampling_strat, (list, tuple)) and len(
                sampling_strat
            ) != len(set):
                raise InvalidSamplingStrategyError(
                    f"Length of sampling_strat ({len(sampling_strat)}) must "
                    f"match number of context sets ({len(set)})"
                )

            for strat in sampling_strat:
                if not isinstance(strat, (str, int, np.integer, float, np.ndarray)):
                    raise InvalidSamplingStrategyError(
                        f"Unknown sampling strategy {strat} of type {type(strat)}"
                    )
                elif isinstance(strat, str) and strat not in [
                    "all",
                    "split",
                    "gapfill",
                ]:
                    raise InvalidSamplingStrategyError(
                        f"Unknown sampling strategy {strat} for type str"
                    )
                elif isinstance(strat, float) and not 0 <= strat <= 1:
                    raise InvalidSamplingStrategyError(
                        f"If sampling strategy is a float, it must be a "
                        f"fraction in [0, 1], got {strat}"
                    )
                elif isinstance(strat, int) and strat < 0:
                    raise InvalidSamplingStrategyError(
                        f"Sampling N must be positive, got {strat}"
                    )
                elif isinstance(strat, np.ndarray) and strat.shape[0] != 2:
                    raise InvalidSamplingStrategyError(
                        "Sampling coordinates must be of shape (2, N), got "
                        f"{strat.shape}"
                    )

            return sampling_strat

        def sample_variable(var, sampling_strat, seed):
            """
            Sample a variable by a given sampling strategy to get input and
            output data.

            Args:
                var:
                    Variable to sample.
                sampling_strat:
                    Sampling strategy to use.
                seed:
                    Seed for random sampling.

            Returns:
                Tuple[X, Y]:
                    Tuple of input and output data.

            Raises:
                ValueError:
                    If the variable is of an unknown type.
            """
            if isinstance(var, (xr.Dataset, xr.DataArray)):
                X, Y = self.sample_da(var, sampling_strat, seed)
            elif isinstance(var, (pd.DataFrame, pd.Series)):
                X, Y = self.sample_df(var, sampling_strat, seed)
            else:
                raise ValueError(f"Unknown type {type(var)} for context set {var}")
            return X, Y

        # Check that the sampling strategies are valid
        context_sampling = check_sampling_strat(context_sampling, self.context)
        target_sampling = check_sampling_strat(target_sampling, self.target)
        # Check `split_frac`
        if split_frac < 0 or split_frac > 1:
            raise ValueError(f"split_frac must be between 0 and 1, got {split_frac}")
        if self.links is None:
            b1 = any(
                [
                    strat in ["split", "gapfill"]
                    for strat in context_sampling
                    if isinstance(strat, str)
                ]
            )
            if target_sampling is None:
                b2 = False
            else:
                b2 = any(
                    [
                        strat in ["split", "gapfill"]
                        for strat in target_sampling
                        if isinstance(strat, str)
                    ]
                )
            if b1 or b2:
                raise ValueError(
                    "If using 'split' or 'gapfill' sampling strategies, the context and target "
                    "sets must be linked with the TaskLoader `links` attribute."
                )
        if self.links is not None:
            for context_idx, target_idx in self.links:
                context_sampling_i = context_sampling[context_idx]
                if target_sampling is None:
                    target_sampling_i = None
                else:
                    target_sampling_i = target_sampling[target_idx]
                link_strats = (context_sampling_i, target_sampling_i)
                if any(
                    [
                        strat in ["split", "gapfill"]
                        for strat in link_strats
                        if isinstance(strat, str)
                    ]
                ):
                    # If one of the sampling strategies is "split" or "gapfill", the other must
                    # use the same splitting strategy
                    if link_strats[0] != link_strats[1]:
                        raise ValueError(
                            f"Linked context set {context_idx} and target set {target_idx} "
                            f"must use the same sampling strategy if one of them "
                            f"uses the 'split' or 'gapfill' sampling strategy. "
                            f"Got {link_strats[0]} and {link_strats[1]}."
                        )

        if not isinstance(date, pd.Timestamp):
            date = pd.Timestamp(date)

        if seed_override is not None:
            # Override the seed for random sampling
            seed = seed_override
        elif datewise_deterministic:
            # Generate a deterministic seed, based on the date, for random sampling
            seed = int(date.strftime("%Y%m%d"))
        else:
            # 'Truly' random sampling
            seed = None

        task = {}

        task["time"] = date
        task["ops"] = []
        task["X_c"] = []
        task["Y_c"] = []
        if target_sampling is not None:
            task["X_t"] = []
            task["Y_t"] = []
        else:
            task["X_t"] = None
            task["Y_t"] = None

        context_slices = [
            self.time_slice_variable(var, date, delta_t)
            for var, delta_t in zip(self.context, self.context_delta_t)
        ]
        target_slices = [
            self.time_slice_variable(var, date, delta_t)
            for var, delta_t in zip(self.target, self.target_delta_t)
        ]

        # TODO move to method
        if (
            self.links is not None
            and target_sampling is not None
            and "split" in context_sampling
            and "split" in target_sampling
        ):
            # Perform the split sampling strategy for linked context and target sets at this point
            # while we have the full context and target data in scope

            context_split_idxs = np.where(np.array(context_sampling) == "split")[0]
            target_split_idxs = np.where(np.array(target_sampling) == "split")[0]
            assert len(context_split_idxs) == len(target_split_idxs), (
                f"Number of context sets with 'split' sampling strategy "
                f"({len(context_split_idxs)}) must match number of target sets "
                f"with 'split' sampling strategy ({len(target_split_idxs)})"
            )
            for split_i, (context_idx, target_idx) in enumerate(
                zip(context_split_idxs, target_split_idxs)
            ):
                assert (context_idx, target_idx) in self.links, (
                    f"Context set {context_idx} and target set {target_idx} must be linked "
                    f"with the `links` attribute if using the 'split' sampling strategy"
                )

                context_var = context_slices[context_idx]
                target_var = target_slices[target_idx]

                for var in [context_var, target_var]:
                    assert isinstance(var, (pd.Series, pd.DataFrame)), (
                        f"If using 'split' sampling strategy for linked context and target sets, "
                        f"the context and target sets must be pandas DataFrames or Series, "
                        f"but got {type(var)}."
                    )

                N_obs = len(context_var)
                N_obs_target_check = len(target_var)
                if N_obs != N_obs_target_check:
                    raise ValueError(
                        f"Cannot split context set {context_idx} and target set {target_idx} "
                        f"because they have different numbers of observations: "
                        f"{N_obs} and {N_obs_target_check}"
                    )
                split_seed = seed + split_i if seed is not None else None
                rng = np.random.default_rng(split_seed)

                N_context = int(N_obs * split_frac)
                idxs_context = rng.choice(N_obs, N_context, replace=False)

                context_var = context_var.iloc[idxs_context]
                target_var = target_var.drop(context_var.index)

                context_slices[context_idx] = context_var
                target_slices[target_idx] = target_var

        # TODO move to method
        if (
            self.links is not None
            and target_sampling is not None
            and "gapfill" in context_sampling
            and "gapfill" in target_sampling
        ):
            # Perform the gapfill sampling strategy for linked context and target sets at this point
            # while we have the full context and target data in scope

            context_gapfill_idxs = np.where(np.array(context_sampling) == "gapfill")[0]
            target_gapfill_idxs = np.where(np.array(target_sampling) == "gapfill")[0]
            assert len(context_gapfill_idxs) == len(target_gapfill_idxs), (
                f"Number of context sets with 'gapfill' sampling strategy "
                f"({len(context_gapfill_idxs)}) must match number of target sets "
                f"with 'gapfill' sampling strategy ({len(target_gapfill_idxs)})"
            )
            for gapfill_i, (context_idx, target_idx) in enumerate(
                zip(context_gapfill_idxs, target_gapfill_idxs)
            ):
                assert (context_idx, target_idx) in self.links, (
                    f"Context set {context_idx} and target set {target_idx} must be linked "
                    f"with the `links` attribute if using the 'gapfill' sampling strategy"
                )

                context_var = context_slices[context_idx]
                target_var = target_slices[target_idx]

                for var in [context_var, target_var]:
                    assert isinstance(var, (xr.DataArray, xr.Dataset)), (
                        f"If using 'gapfill' sampling strategy for linked context and target sets, "
                        f"the context and target sets must be xarray DataArrays or Datasets, "
                        f"but got {type(var)}."
                    )

                split_seed = seed + gapfill_i if seed is not None else None
                rng = np.random.default_rng(split_seed)

                # Keep trying until we get a target set with at least one target point
                keep_searching = True
                while keep_searching:
                    added_mask_date = rng.choice(self.context[context_idx].time)
                    added_mask = (
                        self.context[context_idx].sel(time=added_mask_date).isnull()
                    )
                    curr_mask = context_var.isnull()

                    # Mask out added missing values
                    context_var = context_var.where(~added_mask)

                    # TEMP: Inefficient to convert all non-targets to NaNs and then remove NaNs
                    #   when we could just slice the target values here
                    target_mask = added_mask & ~curr_mask
                    if isinstance(target_var, xr.Dataset):
                        keep_searching = np.all(target_mask.to_array().data == False)
                    else:
                        keep_searching = np.all(target_mask.data == False)
                    if keep_searching:
                        continue  # No target points -- use a different `added_mask`

                    target_var = target_var.where(
                        target_mask
                    )  # Only keep target locations

                    context_slices[context_idx] = context_var
                    target_slices[target_idx] = target_var

        for i, (var, sampling_strat) in enumerate(
            zip(context_slices, context_sampling)
        ):
            context_seed = seed + i if seed is not None else None
            X_c, Y_c = sample_variable(var, sampling_strat, context_seed)
            task["X_c"].append(X_c)
            task["Y_c"].append(Y_c)
        if target_sampling is not None:
            for j, (var, sampling_strat) in enumerate(
                zip(target_slices, target_sampling)
            ):
                target_seed = seed + i + j if seed is not None else None
                X_t, Y_t = sample_variable(var, sampling_strat, target_seed)
                task["X_t"].append(X_t)
                task["Y_t"].append(Y_t)

        if self.aux_at_contexts is not None:
            # Add auxiliary variable sampled at context set as a new context variable
            X_c_offgrid = [X_c for X_c in task["X_c"] if not isinstance(X_c, tuple)]
            if len(X_c_offgrid) == 0:
                # No offgrid context sets
                X_c_offgrid_all = np.empty((2, 0), dtype=self.dtype)
            else:
                X_c_offgrid_all = np.concatenate(X_c_offgrid, axis=1)
            Y_c_aux = self.sample_offgrid_aux(
                X_c_offgrid_all,
                self.time_slice_variable(self.aux_at_contexts, date),
            )
            task["X_c"].append(X_c_offgrid_all)
            task["Y_c"].append(Y_c_aux)

        if self.aux_at_targets is not None and target_sampling is None:
            task["Y_t_aux"] = None
        elif self.aux_at_targets is not None and target_sampling is not None:
            # Add auxiliary variable to target set
            if len(task["X_t"]) > 1:
                raise ValueError(
                    "Cannot add auxiliary variable to target set when there "
                    "are multiple target variables (not supported by default `ConvNP` model)."
                )
            task["Y_t_aux"] = self.sample_offgrid_aux(
                task["X_t"][0],
                self.time_slice_variable(self.aux_at_targets, date),
            )

        return Task(task)

    def __call__(
        self,
        date: pd.Timestamp,
        context_sampling: Union[
            str,
            int,
            float,
            np.ndarray,
            List[Union[str, int, float, np.ndarray]],
        ] = "all",
        target_sampling: Optional[
            Union[
                str,
                int,
                float,
                np.ndarray,
                List[Union[str, int, float, np.ndarray]],
            ]
        ] = None,
        split_frac: float = 0.5,
        datewise_deterministic: bool = False,
        seed_override: Optional[int] = None,
    ) -> Union[Task, List[Task]]:
        """
        Generate a task for a given date (or a list of
        :class:`.data.task.Task` objects for a list of dates).

        There are several sampling strategies available for the context and
        target data:

            - "all": Sample all observations.
            - int: Sample N observations uniformly at random.
            - float: Sample a fraction of observations uniformly at random.
            - :class:`numpy:numpy.ndarray`, shape (2, N):
                Sample N observations at the given x1, x2 coordinates. Coords are assumed to be
                normalised.
            - "split": Split pandas observations into disjoint context and target sets.
                `split_frac` determines the fraction of observations
                to use for the context set. The remaining observations are used
                for the target set.
                The context set and target set must be linked through the ``TaskLoader``
                ``links`` argument. Only valid for pandas data.
            - "gapfill": Generates a training task for filling NaNs in xarray data.
                Randomly samples a missing data (NaN) mask from another timestamp and
                adds it to the context set (i.e. increases the number of NaNs).
                The target set is then the true values of the data at the added
                missing locations.
                The context set and target set must be linked through the ``TaskLoader``
                ``links`` argument. Only valid for xarray data.
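
        For example (a sketch, assuming ``task_loader`` was constructed with a
        single context set and a single target set):

        >>> task = task_loader(
        ...     "2020-01-01", context_sampling=100, target_sampling="all"
        ... )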

        Args:
            date (:class:`pandas.Timestamp`):
                Date for which to generate the task.
            context_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional):
                Sampling strategy for the context data, either a list of
                sampling strategies for each context set, or a single strategy
                applied to all context sets. Default is ``"all"``.
            target_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional):
                Sampling strategy for the target data, either a list of
                sampling strategies for each target set, or a single strategy
                applied to all target sets. Default is ``None``, meaning no target
                data is returned.
            split_frac (float, optional):
                The fraction of observations to use for the context set with
                the "split" sampling strategy for linked context and target set
                pairs. The remaining observations are used for the target set.
                Default is 0.5.
            datewise_deterministic (bool, optional):
                Whether random sampling is datewise deterministic based on the
                date. Default is ``False``.
            seed_override (Optional[int], optional):
                Override the seed for random sampling. This can be used to
                repeat the same random sampling at different dates. Default is
                None.

        Returns:
            :class:`~.data.task.Task` | List[:class:`~.data.task.Task`]:
                Task object or list of task objects for each date containing
                the context and target data.
        """
        if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)):
            return [
                self.task_generation(
                    d,
                    context_sampling,
                    target_sampling,
                    split_frac,
                    datewise_deterministic,
                    seed_override,
                )
                for d in date
            ]
        else:
            return self.task_generation(
                date,
                context_sampling,
                target_sampling,
                split_frac,
                datewise_deterministic,
                seed_override,
            )