alan-turing-institute / deepsensor, build 14112818194
27 Mar 2025 05:21PM UTC, coverage: 81.663% (+0.04%) from 81.626%
Pull Request #150 (github / web-flow): BUG FIX: extent = north_america instead of usa
Merge 55d674cb5 into 38ec5ef26
2053 of 2514 relevant lines covered (81.66%), 1.63 hits per line

Source file: /deepsensor/data/loader.py (file coverage: 92.14%)

import os
import json
import copy

import numpy as np
import xarray as xr
import pandas as pd

from typing import List, Tuple, Union, Optional

from deepsensor.data.task import Task
from deepsensor.errors import InvalidSamplingStrategyError, SamplingTooManyPointsError


class TaskLoader:
    """Generates :class:`~.data.task.Task` objects for training, testing, and inference with DeepSensor models.

    Provides a suite of sampling methods for generating :class:`~.data.task.Task` objects for different kinds of
    predictions, such as spatial interpolation, forecasting, downscaling, or some combination
    of these.

    The behaviour is as follows:
        - If all data is passed as paths, the data is loaded and the paths are overwritten with the loaded data
        - Either all data is passed as paths, or all data is passed as loaded data (else a ``ValueError`` is raised)
        - If all data is passed as paths, the TaskLoader can be saved with the ``save`` method
          (using config)

    Args:
        task_loader_ID:
            If loading a TaskLoader from a config file, this is the folder the
            TaskLoader was saved in (using `.save`). If this argument is passed, all other
            arguments are ignored.
        context (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]):
            Context data. Can be a single :class:`xarray.DataArray`,
            :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a
            list/tuple of these.
        target (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]):
            Target data. Can be a single :class:`xarray.DataArray`,
            :class:`xarray.Dataset` or :class:`pandas.DataFrame`, or a
            list/tuple of these.
        aux_at_contexts (Tuple[int, :class:`xarray.DataArray` | :class:`xarray.Dataset`], optional):
            Auxiliary data at context locations. Tuple of two elements, where
            the first element is the index of the context set at which the
            auxiliary data will be sampled, and the second element is the
            auxiliary data, which can be a single :class:`xarray.DataArray` or
            :class:`xarray.Dataset`. Default: None.
        aux_at_targets (:class:`xarray.DataArray` | :class:`xarray.Dataset`, optional):
            Auxiliary data at target locations. Can be a single
            :class:`xarray.DataArray` or :class:`xarray.Dataset`. Default:
            None.
        links (Tuple[int, int] | List[Tuple[int, int]], optional):
            Specifies links between context and target data. Each link is a
            tuple of two integers, where the first integer is the index of the
            context data and the second integer is the index of the target
            data. Can be a single tuple in the case of a single link. If None,
            no links are specified. Default: None.
        context_delta_t (int | List[int], optional):
            Time difference between context data and t=0 (task init time). Can
            be a single int (same for all context data) or a list/tuple of
            ints. Default is 0.
        target_delta_t (int | List[int], optional):
            Time difference between target data and t=0 (task init time). Can
            be a single int (same for all target data) or a list/tuple of ints.
            Default is 0.
        time_freq (str, optional):
            Time frequency of the data. Default: ``'D'`` (daily).
        xarray_interp_method (str, optional):
            Interpolation method to use when interpolating
            :class:`xarray.DataArray`. Default is ``'linear'``.
        discrete_xarray_sampling (bool, optional):
            When randomly sampling xarray variables, whether to sample at
            discrete points defined at grid cell centres, or at continuous
            points within the grid. Default is ``False``.
        dtype (object, optional):
            Data type of the data. Used to cast the data to the specified
            dtype. Default: ``'float32'``.
    """

    config_fname = "task_loader_config.json"

    def __init__(
        self,
        task_loader_ID: Union[str, None] = None,
        context: Union[
            xr.DataArray,
            xr.Dataset,
            pd.DataFrame,
            str,
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
        ] = None,
        target: Union[
            xr.DataArray,
            xr.Dataset,
            pd.DataFrame,
            str,
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
        ] = None,
        aux_at_contexts: Optional[Tuple[int, Union[xr.DataArray, xr.Dataset]]] = None,
        aux_at_targets: Optional[
            Union[
                xr.DataArray,
                xr.Dataset,
            ]
        ] = None,
        links: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = None,
        context_delta_t: Union[int, List[int]] = 0,
        target_delta_t: Union[int, List[int]] = 0,
        time_freq: str = "D",
        xarray_interp_method: str = "linear",
        discrete_xarray_sampling: bool = False,
        dtype: object = "float32",
    ) -> None:
        if task_loader_ID is not None:
            self.task_loader_ID = task_loader_ID
            # Load TaskLoader from config file
            fpath = os.path.join(task_loader_ID, self.config_fname)
            with open(fpath, "r") as f:
                self.config = json.load(f)

            self.context = self.config["context"]
            self.target = self.config["target"]
            self.aux_at_contexts = self.config["aux_at_contexts"]
            self.aux_at_targets = self.config["aux_at_targets"]
            self.links = self.config["links"]
            if self.links is not None:
                self.links = [tuple(link) for link in self.links]
            self.context_delta_t = self.config["context_delta_t"]
            self.target_delta_t = self.config["target_delta_t"]
            self.time_freq = self.config["time_freq"]
            self.xarray_interp_method = self.config["xarray_interp_method"]
            self.discrete_xarray_sampling = self.config["discrete_xarray_sampling"]
            self.dtype = self.config["dtype"]
        else:
            self.context = context
            self.target = target
            self.aux_at_contexts = aux_at_contexts
            self.aux_at_targets = aux_at_targets
            self.links = links
            self.context_delta_t = context_delta_t
            self.target_delta_t = target_delta_t
            self.time_freq = time_freq
            self.xarray_interp_method = xarray_interp_method
            self.discrete_xarray_sampling = discrete_xarray_sampling
            self.dtype = dtype

        if not isinstance(self.context, (tuple, list)):
            self.context = (self.context,)
        if not isinstance(self.target, (tuple, list)):
            self.target = (self.target,)

        if isinstance(self.context_delta_t, int):
            self.context_delta_t = (self.context_delta_t,) * len(self.context)
        else:
            assert len(self.context_delta_t) == len(self.context), (
                f"Length of context_delta_t ({len(self.context_delta_t)}) must be the same as "
                f"the number of context sets ({len(self.context)})"
            )
        if isinstance(self.target_delta_t, int):
            self.target_delta_t = (self.target_delta_t,) * len(self.target)
        else:
            assert len(self.target_delta_t) == len(self.target), (
                f"Length of target_delta_t ({len(self.target_delta_t)}) must be the same as "
                f"the number of target sets ({len(self.target)})"
            )

        all_paths = self._check_if_all_data_passed_as_paths()
        if all_paths:
            self._set_config()
            self._load_data_from_paths()

        self.context = self._cast_to_dtype(self.context)
        self.target = self._cast_to_dtype(self.target)
        self.aux_at_contexts = self._cast_to_dtype(self.aux_at_contexts)
        self.aux_at_targets = self._cast_to_dtype(self.aux_at_targets)

        self.links = self._check_links(self.links)

        (
            self.context_dims,
            self.target_dims,
            self.aux_at_target_dims,
        ) = self.count_context_and_target_data_dims()
        (
            self.context_var_IDs,
            self.target_var_IDs,
            self.context_var_IDs_and_delta_t,
            self.target_var_IDs_and_delta_t,
            self.aux_at_target_var_IDs,
        ) = self.infer_context_and_target_var_IDs()
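
    # Illustrative construction sketch (not part of the library source): two
    # context sets (one gridded `era5_da`, one off-grid `station_df`, both
    # hypothetical and already normalised) and one gridded target set, with
    # the gridded context lagged by one step of `time_freq`:
    #
    #   task_loader = TaskLoader(
    #       context=[era5_da, station_df],
    #       target=era5_da,
    #       context_delta_t=[-1, 0],
    #       target_delta_t=0,
    #   )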

    def _set_config(self):
        """Instantiate a config dictionary for the TaskLoader object."""
        # Take deepcopy to avoid modifying the original config
        self.config = copy.deepcopy(
            dict(
                context=self.context,
                target=self.target,
                aux_at_contexts=self.aux_at_contexts,
                aux_at_targets=self.aux_at_targets,
                links=self.links,
                context_delta_t=self.context_delta_t,
                target_delta_t=self.target_delta_t,
                time_freq=self.time_freq,
                xarray_interp_method=self.xarray_interp_method,
                discrete_xarray_sampling=self.discrete_xarray_sampling,
                dtype=self.dtype,
            )
        )

    def _check_if_all_data_passed_as_paths(self) -> bool:
        """Check whether all data was passed as paths, set ``self._is_saveable`` accordingly, and return the result."""

        def _check_if_strings(x, mode="all"):
            if x is None:
                return None
            elif isinstance(x, (tuple, list)):
                if mode == "all":
                    return all(isinstance(x_i, str) for x_i in x)
                elif mode == "any":
                    return any(isinstance(x_i, str) for x_i in x)
            else:
                return isinstance(x, str)

        all_paths = all(
            filter(
                lambda x: x is not None,
                [
                    _check_if_strings(self.context),
                    _check_if_strings(self.target),
                    _check_if_strings(self.aux_at_contexts),
                    _check_if_strings(self.aux_at_targets),
                ],
            )
        )
        self._is_saveable = all_paths

        any_paths = any(
            filter(
                lambda x: x is not None,
                [
                    _check_if_strings(self.context, mode="any"),
                    _check_if_strings(self.target, mode="any"),
                    _check_if_strings(self.aux_at_contexts, mode="any"),
                    _check_if_strings(self.aux_at_targets, mode="any"),
                ],
            )
        )
        if any_paths and not all_paths:
            raise ValueError(
                "Data must be passed either all as paths or all as xarray/pandas objects (not a mix)."
            )

        return all_paths

    def _load_data_from_paths(self):
        """Load data from paths and overwrite the paths with the loaded data."""
        loaded_data = {}

        def _load_pandas_or_xarray(path):
            # Need to be careful about this. We need to ensure data gets into
            # the right form for TaskLoader.
            if path is None:
                return None
            elif path in loaded_data:
                # Already loaded this file; reuse the cached object
                return loaded_data[path]
            elif path.endswith(".nc"):
                data = xr.open_dataset(path)
            elif path.endswith(".csv"):
                df = pd.read_csv(path)
                if "time" in df.columns:
                    df["time"] = pd.to_datetime(df["time"])
                    df = df.set_index(["time", "x1", "x2"]).sort_index()
                else:
                    df = df.set_index(["x1", "x2"]).sort_index()
                data = df
            else:
                raise ValueError(f"Unknown file extension for {path}")
            loaded_data[path] = data
            return data

        def _load_data(data):
            if isinstance(data, (tuple, list)):
                data = tuple(_load_pandas_or_xarray(data_i) for data_i in data)
            else:
                data = _load_pandas_or_xarray(data)
            return data

        self.context = _load_data(self.context)
        self.target = _load_data(self.target)
        self.aux_at_contexts = _load_data(self.aux_at_contexts)
        self.aux_at_targets = _load_data(self.aux_at_targets)
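
    # Path-based construction sketch (hypothetical file names): NetCDF files
    # are opened with `xr.open_dataset`, and CSV files must contain "x1"/"x2"
    # columns (plus an optional "time" column) to be indexed:
    #
    #   task_loader = TaskLoader(context="data/era5.nc", target="data/stations.csv")
    #
    # where data/stations.csv might look like:
    #
    #   time,x1,x2,temperature
    #   2020-01-01,0.13,0.42,5.2
    #   2020-01-01,0.55,0.10,4.8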

    def save(self, folder: str):
        """Save TaskLoader config to JSON in `folder`."""
        if not self._is_saveable:
            raise ValueError(
                "TaskLoader cannot be saved because not all data was passed as paths."
            )

        os.makedirs(folder, exist_ok=True)
        fpath = os.path.join(folder, self.config_fname)
        with open(fpath, "w") as f:
            json.dump(self.config, f, indent=4, sort_keys=False)
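
    # Save/reload round-trip sketch (only valid when all data was originally
    # passed as paths, otherwise `save` raises a ValueError):
    #
    #   task_loader.save("saved_task_loader/")
    #   task_loader2 = TaskLoader(task_loader_ID="saved_task_loader/")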

    def _cast_to_dtype(
        self,
        var: Union[
            xr.DataArray,
            xr.Dataset,
            pd.DataFrame,
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame, str]],
        ],
    ) -> Union[xr.DataArray, xr.Dataset, pd.DataFrame, tuple, None]:
        """Cast context and target data to the default dtype.

        ..
            TODO unit test this by passing in a variety of data types and
            checking that they are cast correctly.

        Args:
            var:
                The data to cast: a single xarray/pandas object or a
                tuple/list of such objects.

        Returns:
            The input data, cast to ``self.dtype`` (a tuple if a tuple/list
            was passed, ``None`` if ``None`` was passed).
        """

        def cast_to_dtype(var):
            if isinstance(var, (xr.DataArray, xr.Dataset)):
                var = var.astype(self.dtype)
                var["x1"] = var["x1"].astype(self.dtype)
                var["x2"] = var["x2"].astype(self.dtype)
            elif isinstance(var, (pd.DataFrame, pd.Series)):
                var = var.astype(self.dtype)
                # Note: Numeric pandas indexes are always cast to float64, so we have to cast
                #   x1/x2 coord dtypes during task sampling
            else:
                raise ValueError(f"Unknown type {type(var)} for context set {var}")
            return var

        if var is None:
            return var
        elif isinstance(var, (tuple, list)):
            var = tuple(cast_to_dtype(var_i) for var_i in var)
        else:
            var = cast_to_dtype(var)

        return var

    def load_dask(self) -> None:
        """Load any `dask` data into memory.

        This function triggers the computation and loading of any data that
        is represented as dask arrays or datasets into memory.

        Returns:
            None
        """

        def load(datasets):
            if datasets is None:
                return
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]
            for var in datasets:
                if isinstance(var, (xr.DataArray, xr.Dataset)):
                    # xarray's .load() loads any underlying dask data in place
                    var.load()

        load(self.context)
        load(self.target)
        load(self.aux_at_contexts)
        load(self.aux_at_targets)

        return None

    def count_context_and_target_data_dims(self):
        """Count the number of data dimensions in the context and target data.

        Returns:
            tuple: context_dims, Tuple of data dimensions in the context data.
            tuple: target_dims, Tuple of data dimensions in the target data.
            int: aux_at_target_dims, Number of data dimensions in the
                auxiliary-at-target data (0 if not provided).

        Raises:
            ValueError: If the context/target data is not a tuple/list of
                        :class:`xarray.DataArray`, :class:`xarray.Dataset` or
                        :class:`pandas.DataFrame`.
        """

        def count_data_dims_of_tuple_of_sets(datasets):
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]

            dims = []
            # Distinguish between xr.DataArray, xr.Dataset and pd.DataFrame
            for var in datasets:
                if isinstance(var, xr.Dataset):
                    dim = len(var.data_vars)  # Multiple data variables
                elif isinstance(var, xr.DataArray):
                    dim = 1  # Single data variable
                elif isinstance(var, pd.DataFrame):
                    dim = len(var.columns)  # Assumes all columns are data variables
                elif isinstance(var, pd.Series):
                    dim = 1  # Single data variable
                else:
                    raise ValueError(f"Unknown type {type(var)} for context set {var}")
                dims.append(dim)
            return dims

        context_dims = count_data_dims_of_tuple_of_sets(self.context)
        target_dims = count_data_dims_of_tuple_of_sets(self.target)
        if self.aux_at_contexts is not None:
            context_dims += count_data_dims_of_tuple_of_sets(self.aux_at_contexts)
        if self.aux_at_targets is not None:
            aux_at_target_dims = count_data_dims_of_tuple_of_sets(self.aux_at_targets)[0]
        else:
            aux_at_target_dims = 0

        return tuple(context_dims), tuple(target_dims), aux_at_target_dims
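
    # Worked example (hypothetical data): for context=[xr.Dataset with two
    # data variables, pd.DataFrame with one column] and target=xr.DataArray,
    # this returns context_dims == (2, 1), target_dims == (1,) and
    # aux_at_target_dims == 0 (no aux_at_targets provided).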

    def infer_context_and_target_var_IDs(self):
        """Infer the variable IDs of the context and target data.

        Returns:
            tuple: context_var_IDs, Tuple of variable IDs in the context data.
            tuple: target_var_IDs, Tuple of variable IDs in the target data.
            tuple: context_var_IDs_and_delta_t, As above, with the time delta
                appended to each variable ID.
            tuple: target_var_IDs_and_delta_t, As above, with the time delta
                appended to each variable ID.
            tuple: aux_at_target_var_IDs, Variable IDs of the
                auxiliary-at-target data (None if not provided).

        Raises:
            ValueError: If the context/target data is not a tuple/list of
                        :class:`xarray.DataArray`, :class:`xarray.Dataset` or
                        :class:`pandas.DataFrame`.
        """

        def infer_var_IDs_of_tuple_of_sets(datasets, delta_ts=None):
            """If delta_ts is not None, then add the delta_t to the variable ID."""
            if not isinstance(datasets, (tuple, list)):
                datasets = [datasets]

            var_IDs = []
            # Distinguish between xr.DataArray, xr.Dataset and pd.DataFrame
            for i, var in enumerate(datasets):
                if isinstance(var, xr.DataArray):
                    var_ID = (var.name,)  # Single data variable
                elif isinstance(var, xr.Dataset):
                    var_ID = tuple(var.data_vars.keys())  # Multiple data variables
                elif isinstance(var, pd.DataFrame):
                    var_ID = tuple(var.columns)
                elif isinstance(var, pd.Series):
                    var_ID = (var.name,)
                else:
                    raise ValueError(f"Unknown type {type(var)} for context set {var}")

                if delta_ts is not None:
                    # Add the delta_t to the variable ID
                    var_ID = tuple(f"{var_ID_i}_t{delta_ts[i]}" for var_ID_i in var_ID)
                else:
                    var_ID = tuple(f"{var_ID_i}" for var_ID_i in var_ID)

                var_IDs.append(var_ID)

            return var_IDs

        context_var_IDs = infer_var_IDs_of_tuple_of_sets(self.context)
        context_var_IDs_and_delta_t = infer_var_IDs_of_tuple_of_sets(
            self.context, self.context_delta_t
        )
        target_var_IDs = infer_var_IDs_of_tuple_of_sets(self.target)
        target_var_IDs_and_delta_t = infer_var_IDs_of_tuple_of_sets(
            self.target, self.target_delta_t
        )

        if self.aux_at_contexts is not None:
            context_var_IDs += infer_var_IDs_of_tuple_of_sets(self.aux_at_contexts)
            context_var_IDs_and_delta_t += infer_var_IDs_of_tuple_of_sets(
                self.aux_at_contexts, [0]
            )

        if self.aux_at_targets is not None:
            aux_at_target_var_IDs = infer_var_IDs_of_tuple_of_sets(self.aux_at_targets)[0]
        else:
            aux_at_target_var_IDs = None

        return (
            tuple(context_var_IDs),
            tuple(target_var_IDs),
            tuple(context_var_IDs_and_delta_t),
            tuple(target_var_IDs_and_delta_t),
            aux_at_target_var_IDs,
        )
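
    # Worked example (hypothetical data): a single context DataArray named
    # "t2m" with context_delta_t=-1 yields context_var_IDs == (("t2m",),) and
    # context_var_IDs_and_delta_t == (("t2m_t-1",),).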

    def _check_links(self, links: Union[Tuple[int, int], List[Tuple[int, int]]]):
        """Check that the context-target links are valid.

        Args:
            links (Tuple[int, int] | List[Tuple[int, int]]):
                Specifies links between context and target data. Each link is a
                tuple of two integers, where the first integer is the index of
                the context data and the second integer is the index of the
                target data. Can be a single tuple in the case of a single
                link. If None, no links are specified. Default: None.

        Returns:
            Tuple[int, int] | List[Tuple[int, int]]
                The input links, if valid.

        Raises:
            ValueError
                If the links are not valid.
        """
        if links is None:
            return None

        assert isinstance(
            links, list
        ), f"Links must be a list of length-2 tuples, but got {type(links)}"
        assert len(links) > 0, "If links is not None, it must be a non-empty list"
        assert all(
            isinstance(link, tuple) for link in links
        ), f"Links must be a list of tuples, but got {[type(link) for link in links]}"
        assert all(
            len(link) == 2 for link in links
        ), f"Links must be a list of length-2 tuples, but got lengths {[len(link) for link in links]}"

        # Check that the link indexes are valid
        for link_i, (context_idx, target_idx) in enumerate(links):
            if context_idx >= len(self.context):
                raise ValueError(
                    f"Invalid context index {context_idx} in link {link_i} of {links}: "
                    f"there are only {len(self.context)} context sets"
                )
            if target_idx >= len(self.target):
                raise ValueError(
                    f"Invalid target index {target_idx} in link {link_i} of {links}: "
                    f"there are only {len(self.target)} target sets"
                )

        return links
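
    # Example of a valid `links` value (a sketch): with two context sets and
    # one target set, links=[(1, 0)] pairs context set 1 with target set 0,
    # as required by the "split" and "gapfill" sampling strategies.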

    def __str__(self):
        """String representation of the TaskLoader object (user-friendly)."""
        s = f"TaskLoader({len(self.context_dims)} context sets, {len(self.target_dims)} target sets)"
        s += f"\nContext variable IDs: {self.context_var_IDs}"
        s += f"\nTarget variable IDs: {self.target_var_IDs}"
        if self.aux_at_targets is not None:
            s += f"\nAuxiliary-at-target variable IDs: {self.aux_at_target_var_IDs}"
        return s

    def __repr__(self):
        """Representation of the TaskLoader object (for developers).

        ..
            TODO make this a more verbose version of __str__
        """
        s = str(self)
        s += "\n"
        s += f"\nContext data dimensions: {self.context_dims}"
        s += f"\nTarget data dimensions: {self.target_dims}"
        if self.aux_at_targets is not None:
            s += f"\nAuxiliary-at-target data dimensions: {self.aux_at_target_dims}"
        return s

    def sample_da(
        self,
        da: Union[xr.DataArray, xr.Dataset],
        sampling_strat: Union[str, int, float, np.ndarray],
        seed: Optional[int] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Sample a DataArray according to a given strategy.

        Args:
            da (:class:`xarray.DataArray` | :class:`xarray.Dataset`):
                DataArray to sample, assumed to be sliced for the task already.
            sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`):
                Sampling strategy: "all"/"gapfill", an integer or fraction for
                random grid cell sampling, or a numpy coordinate array.
            seed (int, optional):
                Seed for random sampling. Default is None.

        Returns:
            Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]:
                Tuple of sampled locations ``X_c`` and sampled data values ``Y_c``.

        Raises:
            InvalidSamplingStrategyError:
                If the sampling strategy is not valid or if a numpy coordinate
                array is passed to sample an xarray object, but the coordinates
                are out of bounds.
        """
        da = da.load()  # Converts dask -> numpy if not already loaded
        if isinstance(da, xr.Dataset):
            da = da.to_array()

        if isinstance(sampling_strat, float):
            # Interpret a float as a fraction of the total number of grid cells
            sampling_strat = int(sampling_strat * da.size)

        if isinstance(sampling_strat, (int, np.integer)):
            N = sampling_strat
            rng = np.random.default_rng(seed)
            if self.discrete_xarray_sampling:
                x1 = rng.choice(da.coords["x1"].values, N, replace=True)
                x2 = rng.choice(da.coords["x2"].values, N, replace=True)
                Y_c = da.sel(x1=xr.DataArray(x1), x2=xr.DataArray(x2)).data
            else:
                if N == 0:
                    # Catch the zero-context edge case before interp fails
                    X_c = np.zeros((2, 0), dtype=self.dtype)
                    dim = da.shape[0] if da.ndim == 3 else 1
                    Y_c = np.zeros((dim, 0), dtype=self.dtype)
                    return X_c, Y_c
                x1 = rng.uniform(da.coords["x1"].min(), da.coords["x1"].max(), N)
                x2 = rng.uniform(da.coords["x2"].min(), da.coords["x2"].max(), N)
                Y_c = da.sel(x1=xr.DataArray(x1), x2=xr.DataArray(x2), method="nearest")
                Y_c = np.array(Y_c, dtype=self.dtype)
            X_c = np.array([x1, x2], dtype=self.dtype)

        elif isinstance(sampling_strat, np.ndarray):
            X_c = sampling_strat.astype(self.dtype)
            try:
                Y_c = da.sel(
                    x1=xr.DataArray(X_c[0]),
                    x2=xr.DataArray(X_c[1]),
                    method="nearest",
                    tolerance=0.1,  # Maximum distance from observed point to sample
                )
            except KeyError:
                raise InvalidSamplingStrategyError(
                    "Passed a numpy coordinate array to sample an xarray object, "
                    "but the coordinates are out of bounds."
                )
            Y_c = np.array(Y_c, dtype=self.dtype)

        elif sampling_strat in ["all", "gapfill"]:
            X_c = (
                da.coords["x1"].values[np.newaxis],
                da.coords["x2"].values[np.newaxis],
            )
            Y_c = da.data
            if Y_c.ndim == 2:
                # Returned a 2D array, but we need a 3D array of shape (variable, N_x1, N_x2)
                Y_c = Y_c.reshape(1, *Y_c.shape)
        else:
            raise InvalidSamplingStrategyError(
                f"Unknown sampling strategy {sampling_strat}"
            )

        if Y_c.ndim == 1:
            # Returned a 1D array, but we need a 2D array of shape (variable, N)
            Y_c = Y_c.reshape(1, *Y_c.shape)

        return X_c, Y_c
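
    # Sampling sketch (hypothetical, already-normalised `era5_da`): with an
    # integer strategy, N continuous locations are drawn and matched to the
    # nearest grid values:
    #
    #   X_c, Y_c = task_loader.sample_da(era5_da, sampling_strat=100, seed=42)
    #   # X_c.shape == (2, 100); Y_c.shape == (n_vars, 100)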

    def sample_df(
        self,
        df: Union[pd.DataFrame, pd.Series],
        sampling_strat: Union[str, int, float, np.ndarray],
        seed: Optional[int] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Sample a DataFrame according to a given strategy.

        Args:
            df (:class:`pandas.DataFrame` | :class:`pandas.Series`):
                Dataframe to sample, assumed to be time-sliced for the task
                already.
            sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`):
                Sampling strategy: "all"/"split", an integer or fraction for
                random observation sampling, or a numpy coordinate array.
            seed (int, optional):
                Seed for random sampling. Default is None.

        Returns:
            Tuple[X_c, Y_c]:
                Tuple of sampled locations ``X_c`` and sampled data values ``Y_c``.

        Raises:
            InvalidSamplingStrategyError:
                If the sampling strategy is not valid or if a numpy coordinate
                array is passed to sample a pandas object, but the DataFrame
                does not contain all the requested samples.
        """
        df = df.dropna(how="any")  # If any obs are NaN, drop them

        if isinstance(sampling_strat, float):
            # Interpret a float as a fraction of the total number of observations
            sampling_strat = int(sampling_strat * df.shape[0])

        if isinstance(sampling_strat, (int, np.integer)):
            N = sampling_strat
            rng = np.random.default_rng(seed)
            if N > df.index.size:
                raise SamplingTooManyPointsError(requested=N, available=df.index.size)
            idx = rng.choice(df.index, N, replace=False)
            X_c = df.loc[idx].reset_index()[["x1", "x2"]].values.T.astype(self.dtype)
            Y_c = df.loc[idx].values.T
        elif isinstance(sampling_strat, str) and sampling_strat in [
            "all",
            "split",
        ]:
            # NOTE if "split", we assume that the context-target split has already been applied to the df
            # in an earlier scope with access to both the context and target data. This is maybe risky!
            X_c = df.reset_index()[["x1", "x2"]].values.T.astype(self.dtype)
            Y_c = df.values.T
        elif isinstance(sampling_strat, np.ndarray):
            if df.index.get_level_values("x1").dtype != sampling_strat.dtype:
                raise InvalidSamplingStrategyError(
                    "Passed a numpy coordinate array to sample pandas DataFrame, "
                    "but the coordinate array has a different dtype than the DataFrame. "
                    f"Got {sampling_strat.dtype} but expected {df.index.get_level_values('x1').dtype}."
                )
            X_c = sampling_strat.astype(self.dtype)
            try:
                Y_c = df.loc[pd.IndexSlice[:, X_c[0], X_c[1]]].values.T
            except KeyError:
                raise InvalidSamplingStrategyError(
                    "Passed a numpy coordinate array to sample pandas DataFrame, "
                    "but the DataFrame did not contain all the requested samples.\n"
                    f"Indexes: {df.index}\n"
                    f"Sampling coords: {X_c}\n"
                    "If this is unexpected, check that your numpy sampling array matches "
                    "the DataFrame index values *exactly*."
                )
        else:
            raise InvalidSamplingStrategyError(
                f"Unknown sampling strategy {sampling_strat}"
            )

        if Y_c.ndim == 1:
            # Returned a 1D array, but we need a 2D array of shape (variable, N)
            Y_c = Y_c.reshape(1, *Y_c.shape)

        return X_c, Y_c
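
    # Sampling sketch (hypothetical `df` with a time/x1/x2 MultiIndex): a
    # (2, N) numpy array samples at exact index values, so its dtype must
    # match the DataFrame's x1/x2 index dtype or an
    # InvalidSamplingStrategyError is raised:
    #
    #   X_c = df.reset_index()[["x1", "x2"]].values.T[:, :5]
    #   X_c, Y_c = task_loader.sample_df(df, sampling_strat=X_c)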

    def sample_offgrid_aux(
        self,
        X_t: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]],
        offgrid_aux: Union[xr.DataArray, xr.Dataset],
    ) -> np.ndarray:
        """Sample auxiliary data at off-grid locations.

        Args:
            X_t (:class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]):
                Off-grid locations at which to sample the auxiliary data. Can
                be a tuple of two numpy arrays, or a single numpy array.
            offgrid_aux (:class:`xarray.DataArray` | :class:`xarray.Dataset`):
                Auxiliary data at off-grid locations.

        Returns:
            :class:`numpy:numpy.ndarray`:
                Auxiliary data sampled at the given locations, with a leading
                variable dimension.

        Raises:
            ValueError:
                If the auxiliary data still has a ``time`` dimension (it must
                be time-sliced before sampling).
        """
        if "time" in offgrid_aux.dims:
            raise ValueError(
                "If `aux_at_targets` data has a `time` dimension, it must be sliced before "
                "passing it to `sample_offgrid_aux`."
            )
        if isinstance(X_t, tuple):
            xt1, xt2 = X_t
            xt1 = xt1.ravel()
            xt2 = xt2.ravel()
        else:
            xt1, xt2 = xr.DataArray(X_t[0]), xr.DataArray(X_t[1])
        Y_t_aux = offgrid_aux.sel(x1=xt1, x2=xt2, method="nearest")
        if isinstance(Y_t_aux, xr.Dataset):
            Y_t_aux = Y_t_aux.to_array()
        Y_t_aux = np.array(Y_t_aux, dtype=np.float32)
        if (isinstance(X_t, tuple) and Y_t_aux.ndim == 2) or (
            isinstance(X_t, np.ndarray) and Y_t_aux.ndim == 1
        ):
            # Reshape to (variable, *spatial_dims)
            Y_t_aux = Y_t_aux.reshape(1, *Y_t_aux.shape)
        return Y_t_aux

    def time_slice_variable(self, var, date, delta_t=0):
        """Slice a variable by a given time delta.

        Args:
            var (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | :class:`pandas.Series`):
                Variable to slice.
            date (:class:`pandas.Timestamp`):
                Timestamp to slice the variable at (before the delta is applied).
            delta_t (int):
                Time delta to slice by, in units of ``self.time_freq``.

        Returns:
            var:
                Sliced variable.

        Raises:
            ValueError
                If the variable is of an unknown type.
        """
        # TODO: Does this work with instantaneous time?
        delta_t = pd.Timedelta(delta_t, unit=self.time_freq)
        if isinstance(var, (xr.Dataset, xr.DataArray)):
            if "time" in var.dims:
                var = var.sel(time=date + delta_t)
        elif isinstance(var, (pd.DataFrame, pd.Series)):
            if "time" in var.index.names:
                var = var[var.index.get_level_values("time") == date + delta_t]
        else:
            raise ValueError(f"Unknown variable type {type(var)}")
        return var
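
    # Slicing sketch: with time_freq="D" and delta_t=-1, the variable is
    # sliced at the day before `date`, e.g. (hypothetical `era5_da`):
    #
    #   var_prev = task_loader.time_slice_variable(
    #       era5_da, pd.Timestamp("2020-01-02"), delta_t=-1
    #   )  # selects the 2020-01-01 slice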

    def task_generation(
        self,
        date: pd.Timestamp,
        context_sampling: Union[
            str,
            int,
            float,
            np.ndarray,
            List[Union[str, int, float, np.ndarray]],
        ] = "all",
        target_sampling: Optional[
            Union[
                str,
                int,
                float,
                np.ndarray,
                List[Union[str, int, float, np.ndarray]],
            ]
        ] = None,
        split_frac: float = 0.5,
        datewise_deterministic: bool = False,
        seed_override: Optional[int] = None,
    ) -> Task:
        """Generate a :class:`~.data.task.Task` for a given date. See :meth:`__call__` for argument details."""

        def check_sampling_strat(sampling_strat, set):
            """Check the sampling strategy.

            Ensure ``sampling_strat`` is either a single strategy (broadcast
            to all sets) or a list of length equal to the number of sets.
            Convert to a tuple of length equal to the number of sets and
            return.

            Args:
                sampling_strat:
                    Sampling strategy to check.
                set:
                    Context or target set to check.

            Returns:
                tuple:
                    Tuple of sampling strategies, one for each set.

            Raises:
                InvalidSamplingStrategyError:
                    - If the length of the sampling strategy does not match the number of sets.
                    - If the sampling strategy is not a valid type.
                    - If the sampling strategy is a float but not in [0, 1].
                    - If the sampling strategy is an int but negative.
                    - If the sampling strategy is a numpy array but not of shape (2, N).
            """
            if sampling_strat is None:
                return None
            if not isinstance(sampling_strat, (list, tuple)):
                sampling_strat = tuple([sampling_strat] * len(set))
            elif isinstance(sampling_strat, (list, tuple)) and len(
                sampling_strat
            ) != len(set):
                raise InvalidSamplingStrategyError(
                    f"Length of sampling_strat ({len(sampling_strat)}) must "
                    f"match number of context sets ({len(set)})"
                )

            for strat in sampling_strat:
                if not isinstance(strat, (str, int, np.integer, float, np.ndarray)):
                    raise InvalidSamplingStrategyError(
                        f"Unknown sampling strategy {strat} of type {type(strat)}"
                    )
                elif isinstance(strat, str) and strat not in [
                    "all",
                    "split",
                    "gapfill",
                ]:
                    raise InvalidSamplingStrategyError(
                        f"Unknown sampling strategy {strat} for type str"
                    )
                elif isinstance(strat, float) and not 0 <= strat <= 1:
                    raise InvalidSamplingStrategyError(
                        f"If the sampling strategy is a float, it must be a "
                        f"fraction in [0, 1], got {strat}"
                    )
                elif isinstance(strat, int) and strat < 0:
                    raise InvalidSamplingStrategyError(
                        f"Sampling N must be non-negative, got {strat}"
                    )
                elif isinstance(strat, np.ndarray) and strat.shape[0] != 2:
                    raise InvalidSamplingStrategyError(
                        "Sampling coordinates must be of shape (2, N), got "
                        f"{strat.shape}"
                    )

            return sampling_strat

        def sample_variable(var, sampling_strat, seed):
            """Sample a variable by a given sampling strategy to get input and
            output data.

            Args:
                var:
                    Variable to sample.
                sampling_strat:
                    Sampling strategy to use.
                seed:
                    Seed for random sampling.

            Returns:
                Tuple[X, Y]:
                    Tuple of input and output data.

            Raises:
                ValueError:
                    If the variable is of an unknown type.
            """
            if isinstance(var, (xr.Dataset, xr.DataArray)):
                X, Y = self.sample_da(var, sampling_strat, seed)
            elif isinstance(var, (pd.DataFrame, pd.Series)):
                X, Y = self.sample_df(var, sampling_strat, seed)
            else:
                raise ValueError(f"Unknown type {type(var)} for context set {var}")
            return X, Y

        # Check that the sampling strategies are valid
        context_sampling = check_sampling_strat(context_sampling, self.context)
        target_sampling = check_sampling_strat(target_sampling, self.target)
        # Check `split_frac`
        if split_frac < 0 or split_frac > 1:
            raise ValueError(f"split_frac must be between 0 and 1, got {split_frac}")
        if self.links is None:
            b1 = any(
                strat in ["split", "gapfill"]
                for strat in context_sampling
                if isinstance(strat, str)
            )
            if target_sampling is None:
                b2 = False
            else:
                b2 = any(
                    strat in ["split", "gapfill"]
                    for strat in target_sampling
                    if isinstance(strat, str)
                )
            if b1 or b2:
                raise ValueError(
                    "If using 'split' or 'gapfill' sampling strategies, the context and target "
                    "sets must be linked with the TaskLoader `links` attribute."
                )
        if self.links is not None:
            for context_idx, target_idx in self.links:
                context_sampling_i = context_sampling[context_idx]
                if target_sampling is None:
                    target_sampling_i = None
                else:
                    target_sampling_i = target_sampling[target_idx]
                link_strats = (context_sampling_i, target_sampling_i)
                if any(
                    strat in ["split", "gapfill"]
                    for strat in link_strats
                    if isinstance(strat, str)
                ):
                    # If one of the sampling strategies is "split" or "gapfill", the other must
                    # use the same splitting strategy
                    if link_strats[0] != link_strats[1]:
                        raise ValueError(
                            f"Linked context set {context_idx} and target set {target_idx} "
                            f"must use the same sampling strategy if one of them "
                            f"uses the 'split' or 'gapfill' sampling strategy. "
                            f"Got {link_strats[0]} and {link_strats[1]}."
                        )

        if not isinstance(date, pd.Timestamp):
            date = pd.Timestamp(date)

        if seed_override is not None:
            # Override the seed for random sampling
            seed = seed_override
        elif datewise_deterministic:
            # Generate a deterministic seed, based on the date, for random sampling
            seed = int(date.strftime("%Y%m%d"))
        else:
            # 'Truly' random sampling
            seed = None

        task = {}

        task["time"] = date
        task["ops"] = []
        task["X_c"] = []
        task["Y_c"] = []
        if target_sampling is not None:
            task["X_t"] = []
            task["Y_t"] = []
        else:
            task["X_t"] = None
            task["Y_t"] = None

        context_slices = [
            self.time_slice_variable(var, date, delta_t)
            for var, delta_t in zip(self.context, self.context_delta_t)
        ]
        target_slices = [
            self.time_slice_variable(var, date, delta_t)
            for var, delta_t in zip(self.target, self.target_delta_t)
        ]

        # TODO move to method
        if (
            self.links is not None
            and "split" in context_sampling
            and "split" in target_sampling
        ):
            # Perform the split sampling strategy for linked context and target sets at this point
            # while we have the full context and target data in scope

            context_split_idxs = np.where(np.array(context_sampling) == "split")[0]
            target_split_idxs = np.where(np.array(target_sampling) == "split")[0]
            assert len(context_split_idxs) == len(target_split_idxs), (
                f"Number of context sets with 'split' sampling strategy "
                f"({len(context_split_idxs)}) must match number of target sets "
                f"with 'split' sampling strategy ({len(target_split_idxs)})"
            )
            for split_i, (context_idx, target_idx) in enumerate(
                zip(context_split_idxs, target_split_idxs)
            ):
                assert (context_idx, target_idx) in self.links, (
                    f"Context set {context_idx} and target set {target_idx} must be linked "
                    f"with the `links` attribute if using the 'split' sampling strategy"
                )

                context_var = context_slices[context_idx]
                target_var = target_slices[target_idx]

                for var in [context_var, target_var]:
                    assert isinstance(var, (pd.Series, pd.DataFrame)), (
                        f"If using 'split' sampling strategy for linked context and target sets, "
                        f"the context and target sets must be pandas DataFrames or Series, "
                        f"but got {type(var)}."
                    )

                N_obs = len(context_var)
                N_obs_target_check = len(target_var)
                if N_obs != N_obs_target_check:
                    raise ValueError(
                        f"Cannot split context set {context_idx} and target set {target_idx} "
                        f"because they have different numbers of observations: "
                        f"{N_obs} and {N_obs_target_check}"
                    )
                split_seed = seed + split_i if seed is not None else None
                rng = np.random.default_rng(split_seed)

                N_context = int(N_obs * split_frac)
                idxs_context = rng.choice(N_obs, N_context, replace=False)

                context_var = context_var.iloc[idxs_context]
                target_var = target_var.drop(context_var.index)

                context_slices[context_idx] = context_var
                target_slices[target_idx] = target_var

        # TODO move to method
        if (
            self.links is not None
            and "gapfill" in context_sampling
            and "gapfill" in target_sampling
        ):
            # Perform the gapfill sampling strategy for linked context and target sets at this point
            # while we have the full context and target data in scope

            context_gapfill_idxs = np.where(np.array(context_sampling) == "gapfill")[0]
            target_gapfill_idxs = np.where(np.array(target_sampling) == "gapfill")[0]
            assert len(context_gapfill_idxs) == len(target_gapfill_idxs), (
                f"Number of context sets with 'gapfill' sampling strategy "
                f"({len(context_gapfill_idxs)}) must match number of target sets "
                f"with 'gapfill' sampling strategy ({len(target_gapfill_idxs)})"
            )
            for gapfill_i, (context_idx, target_idx) in enumerate(
                zip(context_gapfill_idxs, target_gapfill_idxs)
            ):
                assert (context_idx, target_idx) in self.links, (
                    f"Context set {context_idx} and target set {target_idx} must be linked "
                    f"with the `links` attribute if using the 'gapfill' sampling strategy"
                )

                context_var = context_slices[context_idx]
                target_var = target_slices[target_idx]

                for var in [context_var, target_var]:
                    assert isinstance(var, (xr.DataArray, xr.Dataset)), (
                        f"If using 'gapfill' sampling strategy for linked context and target sets, "
                        f"the context and target sets must be xarray DataArrays or Datasets, "
                        f"but got {type(var)}."
                    )

                split_seed = seed + gapfill_i if seed is not None else None
                rng = np.random.default_rng(split_seed)

                # Keep trying until we get a target set with at least one target point
                keep_searching = True
                while keep_searching:
                    added_mask_date = rng.choice(self.context[context_idx].time)
                    added_mask = (
                        self.context[context_idx].sel(time=added_mask_date).isnull()
                    )
                    curr_mask = context_var.isnull()

                    # Mask out added missing values
                    context_var = context_var.where(~added_mask)

                    # TEMP: Inefficient to convert all non-targets to NaNs and then remove NaNs
                    #   when we could just slice the target values here
                    target_mask = added_mask & ~curr_mask
                    if isinstance(target_var, xr.Dataset):
                        keep_searching = not target_mask.to_array().data.any()
                    else:
                        keep_searching = not target_mask.data.any()
                    if keep_searching:
                        continue  # No target points -- use a different `added_mask`

                    target_var = target_var.where(
                        target_mask
                    )  # Only keep target locations

                    context_slices[context_idx] = context_var
                    target_slices[target_idx] = target_var

        for i, (var, sampling_strat) in enumerate(
            zip(context_slices, context_sampling)
        ):
            context_seed = seed + i if seed is not None else None
            X_c, Y_c = sample_variable(var, sampling_strat, context_seed)
            task["X_c"].append(X_c)
            task["Y_c"].append(Y_c)
        if target_sampling is not None:
            for j, (var, sampling_strat) in enumerate(
                zip(target_slices, target_sampling)
            ):
                target_seed = seed + i + j if seed is not None else None
                X_t, Y_t = sample_variable(var, sampling_strat, target_seed)
                task["X_t"].append(X_t)
                task["Y_t"].append(Y_t)

        if self.aux_at_contexts is not None:
            # Add auxiliary variable sampled at context set as a new context variable
            X_c_offgrid = [X_c for X_c in task["X_c"] if not isinstance(X_c, tuple)]
            if len(X_c_offgrid) == 0:
                # No off-grid context sets
                X_c_offgrid_all = np.empty((2, 0), dtype=self.dtype)
            else:
                X_c_offgrid_all = np.concatenate(X_c_offgrid, axis=1)
            Y_c_aux = self.sample_offgrid_aux(
                X_c_offgrid_all,
                self.time_slice_variable(self.aux_at_contexts, date),
            )
            task["X_c"].append(X_c_offgrid_all)
            task["Y_c"].append(Y_c_aux)

        if self.aux_at_targets is not None and target_sampling is None:
            task["Y_t_aux"] = None
        elif self.aux_at_targets is not None and target_sampling is not None:
            # Add auxiliary variable to target set
            if len(task["X_t"]) > 1:
                raise ValueError(
                    "Cannot add auxiliary variable to target set when there "
                    "are multiple target variables (not supported by default `ConvNP` model)."
                )
            task["Y_t_aux"] = self.sample_offgrid_aux(
                task["X_t"][0],
                self.time_slice_variable(self.aux_at_targets, date),
            )

        return Task(task)

    def __call__(
        self,
        date: pd.Timestamp,
        context_sampling: Union[
            str,
            int,
            float,
            np.ndarray,
            List[Union[str, int, float, np.ndarray]],
        ] = "all",
        target_sampling: Optional[
            Union[
                str,
                int,
                float,
                np.ndarray,
                List[Union[str, int, float, np.ndarray]],
            ]
        ] = None,
        split_frac: float = 0.5,
        datewise_deterministic: bool = False,
        seed_override: Optional[int] = None,
    ) -> Union[Task, List[Task]]:
        """Generate a task for a given date (or a list of
        :class:`.data.task.Task` objects for a list of dates).

        There are several sampling strategies available for the context and
        target data:

            - "all": Sample all observations.
            - int: Sample N observations uniformly at random.
            - float: Sample a fraction of observations uniformly at random.
            - :class:`numpy:numpy.ndarray`, shape (2, N):
                Sample N observations at the given x1, x2 coordinates. Coords are assumed to be
                normalised.
            - "split": Split pandas observations into disjoint context and target sets.
                `split_frac` determines the fraction of observations
                to use for the context set. The remaining observations are used
                for the target set.
                The context set and target set must be linked through the ``TaskLoader``
                ``links`` argument. Only valid for pandas data.
            - "gapfill": Generates a training task for filling NaNs in xarray data.
                Randomly samples a missing data (NaN) mask from another timestamp and
                adds it to the context set (i.e. increases the number of NaNs).
                The target set is then the true values of the data at the added missing
                locations.
                The context set and target set must be linked through the ``TaskLoader``
                ``links`` argument. Only valid for xarray data.

        Args:
            date (:class:`pandas.Timestamp`):
                Date for which to generate the task.
            context_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional):
                Sampling strategy for the context data, either a list of
                sampling strategies for each context set, or a single strategy
                applied to all context sets. Default is ``"all"``.
            target_sampling (str | int | float | :class:`numpy:numpy.ndarray` | List[str | int | float | :class:`numpy:numpy.ndarray`], optional):
                Sampling strategy for the target data, either a list of
                sampling strategies for each target set, or a single strategy
                applied to all target sets. Default is ``None``, meaning no target
                data is returned.
            split_frac (float, optional):
                The fraction of observations to use for the context set with
                the "split" sampling strategy for linked context and target set
                pairs. The remaining observations are used for the target set.
                Default is 0.5.
            datewise_deterministic (bool, optional):
                Whether random sampling is datewise deterministic based on the
                date. Default is ``False``.
            seed_override (Optional[int], optional):
                Override the seed for random sampling. This can be used to
                repeat the same random sampling at different ``date`` values.
                Default is None.

        Returns:
            :class:`~.data.task.Task` | List[:class:`~.data.task.Task`]:
                Task object, or list of task objects for each date, containing
                the context and target data.
        """
        if isinstance(date, (list, tuple, pd.core.indexes.datetimes.DatetimeIndex)):
            return [
                self.task_generation(
                    d,
                    context_sampling,
                    target_sampling,
                    split_frac,
                    datewise_deterministic,
                    seed_override,
                )
                for d in date
            ]
        else:
            return self.task_generation(
                date,
                context_sampling,
                target_sampling,
                split_frac,
                datewise_deterministic,
                seed_override,
            )
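

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the library): build
    # a small synthetic gridded variable and generate a single Task from it.
    _times = pd.date_range("2020-01-01", periods=3, freq="D")
    _coords = np.linspace(0, 1, 10, dtype="float32")
    _da = xr.DataArray(
        np.random.rand(3, 10, 10).astype("float32"),
        dims=["time", "x1", "x2"],
        coords={"time": _times, "x1": _coords, "x2": _coords},
        name="dummy_var",
    )
    task_loader = TaskLoader(context=_da, target=_da)
    task = task_loader("2020-01-02", context_sampling=20, target_sampling="all")
    print(task["X_c"][0].shape)  # (2, 20) sampled context locations
    print(task["Y_c"][0].shape)  # (1, 20) sampled context values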