# /deepsensor/data/loader.py

from deepsensor.data.task import Task, flatten_X

import os
import json
import copy

import numpy as np
import xarray as xr
import pandas as pd

from typing import List, Tuple, Union, Optional

from deepsensor.errors import InvalidSamplingStrategyError


class TaskLoader:
    """Generates :class:`~.data.task.Task` objects for training, testing, and inference with DeepSensor models.

    Provides a suite of sampling methods for generating :class:`~.data.task.Task` objects
    for different kinds of predictions, such as spatial interpolation, forecasting,
    downscaling, or some combination of these.

    The behaviour is as follows:
        - If all data is passed as paths, load the data and overwrite the paths with the loaded data
        - Either all data is passed as paths, or all data is passed as loaded data (else ``ValueError``)
        - If all data is passed as paths, the TaskLoader can be saved with the ``save`` method
          (using the config)

    Args:
        task_loader_ID:
            If loading a TaskLoader from a config file, this is the folder the
            TaskLoader was saved in (using `.save`). If this argument is passed, all other
            arguments are ignored.
        context (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]):
            Context data. Can be a single :class:`xarray.DataArray`,
            :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a
            list/tuple of these.
        target (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]):
            Target data. Can be a single :class:`xarray.DataArray`,
            :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a
            list/tuple of these.
        aux_at_contexts (Tuple[int, :class:`xarray.DataArray` | :class:`xarray.Dataset`], optional):
            Auxiliary data at context locations. Tuple of two elements, where
            the first element is the index of the context set at which the
            auxiliary data will be sampled, and the second element is the
            auxiliary data, which can be a single :class:`xarray.DataArray` or
            :class:`xarray.Dataset`. Default: None.
        aux_at_targets (:class:`xarray.DataArray` | :class:`xarray.Dataset`, optional):
            Auxiliary data at target locations. Can be a single
            :class:`xarray.DataArray` or :class:`xarray.Dataset`. Default:
            None.
        links (Tuple[int, int] | List[Tuple[int, int]], optional):
            Specifies links between context and target data. Each link is a
            tuple of two integers, where the first integer is the index of the
            context data and the second integer is the index of the target
            data. Can be a single tuple in the case of a single link. If None,
            no links are specified. Default: None.
        context_delta_t (int | List[int], optional):
            Time difference between context data and t=0 (task init time). Can
            be a single int (same for all context data) or a list/tuple of
            ints. Default is 0.
        target_delta_t (int | List[int], optional):
            Time difference between target data and t=0 (task init time). Can
            be a single int (same for all target data) or a list/tuple of ints.
            Default is 0.
        time_freq (str, optional):
            Time frequency of the data. Default: ``'D'`` (daily).
        xarray_interp_method (str, optional):
            Interpolation method to use when interpolating
            :class:`xarray.DataArray`. Default is ``'linear'``.
        discrete_xarray_sampling (bool, optional):
            When randomly sampling xarray variables, whether to sample at
            discrete points defined at grid cell centres, or at continuous
            points within the grid. Default is ``False``.
        dtype (object, optional):
            Data type of the data. Used to cast the data to the specified
            dtype. Default: ``'float32'``.
    """

    config_fname = "task_loader_config.json"

    def __init__(
        self,
        task_loader_ID: Union[str, None] = None,
        context: Union[
            xr.DataArray,
            xr.Dataset,
            pd.DataFrame,
            str,
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
        ] = None,
        target: Union[
            xr.DataArray,
            xr.Dataset,
            pd.DataFrame,
            str,
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
        ] = None,
        aux_at_contexts: Optional[Tuple[int, Union[xr.DataArray, xr.Dataset]]] = None,
        aux_at_targets: Optional[
            Union[
                xr.DataArray,
                xr.Dataset,
            ]
        ] = None,
        links: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
        context_delta_t: Union[int, List[int]] = 0,
        target_delta_t: Union[int, List[int]] = 0,
        time_freq: str = "D",
        xarray_interp_method: str = "linear",
        discrete_xarray_sampling: bool = False,
        dtype: object = "float32",
    ) -> None:
        if task_loader_ID is not None:
            self.task_loader_ID = task_loader_ID
            # Load TaskLoader from config file
            fpath = os.path.join(task_loader_ID, self.config_fname)
            with open(fpath, "r") as f:
                self.config = json.load(f)

            self.context = self.config["context"]
            self.target = self.config["target"]
            self.aux_at_contexts = self.config["aux_at_contexts"]
            self.aux_at_targets = self.config["aux_at_targets"]
            self.links = self.config["links"]
            if self.links is not None:
                self.links = [tuple(link) for link in self.links]
            self.context_delta_t = self.config["context_delta_t"]
            self.target_delta_t = self.config["target_delta_t"]
            self.time_freq = self.config["time_freq"]
            self.xarray_interp_method = self.config["xarray_interp_method"]
            self.discrete_xarray_sampling = self.config["discrete_xarray_sampling"]
            self.dtype = self.config["dtype"]
        else:
            self.context = context
            self.target = target
            self.aux_at_contexts = aux_at_contexts
            self.aux_at_targets = aux_at_targets
            self.links = links
            self.context_delta_t = context_delta_t
            self.target_delta_t = target_delta_t
            self.time_freq = time_freq
            self.xarray_interp_method = xarray_interp_method
            self.discrete_xarray_sampling = discrete_xarray_sampling
            self.dtype = dtype

        if not isinstance(self.context, (tuple, list)):
            self.context = (self.context,)
        if not isinstance(self.target, (tuple, list)):
            self.target = (self.target,)

        if isinstance(self.context_delta_t, int):
            self.context_delta_t = (self.context_delta_t,) * len(self.context)
        else:
            assert len(self.context_delta_t) == len(self.context), (
                f"Length of context_delta_t ({len(self.context_delta_t)}) must be the same as "
                f"the number of context sets ({len(self.context)})"
            )
        if isinstance(self.target_delta_t, int):
            self.target_delta_t = (self.target_delta_t,) * len(self.target)
        else:
            assert len(self.target_delta_t) == len(self.target), (
                f"Length of target_delta_t ({len(self.target_delta_t)}) must be the same as "
                f"the number of target sets ({len(self.target)})"
            )

        all_paths = self._check_if_all_data_passed_as_paths()
        if all_paths:
            self._set_config()
            self._load_data_from_paths()

        self.context = self._cast_to_dtype(self.context)
        self.target = self._cast_to_dtype(self.target)
        self.aux_at_contexts = self._cast_to_dtype(self.aux_at_contexts)
        self.aux_at_targets = self._cast_to_dtype(self.aux_at_targets)

        self.links = self._check_links(self.links)

        (
            self.context_dims,
            self.target_dims,
            self.aux_at_target_dims,
        ) = self.count_context_and_target_data_dims()
        (
            self.context_var_IDs,
            self.target_var_IDs,
            self.context_var_IDs_and_delta_t,
            self.target_var_IDs_and_delta_t,
            self.aux_at_target_var_IDs,
        ) = self.infer_context_and_target_var_IDs()
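
    # Usage sketch (hypothetical data): constructing a TaskLoader directly from
    # loaded objects, assuming `ds_temperature` is an xarray Dataset/DataArray
    # and `df_observations` a pandas DataFrame, both with ("time", "x1", "x2")
    # coordinates/index levels as assumed throughout this class:
    #
    #   task_loader = TaskLoader(context=ds_temperature, target=df_observations)
    #   print(task_loader)  # __str__ summarises sets and variable IDs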

    def _set_config(self):
        """Instantiate a config dictionary for the TaskLoader object."""
        # Take deepcopy to avoid modifying the original config
        self.config = copy.deepcopy(
            dict(
                context=self.context,
                target=self.target,
                aux_at_contexts=self.aux_at_contexts,
                aux_at_targets=self.aux_at_targets,
                links=self.links,
                context_delta_t=self.context_delta_t,
                target_delta_t=self.target_delta_t,
                time_freq=self.time_freq,
                xarray_interp_method=self.xarray_interp_method,
                discrete_xarray_sampling=self.discrete_xarray_sampling,
                dtype=self.dtype,
            )
        )

    def _check_if_all_data_passed_as_paths(self) -> bool:
        """If all data passed as paths, save paths to config and return True."""

        def _check_if_strings(x, mode="all"):
            if x is None:
                return None
            elif isinstance(x, (tuple, list)):
                if mode == "all":
                    return all([isinstance(x_i, str) for x_i in x])
                elif mode == "any":
                    return any([isinstance(x_i, str) for x_i in x])
            else:
                return isinstance(x, str)

        all_paths = all(
            filter(
                lambda x: x is not None,
                [
                    _check_if_strings(self.context),
                    _check_if_strings(self.target),
                    _check_if_strings(self.aux_at_contexts),
                    _check_if_strings(self.aux_at_targets),
                ],
            )
        )
        self._is_saveable = all_paths

        any_paths = any(
            filter(
                lambda x: x is not None,
                [
                    _check_if_strings(self.context, mode="any"),
                    _check_if_strings(self.target, mode="any"),
                    _check_if_strings(self.aux_at_contexts, mode="any"),
                    _check_if_strings(self.aux_at_targets, mode="any"),
                ],
            )
        )
        if any_paths and not all_paths:
            raise ValueError(
                "Data must be passed either all as paths or all as xarray/pandas objects (not a mix)."
            )

        return all_paths

    def _load_data_from_paths(self):
        """Load data from paths and overwrite paths with loaded data."""
        loaded_data = {}

        def _load_pandas_or_xarray(path):
            # Need to be careful about this. We need to ensure data gets into the right form
            #  for TaskLoader.
            if path is None:
                return None
            elif path in loaded_data:
                return loaded_data[path]
            elif path.endswith(".nc"):
                data = xr.open_dataset(path)
            elif path.endswith(".csv"):
                df = pd.read_csv(path)
                if "time" in df.columns:
                    df["time"] = pd.to_datetime(df["time"])
                    df = df.set_index(["time", "x1", "x2"]).sort_index()
                else:
                    df = df.set_index(["x1", "x2"]).sort_index()
                data = df
            else:
                raise ValueError(f"Unknown file extension for {path}")
            loaded_data[path] = data
            return data

        def _load_data(data):
            if isinstance(data, (tuple, list)):
                data = tuple([_load_pandas_or_xarray(data_i) for data_i in data])
            else:
                data = _load_pandas_or_xarray(data)
            return data

        self.context = _load_data(self.context)
        self.target = _load_data(self.target)
        self.aux_at_contexts = _load_data(self.aux_at_contexts)
        self.aux_at_targets = _load_data(self.aux_at_targets)

    def save(self, folder: str):
        """Save TaskLoader config to JSON in `folder`."""
        if not self._is_saveable:
            raise ValueError(
                "TaskLoader cannot be saved because not all data was passed as paths."
            )

        os.makedirs(folder, exist_ok=True)
        fpath = os.path.join(folder, self.config_fname)
        with open(fpath, "w") as f:
            json.dump(self.config, f, indent=4, sort_keys=False)
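
    # Save/load sketch: `save` only works when all data was passed as file
    # paths (see `_check_if_all_data_passed_as_paths`). A hypothetical round
    # trip, assuming "era5.nc" and "stations.csv" exist in the formats handled
    # by `_load_data_from_paths`:
    #
    #   tl = TaskLoader(context="era5.nc", target="stations.csv")
    #   tl.save("tl_folder")  # writes tl_folder/task_loader_config.json
    #   tl_reloaded = TaskLoader(task_loader_ID="tl_folder")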

    def _cast_to_dtype(
        self,
        var: Union[
            xr.DataArray,
            xr.Dataset,
            pd.DataFrame,
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
        ],
    ) -> Union[
        xr.DataArray,
        xr.Dataset,
        pd.DataFrame,
        Tuple[Union[xr.DataArray, xr.Dataset, pd.DataFrame], ...],
    ]:
        """Cast data to the TaskLoader's default dtype.

        ..
            TODO unit test this by passing in a variety of data types and
            checking that they are cast correctly.

        Args:
            var:
                A single data set, or a tuple/list of data sets, to cast.

        Returns:
            The input data cast to ``self.dtype`` (a tuple if a tuple/list was
            passed, otherwise a single object; ``None`` is passed through).
        """

        def cast_to_dtype(var):
            if isinstance(var, xr.DataArray):
                var = var.astype(self.dtype)
                var["x1"] = var["x1"].astype(self.dtype)
                var["x2"] = var["x2"].astype(self.dtype)
            elif isinstance(var, xr.Dataset):
                var = var.astype(self.dtype)
                var["x1"] = var["x1"].astype(self.dtype)
                var["x2"] = var["x2"].astype(self.dtype)
            elif isinstance(var, (pd.DataFrame, pd.Series)):
                var = var.astype(self.dtype)
                # Note: Numeric pandas indexes are always cast to float64, so we have to cast
                #   x1/x2 coord dtypes during task sampling
            else:
                raise ValueError(f"Unknown type {type(var)} for context set {var}")
            return var

        if var is None:
            return var
        elif isinstance(var, (tuple, list)):
            var = tuple([cast_to_dtype(var_i) for var_i in var])
        else:
            var = cast_to_dtype(var)

        return var

    def load_dask(self) -> None:
        """Load any `dask` data into memory.

        This function triggers the computation and loading of any data that
        is represented as dask arrays or datasets into memory.

        Returns:
            None
        """

        def load(datasets):
            if datasets is None:
                return
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]
            for var in datasets:
                if isinstance(var, (xr.DataArray, xr.Dataset)):
                    var.load()  # xarray loads in place and returns self

        load(self.context)
        load(self.target)
        load(self.aux_at_contexts)
        load(self.aux_at_targets)

        return None

    def count_context_and_target_data_dims(self):
        """Count the number of data dimensions in the context and target data.

        Returns:
            tuple: context_dims, Tuple of data dimensions in the context data.
            tuple: target_dims, Tuple of data dimensions in the target data.
            int: aux_at_target_dims, Number of data dimensions in the
                aux-at-target data (0 if there is none).

        Raises:
            ValueError: If the context/target data is not a tuple/list of
                        :class:`xarray.DataArray`, :class:`xarray.Dataset` or
                        :class:`pandas.DataFrame`.
        """

        def count_data_dims_of_tuple_of_sets(datasets):
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]

            dims = []
            # Distinguish between xr.DataArray, xr.Dataset and pd.DataFrame
            for var in datasets:
                if isinstance(var, xr.Dataset):
                    dim = len(var.data_vars)  # Multiple data variables
                elif isinstance(var, xr.DataArray):
                    dim = 1  # Single data variable
                elif isinstance(var, pd.DataFrame):
                    dim = len(var.columns)  # Assumes all columns are data variables
                elif isinstance(var, pd.Series):
                    dim = 1  # Single data variable
                else:
                    raise ValueError(f"Unknown type {type(var)} for context set {var}")
                dims.append(dim)
            return dims

        context_dims = count_data_dims_of_tuple_of_sets(self.context)
        target_dims = count_data_dims_of_tuple_of_sets(self.target)
        if self.aux_at_contexts is not None:
            context_dims += count_data_dims_of_tuple_of_sets(self.aux_at_contexts)
        if self.aux_at_targets is not None:
            aux_at_target_dims = count_data_dims_of_tuple_of_sets(self.aux_at_targets)[0]
        else:
            aux_at_target_dims = 0

        return tuple(context_dims), tuple(target_dims), aux_at_target_dims

    def infer_context_and_target_var_IDs(self):
        """Infer the variable IDs of the context and target data.

        Returns:
            tuple: context_var_IDs, Tuple of variable IDs in the context data.
            tuple: target_var_IDs, Tuple of variable IDs in the target data.

        Raises:
            ValueError: If the context/target data is not a tuple/list of
                        :class:`xarray.DataArray`, :class:`xarray.Dataset` or
                        :class:`pandas.DataFrame`.
        """

        def infer_var_IDs_of_tuple_of_sets(datasets, delta_ts=None):
            """If delta_ts is not None, then add the delta_t to the variable ID."""
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]

            var_IDs = []
            # Distinguish between xr.DataArray, xr.Dataset and pd.DataFrame
            for i, var in enumerate(datasets):
                if isinstance(var, xr.DataArray):
                    var_ID = (var.name,)  # Single data variable
                elif isinstance(var, xr.Dataset):
                    var_ID = tuple(var.data_vars.keys())  # Multiple data variables
                elif isinstance(var, pd.DataFrame):
                    var_ID = tuple(var.columns)
                elif isinstance(var, pd.Series):
                    var_ID = (var.name,)
                else:
                    raise ValueError(f"Unknown type {type(var)} for context set {var}")

                if delta_ts is not None:
                    # Add delta_t to the variable ID
                    var_ID = tuple(
                        [f"{var_ID_i}_t{delta_ts[i]}" for var_ID_i in var_ID]
                    )
                else:
                    var_ID = tuple([f"{var_ID_i}" for var_ID_i in var_ID])

                var_IDs.append(var_ID)

            return var_IDs

        context_var_IDs = infer_var_IDs_of_tuple_of_sets(self.context)
        context_var_IDs_and_delta_t = infer_var_IDs_of_tuple_of_sets(
            self.context, self.context_delta_t
        )
        target_var_IDs = infer_var_IDs_of_tuple_of_sets(self.target)
        target_var_IDs_and_delta_t = infer_var_IDs_of_tuple_of_sets(
            self.target, self.target_delta_t
        )

        if self.aux_at_contexts is not None:
            context_var_IDs += infer_var_IDs_of_tuple_of_sets(self.aux_at_contexts)
            context_var_IDs_and_delta_t += infer_var_IDs_of_tuple_of_sets(
                self.aux_at_contexts, [0]
            )

        if self.aux_at_targets is not None:
            aux_at_target_var_IDs = infer_var_IDs_of_tuple_of_sets(self.aux_at_targets)[0]
        else:
            aux_at_target_var_IDs = None

        return (
            tuple(context_var_IDs),
            tuple(target_var_IDs),
            tuple(context_var_IDs_and_delta_t),
            tuple(target_var_IDs_and_delta_t),
            aux_at_target_var_IDs,
        )
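
    # Variable-ID sketch: for a hypothetical xr.Dataset with data variables
    # "t2m" and "msl" passed as the sole context set with context_delta_t=-1:
    #
    #   context_var_IDs             -> (("t2m", "msl"),)
    #   context_var_IDs_and_delta_t -> (("t2m_t-1", "msl_t-1"),)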

    def _check_links(self, links: Union[Tuple[int, int], List[Tuple[int, int]]]):
        """Check that the context-target links are valid.

        Args:
            links (Tuple[int, int] | List[Tuple[int, int]]):
                Specifies links between context and target data. Each link is a
                tuple of two integers, where the first integer is the index of
                the context data and the second integer is the index of the
                target data. Can be a single tuple in the case of a single
                link. If None, no links are specified. Default: None.

        Returns:
            Tuple[int, int] | List[Tuple[int, int]]:
                The input links, if valid.

        Raises:
            ValueError:
                If the links are not valid.
        """
        if links is None:
            return None

        assert isinstance(
            links, list
        ), f"Links must be a list of length-2 tuples, but got {type(links)}"
        assert len(links) > 0, "If links is not None, it must be a non-empty list"
        assert all(
            isinstance(link, tuple) for link in links
        ), f"Links must be a list of tuples, but got {[type(link) for link in links]}"
        assert all(
            len(link) == 2 for link in links
        ), f"Links must be a list of length-2 tuples, but got lengths {[len(link) for link in links]}"

        # Check that the links are valid
        for link_i, (context_idx, target_idx) in enumerate(links):
            if context_idx >= len(self.context):
                raise ValueError(
                    f"Invalid context index {context_idx} in link {link_i} of {links}: "
                    f"there are only {len(self.context)} context sets"
                )
            if target_idx >= len(self.target):
                raise ValueError(
                    f"Invalid target index {target_idx} in link {link_i} of {links}: "
                    f"there are only {len(self.target)} target sets"
                )

        return links

    def __str__(self):
        """String representation of the TaskLoader object (user-friendly)."""
        s = f"TaskLoader({len(self.context_dims)} context sets, {len(self.target_dims)} target sets)"
        s += f"\nContext variable IDs: {self.context_var_IDs}"
        s += f"\nTarget variable IDs: {self.target_var_IDs}"
        if self.aux_at_targets is not None:
            s += f"\nAuxiliary-at-target variable IDs: {self.aux_at_target_var_IDs}"
        return s

    def __repr__(self):
        """Representation of the TaskLoader object (for developers).

        ..
            TODO make this a more verbose version of __str__
        """
        s = str(self)
        s += "\n"
        s += f"\nContext data dimensions: {self.context_dims}"
        s += f"\nTarget data dimensions: {self.target_dims}"
        if self.aux_at_targets is not None:
            s += f"\nAuxiliary-at-target data dimensions: {self.aux_at_target_dims}"
        return s

    def sample_da(
        self,
        da: Union[xr.DataArray, xr.Dataset],
        sampling_strat: Union[str, int, float, np.ndarray],
        seed: Optional[int] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Sample a DataArray according to a given strategy.

        Args:
            da (:class:`xarray.DataArray` | :class:`xarray.Dataset`):
                DataArray to sample, assumed to be sliced for the task already.
            sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`):
                Sampling strategy: "all" or "gapfill", an int (or float
                fraction) for random sampling, or a numpy array of coordinates
                to sample at.
            seed (int, optional):
                Seed for random sampling. Default is None.

        Returns:
            Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]:
                Tuple of sampled coordinates ``X_c`` and sampled data values
                ``Y_c``.

        Raises:
            InvalidSamplingStrategyError:
                If the sampling strategy is not valid or if a numpy coordinate
                array is passed to sample an xarray object, but the coordinates
                are out of bounds.
        """
        da = da.load()  # Converts dask -> numpy if not already loaded
        if isinstance(da, xr.Dataset):
            da = da.to_array()

        if isinstance(sampling_strat, float):
            sampling_strat = int(sampling_strat * da.size)

        if isinstance(sampling_strat, (int, np.integer)):
            N = sampling_strat
            rng = np.random.default_rng(seed)
            if self.discrete_xarray_sampling:
                x1 = rng.choice(da.coords["x1"].values, N, replace=True)
                x2 = rng.choice(da.coords["x2"].values, N, replace=True)
                Y_c = da.sel(x1=xr.DataArray(x1), x2=xr.DataArray(x2)).data
            elif not self.discrete_xarray_sampling:
                if N == 0:
                    # Catch zero-context edge case before interp fails
                    X_c = np.zeros((2, 0), dtype=self.dtype)
                    dim = da.shape[0] if da.ndim == 3 else 1
                    Y_c = np.zeros((dim, 0), dtype=self.dtype)
                    return X_c, Y_c
                x1 = rng.uniform(da.coords["x1"].min(), da.coords["x1"].max(), N)
                x2 = rng.uniform(da.coords["x2"].min(), da.coords["x2"].max(), N)
                Y_c = da.sel(x1=xr.DataArray(x1), x2=xr.DataArray(x2), method="nearest")
                Y_c = np.array(Y_c, dtype=self.dtype)
            X_c = np.array([x1, x2], dtype=self.dtype)

        elif isinstance(sampling_strat, np.ndarray):
            X_c = sampling_strat.astype(self.dtype)
            try:
                Y_c = da.sel(
                    x1=xr.DataArray(X_c[0]),
                    x2=xr.DataArray(X_c[1]),
                    method="nearest",
                    tolerance=0.1,  # Maximum distance from observed point to sample
                )
            except KeyError:
                raise InvalidSamplingStrategyError(
                    "Passed a numpy coordinate array to sample xarray object, "
                    "but the coordinates are out of bounds."
                )
            Y_c = np.array(Y_c, dtype=self.dtype)

        elif sampling_strat in ["all", "gapfill"]:
            X_c = (
                da.coords["x1"].values[np.newaxis],
                da.coords["x2"].values[np.newaxis],
            )
            Y_c = da.data
            if Y_c.ndim == 2:
                # returned a 2D array, but we need a 3D array of shape (variable, N_x1, N_x2)
                Y_c = Y_c.reshape(1, *Y_c.shape)
        else:
            raise InvalidSamplingStrategyError(
                f"Unknown sampling strategy {sampling_strat}"
            )

        if Y_c.ndim == 1:
            # returned a 1D array, but we need a 2D array of shape (variable, N)
            Y_c = Y_c.reshape(1, *Y_c.shape)

        return X_c, Y_c
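
    # Sampling sketch (hypothetical data): draw 10 random continuous locations
    # from a time-sliced DataArray `da_t`, reproducibly via a seed:
    #
    #   X_c, Y_c = task_loader.sample_da(da_t, sampling_strat=10, seed=42)
    #   # X_c: coordinates, shape (2, 10); Y_c: values, shape (n_vars, 10)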

    def sample_df(
        self,
        df: Union[pd.DataFrame, pd.Series],
        sampling_strat: Union[str, int, float, np.ndarray],
        seed: Optional[int] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Sample a DataFrame according to a given strategy.

        Args:
            df (:class:`pandas.DataFrame` | :class:`pandas.Series`):
                Dataframe to sample, assumed to be time-sliced for the task
                already.
            sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`):
                Sampling strategy: "all" or "split", an int (or float fraction)
                for random sampling, or a numpy array of coordinates to sample
                at.
            seed (int, optional):
                Seed for random sampling. Default is None.

        Returns:
            Tuple[X_c, Y_c]:
                Tuple of sampled coordinates ``X_c`` and sampled data values
                ``Y_c``.

        Raises:
            InvalidSamplingStrategyError:
                If the sampling strategy is not valid or if a numpy coordinate
                array is passed to sample a pandas object, but the DataFrame
                does not contain all the requested samples.
        """
        df = df.dropna(how="any")  # If any obs are NaN, drop them

        if isinstance(sampling_strat, float):
            sampling_strat = int(sampling_strat * df.shape[0])

        if isinstance(sampling_strat, (int, np.integer)):
            N = sampling_strat
            rng = np.random.default_rng(seed)
            idx = rng.choice(df.index, N)
            X_c = df.loc[idx].reset_index()[["x1", "x2"]].values.T.astype(self.dtype)
            Y_c = df.loc[idx].values.T
        elif isinstance(sampling_strat, str) and sampling_strat in [
            "all",
            "split",
        ]:
            # NOTE if "split", we assume that the context-target split has already been applied to the df
            # in an earlier scope with access to both the context and target data. This is maybe risky!
            X_c = df.reset_index()[["x1", "x2"]].values.T.astype(self.dtype)
            Y_c = df.values.T
        elif isinstance(sampling_strat, np.ndarray):
            if df.index.get_level_values("x1").dtype != sampling_strat.dtype:
                raise InvalidSamplingStrategyError(
                    "Passed a numpy coordinate array to sample pandas DataFrame, "
                    "but the coordinate array has a different dtype than the DataFrame. "
                    f"Got {sampling_strat.dtype} but expected {df.index.get_level_values('x1').dtype}."
                )
            X_c = sampling_strat.astype(self.dtype)
            try:
                Y_c = df.loc[pd.IndexSlice[:, X_c[0], X_c[1]]].values.T
            except KeyError:
                raise InvalidSamplingStrategyError(
                    "Passed a numpy coordinate array to sample pandas DataFrame, "
                    "but the DataFrame did not contain all the requested samples.\n"
                    f"Indexes: {df.index}\n"
                    f"Sampling coords: {X_c}\n"
                    "If this is unexpected, check that your numpy sampling array matches "
                    "the DataFrame index values *exactly*."
                )
        else:
            raise InvalidSamplingStrategyError(
                f"Unknown sampling strategy {sampling_strat}"
            )

        if Y_c.ndim == 1:
            # returned a 1D array, but we need a 2D array of shape (variable, N)
            Y_c = Y_c.reshape(1, *Y_c.shape)

        return X_c, Y_c
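
    # Sampling sketch (hypothetical data): sample a station DataFrame at
    # explicit (x1, x2) coordinates. The coordinate array's dtype must match
    # the DataFrame index dtype exactly, else InvalidSamplingStrategyError:
    #
    #   X_coords = np.array([[0.1, 0.2], [0.3, 0.4]])  # shape (2, N)
    #   X_coords = X_coords.astype(df.index.get_level_values("x1").dtype)
    #   X_c, Y_c = task_loader.sample_df(df, sampling_strat=X_coords)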

    def sample_offgrid_aux(
        self,
        X_t: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]],
        offgrid_aux: Union[xr.DataArray, xr.Dataset],
    ) -> np.ndarray:
        """Sample auxiliary data at off-grid locations.

        Args:
            X_t (:class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]):
                Off-grid locations at which to sample the auxiliary data. Can
                be a tuple of two numpy arrays, or a single numpy array.
            offgrid_aux (:class:`xarray.DataArray` | :class:`xarray.Dataset`):
                Auxiliary data at off-grid locations.

        Returns:
            :class:`numpy:numpy.ndarray`:
                Auxiliary data sampled at the given locations using
                nearest-neighbour selection, with a leading variable dimension.

        Raises:
            ValueError:
                If the auxiliary data still has a ``time`` dimension (it must
                be time-sliced before sampling).
        """
        if "time" in offgrid_aux.dims:
            raise ValueError(
                "If `aux_at_targets` data has a `time` dimension, it must be sliced before "
                "passing it to `sample_offgrid_aux`."
            )
        if isinstance(X_t, tuple):
            xt1, xt2 = X_t
            xt1 = xt1.ravel()
            xt2 = xt2.ravel()
        else:
            xt1, xt2 = xr.DataArray(X_t[0]), xr.DataArray(X_t[1])
        Y_t_aux = offgrid_aux.sel(x1=xt1, x2=xt2, method="nearest")
        if isinstance(Y_t_aux, xr.Dataset):
            Y_t_aux = Y_t_aux.to_array()
        Y_t_aux = np.array(Y_t_aux, dtype=np.float32)
        if (isinstance(X_t, tuple) and Y_t_aux.ndim == 2) or (
            isinstance(X_t, np.ndarray) and Y_t_aux.ndim == 1
        ):
            # Reshape to (variable, *spatial_dims)
            Y_t_aux = Y_t_aux.reshape(1, *Y_t_aux.shape)
        return Y_t_aux

    def time_slice_variable(self, var, date, delta_t=0):
        """Slice a variable by a given time delta.

        Args:
            var (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | :class:`pandas.Series`):
                Variable to slice.
            date (:class:`pandas.Timestamp`):
                Task initialisation date (t=0).
            delta_t (int, optional):
                Time delta to slice by, in units of ``self.time_freq``.
                Default is 0.

        Returns:
            var:
                Sliced variable.

        Raises:
            ValueError:
                If the variable is of an unknown type.
        """
        # TODO: Does this work with instantaneous time?
        delta_t = pd.Timedelta(delta_t, unit=self.time_freq)
        if isinstance(var, (xr.Dataset, xr.DataArray)):
            if "time" in var.dims:
                var = var.sel(time=date + delta_t)
        elif isinstance(var, (pd.DataFrame, pd.Series)):
            if "time" in var.index.names:
                var = var[var.index.get_level_values("time") == date + delta_t]
        else:
            raise ValueError(f"Unknown variable type {type(var)}")
        return var
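
    # Slicing sketch: with time_freq="D", delta_t=-1 selects the day before
    # the task date, e.g. (hypothetical data):
    #
    #   var_prev = task_loader.time_slice_variable(
    #       ds, date=pd.Timestamp("2020-01-02"), delta_t=-1
    #   )  # selects 2020-01-01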

    def task_generation(  # noqa: D102
        self,
        date: pd.Timestamp,
        context_sampling: Union[
            str,
            int,
            float,
            np.ndarray,
            List[Union[str, int, float, np.ndarray]],
        ] = "all",
        target_sampling: Optional[
            Union[
                str,
                int,
                float,
                np.ndarray,
                List[Union[str, int, float, np.ndarray]],
            ]
        ] = None,
        split_frac: float = 0.5,
        datewise_deterministic: bool = False,
        seed_override: Optional[int] = None,
    ) -> Task:
        def check_sampling_strat(sampling_strat, set):
            """Check the sampling strategy.

            Ensure ``sampling_strat`` is either a single strategy (broadcast
            to all sets) or a list of length equal to the number of sets.
            Convert to a tuple of length equal to the number of sets and
            return.

            Args:
                sampling_strat:
                    Sampling strategy to check.
                set:
                    Context or target set to check.

            Returns:
                tuple:
                    Tuple of sampling strategies, one for each set.

            Raises:
                InvalidSamplingStrategyError:
                    - If the sampling strategy is invalid.
                    - If the length of the sampling strategy does not match the number of sets.
                    - If the sampling strategy is not a valid type.
                    - If the sampling strategy is a float but not in [0, 1].
                    - If the sampling strategy is an int but negative.
                    - If the sampling strategy is a numpy array but not of shape (2, N).
            """
            if sampling_strat is None:
                return None
            if not isinstance(sampling_strat, (list, tuple)):
                sampling_strat = tuple([sampling_strat] * len(set))
            elif isinstance(sampling_strat, (list, tuple)) and len(
                sampling_strat
            ) != len(set):
                raise InvalidSamplingStrategyError(
                    f"Length of sampling_strat ({len(sampling_strat)}) must "
                    f"match the number of sets ({len(set)})"
                )

            for strat in sampling_strat:
                if not isinstance(strat, (str, int, np.integer, float, np.ndarray)):
                    raise InvalidSamplingStrategyError(
                        f"Unknown sampling strategy {strat} of type {type(strat)}"
                    )
                elif isinstance(strat, str) and strat not in [
                    "all",
                    "split",
                    "gapfill",
                ]:
                    raise InvalidSamplingStrategyError(
                        f"Unknown sampling strategy {strat} for type str"
                    )
                elif isinstance(strat, float) and not 0 <= strat <= 1:
                    raise InvalidSamplingStrategyError(
                        f"If sampling strategy is a float, it must be a "
                        f"fraction in [0, 1], got {strat}"
                    )
                elif isinstance(strat, int) and strat < 0:
                    raise InvalidSamplingStrategyError(
                        f"Sampling N must be non-negative, got {strat}"
                    )
                elif isinstance(strat, np.ndarray) and strat.shape[0] != 2:
                    raise InvalidSamplingStrategyError(
                        "Sampling coordinates must be of shape (2, N), got "
                        f"{strat.shape}"
                    )

            return sampling_strat

        def sample_variable(var, sampling_strat, seed):
            """Sample a variable by a given sampling strategy to get input and
            output data.

            Args:
                var:
                    Variable to sample.
                sampling_strat:
                    Sampling strategy to use.
                seed:
                    Seed for random sampling.

            Returns:
                Tuple[X, Y]:
                    Tuple of input and output data.

            Raises:
                ValueError:
                    If the variable is of an unknown type.
            """
            if isinstance(var, (xr.Dataset, xr.DataArray)):
                X, Y = self.sample_da(var, sampling_strat, seed)
            elif isinstance(var, (pd.DataFrame, pd.Series)):
                X, Y = self.sample_df(var, sampling_strat, seed)
            else:
                raise ValueError(f"Unknown type {type(var)} for context set {var}")
            return X, Y

        # Check that the sampling strategies are valid
        context_sampling = check_sampling_strat(context_sampling, self.context)
        target_sampling = check_sampling_strat(target_sampling, self.target)
        # Check `split_frac`
        if split_frac < 0 or split_frac > 1:
            raise ValueError(f"split_frac must be between 0 and 1, got {split_frac}")
        if self.links is None:
            b1 = any(
                [
                    strat in ["split", "gapfill"]
                    for strat in context_sampling
                    if isinstance(strat, str)
                ]
            )
            if target_sampling is None:
                b2 = False
            else:
                b2 = any(
                    [
                        strat in ["split", "gapfill"]
                        for strat in target_sampling
                        if isinstance(strat, str)
                    ]
                )
            if b1 or b2:
                raise ValueError(
                    "If using 'split' or 'gapfill' sampling strategies, the context and target "
                    "sets must be linked with the TaskLoader `links` attribute."
                )
        if self.links is not None:
            for context_idx, target_idx in self.links:
                context_sampling_i = context_sampling[context_idx]
                if target_sampling is None:
                    target_sampling_i = None
                else:
                    target_sampling_i = target_sampling[target_idx]
                link_strats = (context_sampling_i, target_sampling_i)
                if any(
                    [
                        strat in ["split", "gapfill"]
                        for strat in link_strats
                        if isinstance(strat, str)
                    ]
                ):
                    # If one of the sampling strategies is "split" or "gapfill", the other must
                    # use the same splitting strategy
                    if link_strats[0] != link_strats[1]:
                        raise ValueError(
                            f"Linked context set {context_idx} and target set {target_idx} "
                            f"must use the same sampling strategy if one of them "
                            f"uses the 'split' or 'gapfill' sampling strategy. "
                            f"Got {link_strats[0]} and {link_strats[1]}."
                        )

        if not isinstance(date, pd.Timestamp):
            date = pd.Timestamp(date)

        if seed_override is not None:
            # Override the seed for random sampling
            seed = seed_override
        elif datewise_deterministic:
            # Generate a deterministic seed, based on the date, for random sampling
            seed = int(date.strftime("%Y%m%d"))
        else:
            # 'Truly' random sampling
            seed = None

        task = {}

        task["time"] = date
        task["ops"] = []
        task["X_c"] = []
        task["Y_c"] = []
        if target_sampling is not None:
            task["X_t"] = []
            task["Y_t"] = []
        else:
            task["X_t"] = None
            task["Y_t"] = None

        context_slices = [
            self.time_slice_variable(var, date, delta_t)
            for var, delta_t in zip(self.context, self.context_delta_t)
        ]
        target_slices = [
            self.time_slice_variable(var, date, delta_t)
            for var, delta_t in zip(self.target, self.target_delta_t)
        ]

        # TODO move to method
        if (
            self.links is not None
            and "split" in context_sampling
            and "split" in target_sampling
        ):
            # Perform the split sampling strategy for linked context and target sets at this point
            # while we have the full context and target data in scope

            context_split_idxs = np.where(np.array(context_sampling) == "split")[0]
            target_split_idxs = np.where(np.array(target_sampling) == "split")[0]
            assert len(context_split_idxs) == len(target_split_idxs), (
                f"Number of context sets with 'split' sampling strategy "
                f"({len(context_split_idxs)}) must match number of target sets "
                f"with 'split' sampling strategy ({len(target_split_idxs)})"
            )
            for split_i, (context_idx, target_idx) in enumerate(
                zip(context_split_idxs, target_split_idxs)
            ):
                assert (context_idx, target_idx) in self.links, (
                    f"Context set {context_idx} and target set {target_idx} must be linked "
                    f"with the `links` attribute if using the 'split' sampling strategy"
                )

                context_var = context_slices[context_idx]
                target_var = target_slices[target_idx]

                for var in [context_var, target_var]:
                    assert isinstance(var, (pd.Series, pd.DataFrame)), (
                        f"If using 'split' sampling strategy for linked context and target sets, "
                        f"the context and target sets must be pandas DataFrames or Series, "
                        f"but got {type(var)}."
                    )

                N_obs = len(context_var)
                N_obs_target_check = len(target_var)
                if N_obs != N_obs_target_check:
                    raise ValueError(
                        f"Cannot split context set {context_idx} and target set {target_idx} "
                        f"because they have different numbers of observations: "
                        f"{N_obs} and {N_obs_target_check}"
                    )
                split_seed = seed + split_i if seed is not None else None
                rng = np.random.default_rng(split_seed)

                N_context = int(N_obs * split_frac)
                idxs_context = rng.choice(N_obs, N_context, replace=False)

                context_var = context_var.iloc[idxs_context]
                target_var = target_var.drop(context_var.index)

                context_slices[context_idx] = context_var
                target_slices[target_idx] = target_var

        # TODO move to method
        if (
            self.links is not None
            and "gapfill" in context_sampling
            and "gapfill" in target_sampling
        ):
            # Perform the gapfill sampling strategy for linked context and target sets at this point
            # while we have the full context and target data in scope

            context_gapfill_idxs = np.where(np.array(context_sampling) == "gapfill")[0]
            target_gapfill_idxs = np.where(np.array(target_sampling) == "gapfill")[0]
            assert len(context_gapfill_idxs) == len(target_gapfill_idxs), (
                f"Number of context sets with 'gapfill' sampling strategy "
                f"({len(context_gapfill_idxs)}) must match number of target sets "
                f"with 'gapfill' sampling strategy ({len(target_gapfill_idxs)})"
            )
            for gapfill_i, (context_idx, target_idx) in enumerate(
                zip(context_gapfill_idxs, target_gapfill_idxs)
            ):
                assert (context_idx, target_idx) in self.links, (
                    f"Context set {context_idx} and target set {target_idx} must be linked "
                    f"with the `links` attribute if using the 'gapfill' sampling strategy"
                )

                context_var = context_slices[context_idx]
                target_var = target_slices[target_idx]

                for var in [context_var, target_var]:
                    assert isinstance(var, (xr.DataArray, xr.Dataset)), (
                        f"If using 'gapfill' sampling strategy for linked context and target sets, "
                        f"the context and target sets must be xarray DataArrays or Datasets, "
                        f"but got {type(var)}."
                    )

                split_seed = seed + gapfill_i if seed is not None else None
                rng = np.random.default_rng(split_seed)

                # Keep trying until we get a target set with at least one target point
                keep_searching = True
                while keep_searching:
                    added_mask_date = rng.choice(self.context[context_idx].time)
                    added_mask = (
                        self.context[context_idx].sel(time=added_mask_date).isnull()
                    )
                    curr_mask = context_var.isnull()

                    # Mask out added missing values
                    context_var = context_var.where(~added_mask)

                    # TEMP: Inefficient to convert all non-targets to NaNs and then remove NaNs
                    #   when we could just slice the target values here
                    target_mask = added_mask & ~curr_mask
                    if isinstance(target_var, xr.Dataset):
                        keep_searching = np.all(target_mask.to_array().data == False)
                    else:
                        keep_searching = np.all(target_mask.data == False)
                    if keep_searching:
                        continue  # No target points -- use a different `added_mask`

                    target_var = target_var.where(
                        target_mask
                    )  # Only keep target locations

                    context_slices[context_idx] = context_var
                    target_slices[target_idx] = target_var

        for i, (var, sampling_strat) in enumerate(
            zip(context_slices, context_sampling)
        ):
            context_seed = seed + i if seed is not None else None
            X_c, Y_c = sample_variable(var, sampling_strat, context_seed)
            task["X_c"].append(X_c)
            task["Y_c"].append(Y_c)
        if target_sampling is not None:
            for j, (var, sampling_strat) in enumerate(
                zip(target_slices, target_sampling)
            ):
                target_seed = seed + i + j if seed is not None else None
                X_t, Y_t = sample_variable(var, sampling_strat, target_seed)
                task["X_t"].append(X_t)
                task["Y_t"].append(Y_t)

        if self.aux_at_contexts is not None:
            # Add auxiliary variable sampled at context set as a new context variable
            X_c_offgrid = [X_c for X_c in task["X_c"] if not isinstance(X_c, tuple)]
            if len(X_c_offgrid) == 0:
                # No offgrid context sets
                X_c_offgrid_all = np.empty((2, 0), dtype=self.dtype)
            else:
                X_c_offgrid_all = np.concatenate(X_c_offgrid, axis=1)
            Y_c_aux = self.sample_offgrid_aux(
                X_c_offgrid_all,
                self.time_slice_variable(self.aux_at_contexts, date),
            )
            task["X_c"].append(X_c_offgrid_all)
            task["Y_c"].append(Y_c_aux)

        if self.aux_at_targets is not None and target_sampling is None:
            task["Y_t_aux"] = None
        elif self.aux_at_targets is not None and target_sampling is not None:
            # Add auxiliary variable to target set
            if len(task["X_t"]) > 1:
                raise ValueError(
                    "Cannot add auxiliary variable to target set when there "
                    "are multiple target variables (not supported by default `ConvNP` model)."
                )
            task["Y_t_aux"] = self.sample_offgrid_aux(
                task["X_t"][0],
                self.time_slice_variable(self.aux_at_targets, date),
            )

        return Task(task)

    def __call__(
        self,
        date: pd.Timestamp,
        context_sampling: Union[
            str,
            int,
            float,
            np.ndarray,
            List[Union[str, int, float, np.ndarray]],
        ] = "all",
        target_sampling: Optional[
            Union[
                str,
                int,
                float,
                np.ndarray,
                List[Union[str, int, float, np.ndarray]],
            ]
        ] = None,
        split_frac: float = 0.5,
        datewise_deterministic: bool = False,
        seed_override: Optional[int] = None,
    ) -> Union[Task, List[Task]]:
        """Generate a task for a given date (or a list of
        :class:`.data.task.Task` objects for a list of dates).

        There are several sampling strategies available for the context and
        target data:

            - "all": Sample all observations.
            - int: Sample N observations uniformly at random.
            - float: Sample a fraction of observations uniformly at random.
            - :class:`numpy:numpy.ndarray`, shape (2, N):
                Sample N observations at the given x1, x2 coordinates. Coords are assumed to be
                normalised.
            - "split": Split pandas observations into disjoint context and target sets.
                `split_frac` determines the fraction of observations
                to use for the context set. The remaining observations are used
                for the target set.
                The context set and target set must be linked through the ``TaskLoader``
                ``links`` argument. Only valid for pandas data.
            - "gapfill": Generates a training task for filling NaNs in xarray data.
                Randomly samples a missing data (NaN) mask from another timestamp and
                adds it to the context set (i.e. increases the number of NaNs).
                The target set is then the true values of the data at the added missing
                locations.
                The context set and target set must be linked through the ``TaskLoader``
                ``links`` argument. Only valid for xarray data.

        Args:
            date (:class:`pandas.Timestamp`):
                Date for which to generate the task.
            context_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional):
                Sampling strategy for the context data, either a list of
                sampling strategies for each context set, or a single strategy
                applied to all context sets. Default is ``"all"``.
            target_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional):
                Sampling strategy for the target data, either a list of
                sampling strategies for each target set, or a single strategy
                applied to all target sets. Default is ``None``, meaning no target
                data is returned.
            split_frac (float, optional):
                The fraction of observations to use for the context set with
                the "split" sampling strategy for linked context and target set
                pairs. The remaining observations are used for the target set.
                Default is 0.5.
            datewise_deterministic (bool, optional):
                Whether random sampling is datewise deterministic based on the
                date. Default is ``False``.
            seed_override (Optional[int], optional):
                Override the seed for random sampling. This can be used to use
                the same random sampling at different ``date``. Default is
                None.

        Returns:
            :class:`~.data.task.Task` | List[:class:`~.data.task.Task`]:
                Task object or list of task objects for each date containing
                the context and target data.
        """
        if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)):
            return [
                self.task_generation(
                    d,
                    context_sampling,
                    target_sampling,
                    split_frac,
                    datewise_deterministic,
                    seed_override,
                )
                for d in date
            ]
        else:
            return self.task_generation(
                date,
                context_sampling,
                target_sampling,
                split_frac,
                datewise_deterministic,
                seed_override,
            )
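

# Task-generation sketch (hypothetical data): an end-to-end example, assuming
# `ds_era5` (gridded xarray) and `df_stations` (off-grid pandas) are normalised
# with ("time", "x1", "x2") coords/index levels. The "split" strategy requires
# linking the pandas context and target sets via `links`:
#
#   task_loader = TaskLoader(
#       context=[ds_era5, df_stations],
#       target=df_stations,
#       links=[(1, 0)],  # context set 1 <-> target set 0
#   )
#   task = task_loader(
#       "2020-01-01",
#       context_sampling=["all", "split"],
#       target_sampling="split",
#       split_frac=0.5,
#       datewise_deterministic=True,  # same random sampling for a given date
#   )
#
# Calling with a list of dates (or a pd.DatetimeIndex) returns a list of Tasks:
#
#   tasks = task_loader(pd.date_range("2020-01-01", "2020-01-10"), "all", 100)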